From bd63e6801d1071cf4a85e100f1140abf2a263eb5 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 29 Jan 2026 09:36:55 +0800 Subject: [PATCH] [CK] Add FP8 KV_BLOCKSCALE support for batch prefill Implement per-page K/V quantization for paged attention: - Add KV_BLOCKSCALE enum to BlockAttentionQuantScaleEnum - Use exp2 shift trick to eliminate explicit P scaling overhead - Prefetch physical pages offset for KV cache, overlaps with computations --- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 + .../01_fmha/codegen/ops/fmha_batch_prefill.py | 5 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 20 +- example/ck_tile/01_fmha/quant.hpp | 13 +- .../block_attention_quant_scale_enum.hpp | 12 +- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 106 +++- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 504 +++++++++++++++--- 7 files changed, 558 insertions(+), 104 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index cac6671ca5f..995fc8c9659 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -78,12 +78,14 @@ def get_mask_cpp_check_expr(mask: str) -> str: "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", + "kv_blockscale": "ck_tile::BlockAttentionQuantScaleEnum::KV_BLOCKSCALE", } QSCALE_CHECK_MAP = { "no": "quant_scale_enum::no_scale", "pertensor": "quant_scale_enum::pertensor", "blockscale": "quant_scale_enum::blockscale", + "kv_blockscale": "quant_scale_enum::kv_blockscale", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 42f686e0c00..b575adc7d05 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -677,7 +677,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]: kv_lookup_table, ) in itertools.product( ["t", "f"], - ["pertensor"], + ["pertensor", "kv_blockscale"], get_mask_map(mask_impl).keys(), ["no"], SUPPORTED_KV_MEMORY_LAYOUT, @@ -740,6 +740,9 @@ def get_fwd_blobs( for page_size in SUPPORTED_PAGE_SIZE: if page_size == 1 and pipeline.F_kv_memory_layout != "linear": continue + # kv_blockscale only supports page_size=1024 + if pipeline.F_qscale == "kv_blockscale" and page_size != 1024: + continue k = FmhaFwdKernel( F_idx=0, F_hdim=hdim, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index aedbb0e17c2..1fe14982a15 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -602,6 +602,14 @@ struct fmha_batch_prefill_args std::variant, std::pair> drop_seed_offset; + + // KV_BLOCKSCALE: per-page K/V descales (Q per-tensor, K/V per-page) + // Layout: [num_block, num_kv_head, 2] where 2 = (k_descale, v_descale) + // Mutually exclusive with per-tensor k_descale_ptr/v_descale_ptr + const void* kv_block_descale_ptr = nullptr; + ck_tile::index_t kv_block_descale_stride_block = 0; // Stride along num_block dimension + ck_tile::index_t kv_block_descale_stride_head = 0; // Stride along num_kv_head dimension + ck_tile::index_t kv_block_descale_stride_kv = 1; // Stride for K/V index (last dim) }; template @@ -1225,7 +1233,11 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.p_drop, args.s_randval, 
args.drop_seed_offset, - args.sink_ptr); + args.sink_ptr, + args.kv_block_descale_ptr, + args.kv_block_descale_stride_block, + args.kv_block_descale_stride_head, + args.kv_block_descale_stride_kv); } else { // create batch mode kernel arguments @@ -1278,7 +1290,11 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.sink_ptr); + args.sink_ptr, + args.kv_block_descale_ptr, + args.kv_block_descale_stride_block, + args.kv_block_descale_stride_head, + args.kv_block_descale_stride_kv); } }(); diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index feb28cba24e..9221f9a0a6c 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -11,9 +11,10 @@ // keep sync with BlockAttentionQuantScaleEnum enum class quant_scale_enum { - no_scale = 0, - pertensor = 1, - blockscale, + no_scale = 0, + pertensor = 1, + blockscale = 2, + kv_blockscale = 3, // Q per-tensor, K/V per-page block scale }; struct quant_scale_info @@ -28,6 +29,8 @@ struct quant_scale_info os << "pt"; else if(type == quant_scale_enum::blockscale) os << "bs"; + else if(type == quant_scale_enum::kv_blockscale) + os << "kvbs"; } static quant_scale_info decode(std::string str) @@ -45,6 +48,10 @@ struct quant_scale_info { info.type = quant_scale_enum::blockscale; } + else if(str == "kvbs" || str == "3") + { + info.type = quant_scale_enum::kv_blockscale; + } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 7e0f704bef8..84a2321708d 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -10,9 +10,10 @@ namespace ck_tile { // This class is used for codegen pattern matching enum class BlockAttentionQuantScaleEnum { - NO_SCALE = 0, - PERTENSOR = 1, - BLOCKSCALE, + NO_SCALE = 0, + PERTENSOR = 1, + BLOCKSCALE = 2, + KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale }; template @@ -33,5 +34,10 @@ struct BlockAttentionQuantScaleEnumToStr +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "kv_blockscale"; +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 86e1de3e9fd..03303a0683d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -185,13 +185,44 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t batch_stride_lse = 0; }; - struct FmhaFwdCommonQScaleKargs + // PERTENSOR: Q/K/V all use per-tensor descales + struct FmhaFwdPerTensorQScaleKargs { const void* q_descale_ptr = nullptr; const void* k_descale_ptr = nullptr; const void* v_descale_ptr = nullptr; }; + // KV_BLOCKSCALE: Q per-tensor, K/V per-page descales + struct FmhaFwdKVBlockScaleKargs + { + const void* q_descale_ptr = nullptr; // Per-tensor Q descale + const void* kv_block_descale_ptr = nullptr; // [num_block, num_kv_head, 2] + ck_tile::index_t kv_block_descale_stride_block = 0; // Stride along num_block dimension + ck_tile::index_t kv_block_descale_stride_head = 0; // Stride along num_kv_head dimension + ck_tile::index_t kv_block_descale_stride_kv = 1; // Stride for K/V index + }; + + // Helper 
template to select QScale Kargs type based on QScaleEnum + // EmptyType: type to use when QScaleEnum is NO_SCALE (e.g., FmhaFwdEmptyKargs<3>) + template + struct QScaleKargsSelector + { + using type = EmptyType; + }; + + template + struct QScaleKargsSelector + { + using type = FmhaFwdPerTensorQScaleKargs; + }; + + template + struct QScaleKargsSelector + { + using type = FmhaFwdKVBlockScaleKargs; + }; + struct FmhaFwdDropoutSeedOffset { template @@ -255,9 +286,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + QScaleKargsSelector>::type, std::conditional_t>, std::conditional_t> { @@ -276,9 +305,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + QScaleKargsSelector>::type, std::conditional_t>, std::conditional_t> { @@ -348,7 +375,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const void* sink_ptr = nullptr) + const void* sink_ptr = nullptr, + const void* kv_block_descale_ptr = nullptr, + ck_tile::index_t kv_block_descale_stride_block = 0, + ck_tile::index_t kv_block_descale_stride_head = 0, + ck_tile::index_t kv_block_descale_stride_kv = 1) { Kargs kargs{{q_ptr, k_ptr, @@ -419,6 +450,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.kv_block_descale_ptr = kv_block_descale_ptr; + kargs.kv_block_descale_stride_block = kv_block_descale_stride_block; + kargs.kv_block_descale_stride_head = kv_block_descale_stride_head; + kargs.kv_block_descale_stride_kv = kv_block_descale_stride_kv; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -495,7 +534,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const void* sink_ptr = nullptr) + const void* sink_ptr = nullptr, + const void* kv_block_descale_ptr = nullptr, + ck_tile::index_t kv_block_descale_stride_block = 0, + ck_tile::index_t kv_block_descale_stride_head = 0, + ck_tile::index_t kv_block_descale_stride_kv = 1) { Kargs kargs{{q_ptr, k_ptr, @@ -563,6 +606,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.kv_block_descale_ptr = kv_block_descale_ptr; + kargs.kv_block_descale_stride_block = kv_block_descale_stride_block; + kargs.kv_block_descale_stride_head = kv_block_descale_stride_head; + kargs.kv_block_descale_stride_kv = kv_block_descale_stride_kv; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -1162,6 +1213,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel return kargs.scale_s * q_descale * k_descale; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + // Q is per-tensor, K is per-page (handled in pipeline) + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + return kargs.scale_s * q_descale; + } else { return kargs.scale_s; @@ -1237,6 +1294,37 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel dropout, sink_value); } + else if constexpr(QScaleEnum == 
BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + // KV_BLOCKSCALE: K/V descale is per-page, handled in pipeline + const float* kv_block_descale_ptr = + reinterpret_cast(kargs.kv_block_descale_ptr); + + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + variant_params.sm_scale, + variant, + variant_params, + block_indices, + smem_ptr, + page_idx, + stride_k_for_pipeline, + stride_v_for_pipeline, + kargs.batch_stride_k, + kargs.batch_stride_v, + dropout, + sink_value, + kv_block_descale_ptr, + kargs.kv_block_descale_stride_block, + kargs.kv_block_descale_stride_head, + kargs.kv_block_descale_stride_kv); + } else { return FmhaPipeline{}(q_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 48e8f75ae7e..7622778c89e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -7,14 +7,21 @@ #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { -template -CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, - const index_t& stride_token, - const index_t& stride_page_block, - const CoordVecType& coord_vec, - OffsetVecType& kv_offset_vec, - index_t global_seq_offset = 0) + index_t kN0> +CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx, + const CoordVecType& coord_vec, + index_t global_seq_offset, + index_t (&physical_pages)[kLoopCount]) { static constexpr index_t kLog2PageSize = [] { index_t shift = 0; @@ -42,18 +46,16 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, return shift; }(); - const index_t& thread_coord_start = coord_vec[kCoordAxis]; - constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; + const index_t& thread_coord_start = coord_vec[kCoordAxis]; + if constexpr(kIsKcache) { - // for k offsets + // K cache: per-token lookup (all tokens may be on different pages) static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - kv_offset_vec[k0] = static_cast(page_idx[page_id]) * stride_page_block + - static_cast(token_idx_in_page) * stride_token; + const index_t page_id = global_token_idx >> kLog2PageSize; + physical_pages[k0.value] = page_idx[page_id]; }); } else @@ -71,11 +73,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - - const long_index_t page_base_offset = - static_cast(page_idx[global_token_idx]) * stride_page_block; 
- - kv_offset_vec[k0] = page_base_offset; + physical_pages[k0.value] = page_idx[global_token_idx]; }); } else if constexpr(kVTileCrossesPages) @@ -85,70 +83,131 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx, static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t page_id = global_token_idx >> kLog2PageSize; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - - const long_index_t page_base_offset = - static_cast(page_idx[page_id]) * stride_page_block; - - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - // Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize] - // address pattern. - const long_index_t token_offset = - static_cast((token_idx_in_page / kVectorSize) * - (stride_token * kVectorSize)) + - (token_idx_in_page % kVectorSize); - - kv_offset_vec[k0] = page_base_offset + token_offset; - } - else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT - { - kv_offset_vec[k0] = page_base_offset + - static_cast(token_idx_in_page) * stride_token; - } + const index_t page_id = global_token_idx >> kLog2PageSize; + physical_pages[k0.value] = page_idx[page_id]; }); } - else // !kVTileCrossesPages + else { - // V tile is fully contained in one page, so page_id is shared. - // Use lane0 to compute page_id once and broadcast page_base_offset. + // V tile fully contained in one page: lane0 lookup, broadcast to all const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); const index_t lane0_page_id = (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; + const index_t shared_physical_page = page_idx[lane0_page_id]; - const long_index_t page_base_offset = - static_cast(page_idx[lane0_page_id]) * stride_page_block; + static_for<0, kLoopCount, 1>{}( + [&](auto k0) { physical_pages[k0.value] = shared_physical_page; }); + } + } +} - static_for<0, kLoopCount, 1>{}([&](auto k0) { - // kLoopStride allows non-unit token spacing in the tile distribution. - const index_t token_idx_in_page = - (global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value) & - kInPageOffsetMask; +// kv_offset_array_transform: Converts logical token indices to physical memory offsets +// for paged KV cache access. +// +// This version uses pre-loaded physical_pages array from load_physical_pages(). 
+// Benefits: +// - page_idx is read only once (by load_physical_pages) +// - physical_pages can be prefetched before GEMM to hide memory latency +// - physical_pages can be reused for descale lookup (KV_BLOCKSCALE) +// +// Template parameters: +// - kCoordAxis: Which axis of coord_vec contains the thread's token coordinate +// - kPageBlockSize: Number of tokens per page (must be power of 2) +// - kLoopStart/kLoopCount/kLoopStride: Loop iteration parameters for static_for +// - kKVMemoryLayout: VECTORIZED_LAYOUT or LINEAR_LAYOUT +// - kIsKcache: true for K cache, false for V cache +// - kN0: Tile size in N dimension (used for page crossing detection) +// - kVectorSize: Vector size for vectorized layout (e.g., 8 for fp8) +// +// Memory layout for V cache: +// LINEAR_LAYOUT: [page, token_in_page, head_dim] +// VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize] +// +template +CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_pages)[kLoopCount], + const index_t& stride_token, + const index_t& stride_page_block, + const CoordVecType& coord_vec, + OffsetVecType& kv_offset_vec, + index_t global_seq_offset = 0) +{ + static constexpr index_t kLog2PageSize = [] { + index_t shift = 0; + index_t val = kPageBlockSize; + while(val > 1) + { + val >>= 1; + shift++; + } + return shift; + }(); - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - // Vectorized layout offset - // Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize] - // Offset = (token_idx_in_page / kVectorSize) * (HeadDim * kVectorSize) + - // (token_idx_in_page % kVectorSize) + const index_t& thread_coord_start = coord_vec[kCoordAxis]; + constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; - const long_index_t token_offset = - static_cast((token_idx_in_page / kVectorSize) * - (stride_token * kVectorSize)) + - (token_idx_in_page % kVectorSize); + if constexpr(kIsKcache) + { + // K cache: per-token lookup + // Each token may be on a different page, so we use physical_pages[k0] for each. 
+ // Offset = physical_page * stride_page_block + token_idx_in_page * stride_token + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + const index_t physical_page = physical_pages[k0.value]; - kv_offset_vec[k0] = page_base_offset + token_offset; - } - else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT - { - kv_offset_vec[k0] = page_base_offset + - static_cast(token_idx_in_page) * stride_token; - } - }); - } + kv_offset_vec[k0] = static_cast(physical_page) * stride_page_block + + static_cast(token_idx_in_page) * stride_token; + }); + } + else // !kVTileCrossesPages + { + // V cache: use physical_pages[k0] for each token + // physical_pages was already populated correctly by load_physical_pages(), handling: + // - page_size=1: page_idx maps token_idx -> physical_page directly + // - V tile crosses pages: per-token page lookup + // - V tile in single page: lane0 lookup with broadcast to all lanes + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + const index_t physical_page = physical_pages[k0.value]; + + const long_index_t page_base_offset = + static_cast(physical_page) * stride_page_block; + + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized layout offset calculation: + // Layout: [page, token_in_page/kVectorSize, head_dim, kVectorSize] + // Offset = page_base + (token/kVectorSize) * (head_dim * kVectorSize) + + // (token % kVectorSize) + const long_index_t token_offset = + static_cast((token_idx_in_page / kVectorSize) * + (stride_token * kVectorSize)) + + (token_idx_in_page % kVectorSize); + + kv_offset_vec[k0] = page_base_offset + token_offset; + } + else // LINEAR_LAYOUT + { + // Linear layout: [page, token_in_page, head_dim] + // Offset = page_base + token_idx_in_page * stride_token + kv_offset_vec[k0] = + page_base_offset + static_cast(token_idx_in_page) * stride_token; + } + }); } } @@ -209,6 +268,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; + + // For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] + // This avoids explicit P *= scale_p and v_descale /= scale_p operations + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -341,8 +406,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t page_stride_k, const index_t page_stride_v, DropoutType& dropout, - const float sink_v) const + const float sink_v, + // KV_BLOCKSCALE parameters (only used when QScaleEnum == KV_BLOCKSCALE) + const float* kv_block_descale_ptr = nullptr, + index_t kv_block_descale_stride_block = 0, + index_t kv_block_descale_stride_head = 0, + index_t kv_block_descale_stride_kv = 1) const { + // KV_BLOCKSCALE requires page_block_size >= kN0 to ensure + // all tokens in a main loop iteration belong to the same page + if 
constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + static_assert(kPageBlockSize >= kN0, "KV_BLOCKSCALE requires kPageBlockSize >= kN0"); + } + static_assert( std::is_same_v> && std::is_same_v> && @@ -494,6 +571,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; statically_indexed_array k_offsets; index_t current_seq_k = seqlen_k_start; + + // Load physical pages first, then compute offsets. + // k_physical_pages can be reused for descale lookup later. + index_t k_physical_pages[NRepeat] = {}; + load_physical_pages(page_idx, k_coord, current_seq_k, k_physical_pages); + kv_offset_array_transform, decltype(k_coord), 0, @@ -505,7 +596,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync true, kN0, kVectorSize>( - page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); + k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), @@ -644,6 +735,50 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync "V page-index Y dim must be valid"); statically_indexed_array v_offsets; + // V physical pages array for use with kv_offset_array_transform + // For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner + index_t v_physical_pages[V_PageIdxRepeat] = {}; + + // Prefetch V physical pages - can be called early to hide buffer load latency + auto prefetch_v_physical_pages = [&](auto k_loop_start) { + constexpr index_t kLoopStart = decltype(k_loop_start)::value; + if constexpr(V_KIterOuter > 1) + { + static_for<0, V_KIterOuter, 1>{}([&](auto k2) { + // Load physical pages for this k2 slice into the appropriate portion of array + index_t v_physical_pages_k2[V_KIterInner] = {}; + load_physical_pages(page_idx, v_coord, current_seq_k, v_physical_pages_k2); + + // Copy to merged array + static_for<0, V_KIterInner, 1>{}([&](auto k1) { + constexpr auto idx = k1.value + k2.value * V_KIterInner; + v_physical_pages[idx] = v_physical_pages_k2[k1.value]; + }); + }); + } + else + { + load_physical_pages(page_idx, v_coord, current_seq_k, v_physical_pages); + } + }; + + // Update V offsets using pre-loaded physical pages auto update_v_offsets = [&](auto k_loop_start) { constexpr index_t kLoopStart = decltype(k_loop_start)::value; // For 3D K decomposition (K2, K0, K1), compute offsets for each K2 slice @@ -653,6 +788,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { static_for<0, V_KIterOuter, 1>{}([&](auto k2) { statically_indexed_array v_offsets_k2; + // Extract physical pages for this k2 slice + index_t v_physical_pages_k2[V_KIterInner]; + static_for<0, V_KIterInner, 1>{}([&](auto k1) { + constexpr auto idx = k1.value + k2.value * V_KIterInner; + v_physical_pages_k2[k1.value] = v_physical_pages[idx]; + }); + kv_offset_array_transform, decltype(v_coord), I1, @@ -663,8 +805,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, false, kN0, - kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets_k2, current_seq_k); + kVectorSize>(v_physical_pages_k2, + stride_v, + page_stride_v, + v_coord, + v_offsets_k2, + current_seq_k); + static_for<0, V_KIterInner, 1>{}([&](auto k1) { constexpr auto idx = number{}; v_offsets[idx] = v_offsets_k2[k1]; @@ -684,9 +831,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync false, kN0, kVectorSize>( - page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + 
v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } }; + + // Prefetch V physical pages early to hide buffer load latency + prefetch_v_physical_pages(number<0>{}); update_v_offsets(number<0>{}); auto v_dram_window = make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -717,6 +867,41 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // main loop do { + // KV_BLOCKSCALE: load per-page K/V descale factors + // Uses k_physical_pages[0] from load_physical_pages to avoid redundant page_idx reads. + // Assumes kPageBlockSize >= kN0, so all tokens in one main loop iteration belong to + // the same page (single scale pair). + // + // TODO: Cross-page KV_BLOCKSCALE support + // Currently only supports kPageBlockSize >= kN0 (all tokens in tile on same page). + // To support smaller page sizes (cross-page tiles), need: + // + // 1. K descale: Load per-token k_descale_vec[NRepeat] based on k_physical_pages[k0] + // - After GEMM0 (S = Q × K^T), apply column-wise scaling: S[:,j] *= k_descale[j] + // - Requires modifying s_acc_element_func to accept column index + // + // 2. V descale: Load per-token v_descale_vec[V_PageIdxRepeat] based on + // v_physical_pages[k0] + // - Before GEMM1 (O = P × V), apply row-wise scaling to P: P[i,j] *= v_descale[j] + // - Or pre-scale V in LDS (more complex) + // + // 3. K and V may be on different pages for the same token index, so need separate + // lookups + // + [[maybe_unused]] float k_descale = 1.0f; + [[maybe_unused]] float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + const index_t scale_offset = + k_physical_pages[0] * kv_block_descale_stride_block + + block_indices.kv_head_idx * kv_block_descale_stride_head; + k_descale = kv_block_descale_ptr[scale_offset + 0 * kv_block_descale_stride_kv]; + v_descale = kv_block_descale_ptr[scale_offset + 1 * kv_block_descale_stride_kv]; + } + + // Prefetch V physical pages early - overlaps with GEMM0 computation + prefetch_v_physical_pages(number{}); + // STAGE 1, QK gemm clear_tile(s_acc); // initialize C if constexpr(k0_loops > 1) @@ -763,9 +948,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(1); auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + // V physical pages already prefetched before GEMM0 update_v_offsets(number{}); v_dram_window.update_page_idx(v_offsets); + // KV_BLOCKSCALE: apply k_descale to s_acc (dequantize QK result) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + tile_elementwise_inout([&k_descale](auto& x) { x *= k_descale; }, s_acc); + } + const auto p = [&]() { const auto bias_tile = load_tile(bias_dram_window); // load bias tile @@ -875,6 +1067,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } const auto s = cast_tile(s_acc); // S{j} + + // Prefetch V physical pages early - overlaps with softmax computation + if constexpr(k1_loops > 1) + { + prefetch_v_physical_pages(number<2 * kK1>{}); + } + auto m_local = block_tile_reduce( s, sequence<1>{}, @@ -953,7 +1152,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + // For KV_BLOCKSCALE: precompute (m - shift) once per row + // exp2(s - (m - shift)) = exp2(s - m + shift) = exp2(s - m) * 2^shift + // This scales P by 2^shift (≈448 for fp8_e4m3) without explicit multiply + auto 
validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -961,13 +1174,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { @@ -1049,6 +1262,22 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync }(); // STAGE 3, KV gemm + // KV_BLOCKSCALE: accumulate P*V into temporary tile before applying v_descale + auto o_acc_unscaled = decltype(o_acc){}; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + clear_tile(o_acc_unscaled); + } + + // Select GEMM1 target: o_acc_unscaled for KV_BLOCKSCALE (needs v_descale), o_acc + // otherwise + auto& gemm1_acc = [&]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + return o_acc_unscaled; + else + return o_acc; + }(); + if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { @@ -1056,11 +1285,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + // Update V offsets using previously prefetched physical pages update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); } + + // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1 + if constexpr(i_k1 + 1 < k1_loops - 1) + { + prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{}); + } + block_sync_lds(); - gemm_1(o_acc, + gemm_1(gemm1_acc, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), get_slice_tile( @@ -1104,6 +1341,17 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync move_tile_window(k_dram_block_window, {kN0, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + // KV_BLOCKSCALE: reload physical pages for the new tile + load_physical_pages(page_idx, k_coord, current_seq_k, k_physical_pages); + kv_offset_array_transform, decltype(k_coord), 0, @@ -1115,7 +1363,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync true, kN0, kVectorSize>( - page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); + k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) @@ -1131,13 +1379,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { block_sync_lds(); gemm_1( - o_acc, + gemm1_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), get_slice_tile( v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); } + + // KV_BLOCKSCALE: apply v_descale and accumulate o_acc_unscaled into o_acc + // Note: No division by scale_p needed because: + // 1. P was scaled by 2^shift through exp2 shift trick + // 2. 
rowsum l was also scaled by 2^shift + // 3. Final O = sum(P*V) / l, so the 2^shift cancels out + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o_unscaled) { o += o_unscaled * v_descale; }, + o_acc, + o_acc_unscaled); + } } while(i_total_loops < num_total_loop); // store lse @@ -1257,6 +1518,77 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync dropout, sink_v); } + + // Overload for KV_BLOCKSCALE: K/V descale is per-page + // This is a convenience overload that forwards to the main operator() with kv_scale parameters + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + const index_t* page_idx, + const index_t stride_k, + const index_t stride_v, + const index_t page_stride_k, + const index_t page_stride_v, + DropoutType& dropout, + float sink_v, + const float* kv_block_descale_ptr, + index_t kv_block_descale_stride_block, + index_t kv_block_descale_stride_head, + index_t kv_block_descale_stride_kv) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + page_idx, + stride_k, + stride_v, + page_stride_k, + page_stride_v, + dropout, + sink_v, + kv_block_descale_ptr, + kv_block_descale_stride_block, + kv_block_descale_stride_head, + kv_block_descale_stride_kv); + } }; } // namespace ck_tile
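
A few standalone host-side C++ sketches follow. They restate the data layout and the arithmetic this patch relies on; they are illustrative only, not excerpts from the kernel, and any identifier not named in the diff above is an assumption.

First, the per-page descale layout. The new fmha_batch_prefill_args fields describe a [num_block, num_kv_head, 2] tensor where the trailing 2 packs (k_descale, v_descale) for one physical KV page. A minimal sketch of a contiguous buffer with matching strides (the KvBlockDescale type is hypothetical; strides are in float elements, matching the float indexing in the kernel's main loop):

    #include <vector>

    // Illustrative only: contiguous [num_block, num_kv_head, 2] float buffer,
    // trailing 2 = (k_descale, v_descale) for one physical KV page.
    struct KvBlockDescale
    {
        int num_block   = 0;
        int num_kv_head = 0;
        std::vector<float> data; // num_block * num_kv_head * 2 floats

        float& at(int block, int head, int kv) // kv: 0 = K, 1 = V
        {
            return data[(block * num_kv_head + head) * 2 + kv];
        }
    };

    // With this layout the new argument fields would be filled as:
    //   args.kv_block_descale_ptr          = descale.data.data();
    //   args.kv_block_descale_stride_block = descale.num_kv_head * 2;
    //   args.kv_block_descale_stride_head  = 2;
    //   args.kv_block_descale_stride_kv    = 1;   // K and V descales adjacent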
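The kargs selection added to the kernel is a small trait with one specialization per quant-scale mode. A stand-alone rendering of that pattern, with member lists abbreviated and template heads reconstructed (the kernel's exact declarations may differ; the EmptyType fallback follows the patch comment "type to use when QScaleEnum is NO_SCALE"):

    enum class BlockAttentionQuantScaleEnum
    {
        NO_SCALE      = 0,
        PERTENSOR     = 1,
        BLOCKSCALE    = 2,
        KV_BLOCKSCALE = 3, // Q per-tensor, K/V per-page block scale
    };

    struct FmhaFwdPerTensorQScaleKargs
    {
        const void* q_descale_ptr = nullptr;
        const void* k_descale_ptr = nullptr;
        const void* v_descale_ptr = nullptr;
    };

    struct FmhaFwdKVBlockScaleKargs
    {
        const void* q_descale_ptr                 = nullptr;
        const void* kv_block_descale_ptr          = nullptr;
        long        kv_block_descale_stride_block = 0; // ck_tile::index_t in the kernel
        long        kv_block_descale_stride_head  = 0;
        long        kv_block_descale_stride_kv    = 1;
    };

    // Primary template: non-specialized enum values fall back to EmptyType
    // (the kernel passes FmhaFwdEmptyKargs<3> for the NO_SCALE case).
    template <BlockAttentionQuantScaleEnum QScaleEnum, typename EmptyType>
    struct QScaleKargsSelector
    {
        using type = EmptyType;
    };

    template <typename EmptyType>
    struct QScaleKargsSelector<BlockAttentionQuantScaleEnum::PERTENSOR, EmptyType>
    {
        using type = FmhaFwdPerTensorQScaleKargs;
    };

    template <typename EmptyType>
    struct QScaleKargsSelector<BlockAttentionQuantScaleEnum::KV_BLOCKSCALE, EmptyType>
    {
        using type = FmhaFwdKVBlockScaleKargs;
    };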
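The pipeline refactor splits the old kv_offset_array_transform into a page-table read (load_physical_pages) followed by a pure offset computation. A scalar, single-token reference of the two steps for both memory layouts described in the new comments (function names here are illustrative; the kernel versions are vectorized over the tile):

    #include <cstdint>

    // Step 1 (what load_physical_pages does once per tile): resolve a logical
    // token index to a physical page id through the page table.
    inline std::int64_t physical_page_of(const std::int32_t* page_idx,
                                         std::int32_t token, std::int32_t log2_page_size)
    {
        return page_idx[token >> log2_page_size];
    }

    // Step 2 (what kv_offset_array_transform now does from pre-loaded pages).
    // LINEAR_LAYOUT: [page, token_in_page, head_dim]
    inline std::int64_t kv_offset_linear(std::int64_t physical_page, std::int32_t token,
                                         std::int32_t log2_page_size,
                                         std::int64_t stride_page_block, std::int64_t stride_token)
    {
        const std::int32_t token_in_page = token & ((1 << log2_page_size) - 1);
        return physical_page * stride_page_block
             + std::int64_t(token_in_page) * stride_token;
    }

    // VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
    inline std::int64_t kv_offset_vectorized(std::int64_t physical_page, std::int32_t token,
                                             std::int32_t log2_page_size, std::int32_t vector_size,
                                             std::int64_t stride_page_block, std::int64_t stride_token)
    {
        const std::int32_t token_in_page = token & ((1 << log2_page_size) - 1);
        return physical_page * stride_page_block
             + std::int64_t(token_in_page / vector_size) * (stride_token * vector_size)
             + (token_in_page % vector_size);
    }

Separating the two steps is what lets the pipeline prefetch the page-table reads ahead of the GEMMs and reuse the resolved physical page id for the per-page descale lookup.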
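Finally, a small host-side check of the identity behind the exp2 shift trick, and of why neither P nor v_descale needs an explicit 2^shift correction: the factor shows up in both the P*V partial sums and the softmax rowsum, so it cancels in O = (P*V) / l. The sketch omits scale_s, the fp8 cast of P, and the running-max rescaling of the real pipeline:

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float shift = 8.0f;                        // OCP_FP8_SHIFT from the patch
        const float s[4]  = {-3.5f, -1.0f, 0.0f, 2.25f}; // one softmax row (post QK, post k_descale)
        const float v[4]  = {0.25f, -1.5f, 0.5f, 2.0f};  // stand-in V column
        const float m     = 2.25f;                       // row max

        float l_plain = 0.f, o_plain = 0.f, l_shift = 0.f, o_shift = 0.f;
        for(int j = 0; j < 4; ++j)
        {
            const float p_plain = std::exp2(s[j] - m);           // in (0, 1]
            const float p_shift = std::exp2(s[j] - (m - shift)); // == p_plain * 2^shift, in (0, 2^shift]
            l_plain += p_plain;  o_plain += p_plain * v[j];
            l_shift += p_shift;  o_shift += p_shift * v[j];
        }
        // The 2^shift factor cancels between numerator and denominator.
        assert(std::fabs(o_plain / l_plain - o_shift / l_shift) < 1e-5f);
        return 0;
    }

In the kernel, k_descale is multiplied into s_acc before this softmax and v_descale is applied once per tile to the P*V partial accumulator (o_acc_unscaled), so the only per-element cost of KV_BLOCKSCALE is the shifted exponent.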