From 48b0594c1fe743258d1baa8a6fc994f2efaff95f Mon Sep 17 00:00:00 2001 From: adarshxs Date: Tue, 5 May 2026 22:42:59 +0530 Subject: [PATCH 1/7] Port upstream DeepGEMM updates --- deep-gemm/CMakeLists.txt | 9 +- deep-gemm/README.md | 138 +- deep-gemm/build.toml | 79 +- deep-gemm/csrc/apis/attention.hpp | 440 ++++-- deep-gemm/csrc/apis/einsum.hpp | 73 +- deep-gemm/csrc/apis/gemm.hpp | 135 +- deep-gemm/csrc/apis/hyperconnection.hpp | 16 +- deep-gemm/csrc/apis/layout.hpp | 71 +- deep-gemm/csrc/apis/mega.hpp | 239 +++ deep-gemm/csrc/apis/runtime.hpp | 45 +- deep-gemm/csrc/impl.cu | 171 +- deep-gemm/csrc/indexing/main.cu | 7 +- deep-gemm/csrc/jit/cache.hpp | 2 +- deep-gemm/csrc/jit/compiler.hpp | 264 ++-- deep-gemm/csrc/jit/device_runtime.hpp | 53 +- deep-gemm/csrc/jit/handle.hpp | 90 +- deep-gemm/csrc/jit/include_parser.hpp | 80 + deep-gemm/csrc/jit/kernel_runtime.hpp | 118 +- .../csrc/jit_kernels/heuristics/common.hpp | 353 +---- .../csrc/jit_kernels/heuristics/config.hpp | 171 ++ .../csrc/jit_kernels/heuristics/mega_moe.hpp | 240 +++ .../csrc/jit_kernels/heuristics/runtime.hpp | 62 + .../csrc/jit_kernels/heuristics/sm100.hpp | 334 ++-- .../csrc/jit_kernels/heuristics/sm90.hpp | 336 ++-- .../csrc/jit_kernels/heuristics/utils.hpp | 23 + deep-gemm/csrc/jit_kernels/impls/epilogue.hpp | 2 +- .../csrc/jit_kernels/impls/runtime_utils.hpp | 74 +- .../jit_kernels/impls/sm100_bf16_gemm.hpp | 438 +++--- .../jit_kernels/impls/sm100_bmk_bnk_mn.hpp | 24 +- .../impls/sm100_fp8_fp4_gemm_1d1d.hpp | 459 ++++++ .../impls/sm100_fp8_fp4_mega_moe.hpp | 220 +++ .../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp | 2 +- .../impls/sm100_tf32_hc_prenorm_gemm.hpp | 38 +- .../csrc/jit_kernels/impls/sm90_bf16_gemm.hpp | 424 ++--- .../jit_kernels/impls/sm90_bmk_bnk_mn.hpp | 16 +- .../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp | 187 +-- .../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 356 +++-- .../impls/sm90_tf32_hc_prenorm_gemm.hpp | 38 +- .../jit_kernels/impls/smxx_clean_logits.hpp | 14 +- .../csrc/jit_kernels/impls/smxx_cublaslt.hpp | 75 +- .../impls/smxx_fp8_fp4_mqa_logits.hpp | 328 ++++ .../impls/smxx_fp8_fp4_paged_mqa_logits.hpp | 463 ++++++ .../csrc/jit_kernels/impls/smxx_layout.hpp | 77 +- deep-gemm/csrc/python_api.cpp | 2 + deep-gemm/csrc/utils/exception.hpp | 8 +- deep-gemm/csrc/utils/format.hpp | 13 + deep-gemm/csrc/utils/hash.hpp | 21 +- deep-gemm/csrc/utils/layout.hpp | 7 +- deep-gemm/csrc/utils/math.hpp | 7 +- deep-gemm/csrc/utils/system.hpp | 53 +- deep-gemm/csrc/utils/torch_compat.hpp | 36 + deep-gemm/deep_gemm/__init__.py | 20 +- .../include/deep_gemm/comm/barrier.cuh | 83 + .../include/deep_gemm/common/compile.cuh | 18 + .../include/deep_gemm/common/cute_tie.cuh | 2 + .../include/deep_gemm/common/exception.cuh | 43 + .../include/deep_gemm/common/math.cuh | 149 ++ .../include/deep_gemm/common/tma_copy.cuh | 92 ++ .../include/deep_gemm/common/types.cuh | 43 + .../include/deep_gemm/common/utils.cuh | 165 +- .../deep_gemm/epilogue/sm100_store_cd.cuh | 137 ++ .../epilogue/sm100_store_cd_swap_ab.cuh | 144 ++ .../include/deep_gemm/epilogue/transform.cuh | 24 + .../deep_gemm/impls/sm100_bf16_gemm.cuh | 345 ++--- .../deep_gemm/impls/sm100_bmk_bnk_mn.cuh | 56 +- .../deep_gemm/impls/sm100_fp4_mqa_logits.cuh | 457 ++++++ .../impls/sm100_fp4_paged_mqa_logits.cuh | 510 ++++++ .../impls/sm100_fp8_fp4_gemm_1d1d.cuh | 514 ++++++ .../impls/sm100_fp8_fp4_mega_moe.cuh | 1380 +++++++++++++++++ .../deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 14 +- .../deep_gemm/impls/sm100_fp8_mqa_logits.cuh | 251 ++- 
.../impls/sm100_fp8_paged_mqa_logits.cuh | 369 +++-- .../impls/sm100_tf32_hc_prenorm_gemm.cuh | 65 +- .../deep_gemm/impls/sm90_bf16_gemm.cuh | 87 +- .../deep_gemm/impls/sm90_bmk_bnk_mn.cuh | 65 +- .../deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh | 161 +- .../deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 101 +- .../deep_gemm/impls/sm90_fp8_mqa_logits.cuh | 127 +- .../impls/sm90_fp8_paged_mqa_logits.cuh | 231 +-- .../impls/sm90_tf32_hc_prenorm_gemm.cuh | 65 +- .../deep_gemm/impls/smxx_clean_logits.cuh | 53 +- .../include/deep_gemm/impls/smxx_layout.cuh | 55 +- .../include/deep_gemm/layout/mega_moe.cuh | 260 ++++ .../include/deep_gemm/layout/sym_buffer.cuh | 41 + .../deep_gemm/include/deep_gemm/mma/sm100.cuh | 151 ++ .../deep_gemm/include/deep_gemm/mma/sm90.cuh | 293 ++++ .../deep_gemm/include/deep_gemm/ptx/ld_st.cuh | 251 +++ .../include/deep_gemm/ptx/tcgen05.cuh | 168 ++ .../deep_gemm/include/deep_gemm/ptx/tma.cuh | 112 ++ .../deep_gemm/include/deep_gemm/ptx/utils.cuh | 53 + .../deep_gemm/include/deep_gemm/ptx/wgmma.cuh | 25 + .../include/deep_gemm/scheduler/gemm.cuh | 300 ++++ .../include/deep_gemm/scheduler/mega_moe.cuh | 221 +++ .../deep_gemm/scheduler/paged_mqa_logits.cuh | 239 +++ .../legacy/a_fused_m_grouped_gemm.py | 2 +- .../legacy/b_fused_k_grouped_gemm.py | 2 +- deep-gemm/deep_gemm/mega/__init__.py | 130 ++ deep-gemm/deep_gemm/testing/bench.py | 17 +- deep-gemm/deep_gemm/utils/__init__.py | 1 + deep-gemm/deep_gemm/utils/dist.py | 74 + deep-gemm/deep_gemm/utils/layout.py | 6 +- deep-gemm/deep_gemm/utils/math.py | 66 +- deep-gemm/scripts/quick_plot_pm.py | 448 ++++++ deep-gemm/scripts/readme_example.py | 49 - deep-gemm/scripts/run_ncu_mega_moe.sh | 89 ++ deep-gemm/setup.py | 3 +- deep-gemm/tests/generators.py | 58 +- deep-gemm/tests/test_attention.py | 436 ++++-- deep-gemm/tests/test_bf16.py | 27 +- deep-gemm/tests/test_cublaslt.py | 20 - deep-gemm/tests/test_einsum.py | 6 +- deep-gemm/tests/test_fp8_fp4.py | 51 +- deep-gemm/tests/test_layout.py | 24 +- deep-gemm/tests/test_lazy_init.py | 7 +- deep-gemm/tests/test_mega_moe.py | 295 ++++ deep-gemm/tests/test_sanitizer.py | 5 +- deep-gemm/torch-ext/deep_gemm/_C.py | 194 +++ deep-gemm/torch-ext/deep_gemm/__init__.py | 157 +- .../include/deep_gemm/comm/barrier.cuh | 83 + .../include/deep_gemm/common/compile.cuh | 18 + .../include/deep_gemm/common/cute_tie.cuh | 2 + .../include/deep_gemm/common/exception.cuh | 43 + .../include/deep_gemm/common/math.cuh | 149 ++ .../include/deep_gemm/common/tma_copy.cuh | 92 ++ .../include/deep_gemm/common/types.cuh | 43 + .../include/deep_gemm/common/utils.cuh | 165 +- .../deep_gemm/epilogue/sm100_store_cd.cuh | 137 ++ .../epilogue/sm100_store_cd_swap_ab.cuh | 144 ++ .../include/deep_gemm/epilogue/transform.cuh | 24 + .../deep_gemm/impls/sm100_bf16_gemm.cuh | 345 ++--- .../deep_gemm/impls/sm100_bmk_bnk_mn.cuh | 56 +- .../deep_gemm/impls/sm100_fp4_mqa_logits.cuh | 457 ++++++ .../impls/sm100_fp4_paged_mqa_logits.cuh | 510 ++++++ .../impls/sm100_fp8_fp4_gemm_1d1d.cuh | 514 ++++++ .../impls/sm100_fp8_fp4_mega_moe.cuh | 1380 +++++++++++++++++ .../deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 14 +- .../deep_gemm/impls/sm100_fp8_mqa_logits.cuh | 251 ++- .../impls/sm100_fp8_paged_mqa_logits.cuh | 369 +++-- .../impls/sm100_tf32_hc_prenorm_gemm.cuh | 65 +- .../deep_gemm/impls/sm90_bf16_gemm.cuh | 87 +- .../deep_gemm/impls/sm90_bmk_bnk_mn.cuh | 65 +- .../deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh | 161 +- .../deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 101 +- .../deep_gemm/impls/sm90_fp8_mqa_logits.cuh | 127 +- 
.../impls/sm90_fp8_paged_mqa_logits.cuh | 231 +-- .../impls/sm90_tf32_hc_prenorm_gemm.cuh | 65 +- .../deep_gemm/impls/smxx_clean_logits.cuh | 53 +- .../include/deep_gemm/impls/smxx_layout.cuh | 55 +- .../include/deep_gemm/layout/mega_moe.cuh | 260 ++++ .../include/deep_gemm/layout/sym_buffer.cuh | 41 + .../deep_gemm/include/deep_gemm/mma/sm100.cuh | 151 ++ .../deep_gemm/include/deep_gemm/mma/sm90.cuh | 293 ++++ .../deep_gemm/include/deep_gemm/ptx/ld_st.cuh | 251 +++ .../include/deep_gemm/ptx/tcgen05.cuh | 168 ++ .../deep_gemm/include/deep_gemm/ptx/tma.cuh | 112 ++ .../deep_gemm/include/deep_gemm/ptx/utils.cuh | 53 + .../deep_gemm/include/deep_gemm/ptx/wgmma.cuh | 25 + .../include/deep_gemm/scheduler/gemm.cuh | 300 ++++ .../include/deep_gemm/scheduler/mega_moe.cuh | 221 +++ .../deep_gemm/scheduler/paged_mqa_logits.cuh | 239 +++ .../torch-ext/deep_gemm/legacy/__init__.py | 5 + .../legacy/a_fused_k_grouped_gemm.py | 88 ++ .../legacy/a_fused_m_grouped_gemm.py | 92 ++ .../legacy/b_fused_k_grouped_gemm.py | 86 + .../deep_gemm/legacy/m_grouped_gemm.py | 84 + .../deep_gemm/legacy/tune_options.py | 28 + .../torch-ext/deep_gemm/mega/__init__.py | 130 ++ .../torch-ext/deep_gemm/testing/bench.py | 17 +- .../torch-ext/deep_gemm/utils/__init__.py | 1 + deep-gemm/torch-ext/deep_gemm/utils/dist.py | 74 + deep-gemm/torch-ext/deep_gemm/utils/layout.py | 42 +- deep-gemm/torch-ext/deep_gemm/utils/math.py | 66 +- deep-gemm/torch-ext/torch_binding.cpp | 88 +- deep-gemm/torch-ext/torch_binding.h | 59 +- 174 files changed, 21026 insertions(+), 4792 deletions(-) create mode 100644 deep-gemm/csrc/apis/mega.hpp create mode 100644 deep-gemm/csrc/jit/include_parser.hpp create mode 100644 deep-gemm/csrc/jit_kernels/heuristics/config.hpp create mode 100644 deep-gemm/csrc/jit_kernels/heuristics/mega_moe.hpp create mode 100644 deep-gemm/csrc/jit_kernels/heuristics/runtime.hpp create mode 100644 deep-gemm/csrc/jit_kernels/heuristics/utils.hpp create mode 100644 deep-gemm/csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp create mode 100644 deep-gemm/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp create mode 100644 deep-gemm/csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp create mode 100644 deep-gemm/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp create mode 100644 deep-gemm/csrc/utils/torch_compat.hpp create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/comm/barrier.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/common/compile.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/common/exception.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/common/math.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/common/tma_copy.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/common/types.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/epilogue/transform.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/layout/mega_moe.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh 
create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/mma/sm100.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/mma/sm90.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/ptx/ld_st.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/ptx/tma.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/ptx/utils.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/ptx/wgmma.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/scheduler/gemm.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh create mode 100644 deep-gemm/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh create mode 100644 deep-gemm/deep_gemm/mega/__init__.py create mode 100644 deep-gemm/deep_gemm/utils/dist.py create mode 100644 deep-gemm/scripts/quick_plot_pm.py delete mode 100644 deep-gemm/scripts/readme_example.py create mode 100755 deep-gemm/scripts/run_ncu_mega_moe.sh delete mode 100644 deep-gemm/tests/test_cublaslt.py create mode 100644 deep-gemm/tests/test_mega_moe.py create mode 100644 deep-gemm/torch-ext/deep_gemm/_C.py create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/comm/barrier.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/compile.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/exception.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/math.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/tma_copy.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/types.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/transform.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/layout/mega_moe.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/mma/sm100.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/mma/sm90.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/ld_st.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/tma.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/utils.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/wgmma.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/gemm.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh create mode 100644 deep-gemm/torch-ext/deep_gemm/legacy/__init__.py create mode 100644 deep-gemm/torch-ext/deep_gemm/legacy/a_fused_k_grouped_gemm.py create mode 
100644 deep-gemm/torch-ext/deep_gemm/legacy/a_fused_m_grouped_gemm.py
 create mode 100644 deep-gemm/torch-ext/deep_gemm/legacy/b_fused_k_grouped_gemm.py
 create mode 100644 deep-gemm/torch-ext/deep_gemm/legacy/m_grouped_gemm.py
 create mode 100644 deep-gemm/torch-ext/deep_gemm/legacy/tune_options.py
 create mode 100644 deep-gemm/torch-ext/deep_gemm/mega/__init__.py
 create mode 100644 deep-gemm/torch-ext/deep_gemm/utils/dist.py

diff --git a/deep-gemm/CMakeLists.txt b/deep-gemm/CMakeLists.txt
index 79f1964d..bbf625d3 100644
--- a/deep-gemm/CMakeLists.txt
+++ b/deep-gemm/CMakeLists.txt
@@ -3,8 +3,7 @@ cmake_minimum_required(VERSION 3.10)
 project(deep_gemm LANGUAGES CXX CUDA)
 
 set(CMAKE_VERBOSE_MAKEFILE ON)
-set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi -Wno-deprecated-declarations")
 set(CUDA_SEPARABLE_COMPILATION ON)
 list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
 list(APPEND CUDA_NVCC_FLAGS "-O3")
@@ -18,11 +17,11 @@ find_package(CUDAToolkit REQUIRED)
 find_package(pybind11 REQUIRED)
 find_package(Torch REQUIRED)
 
-set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CUDA_STANDARD 17)
+set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CUDA_STANDARD 20)
 
 include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
-include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
+include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include/cccl ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
 link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
 
 # The main Python API entrance
diff --git a/deep-gemm/README.md b/deep-gemm/README.md
index c81bf46b..03f0e0bc 100644
--- a/deep-gemm/README.md
+++ b/deep-gemm/README.md
@@ -1,45 +1,27 @@
 # DeepGEMM
 
-DeepGEMM is a library designed for clean and efficient General Matrix Multiplications (GEMMs). It supports FP8 and BF16 (work in progress) for both normal and Mix-of-Experts (MoE) grouped scenarios. Written in CUDA, the library requires no kernel compilation during installation, compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module.
+DeepGEMM is a unified, high-performance tensor core kernel library that brings together the key computation primitives of modern large language models — GEMMs (FP8, FP4, BF16), fused MoE with overlapped communication (Mega MoE), MQA scoring for the lightning indexer, HyperConnection (HC), and more — into a single, cohesive CUDA codebase. All kernels are compiled at runtime via a lightweight Just-In-Time (JIT) module, requiring no CUDA compilation during installation.
 
-DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), but it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only a limited number of core kernel functions. This makes it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.
+DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), but avoids heavy reliance on their templates or algebras. The library is designed for simplicity, with only a limited number of core kernel functions, making it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.
 
 Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.
 
 ## News
 
+- 2026.04.16: Mega MoE, FP8xFP4 GEMM, FP4 Indexer, PDL, faster JIT compilation and more.
+  - Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details.
+  - For Mega MoE benchmarks, refer to [#316](https://github.com/deepseek-ai/DeepGEMM/pull/316).
 - 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer of DeepSeek v3.2.
-    - Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
+  - Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
 - 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module.
-    - NVRTC and post-compilation SASS optimization are all disabled.
-        - NVRTC will be supported later.
-        - As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported.
-    - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details.
+  - NVRTC and post-compilation SASS optimization are both disabled.
+    - NVRTC will be supported later.
+    - As NVCC 12.9 automatically performs FFMA interleaving, post-compilation optimizations are no longer supported.
+  - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details.
 - 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
 - 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may cause performance loss in some cases).
 - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
-## Roadmap
-
-- [x] More correctness tests for grouped-contiguous layout
-- [x] Shared memory swizzling for output
-- [x] MoE scheduler with TMA multicast compatibility
-- [x] Fix TMA multicast compatibility for indivisible shapes
-- [x] Skip useless computation on M
-- [x] NVRTC as a faster compiler
-- [x] Sanitizer for testing
-- [x] Weight gradient kernels for dense models
-- [x] Weight gradient kernels for MoE models
-- [ ] Better `get_best_configs` modeling
-- [ ] CUDA PDL support
-- [ ] Larger TMA multicast size for some shapes
-- [x] MMA template refactor with CUTLASS
-- [x] Remove shape limitations on N and K
-- [x] BF16 kernels
-- [ ] Split/stream-k optimizations
-- [ ] Ampere kernels
-- [ ] Polish docs
-
 ## Quick start
 
 ### Requirements
@@ -48,9 +30,9 @@
 - Python 3.8 or higher
 - Compilers with C++20 support
 - CUDA Toolkit:
-    - CUDA 12.3 or higher for SM90
-    - **We highly recommend 12.9 or higher for the best performance**
-    - CUDA 12.9 or higher for SM100
+  - CUDA 12.3 or higher for SM90
+  - **We highly recommend 12.9 or higher for the best performance**
+  - CUDA 12.9 or higher for SM100
 - PyTorch 2.1 or higher
 - CUTLASS 4.0 or higher (could be cloned by Git submodule)
 - `{fmt}` library (could be cloned by Git submodule)
@@ -65,11 +47,6 @@
 cd DeepGEMM
 
 # Link some essential includes and build the CPP JIT module
 cat develop.sh
 ./develop.sh
-
-# Test all GEMM implements
-python tests/test_layout.py
-python tests/test_attention.py
-python tests/test_core.py
 ```
 
 ### Installation
@@ -134,17 +111,47 @@ out_ij = out_ij.sum()   # Scalar
 
 For more details and the paged version `fp8_paged_mqa_logits`, please refer to `tests/test_attention.py`.
 
+#### Mega MoE
+
+Mega MoE fuses EP dispatch, linear 1 (FP8xFP4), SwiGLU, linear 2 (FP8xFP4), and EP combine into a single mega-kernel, overlapping NVLink communication with tensor core computation. It requires a multi-process launch with symmetric memory. Usage:
+
+```python
+# Allocate symmetric memory buffer
+# NOTES: requires PyTorch >= 2.9
+buffer = deep_gemm.get_symm_buffer_for_mega_moe(
+    group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden
+)
+
+# Transform weights (FP4 with UE8M0 SF) into the required layout
+transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)
+
+# Copy inputs into the buffer before each call
+# You may fuse these copies into previous kernels
+buffer.x[:num_tokens].copy_(x_fp8)
+buffer.x_sf[:num_tokens].copy_(x_sf)
+buffer.topk_idx[:num_tokens].copy_(topk_idx)
+buffer.topk_weights[:num_tokens].copy_(topk_weights)
+
+# Run the fused mega MoE kernel
+y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+deep_gemm.fp8_fp4_mega_moe(y, transformed_l1, transformed_l2, buffer)
+```
+
+For the full example with multi-process setup and benchmarking, please refer to `tests/test_mega_moe.py`.
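The snippet above presupposes an existing process group (`group`) and concrete model shapes. As a rough, hypothetical sketch of how that context could be set up (one process per GPU, matching the multi-process requirement stated above; the dimensions below are made up, and the real flow lives in `tests/test_mega_moe.py`):

```python
# A minimal sketch of the multi-process setup assumed by the snippet above.
# Launch with e.g.: torchrun --nproc-per-node=8 mega_moe_example.py
import torch
import torch.distributed as dist
import deep_gemm

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
group = dist.group.WORLD

# Hypothetical model dimensions, for illustration only
num_experts, num_topk = 256, 8
hidden, intermediate_hidden = 7168, 2048
num_max_tokens_per_rank = 128

# Symmetric memory buffer shared by all ranks (requires PyTorch >= 2.9)
buffer = deep_gemm.get_symm_buffer_for_mega_moe(
    group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden)
```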
+
 #### Utilities
 
 The library provides some utility functions besides the above kernels:
 
-- `deep_gemm.set_num_sms`: set the maximum SM count to use
-- `deep_gemm.get_num_sms`: get the current SM maximum count (return the device SM count if not set)
-- `deep_gemm.set_tc_util`: set an approximated tensor core utilization ratio
-- `deep_gemm.get_tc_util`: get the current tensor core utilization ratio
-- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into required layout
+- `deep_gemm.set_num_sms` / `get_num_sms`: set/get the maximum SM count to use
+- `deep_gemm.set_tc_util` / `get_tc_util`: set/get an approximated tensor core utilization ratio
+- `deep_gemm.set_pdl` / `get_pdl`: enable/disable Programmatic Dependent Launch (PDL)
+- `deep_gemm.set_mk_alignment_for_contiguous_layout` / `get_mk_alignment_for_contiguous_layout`: set/get the group-level M/K alignment for the contiguous layout
+- `deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout`: get the theoretical minimum M/K alignment
+- `deep_gemm.set_ignore_compile_dims`: configure dimensions to ignore during JIT compilation
+- `deep_gemm.set_block_size_multiple_of`: constrain block sizes to be multiples of a given value
+- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into the required layout
 - `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size
-- `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout
 - `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor
 - `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0)
 - `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel
@@ -152,17 +159,30 @@
 The library also provides some environment variables, which may be useful:
 
 - General
-    - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default
-- JIT cache related
-    - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
-- NVCC/NVRTC selections
-    - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default
-    - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default
-- Compiler options
-    - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default
-    - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default
-- Heuristic selection
-    - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
+  - `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default
+  - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
+- JIT cache
+  - `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default
+- Compiler selection
+  - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, but may have lower performance in some cases), `0` by default
+  - `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME`
+  - `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default
+- Compiler output
+  - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default
+  - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default
+  - `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default
+  - `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default
+- Debug and profiling
+  - `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default
+  - `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default
+  - `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default
+  - `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default
+  - `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero the symmetric buffer before each Mega MoE call for debugging, `0` by default
+  - `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default
+- Build options
+  - `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip the CUDA extension build during installation, `0` by default
+  - `DG_FORCE_BUILD`: `0` or `1`, force a local build instead of downloading pre-built wheels, `0` by default
+  - `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use the CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default
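For example, since these are ordinary environment variables, they can be set from the shell or from Python before `deep_gemm` is first used. A small sketch (the cache path below is hypothetical):

```python
# Set JIT-related environment variables before deep_gemm is imported/used,
# so the JIT module sees them (semantics as documented in the list above).
import os

os.environ['DG_JIT_DEBUG'] = '1'                  # print JIT debugging information
os.environ['DG_PRINT_CONFIGS'] = '1'              # print the selected config per shape
os.environ['DG_JIT_CACHE_DIR'] = '/tmp/dg_cache'  # hypothetical cache directory

import deep_gemm  # noqa: E402  (import after env setup)
```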
 
 For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.
 
@@ -174,6 +194,14 @@
 DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project...
 
 This code repository is released under [the MIT License](LICENSE).
 
 ---
+## Citation
 
-vendored at commit 477618cd51baffca09c4b0b87e97c03fe827ef03
\ No newline at end of file
+```bibtex
+@misc{deepgemm2025,
+  title={DeepGEMM: clean and efficient BLAS kernel library on GPU},
+  author={Chenggang Zhao and Zhean Xu and Liang Zhao and Jiashi Li and Chenhao Xu and Anyi Xu and Shengyu Liu and Kexing Zhou and Kuai Yu},
+  year={2025},
+  publisher = {GitHub},
+  howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}},
+}
+```
diff --git a/deep-gemm/build.toml b/deep-gemm/build.toml
index 84c36b1f..5f4ab1f8 100644
--- a/deep-gemm/build.toml
+++ b/deep-gemm/build.toml
@@ -2,10 +2,10 @@
 name = "deep-gemm"
 license = "MIT"
 backends = ["cuda"]
-version = 1
+version = 2
 
 [general.cuda]
-minver = "12.3"
+minver = "12.8"
 
 [torch]
 src = [
@@ -16,7 +16,7 @@
 pyext = ["py", "cuh", "hpp", "h"]
 
 [kernel.deep_gemm]
 backend = "cuda"
-cuda-minver = "12.3"
+cuda-minver = "12.8"
 cuda-capabilities = ["9.0a"]
 depends = ["torch", "cutlass_4_0"]
 include = [
@@ -24,7 +24,7 @@
   "deep_gemm/include",
 ]
 cuda-flags = [
-  "-std=c++17",
+  "-std=c++20",
   "-O3",
   "-DNDEBUG",
   "--expt-relaxed-constexpr",
@@ -32,30 +32,33 @@
   "-Wno-deprecated-declarations",
 ]
 src = [
-  # Compiled source
   "csrc/impl.cu",
-  # API headers
   "csrc/apis/attention.hpp",
   "csrc/apis/einsum.hpp",
   "csrc/apis/gemm.hpp",
   "csrc/apis/hyperconnection.hpp",
   "csrc/apis/layout.hpp",
+  "csrc/apis/mega.hpp",
   "csrc/apis/runtime.hpp",
-  # JIT infrastructure headers
   "csrc/jit/cache.hpp",
   "csrc/jit/compiler.hpp",
   "csrc/jit/device_runtime.hpp",
   "csrc/jit/handle.hpp",
+  "csrc/jit/include_parser.hpp",
   "csrc/jit/kernel_runtime.hpp",
-  # JIT kernel heuristics
   "csrc/jit_kernels/heuristics/common.hpp",
+  "csrc/jit_kernels/heuristics/config.hpp",
+  "csrc/jit_kernels/heuristics/mega_moe.hpp",
+  "csrc/jit_kernels/heuristics/runtime.hpp",
   "csrc/jit_kernels/heuristics/sm100.hpp",
   "csrc/jit_kernels/heuristics/sm90.hpp",
-  # JIT kernel implementations
+  "csrc/jit_kernels/heuristics/utils.hpp",
   "csrc/jit_kernels/impls/epilogue.hpp",
"csrc/jit_kernels/impls/runtime_utils.hpp", "csrc/jit_kernels/impls/sm100_bf16_gemm.hpp", "csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp", + "csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp", + "csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp", "csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp", "csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp", "csrc/jit_kernels/impls/sm90_bf16_gemm.hpp", @@ -65,10 +68,11 @@ src = [ "csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp", "csrc/jit_kernels/impls/smxx_clean_logits.hpp", "csrc/jit_kernels/impls/smxx_cublaslt.hpp", + "csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp", + "csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp", "csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp", "csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp", "csrc/jit_kernels/impls/smxx_layout.hpp", - # Utility headers "csrc/utils/compatibility.hpp", "csrc/utils/exception.hpp", "csrc/utils/format.hpp", @@ -77,31 +81,56 @@ src = [ "csrc/utils/lazy_init.hpp", "csrc/utils/math.hpp", "csrc/utils/system.hpp", - # Runtime JIT headers (deep_gemm/include) + "csrc/utils/torch_compat.hpp", + "deep_gemm/include/deep_gemm/comm/barrier.cuh", + "deep_gemm/include/deep_gemm/common/compile.cuh", "deep_gemm/include/deep_gemm/common/cute_tie.cuh", "deep_gemm/include/deep_gemm/common/epilogue_utils.cuh", - "deep_gemm/include/deep_gemm/common/types.hpp", - "deep_gemm/include/deep_gemm/common/sm90_utils.cuh", + "deep_gemm/include/deep_gemm/common/exception.cuh", + "deep_gemm/include/deep_gemm/common/math.cuh", "deep_gemm/include/deep_gemm/common/reduction.cuh", - "deep_gemm/include/deep_gemm/common/utils.cuh", - "deep_gemm/include/deep_gemm/common/tma_utils.cuh", - "deep_gemm/include/deep_gemm/common/sm100_utils.cuh", "deep_gemm/include/deep_gemm/common/scheduler.cuh", - "deep_gemm/include/deep_gemm/impls/smxx_layout.cuh", - "deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh", - "deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh", - "deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh", + "deep_gemm/include/deep_gemm/common/sm100_utils.cuh", + "deep_gemm/include/deep_gemm/common/sm90_utils.cuh", + "deep_gemm/include/deep_gemm/common/tma_copy.cuh", + "deep_gemm/include/deep_gemm/common/tma_utils.cuh", + "deep_gemm/include/deep_gemm/common/types.cuh", + "deep_gemm/include/deep_gemm/common/types.hpp", + "deep_gemm/include/deep_gemm/common/utils.cuh", + "deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh", + "deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh", + "deep_gemm/include/deep_gemm/epilogue/transform.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh", "deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh", "deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh", "deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh", "deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh", - "deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh", - 
"deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh", - "deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh", "deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh", - "deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh", "deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh", - "deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh", + "deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh", + "deep_gemm/include/deep_gemm/impls/smxx_layout.cuh", + "deep_gemm/include/deep_gemm/layout/mega_moe.cuh", + "deep_gemm/include/deep_gemm/layout/sym_buffer.cuh", + "deep_gemm/include/deep_gemm/mma/sm100.cuh", + "deep_gemm/include/deep_gemm/mma/sm90.cuh", + "deep_gemm/include/deep_gemm/ptx/ld_st.cuh", + "deep_gemm/include/deep_gemm/ptx/tcgen05.cuh", + "deep_gemm/include/deep_gemm/ptx/tma.cuh", + "deep_gemm/include/deep_gemm/ptx/utils.cuh", + "deep_gemm/include/deep_gemm/ptx/wgmma.cuh", + "deep_gemm/include/deep_gemm/scheduler/gemm.cuh", + "deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh", + "deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh", ] [general.hub] diff --git a/deep-gemm/csrc/apis/attention.hpp b/deep-gemm/csrc/apis/attention.hpp index c83233d0..0b628884 100644 --- a/deep-gemm/csrc/apis/attention.hpp +++ b/deep-gemm/csrc/apis/attention.hpp @@ -5,9 +5,9 @@ #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE #include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" -#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" -#include "../jit_kernels/impls/smxx_fp8_mqa_logits.hpp" -#include "../jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp" +#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp" +#include "../jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp" +#include "../jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp" #include "../jit_kernels/impls/smxx_clean_logits.hpp" #endif @@ -24,8 +24,8 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair(a.first); - const auto& [n , k_] = get_shape<2>(b.first); - const auto& [m_, n_] = get_shape<2>(d); + const auto [m , k ] = get_shape<2>(a.first); + const auto [n , k_] = get_shape<2>(b.first); + const auto [m_, n_] = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and k == k_); DG_HOST_ASSERT(n > 0 and k > 0); DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); @@ -45,7 +45,7 @@ static void fp8_gemm_nt_skip_head_mid(const std::pairget_arch_major(); - const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right); + const auto arch_major = device_runtime->get_arch_major(); + const auto epilogue_type = fmt::format("epilogue::transform::EpilogueHeadSplits<{}, {}, {}>", left, mid, right); if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) { - const auto& major_sfb = get_major_type_ab(sfb); + const auto major_sfb = get_major_type_ab(sfb); sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type); } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { // NOTES: Only granularity 128 and FP8 are exposed in the API @@ -73,59 +73,113 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair& kv, - const torch::Tensor& weights, - const torch::Tensor& cu_seq_len_k_start, - const torch::Tensor& cu_seq_len_k_end, - const bool& clean_logits, - const int& max_seqlen_k) 
{ - const auto& [seq_len, num_heads, head_dim] = get_shape<3>(q); - const auto& [seq_len_kv, head_dim_] = get_shape<2>(kv.first); - const auto& [seq_len_, num_heads_] = get_shape<2>(weights); - const auto& [seq_len_kv_] = get_shape<1>(kv.second); - - DG_HOST_ASSERT(seq_len == seq_len_); - DG_HOST_ASSERT(num_heads == num_heads_ and head_dim == head_dim_); - DG_HOST_ASSERT(seq_len_kv == seq_len_kv_); - DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len); - DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len); - - DG_HOST_ASSERT(q.is_contiguous() and kv.first.is_contiguous()); - DG_HOST_ASSERT(kv.second.is_contiguous()); - DG_HOST_ASSERT(weights.is_contiguous()); - DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous()); - DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous()); +static torch::Tensor fp8_fp4_mqa_logits(const std::tuple>& q, + const std::tuple& kv, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const bool& clean_logits, + const int& max_seqlen_k, + const at::ScalarType& logits_dtype) { + const auto [q_fp, q_sf] = q; + const auto [kv_fp, kv_sf] = kv; + const bool is_fp4 = q_sf.has_value(); + int seq_len, seq_len_kv, num_heads, head_dim; + + if (is_fp4) { + // Check FP4 Q + std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp); + head_dim *= 2; + DG_HOST_ASSERT(num_heads == 32 or num_heads == 64); + DG_HOST_ASSERT(head_dim == 128); + DG_HOST_ASSERT(q_fp.is_contiguous()); + DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4); + + // Check SF Q + auto [_seq_len, _num_heads] = get_shape<2>(q_sf.value()); + DG_HOST_ASSERT(seq_len == _seq_len and num_heads == _num_heads); + DG_HOST_ASSERT(q_sf.value().is_contiguous()); + DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32); + + // Check FP4 KV + int _head_dim; + std::tie(seq_len_kv, _head_dim) = get_shape<2>(kv_fp); + _head_dim *= 2; + DG_HOST_ASSERT(head_dim == _head_dim); + DG_HOST_ASSERT(kv_fp.is_contiguous()); + DG_HOST_ASSERT(kv_fp.scalar_type() == kPackedFP4); + + // Check SF KV + auto [_seq_len_kv] = get_shape<1>(kv_sf); + DG_HOST_ASSERT(seq_len_kv == _seq_len_kv); + DG_HOST_ASSERT(kv_sf.is_contiguous()); + DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kInt32); + } else { + // Check FP8 Q + std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp); + DG_HOST_ASSERT(num_heads == 32 or num_heads == 64); + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + DG_HOST_ASSERT(q_fp.is_contiguous()); + DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn); + + // Check FP4 KV + int _head_dim; + std::tie(seq_len_kv, _head_dim) = get_shape<2>(kv_fp); + DG_HOST_ASSERT(head_dim == _head_dim); + DG_HOST_ASSERT(kv_fp.is_contiguous()); + DG_HOST_ASSERT(kv_fp.scalar_type() == torch::kFloat8_e4m3fn); + + // Check SF KV + auto [_seq_len_kv] = get_shape<1>(kv_sf); + DG_HOST_ASSERT(seq_len_kv == _seq_len_kv); + DG_HOST_ASSERT(kv_sf.is_contiguous()); + DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kFloat); + } - DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(kv.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(kv.second.scalar_type() == torch::kFloat); + // Check weights + auto [_seq_len, _num_heads] = get_shape<2>(weights); + DG_HOST_ASSERT(seq_len == _seq_len and num_heads == _num_heads); + DG_HOST_ASSERT(weights.stride(1) == 1); DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat); + + // Check cu_seq_len_k_start + DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len); + 
DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous()); DG_HOST_ASSERT(cu_seq_len_k_start.scalar_type() == torch::kInt); + + // Check cu_seq_len_k_end + DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len); + DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous()); DG_HOST_ASSERT(cu_seq_len_k_end.scalar_type() == torch::kInt); - constexpr int seq_len_alignment = 4; + // Allocate output + constexpr int block_qh = 128; constexpr int block_kv = 256; - const auto aligned_seq_len = align(seq_len, seq_len_alignment); - + const int block_q = block_qh / num_heads; + DG_HOST_ASSERT(block_qh % num_heads == 0); + torch::Tensor logits; - int stride_logits; + int aligned_seq_len = align(seq_len, block_q), stride_logits; if (max_seqlen_k == 0) { - stride_logits = align(seq_len_kv + block_kv, 4); - logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat)); + // Logits stride must be 16-byte aligned + stride_logits = align(seq_len_kv + block_kv, 8); + logits = torch::empty({aligned_seq_len, stride_logits}, q_fp.options().dtype(logits_dtype)); logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)}); } else { stride_logits = align(max_seqlen_k, block_kv); - logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat)); + logits = torch::empty({aligned_seq_len, stride_logits}, q_fp.options().dtype(logits_dtype)); logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, max_seqlen_k)}); DG_HOST_ASSERT(not clean_logits); } // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9 or arch_major == 10) { - smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, - seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, seq_len_alignment); + const auto arch_major = device_runtime->get_arch_major(); + if (is_fp4 and arch_major == 10) { + sm100_fp4_mqa_logits(q_fp, q_sf.value(), kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype, + seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv); + } else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) { + smxx_fp8_mqa_logits(q_fp, kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype, + seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } @@ -136,25 +190,31 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q, return logits; } -static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms) { - const bool is_context_lens_2d = context_lens.dim() == 2; - int batch_size = 0, next_n = 0; - if (is_context_lens_2d) { - batch_size = context_lens.size(0); - next_n = context_lens.size(1); - } else { - DG_HOST_ASSERT(context_lens.dim() == 1); - batch_size = context_lens.size(0); - } +static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms, const std::optional& indices) { + // NOTES: Only 2D context lens is supported for now + DG_HOST_ASSERT(context_lens.dim() == 2); + const bool is_context_lens_2d = true; + const int batch_size = context_lens.size(0); + const int next_n = context_lens.size(1); + const bool is_varlen = indices.has_value(); DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); DG_HOST_ASSERT(context_lens.is_contiguous()); + // 
Create metadata tensor auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options()); // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9 or arch_major == 10) { - smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d); + const auto arch_major = device_runtime->get_arch_major(); + if (is_varlen) { + const auto& indices_tensor = indices.value(); + DG_HOST_ASSERT(arch_major == 10 and next_n == 1 and (block_kv == 64 or block_kv == 32)); + DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size); + DG_HOST_ASSERT(indices_tensor.is_contiguous()); + DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt); + smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, true, indices_tensor.data_ptr()); + } else if (arch_major == 9 or arch_major == 10) { + DG_HOST_ASSERT(block_kv == 64 or (arch_major == 10 and block_kv == 32)); + smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, false, nullptr); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } @@ -162,85 +222,156 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_ return schedule_metadata; } -static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, - const torch::Tensor& fused_kv_cache, - const torch::Tensor& weights, - const torch::Tensor& context_lens, - const torch::Tensor& block_table, - const torch::Tensor& schedule_meta, - const int& max_context_len, - const bool& clean_logits) { - const auto& [batch_size, next_n, num_heads, head_dim] = get_shape<4>(q); - const auto& [num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf] = get_shape<4>(fused_kv_cache); - const auto& [batch_size_next_n, num_heads_] = get_shape<2>(weights); - const auto& [batch_size_, max_block_len] = get_shape<2>(block_table); - const auto& [schedule_meta_size, meta_info_size] = get_shape<2>(schedule_meta); - const auto& num_sms = device_runtime->get_num_sms(); - const auto& kv_cache_stride_bytes = fused_kv_cache.stride(0); - const auto& block_table_stride = block_table.stride(0); - - const bool is_context_lens_2d = context_lens.dim() == 2; - if (is_context_lens_2d) { - const auto& [batch_size__, next_n_] = get_shape<2>(context_lens); - DG_HOST_ASSERT(batch_size == batch_size__ and next_n == next_n_); +static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple>& q, + const torch::Tensor& fused_kv_cache, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& block_table, + const torch::Tensor& schedule_meta, + const int& max_context_len, + const bool& clean_logits, + const at::ScalarType& logits_dtype, + const std::optional& indices) { + const auto [q_fp, q_sf] = q; + const bool is_fp4 = q_sf.has_value(); + + torch::Tensor kv_cache, kv_cache_sf; + int batch_size, next_n, num_heads, head_dim; + int num_kv_blocks, block_kv; + int kv_cache_stride_bytes; + int block_table_stride = block_table.stride(0); + int num_sms = device_runtime->get_num_sms(); + + if (is_fp4) { + // Check FP4 Q + std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp); + head_dim *= 2; + DG_HOST_ASSERT(next_n >= 1); + DG_HOST_ASSERT(num_heads == 32 or num_heads == 64); + DG_HOST_ASSERT(head_dim == 128); + DG_HOST_ASSERT(q_fp.is_contiguous()); + DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4); + + 
// Check SF Q + auto [_batch_size, _next_n, _num_heads] = get_shape<3>(q_sf.value()); + DG_HOST_ASSERT(batch_size == _batch_size and next_n == _next_n and num_heads == _num_heads); + DG_HOST_ASSERT(q_sf.value().is_contiguous()); + DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32); + + // Check fused KV cache + int num_heads_kv, fp4_with_sf_bytes; + std::tie(num_kv_blocks, block_kv, num_heads_kv, fp4_with_sf_bytes) = get_shape<4>(fused_kv_cache); + DG_HOST_ASSERT(block_kv == 32 or block_kv == 64); + DG_HOST_ASSERT(num_heads_kv == 1 and fp4_with_sf_bytes == head_dim / 2 + static_cast(sizeof(int))); + DG_HOST_ASSERT(fused_kv_cache.stride(1) == fp4_with_sf_bytes and fused_kv_cache.stride(3) == 1); + DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte); + + // Derive FP4 values and SF tensor + kv_cache_stride_bytes = fused_kv_cache.stride(0); + DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(int) == 0); + kv_cache = torch::from_blob( + fused_kv_cache.data_ptr(), + {num_kv_blocks, block_kv, head_dim / 2}, + {kv_cache_stride_bytes, head_dim / 2, 1}, + torch::TensorOptions().dtype(kPackedFP4) + ); + kv_cache_sf = torch::from_blob( + fused_kv_cache.data_ptr() + block_kv * head_dim / 2, + {num_kv_blocks, block_kv}, + {kv_cache_stride_bytes / static_cast(sizeof(int)), 1}, + torch::TensorOptions().dtype(torch::kInt32) + ); } else { - DG_HOST_ASSERT(context_lens.dim() == 1); - const auto& [batch_size__] = get_shape<1>(context_lens); - DG_HOST_ASSERT(batch_size == batch_size__); + // Check FP8 Q + std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp); + DG_HOST_ASSERT(next_n >= 1); + DG_HOST_ASSERT(num_heads == 32 or num_heads == 64); + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + DG_HOST_ASSERT(q_fp.is_contiguous()); + DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn); + + // Check fused KV cache + int num_heads_kv, head_dim_with_sf; + std::tie(num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf) = get_shape<4>(fused_kv_cache); + DG_HOST_ASSERT(block_kv == 32 or block_kv == 64); + DG_HOST_ASSERT(num_heads_kv == 1 and head_dim_with_sf == head_dim + static_cast(sizeof(float))); + DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf and fused_kv_cache.stride(3) == 1); + DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte); + + // Derive FP8 values and SF tensor + kv_cache_stride_bytes = fused_kv_cache.stride(0); + DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0); + kv_cache = torch::from_blob( + fused_kv_cache.data_ptr(), + {num_kv_blocks, block_kv, head_dim}, + {kv_cache_stride_bytes, head_dim, 1}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn) + ); + kv_cache_sf = torch::from_blob( + fused_kv_cache.data_ptr() + block_kv * head_dim, + {num_kv_blocks, block_kv}, + {kv_cache_stride_bytes / static_cast(sizeof(float)), 1}, + torch::TensorOptions().dtype(torch::kFloat32) + ); + + // Weights must be contiguous for FP8 + DG_HOST_ASSERT(weights.is_contiguous()); } - DG_HOST_ASSERT(batch_size == batch_size_); - DG_HOST_ASSERT(batch_size_next_n == batch_size * next_n); - DG_HOST_ASSERT(num_heads == num_heads_ and num_heads_kv == 1); - DG_HOST_ASSERT(head_dim_with_sf == head_dim + static_cast(sizeof(float))); - DG_HOST_ASSERT(schedule_meta_size == num_sms + 1 and meta_info_size == 2); - - DG_HOST_ASSERT(next_n == 1 or next_n == 2); - DG_HOST_ASSERT(block_kv == 64); - - DG_HOST_ASSERT(q.is_contiguous()); - DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0); - DG_HOST_ASSERT(fused_kv_cache.stride(1) == 
head_dim_with_sf); - DG_HOST_ASSERT(fused_kv_cache.stride(2) == head_dim_with_sf); - DG_HOST_ASSERT(fused_kv_cache.stride(3) == 1); - DG_HOST_ASSERT(weights.is_contiguous()); - DG_HOST_ASSERT(context_lens.is_contiguous()); - DG_HOST_ASSERT(block_table.stride(1) == 1); - DG_HOST_ASSERT(schedule_meta.is_contiguous()); - - DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte); + // Check weights + auto [_batch_size_next_n, _num_heads] = get_shape<2>(weights); + DG_HOST_ASSERT(_batch_size_next_n == batch_size * next_n and _num_heads == num_heads); + DG_HOST_ASSERT(weights.stride(1) == 1); DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat); - DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); + + // Check block table + auto [_batch_size, _max_block_len] = get_shape<2>(block_table); + DG_HOST_ASSERT(_batch_size == batch_size); + DG_HOST_ASSERT(block_table.stride(1) == 1); DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt); + + // Check indices + const bool is_varlen = indices.has_value(); + const auto arch_major = device_runtime->get_arch_major(); + const auto indices_tensor = indices.value_or(torch::Tensor()); + if (is_varlen) { + DG_HOST_ASSERT(arch_major == 10 and next_n == 1); + DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size); + DG_HOST_ASSERT(indices_tensor.is_contiguous()); + DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt); + } + + // Check schedule metadata + auto [_schedule_meta_size, _meta_info_size] = get_shape<2>(schedule_meta); + DG_HOST_ASSERT(_schedule_meta_size == num_sms + 1 and _meta_info_size == 2); + DG_HOST_ASSERT(schedule_meta.is_contiguous()); DG_HOST_ASSERT(schedule_meta.scalar_type() == torch::kInt); - // Derive FP8 values and SF tensor from KV cache - const auto& kv_cache = torch::from_blob( - fused_kv_cache.data_ptr(), - {num_kv_blocks, block_kv, head_dim}, - {kv_cache_stride_bytes, head_dim, 1}, - torch::TensorOptions().dtype(torch::kFloat8_e4m3fn) - ); - const auto& kv_cache_scales = torch::from_blob( - fused_kv_cache.data_ptr() + block_kv * head_dim, - {num_kv_blocks, block_kv}, - {kv_cache_stride_bytes / static_cast(sizeof(float)), 1}, - torch::TensorOptions().dtype(torch::kFloat32) - ); + // Check context lengths + // NOTES: Only 2D context lens is supported for now + DG_HOST_ASSERT(context_lens.dim() == 2); + const bool is_context_lens_2d = true; + const auto [__batch_size, _next_n] = get_shape<2>(context_lens); + DG_HOST_ASSERT(batch_size == __batch_size and next_n == _next_n); + DG_HOST_ASSERT(context_lens.is_contiguous()); + DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); // Allocate output constexpr int split_kv = 256; - const auto& aligned_max_context_len = align(max_context_len, split_kv); - auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q.options().dtype(torch::kFloat)); + const auto aligned_max_context_len = align(max_context_len, split_kv); + auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q_fp.options().dtype(logits_dtype)); logits = logits.slice(-1, 0, max_context_len); + DG_HOST_ASSERT(logits_dtype == torch::kFloat32 or logits_dtype == torch::kBFloat16); // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9 or arch_major == 10) { - smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta, - batch_size, next_n, num_heads, head_dim, num_kv_blocks, 
block_kv, is_context_lens_2d, - kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, split_kv); + if (is_fp4 and arch_major == 10) { + sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta, + logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d, + is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv); + } else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) { + smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta, + logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d, + is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } @@ -253,9 +384,36 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, return logits; } + +// Legacy API wrappers +static torch::Tensor fp8_mqa_logits(const torch::Tensor& q, + const std::tuple& kv, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const bool& clean_logits, + const int& max_seqlen_k) { + return fp8_fp4_mqa_logits(std::make_tuple(q, std::nullopt), kv, weights, + cu_seq_len_k_start, cu_seq_len_k_end, + clean_logits, max_seqlen_k, torch::kFloat); +} + +static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& fused_kv_cache, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& block_table, + const torch::Tensor& schedule_meta, + const int& max_context_len, + const bool& clean_logits, + const std::optional& indices) { + return fp8_fp4_paged_mqa_logits(std::make_tuple(q, std::nullopt), fused_kv_cache, weights, + context_lens, block_table, schedule_meta, + max_context_len, clean_logits, torch::kFloat, indices); +} #endif -#ifdef DG_USE_PYBIND11 +#if !defined(__CUDACC__) static void register_apis(pybind11::module_& m) { #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid, @@ -263,17 +421,33 @@ static void register_apis(pybind11::module_& m) { py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + m.def("fp8_fp4_mqa_logits", &fp8_fp4_mqa_logits, + py::arg("q"), py::arg("kv"), py::arg("weights"), + py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"), + py::arg("clean_logits") = true, + py::arg("max_seqlen_k") = 0, + py::arg("logits_dtype") = torch::kFloat32); + m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata, + py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"), + py::arg("indices") = std::nullopt); + m.def("fp8_fp4_paged_mqa_logits", &fp8_fp4_paged_mqa_logits, + py::arg("q"), py::arg("kv_cache"), py::arg("weights"), + py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"), + py::arg("max_context_len"), + py::arg("clean_logits") = false, + py::arg("logits_dtype") = torch::kFloat32, + py::arg("indices") = std::nullopt); + // Legacy API m.def("fp8_mqa_logits", &fp8_mqa_logits, py::arg("q"), py::arg("kv"), py::arg("weights"), py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"), py::arg("clean_logits") = true, py::arg("max_seqlen_k") = 0); - m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata, - py::arg("context_lens"), 
py::arg("block_kv"), py::arg("num_sms")); m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits, py::arg("q"), py::arg("kv_cache"), py::arg("weights"), py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"), - py::arg("max_context_len"), py::arg("clean_logits") = false); + py::arg("max_context_len"), py::arg("clean_logits") = false, + py::arg("indices") = std::nullopt); #endif } #endif diff --git a/deep-gemm/csrc/apis/einsum.hpp b/deep-gemm/csrc/apis/einsum.hpp index ad489923..5154747c 100644 --- a/deep-gemm/csrc/apis/einsum.hpp +++ b/deep-gemm/csrc/apis/einsum.hpp @@ -1,6 +1,6 @@ #pragma once -#ifdef DG_USE_PYBIND11 +#if !defined(__CUDACC__) #include #include #endif @@ -31,7 +31,7 @@ static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const tor DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(not c.has_value()); - const auto& workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32)); + const auto workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32)); DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(), c10::cuda::getCurrentCUDAStream())); bmk_bnk_mn(a, b, workspace, workspace); @@ -45,12 +45,12 @@ static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const tor DG_HOST_ASSERT(b.is_contiguous()); DG_HOST_ASSERT(d.is_contiguous()); - const auto& [s , m, k ] = get_shape<3>(a); - const auto& [s_, n, k_] = get_shape<3>(b); + const auto [s , m, k ] = get_shape<3>(a); + const auto [s_, n, k_] = get_shape<3>(b); DG_HOST_ASSERT(s == s_ and k == k_); // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); + const auto arch_major = device_runtime->get_arch_major(); if (arch_major == 9) { sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k); } else if (arch_major == 10) { @@ -61,9 +61,9 @@ static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const tor } static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) { - const auto& [b , h , r ] = get_shape<3>(A); - const auto& [h_, d , r_] = get_shape<3>(B); - const auto& [b_, h__, d_] = get_shape<3>(D); + const auto [b , h , r ] = get_shape<3>(A); + const auto [h_, d , r_] = get_shape<3>(B); + const auto [b_, h__, d_] = get_shape<3>(D); DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__); DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1); @@ -71,7 +71,7 @@ static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const to DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1); // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); + const auto arch_major = device_runtime->get_arch_major(); if (use_cublaslt) { cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d); } else if (arch_major == 9) { @@ -84,9 +84,9 @@ static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const to } static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) { - const auto& [b , h , d ] = get_shape<3>(A); - const auto& [h_, d_ , r ] = get_shape<3>(B); - const auto& [b_, h__, r_] = get_shape<3>(D); + const auto [b , h , d ] = get_shape<3>(A); + const auto [h_, d_ , r ] = get_shape<3>(B); + const auto [b_, h__, r_] = get_shape<3>(D); DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__); DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1); @@ 
-94,7 +94,7 @@ static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const to
     DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);
 
     // Dispatch implementation
-    const auto& arch_major = device_runtime->get_arch_major();
+    const auto arch_major = device_runtime->get_arch_major();
     if (use_cublaslt) {
         cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d);
     } else if (arch_major == 9) {
@@ -144,16 +144,16 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
                     std::optional<std::tuple<int, int, int>> recipe,
                     const std::string& compiled_dims) {
     // Shape must be `[B, M, K] @ [B, N, K].T`
-    const auto& major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
-    const auto& major_b = b.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
+    const auto major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
+    const auto major_b = b.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
     DG_HOST_ASSERT(a.stride(-1) == 1 or a.stride(-2) == 1);
     DG_HOST_ASSERT(b.stride(-1) == 1 or b.stride(-2) == 1);
     DG_HOST_ASSERT(d.stride(-1) == 1);
 
     // Type and shape checks
-    const auto& [batch_size , m , k ] = get_shape<3>(a);
-    const auto& [batch_size_ , n , k_] = get_shape<3>(b);
-    const auto& [batch_size__, m_, n_] = get_shape<3>(d);
+    const auto [batch_size , m , k ] = get_shape<3>(a);
+    const auto [batch_size_ , n , k_] = get_shape<3>(b);
+    const auto [batch_size__, m_, n_] = get_shape<3>(d);
     DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__);
     DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
     DG_HOST_ASSERT(a.scalar_type() == torch::kFloat8_e4m3fn);
@@ -165,15 +165,16 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
         return;
 
     // Transform scaling factors
-    const auto& [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
+    const auto [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
         sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false);
 
     // Dispatch implementation
     const auto arch_major = device_runtime->get_arch_major();
     if (arch_major == 10) {
-        sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, compiled_dims);
+        sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims);
     } else {
-        const auto& major_sfb = get_major_type_ab(sfb);
+        const auto major_sfb = get_major_type_ab(sfb);
+        DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128);
         sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, major_sfb, compiled_dims);
     }
 }
@@ -189,26 +190,26 @@ static void fp8_einsum(const std::string& expr,
     if (expr == "bhr,hdr->bhd") {
         // Permute dims to satisfy the order of (batch_size, m, n, k)
         // (batch_size, m, n, k): (h, b, d, r)
-        const auto& perm_a = a.first.permute({1, 0, 2});
-        const auto& perm_sfa = a.second.permute({1, 0, 2});
-        const auto& perm_d = d.permute({1, 0, 2});
-        const auto& perm_c = c.has_value() ?
std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; fp8_bmm(perm_a, perm_sfa, b.first, b.second, perm_d, perm_c, recipe, "nk"); } else if (expr == "bhd,hdr->bhr" and arch_major == 10) { // (batch_size, m, n, k): (h, b, r, d) - const auto& perm_a = a.first.permute({1, 0, 2}); - const auto& perm_sfa = a.second.permute({1, 0, 2}); - const auto& perm_b = b.first.permute({0, 2, 1}); - const auto& perm_sfb = b.second.permute({0, 2, 1}); - const auto& perm_d = d.permute({1, 0, 2}); - const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; + const auto perm_a = a.first.permute({1, 0, 2}); + const auto perm_sfa = a.second.permute({1, 0, 2}); + const auto perm_b = b.first.permute({0, 2, 1}); + const auto perm_sfb = b.second.permute({0, 2, 1}); + const auto perm_d = d.permute({1, 0, 2}); + const auto perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, perm_d, perm_c, recipe, "nk"); } else if (expr == "bhd,bhr->hdr" and arch_major == 10) { // (batch_size, m, n, k): (h, d, r, b) - const auto& perm_a = a.first.permute({1, 2, 0}); - const auto& perm_sfa = a.second.permute({1, 2, 0}); - const auto& perm_b = b.first.permute({1, 2, 0}); - const auto& perm_sfb = b.second.permute({1, 2, 0}); + const auto perm_a = a.first.permute({1, 2, 0}); + const auto perm_sfa = a.second.permute({1, 2, 0}); + const auto perm_b = b.first.permute({1, 2, 0}); + const auto perm_sfb = b.second.permute({1, 2, 0}); fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, d, c, recipe, "mn"); } else { DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr)); @@ -216,7 +217,7 @@ static void fp8_einsum(const std::string& expr, } #endif -#ifdef DG_USE_PYBIND11 +#if !defined(__CUDACC__) static void register_apis(pybind11::module_& m) { #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE m.def("einsum", &einsum, diff --git a/deep-gemm/csrc/apis/gemm.hpp b/deep-gemm/csrc/apis/gemm.hpp index f12517cf..924820b5 100644 --- a/deep-gemm/csrc/apis/gemm.hpp +++ b/deep-gemm/csrc/apis/gemm.hpp @@ -6,7 +6,7 @@ #include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/sm90_bf16_gemm.hpp" -#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm100_bf16_gemm.hpp" #endif @@ -23,7 +23,7 @@ static bool early_return(const int& m, const int &n, const int& k, return true; // Checks - const bool& is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr(); + const bool is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr(); if (is_cd_same) DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); if (c.has_value()) { @@ -57,8 +57,8 @@ static void fp8_fp4_gemm_nt(const std::pair& a, const std::string& compiled_dims, const bool& disable_ue8m0_cast) { // Shape must be `[M, K] @ [N, K].T` - const auto& major_a = get_major_type_ab(a.first); - const auto& major_b = get_major_type_ab(b.first); + const auto major_a = get_major_type_ab(a.first); + const auto major_b = get_major_type_ab(b.first); if (fp8_requires_k_major()) { DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); @@ -89,7 +89,7 @@ static void fp8_fp4_gemm_nt(const std::pair& a, if (gran_n == 1) { sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); } else { - const auto& 
major_sfb = get_major_type_ab(sfb); + const auto major_sfb = get_major_type_ab(sfb); sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims); } } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { @@ -152,8 +152,8 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair& expected_m_for_psum_layout) { // Shape must be `[M, K] @ [G, N, K].mT` - const auto& major_a = get_major_type_ab(a.first); - const auto& major_b = get_major_type_ab(b.first); + const auto major_a = get_major_type_ab(a.first); + const auto major_b = get_major_type_ab(b.first); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); if (fp8_requires_k_major()) DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); @@ -171,10 +171,10 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair(grouped_layout); + const auto [num_groups_] = get_shape<1>(grouped_layout); DG_HOST_ASSERT(num_groups == num_groups_); } else { - const auto& [m__] = get_shape<1>(grouped_layout); + const auto [m__] = get_shape<1>(grouped_layout); DG_HOST_ASSERT(m == m__); DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); } @@ -192,10 +192,10 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair& recipe, const std::string& compiled_dims) { // Must be 1D1D kernel - DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1); + + const int gran_k = std::get<2>(recipe); + DG_HOST_ASSERT(gran_k == 32 or gran_k == 128); // Shape checks - const auto& [num_groups, m, n] = get_shape<3>(d); - const auto& [sum_k_ , m_] = get_shape<2>(a.first); - const auto& [sum_k__, n_] = get_shape<2>(b.first); + const auto [num_groups, m, n] = get_shape<3>(d); + const auto [sum_k_ , m_] = get_shape<2>(a.first); + const auto [sum_k__, n_] = get_shape<2>(b.first); const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); @@ -297,13 +300,13 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pairget_arch_major(); + const auto arch_major = device_runtime->get_arch_major(); if (arch_major == 10) { - sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, + sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, gran_k, cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); @@ -322,9 +325,9 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair(d); - const auto& sum_mk = a.first.numel(); - const auto& sum_nk = b.first.numel(); + const auto [num_groups, m, n] = get_shape<3>(d); + const auto sum_mk = a.first.numel(); + const auto sum_nk = b.first.numel(); const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); DG_HOST_ASSERT(sum_mk == static_cast(sum_k) * m); DG_HOST_ASSERT(sum_nk == static_cast(sum_k) * n); @@ -340,17 +343,17 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pairget_num_sms(); - const auto& tensor_map_buffer = torch::empty({num_sms * 4 * static_cast(sizeof(CUtensorMap))}, - a.first.options().dtype(torch::kByte)); + const auto num_sms = device_runtime->get_num_sms(); + const auto tensor_map_buffer = torch::empty({num_sms * 4 * static_cast(sizeof(CUtensorMap))}, + a.first.options().dtype(torch::kByte)); // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); + const auto arch_major = device_runtime->get_arch_major(); if (arch_major == 9) { 
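        // NOTES: a minimal sketch of the `(1, 1, gran_k)` k-grouped recipe handled above,
        // which now permits a `gran_k` of 32 or 128 (one scaling factor per 1 x gran_k
        // slice along K for each operand). Assuming two groups with `ks = {2048, 4096}`
        // (illustrative values only):
        //     const int sum_k = 2048 + 4096;   // 6144; `a.first.numel() == sum_k * m`, `b.first.numel() == sum_k * n`
        //     // gran_k = 128  ->  6144 / 128 = 48  scales along K
        //     // gran_k = 32   ->  6144 / 32  = 192 scales along K (4x finer scaling)
        // The scales themselves are packed by `transform_k_grouped_sf_into_required_layout`.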
        sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
                                     cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
@@ -367,16 +370,16 @@ static void bf16_gemm_nt(const torch::Tensor& a,
                          const std::optional<torch::Tensor>& c,
                          const std::string& compiled_dims) {
     // Shape must be `[M, K] @ [N, K].T`
-    const auto& major_a = get_major_type_ab(a);
-    const auto& major_b = get_major_type_ab(b);
+    const auto major_a = get_major_type_ab(a);
+    const auto major_b = get_major_type_ab(b);
 
     // C/D must be N-major
     check_major_type_cd(d);
 
     // Type and shape checks
-    const auto& [m , k ] = get_shape<2>(a);
-    const auto& [n , k_] = get_shape<2>(b);
-    const auto& [m_, n_] = get_shape<2>(d);
+    const auto [m , k ] = get_shape<2>(a);
+    const auto [n , k_] = get_shape<2>(b);
+    const auto [m_, n_] = get_shape<2>(d);
     DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
     DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
     DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
@@ -387,7 +390,7 @@ static void bf16_gemm_nt(const torch::Tensor& a,
         return;
 
     // Dispatch into different implementations
-    const auto& arch_major = device_runtime->get_arch_major();
+    const auto arch_major = device_runtime->get_arch_major();
     if (arch_major == 9) {
         sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
     } else if (arch_major == 10) {
@@ -427,15 +430,15 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc
                                               const bool& use_psum_layout,
                                               const std::optional<int>& expected_m_for_psum_layout) {
     // Shape must be `[M, K] @ [G, N, K].mT`
-    const auto& major_a = get_major_type_ab(a);
-    const auto& major_b = get_major_type_ab(b);
+    const auto major_a = get_major_type_ab(a);
+    const auto major_b = get_major_type_ab(b);
     DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
     DG_HOST_ASSERT(grouped_layout.is_contiguous());
 
     // Type and shape checks
-    const auto& [m, k] = get_shape<2>(a);
-    const auto& [num_groups, n, k_] = get_shape<3>(b);
-    const auto& [m_, n_] = get_shape<2>(d);
+    const auto [m, k] = get_shape<2>(a);
+    const auto [num_groups, n, k_] = get_shape<3>(b);
+    const auto [m_, n_] = get_shape<2>(d);
     DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
     DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
     DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
@@ -445,10 +448,10 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc
 
     // Layout checks
     if (use_psum_layout) {
-        const auto& [num_groups_] = get_shape<1>(grouped_layout);
+        const auto [num_groups_] = get_shape<1>(grouped_layout);
         DG_HOST_ASSERT(num_groups == num_groups_);
     } else {
-        const auto& [m__] = get_shape<1>(grouped_layout);
+        const auto [m__] = get_shape<1>(grouped_layout);
         DG_HOST_ASSERT(m == m__);
         DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
     }
@@ -461,11 +464,11 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc
         return;
 
     // Dispatch implementation
-    const auto& arch_major = device_runtime->get_arch_major();
+    const auto arch_major = device_runtime->get_arch_major();
     if (arch_major == 9) {
-        DG_HOST_ASSERT(not use_psum_layout);
         sm90_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
-                                            num_groups, m, n, k, major_a, major_b, compiled_dims);
+                                            num_groups, m, n, k, major_a, major_b, compiled_dims,
+                                            use_psum_layout, expected_m_for_psum_layout);
     } else if (arch_major == 10) {
         sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
                                              num_groups, m, n, k, major_a, major_b, compiled_dims,
@@ -487,16 +490,16 @@ static void
m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T const torch::Tensor& d, const torch::Tensor& masked_m, const int& expected_m, const std::string& compiled_dims) { // Shape must be `[G, M, K] @ [G, N, K].mT` - const auto& major_a = get_major_type_ab(a); - const auto& major_b = get_major_type_ab(b); + const auto major_a = get_major_type_ab(a); + const auto major_b = get_major_type_ab(b); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); DG_HOST_ASSERT(masked_m.is_contiguous()); // Type and shape checks - const auto& [num_groups, m, k] = get_shape<3>(a); - const auto& [num_groups_, n, k_] = get_shape<3>(b); - const auto& [num_groups__, m_, n_] = get_shape<3>(d); - const auto& num_groups___ = static_cast(masked_m.numel()); + const auto [num_groups, m, k] = get_shape<3>(a); + const auto [num_groups_, n, k_] = get_shape<3>(b); + const auto [num_groups__, m_, n_] = get_shape<3>(d); + const auto num_groups___ = static_cast(masked_m.numel()); DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); @@ -509,7 +512,7 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T check_major_type_cd(d); // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); + const auto arch_major = device_runtime->get_arch_major(); if (arch_major == 9) { sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m, num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); @@ -529,9 +532,9 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a, const std::optional& c, const std::string& compiled_dims) { // Shape checks - const auto& [num_groups, m, n] = get_shape<3>(d); - const auto& [sum_k_ , m_] = get_shape<2>(a); - const auto& [sum_k__, n_] = get_shape<2>(b); + const auto [num_groups, m, n] = get_shape<3>(d); + const auto [sum_k_ , m_] = get_shape<2>(a); + const auto [sum_k__, n_] = get_shape<2>(b); const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); @@ -546,7 +549,7 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a, return; // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); + const auto arch_major = device_runtime->get_arch_major(); if (arch_major == 9) { sm90_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor, cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); @@ -562,20 +565,20 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a, static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, const std::optional& c) { // Shape must be `[M, K] @ [N, K].T` - const auto& major_a = get_major_type_ab(a); - const auto& major_b = get_major_type_ab(b); + const auto major_a = get_major_type_ab(a); + const auto major_b = get_major_type_ab(b); // Type and shape checks - const auto& [m , k ] = get_shape<2>(a); - const auto& [n , k_] = get_shape<2>(b); - const auto& [m_, n_] = get_shape<2>(d); + const auto [m , k ] = get_shape<2>(a); + const auto [n , k_] = get_shape<2>(b); + const auto [m_, n_] = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); // Early return for trivial cases if (early_return(m, n, k, d, c)) return; - cublaslt_gemm(a, b, c, d, m, n, k, major_a, major_b); + cublaslt_gemm(a, b, d, m, 
n, k, major_a, major_b, c.has_value());
 }
 
 static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b,
@@ -593,7 +596,7 @@ static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
     cublaslt_gemm_nt(a.transpose(0, 1), b, d, c);
 }
 
-#ifdef DG_USE_PYBIND11
+#if !defined(__CUDACC__)
 static void register_apis(pybind11::module_& m) {
 
 #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
diff --git a/deep-gemm/csrc/apis/hyperconnection.hpp b/deep-gemm/csrc/apis/hyperconnection.hpp
index 713de4a3..d834c35e 100644
--- a/deep-gemm/csrc/apis/hyperconnection.hpp
+++ b/deep-gemm/csrc/apis/hyperconnection.hpp
@@ -24,16 +24,16 @@ static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
     DG_HOST_ASSERT(sqr_sum.is_contiguous());
 
     // Type and shape checks
-    const auto& [m, k ] = get_shape<2>(a);
-    const auto& [n, k_] = get_shape<2>(b);
+    const auto [m, k ] = get_shape<2>(a);
+    const auto [n, k_] = get_shape<2>(b);
     if (num_splits.has_value()) {
-        const auto& [num_splits_, m_, n_] = get_shape<3>(d);
-        const auto& [num_splits__, m__] = get_shape<2>(sqr_sum);
+        const auto [num_splits_, m_, n_] = get_shape<3>(d);
+        const auto [num_splits__, m__] = get_shape<2>(sqr_sum);
         DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1);
         DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
     } else {
-        const auto& [m_, n_] = get_shape<2>(d);
-        const auto& [m__] = get_shape<1>(sqr_sum);
+        const auto [m_, n_] = get_shape<2>(d);
+        const auto [m__] = get_shape<1>(sqr_sum);
         DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
     }
     DG_HOST_ASSERT(n > 0 and k > 0);
@@ -47,7 +47,7 @@ static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
         return;
 
     // Dispatch into different implementations
-    const auto& arch_major = device_runtime->get_arch_major();
+    const auto arch_major = device_runtime->get_arch_major();
     if (arch_major == 9) {
         sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1);
     } else if (arch_major == 10) {
@@ -59,7 +59,7 @@ static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
 
 #endif
 
-#ifdef DG_USE_PYBIND11
+#if !defined(__CUDACC__)
 static void register_apis(pybind11::module_& m) {
 #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
     m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm,
diff --git a/deep-gemm/csrc/apis/layout.hpp b/deep-gemm/csrc/apis/layout.hpp
index 3ec1c6a6..947212a2 100644
--- a/deep-gemm/csrc/apis/layout.hpp
+++ b/deep-gemm/csrc/apis/layout.hpp
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "../jit_kernels/heuristics/runtime.hpp"
 #include "../utils/layout.hpp"
 #include "../utils/compatibility.hpp"
 
@@ -12,21 +13,24 @@ namespace deep_gemm::layout {
 #if DG_TENSORMAP_COMPATIBLE
 static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
                                                        const int& mn, const int& k,
-                                                       const std::optional<std::tuple<int, int, int>>& recipe,
-                                                       const std::optional<std::tuple<int, int>>& recipe_ab,
+                                                       const std::variant<std::tuple<int, int, int>,
+                                                                          std::tuple<int, int>>& recipe,
                                                        const std::optional<int>& num_groups,
-                                                       const bool& is_sfa,
+                                                       const std::optional<bool>& is_sfa,
                                                        const bool& disable_ue8m0_cast) {
-    const auto& arch_major = device_runtime->get_arch_major();
+    const auto arch_major = device_runtime->get_arch_major();
 
+    // Get granularity MN/K from recipe
     int gran_mn, gran_k;
-    if (recipe.has_value()) {
-        DG_HOST_ASSERT(not recipe_ab.has_value());
-        gran_mn = is_sfa ?
std::get<0>(recipe.value()) : std::get<1>(recipe.value()); - gran_k = std::get<2>(recipe.value()); + if (auto p = std::get_if>(&recipe)) { + DG_HOST_ASSERT(is_sfa.has_value()); + gran_mn = is_sfa.value() ? std::get<0>(*p) : std::get<1>(*p); + gran_k = std::get<2>(*p); + } else if (auto p = std::get_if>(&recipe)) { + DG_HOST_ASSERT(not is_sfa.has_value()); + std::tie(gran_mn, gran_k) = *p; } else { - DG_HOST_ASSERT(recipe_ab.has_value()); - std::tie(gran_mn, gran_k) = recipe_ab.value(); + DG_HOST_UNREACHABLE("Invalid recipe"); } // Pre-transform checks @@ -43,8 +47,8 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, // (FP32, x, gran_k) on SM100: transform to (INT, 1, gran_k), TMA-aligned and MN-major if (sf.scalar_type() == torch::kFloat and (gran_k == 32 or gran_k == 128) and arch_major == 10) { DG_HOST_ASSERT(not disable_ue8m0_cast); - const auto& broadcasted = gran_mn == 1 ? sf : - sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn)); + const auto broadcasted = gran_mn == 1 ? sf : + sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn)); return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); } @@ -64,11 +68,19 @@ static std::tuple transform_sf_pair_into const std::optional& num_groups_a, const std::optional& num_groups_b, const bool& disable_ue8m0_cast = false) { - DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value()); + // Use default recipe, if none is specified if (not recipe_a.has_value() and not recipe.has_value()) recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type()); - const auto transformed_sfa = transform_sf_into_required_layout(sfa, m, k, recipe, recipe_a, num_groups_a, true, disable_ue8m0_cast); - const auto transformed_sfb = transform_sf_into_required_layout(sfb, n, k, recipe, recipe_b, num_groups_b, false, disable_ue8m0_cast); + + // Must be either 'recipe' or the 'recipe_a' + 'recipe_b' pair. + DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value()); + DG_HOST_ASSERT(recipe_a.has_value() != recipe.has_value()); + + // Transform SFA and SFB layout + const auto transformed_sfa = recipe.has_value() ? transform_sf_into_required_layout(sfa, m, k, recipe.value(), num_groups_a, true, disable_ue8m0_cast) + : transform_sf_into_required_layout(sfa, m, k, recipe_a.value(), num_groups_a, std::nullopt, disable_ue8m0_cast); + const auto transformed_sfb = recipe.has_value() ? transform_sf_into_required_layout(sfb, n, k, recipe.value(), num_groups_b, false, disable_ue8m0_cast) + : transform_sf_into_required_layout(sfb, n, k, recipe_b.value(), num_groups_b, std::nullopt, disable_ue8m0_cast); const int gran_k_a = recipe_a.has_value() ? std::get<1>(recipe_a.value()) : std::get<2>(recipe.value()); const int gran_k_b = recipe_b.has_value() ? 
std::get<1>(recipe_b.value()) : std::get<2>(recipe.value()); return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b); @@ -79,8 +91,12 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te const torch::Tensor& ks_tensor, const std::tuple& recipe) { DG_HOST_ASSERT(sf.dim() == 2); - DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); - const auto& arch_major = device_runtime->get_arch_major(); + DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1); + + const int gran_k = std::get<2>(recipe); + DG_HOST_ASSERT(gran_k == 32 or gran_k == 128); + + const auto arch_major = device_runtime->get_arch_major(); // FP32 on SM90 if (sf.scalar_type() == torch::kFloat and arch_major == 9) @@ -88,7 +104,7 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te // FP32 on SM100 if (sf.scalar_type() == torch::kFloat and arch_major == 10) - return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks); + return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks, gran_k); // INT on SM100 if (sf.scalar_type() == torch::kInt and arch_major == 10) @@ -99,14 +115,13 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te #endif -#ifdef DG_USE_PYBIND11 +#if !defined(__CUDACC__) static void register_apis(pybind11::module_& m) { - #if DG_TENSORMAP_COMPATIBLE m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout, - py::arg("sf"), py::arg("mn"), py::arg("k"), - py::arg("recipe") = std::nullopt, py::arg("recipe_ab") = std::nullopt, - py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false, + py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"), + py::arg("num_groups") = std::nullopt, + py::arg("is_sfa") = std::nullopt, py::arg("disable_ue8m0_cast") = false); m.def("get_tma_aligned_size", &get_tma_aligned_size); @@ -115,7 +130,15 @@ static void register_apis(pybind11::module_& m) { m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); #endif - m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout); + m.def("set_mk_alignment_for_contiguous_layout", [&](const int& new_value) { + heuristics_runtime->set_mk_alignment_for_contiguous_layout(new_value); + }); + m.def("get_mk_alignment_for_contiguous_layout", [&]() { + return heuristics_runtime->get_mk_alignment_for_contiguous_layout(); + }); + m.def("get_theoretical_mk_alignment_for_contiguous_layout", [&](const std::optional& expected_m) { + return heuristics_runtime->get_theoretical_mk_alignment_for_contiguous_layout(expected_m); + }, py::arg("expected_m") = std::nullopt); } #endif diff --git a/deep-gemm/csrc/apis/mega.hpp b/deep-gemm/csrc/apis/mega.hpp new file mode 100644 index 00000000..76711b50 --- /dev/null +++ b/deep-gemm/csrc/apis/mega.hpp @@ -0,0 +1,239 @@ +#pragma once + +#include +#if !defined(__CUDACC__) +#include +#endif + +#if DG_TENSORMAP_COMPATIBLE +#include "../jit/compiler.hpp" +#endif +#include "../jit/device_runtime.hpp" +#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp" + +namespace deep_gemm::mega { + +static int get_token_alignment_for_mega_moe() { + return layout::kLCMCandidateBlockM; +} + +static std::tuple(const torch::Tensor&)>> +get_symm_buffer_size_for_mega_moe( + const int& num_ranks, const int& num_experts, + const int& num_max_tokens_per_rank, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const bool& 
use_fp8_dispatch, const std::string& activation) {
+    DG_HOST_ASSERT(num_experts % num_ranks == 0);
+
+    // Workspace bytes
+    const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
+
+    // Layouts
+    const auto fp8_token_layout = layout::Data(hidden);
+    const auto bf16_token_layout = layout::Data(hidden * 2);
+    const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
+    const auto fp8_sf_layout = layout::Data(hidden / 32);
+    const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32);
+    const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
+    const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
+    const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
+
+    // Input buffers
+    const auto input_token_buffer = layout::Buffer(
+        fp8_token_layout, 1, num_max_tokens_per_rank,
+        workspace.get_end_ptr());
+    const auto input_sf_buffer = layout::Buffer(
+        fp8_sf_layout, 1, num_max_tokens_per_rank,
+        input_token_buffer.get_end_ptr());
+    const auto input_topk_idx_buffer = layout::Buffer(
+        input_topk_idx_layout, 1, num_max_tokens_per_rank,
+        input_sf_buffer.get_end_ptr());
+    const auto input_topk_weights_buffer = layout::Buffer(
+        input_topk_weights_layout, 1, num_max_tokens_per_rank,
+        input_topk_idx_buffer.get_end_ptr());
+
+    // Buffer configs
+    const auto num_max_pool_tokens = static_cast(workspace.num_max_pool_tokens);
+    int num_max_padded_sf_pool_tokens = 0;
+    for (int block_m: layout::kCandidateBlockM) {
+        num_max_padded_sf_pool_tokens = std::max(
+            num_max_padded_sf_pool_tokens,
+            layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
+        );
+    }
+
+    // L1 input buffer
+    const auto l1_token_buffer = layout::Buffer(
+        fp8_token_layout, 1, num_max_pool_tokens,
+        input_topk_weights_buffer.get_end_ptr());
+    const auto l1_sf_buffer = layout::Buffer(
+        fp8_sf_layout, 1, num_max_padded_sf_pool_tokens,
+        l1_token_buffer.get_end_ptr());
+    const auto l1_topk_weights_buffer = layout::Buffer(
+        l1_topk_weights_layout, 1, num_max_pool_tokens,
+        l1_sf_buffer.get_end_ptr());
+
+    // L2 input buffer
+    const auto l2_token_buffer = layout::Buffer(
+        fp8_intermediate_token_layout, 1, num_max_pool_tokens,
+        l1_topk_weights_buffer.get_end_ptr());
+    const auto l2_sf_buffer = layout::Buffer(
+        fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
+        l2_token_buffer.get_end_ptr());
+
+    // Combine input buffer: BF16 tokens for cross-rank combine
+    const auto combine_token_buffer = layout::Buffer(
+        bf16_token_layout, num_topk, num_max_tokens_per_rank,
+        l2_sf_buffer.get_end_ptr());
+
+    // Check SF buffer requirements
+    DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
+    DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);
+
+    // Slice function: creates `(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
+    // NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
+    auto slice_input_buffers = [=](const torch::Tensor& buffer) {
+        auto x = torch::from_blob(
+            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)),
+            {num_max_tokens_per_rank, hidden},
+            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
+        auto x_sf = torch::from_blob(
+            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)),
+            {num_max_tokens_per_rank, hidden / 128},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + auto topk_idx = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_idx_buffer.base)), + {num_max_tokens_per_rank, num_topk}, + torch::TensorOptions().dtype(torch::kInt64).device(buffer.device())); + auto topk_weights = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_weights_buffer.base)), + {num_max_tokens_per_rank, num_topk}, + torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); + auto l1_acts = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)), + {num_max_pool_tokens, hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + auto l1_acts_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), + {num_max_padded_sf_pool_tokens, hidden / 128}, + {1, num_max_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + auto l2_acts = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), + {num_max_pool_tokens, intermediate_hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + auto l2_acts_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), + {num_max_padded_sf_pool_tokens, intermediate_hidden / 128}, + {1, num_max_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); + }; + return {reinterpret_cast(combine_token_buffer.get_end_ptr()), slice_input_buffers}; +} + +static void fp8_fp4_mega_moe( + const torch::Tensor& y, + const std::tuple& l1_weights_tuple, + const std::tuple& l2_weights_tuple, + const std::optional& cumulative_local_expert_recv_stats, + const torch::Tensor& sym_buffer, + const std::vector& sym_buffer_ptrs, const int& rank_idx, + const int& num_max_tokens_per_rank, + const int& num_experts, const int& num_topk, + const std::tuple& recipe, + const std::string& activation, + const std::optional& activation_clamp_opt, + const bool& fast_math +) { + const auto [l1_weights, l1_weights_sf] = l1_weights_tuple; + const auto [l2_weights, l2_weights_sf] = l2_weights_tuple; + + // Config checks + const auto num_tokens = static_cast(y.size(0)); + const auto [rm, rn, rk] = recipe; + DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 32); + DG_HOST_ASSERT(activation == "swiglu"); + + // Activation checks + const auto activation_clamp = + activation_clamp_opt.value_or(std::numeric_limits::infinity()); + DG_HOST_ASSERT(activation_clamp >= 0); + + // Tensor checks + DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K); + const auto arch_major = device_runtime->get_arch_major(); + const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = + check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major); + const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = + check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major); + DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank); + DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_); + DG_HOST_ASSERT(hidden == hidden_); + DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden); + DG_HOST_ASSERT(l1_weights.is_contiguous() and 
l2_weights.is_contiguous()); + + // Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment + constexpr int kGranMN = 1, kGranK = 32; + check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, + num_experts_per_rank, true, false, torch::kInt); + check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, + num_experts_per_rank, true, false, torch::kInt); + + // Check stats counter + if (cumulative_local_expert_recv_stats.has_value()) { + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); + } + + // Check buffer bytes + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts_ = num_experts_per_rank * num_ranks; + const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe( + num_ranks, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + true, activation); + DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast(num_required_bytes)); + DG_HOST_ASSERT(num_experts == num_experts_); + + // Already registered tensors + const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer); + + // Dispatch into different architectures + if (arch_major == 10) { + sm100_fp8_fp4_mega_moe(y, + l1_acts, l1_acts_sf, + l2_acts, l2_acts_sf, + l1_weights, l2_weights, + l1_weights_sf, l2_weights_sf, + cumulative_local_expert_recv_stats, + sym_buffer_ptrs, + rank_idx, num_max_tokens_per_rank, + num_experts_per_rank, + num_tokens, num_topk, + hidden, intermediate_hidden, + activation_clamp, fast_math); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + // Zero the entire symmetric buffer for debug mode + // NOTES: caller must re-copy inputs into the buffer before each kernel call + if (get_env("DG_COMM_KERNEL_DEBUG")) + sym_buffer.zero_(); +} + +#if !defined(__CUDACC__) +static void register_apis(pybind11::module_& m) { +#if DG_TENSORMAP_COMPATIBLE + m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe); + m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe); + m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe); +#endif +} +#endif + +} // namespace deep_gemm::mega diff --git a/deep-gemm/csrc/apis/runtime.hpp b/deep-gemm/csrc/apis/runtime.hpp index 725cc09c..29476d87 100644 --- a/deep-gemm/csrc/apis/runtime.hpp +++ b/deep-gemm/csrc/apis/runtime.hpp @@ -4,33 +4,11 @@ #include "../jit/compiler.hpp" #endif #include "../jit/device_runtime.hpp" +#include "../jit_kernels/heuristics/runtime.hpp" namespace deep_gemm::runtime { -static void deep_gemm_set_num_sms(int64_t new_num_sms) { - device_runtime->set_num_sms(static_cast(new_num_sms)); -} - -static int64_t deep_gemm_get_num_sms() { - return device_runtime->get_num_sms(); -} - -static void deep_gemm_set_tc_util(int64_t new_tc_util) { - device_runtime->set_tc_util(static_cast(new_tc_util)); -} - -static int64_t deep_gemm_get_tc_util() { - return device_runtime->get_tc_util(); -} - -static void deep_gemm_init(const std::string& library_root_path, const std::string& cuda_home_path_by_python) { -#if DG_TENSORMAP_COMPATIBLE - Compiler::prepare_init(library_root_path, cuda_home_path_by_python); - KernelRuntime::prepare_init(cuda_home_path_by_python); -#endif -} - -#ifdef DG_USE_PYBIND11 +#if !defined(__CUDACC__) static void register_apis(pybind11::module_& m) { 
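    // NOTES: a minimal sketch of how the knobs registered below are meant to be driven
    // from the host side (values are illustrative only; the lambdas simply forward to
    // `device_runtime` / `heuristics_runtime`):
    //     device_runtime->set_num_sms(112);                        // cap the SM count used by kernels
    //     device_runtime->set_pdl(true);                           // toggle programmatic dependent launch
    //     heuristics_runtime->set_ignore_compile_dims(true);
    //     heuristics_runtime->set_block_size_multiple_of(64, 64);  // constrain block-size selection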
m.def("set_num_sms", [&](const int& new_num_sms) { device_runtime->set_num_sms(new_num_sms); @@ -44,10 +22,29 @@ static void register_apis(pybind11::module_& m) { m.def("get_tc_util", [&]() { return device_runtime->get_tc_util(); }); + m.def("set_pdl", [&](const bool& new_enable_pdl) { + device_runtime->set_pdl(new_enable_pdl); + }); + m.def("get_pdl", [&]() { + return device_runtime->get_pdl(); + }); + m.def("set_ignore_compile_dims", [&](const bool& new_value) { + heuristics_runtime->set_ignore_compile_dims(new_value); + }); + m.def("set_block_size_multiple_of", [&](const std::variant>& new_value) { + if (std::holds_alternative(new_value)) { + auto x = std::get(new_value); + heuristics_runtime->set_block_size_multiple_of(x, x); + } else { + auto [x, y] = std::get>(new_value); + heuristics_runtime->set_block_size_multiple_of(x, y); + } + }); m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) { #if DG_TENSORMAP_COMPATIBLE Compiler::prepare_init(library_root_path, cuda_home_path_by_python); KernelRuntime::prepare_init(cuda_home_path_by_python); + IncludeParser::prepare_init(library_root_path); #endif }); } diff --git a/deep-gemm/csrc/impl.cu b/deep-gemm/csrc/impl.cu index ac2458db..4b7611df 100644 --- a/deep-gemm/csrc/impl.cu +++ b/deep-gemm/csrc/impl.cu @@ -1,13 +1,27 @@ -#include +#include "utils/torch_compat.hpp" +#include "utils/exception.hpp" #include +#include +#include #include +// Upstream's DG_UNIFIED_ASSERT maps to device `trap` whenever NVCC is compiling. +// In Kernel Hub builds, that also affects NVCC's host pass and breaks x86 asm. +#ifndef DG_UNIFIED_ASSERT +#if defined(__CUDA_ARCH__) +#define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond) +#else +#define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond) +#endif +#endif + #include "../torch-ext/torch_binding.h" #include "apis/attention.hpp" #include "apis/einsum.hpp" #include "apis/hyperconnection.hpp" #include "apis/gemm.hpp" #include "apis/layout.hpp" +#include "apis/mega.hpp" #include "apis/runtime.hpp" using Tensor = at::Tensor; @@ -32,32 +46,71 @@ static std::optional> make_recipe2( return std::make_tuple(static_cast(r0), static_cast(r1)); } +static std::variant, std::tuple> make_layout_recipe( + int64_t r0, int64_t r1, int64_t r2, int64_t recipe_len) { + if (recipe_len == 3) { + return std::make_tuple(static_cast(r0), static_cast(r1), static_cast(r2)); + } + return std::make_tuple(static_cast(r0), static_cast(r1)); +} + // Runtime ops void deep_gemm_init(const std::string& path, const std::string& cuda_home) { - deep_gemm::runtime::deep_gemm_init(path, cuda_home); +#if DG_TENSORMAP_COMPATIBLE + deep_gemm::Compiler::prepare_init(path, cuda_home); + deep_gemm::KernelRuntime::prepare_init(cuda_home); + deep_gemm::IncludeParser::prepare_init(path); +#endif } void deep_gemm_set_num_sms(int64_t num_sms) { - deep_gemm::runtime::deep_gemm_set_num_sms(num_sms); + deep_gemm::device_runtime->set_num_sms(static_cast(num_sms)); } int64_t deep_gemm_get_num_sms() { - return deep_gemm::runtime::deep_gemm_get_num_sms(); + return deep_gemm::device_runtime->get_num_sms(); } void deep_gemm_set_tc_util(int64_t tc_util) { - deep_gemm::runtime::deep_gemm_set_tc_util(tc_util); + deep_gemm::device_runtime->set_tc_util(static_cast(tc_util)); } int64_t deep_gemm_get_tc_util() { - return deep_gemm::runtime::deep_gemm_get_tc_util(); + return deep_gemm::device_runtime->get_tc_util(); +} + +void deep_gemm_set_pdl(bool enable_pdl) { + deep_gemm::device_runtime->set_pdl(enable_pdl); +} + +bool 
deep_gemm_get_pdl() { + return deep_gemm::device_runtime->get_pdl(); +} + +void deep_gemm_set_ignore_compile_dims(bool ignore_compile_dims) { + deep_gemm::heuristics_runtime->set_ignore_compile_dims(ignore_compile_dims); +} + +void deep_gemm_set_block_size_multiple_of(int64_t block_m, int64_t block_n) { + deep_gemm::heuristics_runtime->set_block_size_multiple_of( + static_cast(block_m), static_cast(block_n)); } // Layout ops +void deep_gemm_set_mk_alignment_for_contiguous_layout(int64_t alignment) { + deep_gemm::heuristics_runtime->set_mk_alignment_for_contiguous_layout(static_cast(alignment)); +} + int64_t deep_gemm_get_mk_alignment_for_contiguous_layout() { - return deep_gemm::get_mk_alignment_for_contiguous_layout(); + return deep_gemm::heuristics_runtime->get_mk_alignment_for_contiguous_layout(); +} + +int64_t deep_gemm_get_theoretical_mk_alignment_for_contiguous_layout( + int64_t expected_m, bool has_expected_m) { + auto value = has_expected_m ? std::make_optional(static_cast(expected_m)) : std::nullopt; + return deep_gemm::heuristics_runtime->get_theoretical_mk_alignment_for_contiguous_layout(value); } Tensor deep_gemm_get_tma_aligned_size(int64_t mn, int64_t element_size) { @@ -75,23 +128,23 @@ Tensor deep_gemm_get_mn_major_tma_aligned_packed_ue8m0_tensor(const Tensor& sf) } Tensor deep_gemm_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( - const Tensor& sf, const Tensor& ks_tensor, const Tensor& ks_int_tensor) { + const Tensor& sf, const Tensor& ks_tensor, const Tensor& ks_int_tensor, int64_t gran_k) { auto ks = tensor_to_vec_int(ks_int_tensor); - return deep_gemm::get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks); + return deep_gemm::get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks, static_cast(gran_k)); } Tensor deep_gemm_transform_sf_into_required_layout( const Tensor& sf, int64_t mn, int64_t k, - int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, - int64_t recipe_ab_0, int64_t recipe_ab_1, bool has_recipe_ab, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, int64_t recipe_len, int64_t num_groups, bool has_num_groups, - bool is_sfa, bool disable_ue8m0_cast) { - auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); - auto recipe_ab = make_recipe2(recipe_ab_0, recipe_ab_1, has_recipe_ab); + bool is_sfa, bool has_is_sfa, bool disable_ue8m0_cast) { + auto recipe = make_layout_recipe(recipe_0, recipe_1, recipe_2, recipe_len); auto ng = has_num_groups ? std::make_optional(static_cast(num_groups)) : std::nullopt; + auto sfa = has_is_sfa ? 
std::make_optional(is_sfa) : std::nullopt; return deep_gemm::layout::transform_sf_into_required_layout( sf, static_cast(mn), static_cast(k), - recipe, recipe_ab, ng, is_sfa, disable_ue8m0_cast); + recipe, ng, sfa, disable_ue8m0_cast); } // GEMM ops - FP8/FP4 @@ -372,32 +425,110 @@ void deep_gemm_fp8_gemm_nt_skip_head_mid( a, b, d, head_splits, recipe, compiled_dims, disable_ue8m0_cast); } +Tensor deep_gemm_fp8_fp4_mqa_logits( + const Tensor& q_data, const std::optional& q_sf, + const Tensor& kv_data, const Tensor& kv_sf, + const Tensor& weights, + const Tensor& cu_seq_len_k_start, const Tensor& cu_seq_len_k_end, + bool clean_logits, int64_t max_seqlen_k, at::ScalarType logits_dtype) { + auto q = std::make_tuple(q_data, q_sf); + auto kv = std::make_tuple(kv_data, kv_sf); + return deep_gemm::attention::fp8_fp4_mqa_logits( + q, kv, weights, cu_seq_len_k_start, cu_seq_len_k_end, + clean_logits, static_cast(max_seqlen_k), logits_dtype); +} + Tensor deep_gemm_fp8_mqa_logits( const Tensor& q, const Tensor& kv_data, const Tensor& kv_sf, const Tensor& weights, const Tensor& cu_seq_len_k_start, const Tensor& cu_seq_len_k_end, bool clean_logits, int64_t max_seqlen_k) { - auto kv = std::make_pair(kv_data, kv_sf); + auto kv = std::make_tuple(kv_data, kv_sf); return deep_gemm::attention::fp8_mqa_logits( q, kv, weights, cu_seq_len_k_start, cu_seq_len_k_end, clean_logits, static_cast(max_seqlen_k)); } Tensor deep_gemm_get_paged_mqa_logits_metadata( - const Tensor& context_lens, int64_t block_kv, int64_t num_sms) { + const Tensor& context_lens, int64_t block_kv, int64_t num_sms, + const std::optional& indices) { return deep_gemm::attention::get_paged_mqa_logits_metadata( - context_lens, static_cast(block_kv), static_cast(num_sms)); + context_lens, static_cast(block_kv), static_cast(num_sms), indices); +} + +Tensor deep_gemm_fp8_fp4_paged_mqa_logits( + const Tensor& q_data, const std::optional& q_sf, + const Tensor& fused_kv_cache, + const Tensor& weights, const Tensor& context_lens, + const Tensor& block_table, const Tensor& schedule_meta, + int64_t max_context_len, bool clean_logits, at::ScalarType logits_dtype, + const std::optional& indices) { + auto q = std::make_tuple(q_data, q_sf); + return deep_gemm::attention::fp8_fp4_paged_mqa_logits( + q, fused_kv_cache, weights, context_lens, block_table, schedule_meta, + static_cast(max_context_len), clean_logits, logits_dtype, indices); } Tensor deep_gemm_fp8_paged_mqa_logits( const Tensor& q, const Tensor& fused_kv_cache, const Tensor& weights, const Tensor& context_lens, const Tensor& block_table, const Tensor& schedule_meta, - int64_t max_context_len, bool clean_logits) { + int64_t max_context_len, bool clean_logits, const std::optional& indices) { return deep_gemm::attention::fp8_paged_mqa_logits( q, fused_kv_cache, weights, context_lens, block_table, schedule_meta, - static_cast(max_context_len), clean_logits); + static_cast(max_context_len), clean_logits, indices); +} + +int64_t deep_gemm_get_token_alignment_for_mega_moe() { + return deep_gemm::mega::get_token_alignment_for_mega_moe(); +} + +int64_t deep_gemm_get_symm_buffer_size_for_mega_moe( + int64_t num_ranks, int64_t num_experts, int64_t num_max_tokens_per_rank, + int64_t num_topk, int64_t hidden, int64_t intermediate_hidden, + bool use_fp8_dispatch, const std::string& activation) { + auto [num_bytes, slice] = deep_gemm::mega::get_symm_buffer_size_for_mega_moe( + static_cast(num_ranks), static_cast(num_experts), + static_cast(num_max_tokens_per_rank), static_cast(num_topk), + static_cast(hidden), 
static_cast(intermediate_hidden), + use_fp8_dispatch, activation); + return num_bytes; +} + +std::vector deep_gemm_get_symm_buffer_views_for_mega_moe( + const Tensor& buffer, int64_t num_ranks, int64_t num_experts, + int64_t num_max_tokens_per_rank, int64_t num_topk, int64_t hidden, + int64_t intermediate_hidden, bool use_fp8_dispatch, const std::string& activation) { + auto [num_bytes, slice] = deep_gemm::mega::get_symm_buffer_size_for_mega_moe( + static_cast(num_ranks), static_cast(num_experts), + static_cast(num_max_tokens_per_rank), static_cast(num_topk), + static_cast(hidden), static_cast(intermediate_hidden), + use_fp8_dispatch, activation); + auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(buffer); + return {x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf}; +} + +void deep_gemm_fp8_fp4_mega_moe( + const Tensor& y, + const Tensor& l1_weights, const Tensor& l1_weights_sf, + const Tensor& l2_weights, const Tensor& l2_weights_sf, + const std::optional& cumulative_local_expert_recv_stats, + const Tensor& sym_buffer, c10::List sym_buffer_ptrs, int64_t rank_idx, + int64_t num_max_tokens_per_rank, int64_t num_experts, int64_t num_topk, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, + const std::string& activation, const std::optional& activation_clamp, + bool fast_math) { + auto ptrs = std::vector(sym_buffer_ptrs.begin(), sym_buffer_ptrs.end()); + auto recipe = std::make_tuple(static_cast(recipe_0), static_cast(recipe_1), static_cast(recipe_2)); + auto clamp = activation_clamp.has_value() + ? std::make_optional(static_cast(activation_clamp.value())) + : std::nullopt; + deep_gemm::mega::fp8_fp4_mega_moe( + y, std::make_tuple(l1_weights, l1_weights_sf), std::make_tuple(l2_weights, l2_weights_sf), + cumulative_local_expert_recv_stats, sym_buffer, ptrs, static_cast(rank_idx), + static_cast(num_max_tokens_per_rank), static_cast(num_experts), + static_cast(num_topk), recipe, activation, clamp, fast_math); } // Einsum ops diff --git a/deep-gemm/csrc/indexing/main.cu b/deep-gemm/csrc/indexing/main.cu index 1b96da2f..a42b66f9 100644 --- a/deep-gemm/csrc/indexing/main.cu +++ b/deep-gemm/csrc/indexing/main.cu @@ -3,12 +3,14 @@ #include #include #include -#include +#include // Attention kernels #include #include +#include #include +#include #include // Einsum kernels @@ -23,6 +25,9 @@ #include #include +// Mega kernels +#include + using namespace deep_gemm; int main() { diff --git a/deep-gemm/csrc/jit/cache.hpp b/deep-gemm/csrc/jit/cache.hpp index 1e8659fd..ddc763d0 100644 --- a/deep-gemm/csrc/jit/cache.hpp +++ b/deep-gemm/csrc/jit/cache.hpp @@ -17,7 +17,7 @@ class KernelRuntimeCache { std::shared_ptr get(const std::filesystem::path& dir_path) { // Hit the runtime cache - if (const auto& iterator = cache.find(dir_path); iterator != cache.end()) + if (const auto iterator = cache.find(dir_path); iterator != cache.end()) return iterator->second; if (KernelRuntime::check_validity(dir_path)) diff --git a/deep-gemm/csrc/jit/compiler.hpp b/deep-gemm/csrc/jit/compiler.hpp index 38d090e7..265b787d 100644 --- a/deep-gemm/csrc/jit/compiler.hpp +++ b/deep-gemm/csrc/jit/compiler.hpp @@ -2,10 +2,13 @@ #include #include +#include #include #include +#ifdef DG_ENABLE_NVRTC_COMPILER #include -#include +#endif +#include #include #include "../utils/exception.hpp" @@ -15,92 +18,22 @@ #include "../utils/system.hpp" #include "cache.hpp" #include "device_runtime.hpp" +#include "include_parser.hpp" namespace deep_gemm { -// Lazy-load NVRTC to avoid 
link-time dependency on libnvrtc.so. -// kernel-builder doesn't support linking extra CUDA libs yet, so we dlopen -// at runtime — same pattern as the CUDA driver API in jit/handle.hpp. -static void* get_nvrtc_handle() { - static void* handle = nullptr; - if (handle == nullptr) { - handle = dlopen("libnvrtc.so", RTLD_LAZY | RTLD_LOCAL); - if (handle == nullptr) - handle = dlopen("libnvrtc.so.12", RTLD_LAZY | RTLD_LOCAL); - DG_HOST_ASSERT(handle != nullptr and "Failed to load NVRTC library"); - } - return handle; -} - -#define DECL_LAZY_NVRTC_FUNCTION(name) \ -template \ -static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ - using FuncType = decltype(&name); \ - static FuncType func = nullptr; \ - if (func == nullptr) { \ - func = reinterpret_cast(dlsym(get_nvrtc_handle(), #name)); \ - DG_HOST_ASSERT(func != nullptr and "Failed to load NVRTC function"); \ - } \ - return func(std::forward(args)...); \ -} - -DECL_LAZY_NVRTC_FUNCTION(nvrtcVersion); -DECL_LAZY_NVRTC_FUNCTION(nvrtcCreateProgram); -DECL_LAZY_NVRTC_FUNCTION(nvrtcCompileProgram); -DECL_LAZY_NVRTC_FUNCTION(nvrtcGetProgramLogSize); -DECL_LAZY_NVRTC_FUNCTION(nvrtcGetProgramLog); -DECL_LAZY_NVRTC_FUNCTION(nvrtcGetPTXSize); -DECL_LAZY_NVRTC_FUNCTION(nvrtcGetPTX); -DECL_LAZY_NVRTC_FUNCTION(nvrtcGetCUBINSize); -DECL_LAZY_NVRTC_FUNCTION(nvrtcGetCUBIN); -DECL_LAZY_NVRTC_FUNCTION(nvrtcDestroyProgram); -DECL_LAZY_NVRTC_FUNCTION(nvrtcGetErrorString); - -// Redirect nvrtc calls to lazy-loaded versions so NVRTCCompiler is unchanged -#define nvrtcVersion lazy_nvrtcVersion -#define nvrtcCreateProgram lazy_nvrtcCreateProgram -#define nvrtcCompileProgram lazy_nvrtcCompileProgram -#define nvrtcGetProgramLogSize lazy_nvrtcGetProgramLogSize -#define nvrtcGetProgramLog lazy_nvrtcGetProgramLog -#define nvrtcGetPTXSize lazy_nvrtcGetPTXSize -#define nvrtcGetPTX lazy_nvrtcGetPTX -#define nvrtcGetCUBINSize lazy_nvrtcGetCUBINSize -#define nvrtcGetCUBIN lazy_nvrtcGetCUBIN -#define nvrtcDestroyProgram lazy_nvrtcDestroyProgram -#define nvrtcGetErrorString lazy_nvrtcGetErrorString - class Compiler { public: static std::filesystem::path library_root_path; static std::filesystem::path library_include_path; static std::filesystem::path cuda_home; - static std::string library_version; static std::filesystem::path cuobjdump_path; - static std::string get_library_version() { - const auto dg_include = library_include_path / "deep_gemm"; - if (not std::filesystem::exists(dg_include)) { - // Fallback: hash the root path itself - std::string fallback(library_root_path.string()); - return get_hex_digest(std::vector(fallback.begin(), fallback.end())); - } - std::vector buffer; - for (const auto& f: collect_files(dg_include)) { - std::ifstream in(f, std::ios::binary); - DG_HOST_ASSERT(in.is_open()); - buffer.insert(buffer.end(), - std::istreambuf_iterator(in), - std::istreambuf_iterator()); - } - return get_hex_digest(buffer); - } - static void prepare_init(const std::string& library_root_path, const std::string& cuda_home_path_by_python) { Compiler::library_root_path = library_root_path; Compiler::library_include_path = Compiler::library_root_path / "include"; Compiler::cuda_home = cuda_home_path_by_python; - Compiler::library_version = get_library_version(); Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump"; } @@ -112,12 +45,11 @@ class Compiler { DG_HOST_ASSERT(not library_root_path.empty()); DG_HOST_ASSERT(not library_include_path.empty()); DG_HOST_ASSERT(not cuda_home.empty()); - DG_HOST_ASSERT(not library_version.empty()); 
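// NOTES: a minimal, generic sketch of the lazy `dlopen`/`dlsym` binding that the
// removed `DECL_LAZY_NVRTC_FUNCTION` machinery above expanded to, using placeholder
// names (`libfoo.so`, `foo_version`) rather than actual NVRTC symbols:
//     #include <dlfcn.h>
//     static void* get_foo_handle() {
//         static void* handle = dlopen("libfoo.so", RTLD_LAZY | RTLD_LOCAL);
//         DG_HOST_ASSERT(handle != nullptr);
//         return handle;
//     }
//     template <typename... Args>
//     static auto lazy_foo_version(Args&&... args) -> decltype(foo_version(args...)) {
//         static auto func = reinterpret_cast<decltype(&foo_version)>(
//             dlsym(get_foo_handle(), "foo_version"));
//         DG_HOST_ASSERT(func != nullptr);
//         return func(std::forward<Args>(args)...);
//     }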
DG_HOST_ASSERT(not cuobjdump_path.empty()); // Cache settings cache_dir_path = std::filesystem::path(get_env("HOME")) / ".deep_gemm"; - if (const auto& env_cache_dir_path = get_env("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty()) + if (const auto env_cache_dir_path = get_env("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty()) cache_dir_path = env_cache_dir_path; // The compiler flags applied to all derived compilers @@ -137,58 +69,79 @@ class Compiler { return make_dirs(cache_dir_path / "tmp"); } - std::filesystem::path get_tmp_file_path() const { - return make_tmp_dir() / get_uuid(); + static void fsync_path(const std::filesystem::path& path) { + const auto fd = ::open(path.c_str(), O_RDONLY); + if (fd >= 0) { + ::fsync(fd); + ::close(fd); + } } - void put(const std::filesystem::path& path, const std::string& data) const { - const auto tmp_file_path = get_tmp_file_path(); + // Recursively fsync a directory: files and subdirectories first (bottom-up), then the directory itself + // NOTES: ensures data and directory entries are visible on other nodes in distributed filesystems + static void fsync_dir(const std::filesystem::path& dir_path) { // NOLINT(*-no-recursion) + for (const auto& entry: std::filesystem::directory_iterator(dir_path)) { + if (entry.is_directory()) + fsync_dir(entry.path()); + else if (entry.is_regular_file()) + fsync_path(entry.path()); + } + fsync_path(dir_path); + } - // Write into the temporary file - std::ofstream out(tmp_file_path, std::ios::binary); + static void put(const std::filesystem::path& path, const std::string& data) { + std::ofstream out(path, std::ios::binary); DG_HOST_ASSERT(out.write(data.data(), data.size())); out.close(); - // Atomically replace - std::filesystem::rename(tmp_file_path, path); + // NOTES: fsync to ensure the data is visible to other processes (e.g., NVCC) + // on distributed filesystems, where `close()` alone does not guarantee persistence + fsync_path(path); } std::shared_ptr build(const std::string& name, const std::string& code) const { - const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code); + const auto kernel_signature = fmt::format("{}$${}$${}$${}", name, signature, flags, code); const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature)); // Hit the runtime cache - if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr) + if (const auto runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr) return runtime; - // Create the kernel directory - make_dirs(dir_path); + // Compile into a temporary directory, then atomically rename the whole directory + // NOTES: renaming a directory is atomic on both local and distributed filesystems, + // avoiding the stale inode issue that occurs when renaming individual files + const auto tmp_dir_path = make_tmp_dir() / get_uuid(); + make_dirs(tmp_dir_path); - // Compile into a temporary CUBIN - const auto tmp_cubin_path = get_tmp_file_path(); + // Compile into the temporary directory + const auto tmp_cubin_path = tmp_dir_path / "kernel.cubin"; if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_PTX")) { - // Dump PTX if needed - const auto tmp_ptx_path = get_tmp_file_path(); - compile(code, dir_path, tmp_cubin_path, tmp_ptx_path); - - // Replace into the cache directory - std::filesystem::rename(tmp_ptx_path, dir_path / "kernel.ptx"); + const auto tmp_ptx_path = tmp_dir_path / "kernel.ptx"; + compile(code, tmp_dir_path, tmp_cubin_path, 
tmp_ptx_path); } else { - compile(code, dir_path, tmp_cubin_path); + compile(code, tmp_dir_path, tmp_cubin_path); } - // Replace into the cache directory - const auto cubin_path = dir_path / "kernel.cubin"; - std::filesystem::rename(tmp_cubin_path, cubin_path); - // Disassemble if needed if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_SASS")) { - // Dump into a temporary SASS - const auto tmp_sass_path = get_tmp_file_path(); - disassemble(cubin_path, tmp_sass_path); + const auto tmp_sass_path = tmp_dir_path / "kernel.sass"; + disassemble(tmp_cubin_path, tmp_sass_path); + } - // Replace into the current directory - std::filesystem::rename(tmp_sass_path, dir_path / "kernel.sass"); + // Fsync before rename to ensure visibility on distributed filesystems + fsync_dir(tmp_dir_path); + + // Atomically rename the temporary directory to the final cache path + // NOTES: if another rank already created dir_path, the rename will fail, which is fine + make_dirs(dir_path.parent_path()); + std::error_code error_code; + std::filesystem::rename(tmp_dir_path, dir_path, error_code); + if (error_code) { + // Another rank beat us; clean up our directory and use the existing one + // NOTES: avoid `std::filesystem::remove_all` here, as it can segfault on + // distributed filesystems when concurrent processes operate + // on the same parent directory, leaving stale directory entries + safe_remove_all(tmp_dir_path); } // Put into the runtime cache @@ -201,10 +154,10 @@ class Compiler { // Disassemble the CUBIN file to SASS const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str()); if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) - fprintf(stderr, "Running cuobjdump command: %s\n", command.c_str()); + printf("Running cuobjdump command: %s\n", command.c_str()); const auto [return_code, output] = call_external_command(command); if (return_code != 0) { - fprintf(stderr, "cuobjdump failed: %s\n", output.c_str()); + printf("cuobjdump failed: %s\n", output.c_str()); DG_HOST_ASSERT(false and "cuobjdump failed"); } } @@ -215,7 +168,6 @@ class Compiler { DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path); DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path); DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home); -DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version); DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path); class NVCCCompiler final: public Compiler { @@ -225,18 +177,18 @@ class NVCCCompiler final: public Compiler { DG_HOST_ASSERT(std::filesystem::exists(nvcc_path)); // Call the version command - const auto& command = std::string(nvcc_path) + " --version"; - const auto& [return_code, output] = call_external_command(command); + const auto command = std::string(nvcc_path) + " --version"; + const auto [return_code, output] = call_external_command(command); DG_HOST_ASSERT(return_code == 0); - // Parse "release X.Y" without std::regex - int major = 0, minor = 0; - const char* release_pos = std::strstr(output.c_str(), "release "); - DG_HOST_ASSERT(release_pos != nullptr and "Could not find 'release' in nvcc --version output"); - std::sscanf(release_pos + 8, "%d.%d", &major, &minor); + // The version should be at least 12.3; 12.9 or newer gives the best performance + int major, minor; + std::smatch match; + DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))"))); + std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor); DG_HOST_ASSERT((major > 12 or (major == 12 and minor
>= 3)) and "NVCC version should be >= 12.3"); if (major == 12 and minor < 9) - fprintf(stderr, "Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n"); + printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n"); return {major, minor}; } @@ -244,66 +196,64 @@ class NVCCCompiler final: public Compiler { NVCCCompiler() { // Override the compiler signature nvcc_path = cuda_home / "bin" / "nvcc"; - if (const auto& env_nvcc_path = get_env("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty()) + if (const auto env_nvcc_path = get_env("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty()) nvcc_path = env_nvcc_path; - const auto& [nvcc_major, nvcc_minor] = get_nvcc_version(); + const auto [nvcc_major, nvcc_minor] = get_nvcc_version(); signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor); // Then override the compiler flags // Only NVCC >= 12.9 supports arch-specific family suffix - const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9); - // DG_CUTLASS_INCLUDE is set by Python _find_cutlass_include() before ops.init() - const auto& cutlass_include = get_env("DG_CUTLASS_INCLUDE"); - std::string cutlass_flag = cutlass_include.empty() ? "" : fmt::format(" -I{}", cutlass_include); - flags = fmt::format("{} -I{}{} --gpu-architecture=sm_{} " "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " "-O3 --expt-relaxed-constexpr --expt-extended-lambda", - flags, library_include_path.c_str(), cutlass_flag, arch); - - // print flags if ENV is set - if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_FLAGS", 0)) - fprintf(stderr, "NVCC compiler flags: %s\n", flags.c_str()); + const auto arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9); + flags = fmt::format("{} -I{} --gpu-architecture=sm_{} " "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " "-O3 --expt-relaxed-constexpr --expt-extended-lambda", - flags, library_include_path.c_str(), arch); } void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path, const std::optional &ptx_path) const override { // Write the code into the cache directory - const auto& code_path = dir_path / "kernel.cu"; + const auto code_path = dir_path / "kernel.cu"; put(code_path, code); // Compile - const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); + // Avoid cwd files shadowing C++ standard library headers + const auto compile_dir = make_tmp_dir(); + const auto command = fmt::format("cd {} && {} {} -cubin -o {} {}", + compile_dir.c_str(), nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) - fprintf(stderr, "Running NVCC command: %s\n", command.c_str()); - const auto& [return_code, output] = call_external_command(command); + printf("Running NVCC command: %s\n", command.c_str()); + const auto [return_code, output] = call_external_command(command); if (return_code != 0) { - fprintf(stderr, "NVCC compilation failed: %s\n", output.c_str()); + printf("NVCC compilation failed: %s\n", output.c_str()); DG_HOST_ASSERT(false and "NVCC compilation failed"); } // Compile to PTX if needed if (ptx_path.has_value()) { - const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags); + const auto ptx_command = fmt::format("cd {} && {} {} -ptx -o {} {}", + compile_dir.c_str(), nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags); if (get_env("DG_JIT_DEBUG", 0) or
get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) - fprintf(stderr, "Running NVCC PTX command: %s\n", ptx_command.c_str()); + printf("Running NVCC PTX command: %s\n", ptx_command.c_str()); const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command); if (ptx_return_code != 0) { - fprintf(stderr, "NVCC PTX compilation failed: %s\n", ptx_output.c_str()); + printf("NVCC PTX compilation failed: %s\n", ptx_output.c_str()); DG_HOST_ASSERT(false and "NVCC PTX compilation failed"); } } - // Check local memory usage (without std::regex — avoids ABI issues) + // Check local memory usage if (get_env("DG_JIT_PTXAS_CHECK", 0)) - DG_HOST_ASSERT(output.find("Local memory used") == std::string::npos); + DG_HOST_ASSERT(not std::regex_search(output, std::regex(R"(Local memory used)"))); // Print PTXAS log if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) - fprintf(stderr, "%s", output.c_str()); + printf("%s", output.c_str()); } }; +#ifdef DG_ENABLE_NVRTC_COMPILER class NVRTCCompiler final: public Compiler { public: NVRTCCompiler() { @@ -317,9 +267,6 @@ class NVRTCCompiler final: public Compiler { std::string include_dirs; include_dirs += fmt::format("-I{} ", library_include_path.string()); include_dirs += fmt::format("-I{} ", (cuda_home / "include").string()); - // DG_CUTLASS_INCLUDE is set by Python _find_cutlass_include() before ops.init() - if (const auto& cutlass_include = get_env("DG_CUTLASS_INCLUDE"); not cutlass_include.empty()) - include_dirs += fmt::format("-I{} ", cutlass_include); // Add PCH support for version 12.8 and above // NOTES: PCH is vital for compilation speed @@ -332,7 +279,7 @@ class NVRTCCompiler final: public Compiler { // Override the compiler flags // Only NVRTC >= 12.9 supports arch-specific family suffix - const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9); + const auto arch = device_runtime->get_arch(false, major > 12 or minor >= 9); flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128", flags, include_dirs, arch, pch_flags); } @@ -341,21 +288,15 @@ class NVRTCCompiler final: public Compiler { const std::filesystem::path &cubin_path, const std::optional &ptx_path) const override { // Write the code into the cache directory - const auto& code_path = dir_path / "kernel.cu"; + const auto code_path = dir_path / "kernel.cu"; put(code_path, code); - // Split flags by whitespace (without std::istringstream — avoids ABI issues) + // Parse compilation options + std::istringstream iss(flags); std::vector options; - { - size_t i = 0; - while (i < flags.size()) { - while (i < flags.size() && (flags[i] == ' ' || flags[i] == '\t')) ++i; - if (i >= flags.size()) break; - size_t start = i; - while (i < flags.size() && flags[i] != ' ' && flags[i] != '\t') ++i; - options.push_back(flags.substr(start, i - start)); - } - } + std::string option; + while (iss >> option) + options.push_back(option); // Convert to C-style string array for NVRTC std::vector option_cstrs; @@ -364,16 +305,16 @@ class NVRTCCompiler final: public Compiler { // Print compiler command if requested if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) { - fprintf(stderr, "Compiling JIT runtime with NVRTC options: "); + printf("Compiling JIT runtime with NVRTC options: "); for (const auto& opt: options) - fprintf(stderr, "%s ", opt.c_str()); - fprintf(stderr, "\n"); + printf("%s ", opt.c_str()); + printf("\n"); } // Create NVRTC program and compile nvrtcProgram program; DG_NVRTC_CHECK(nvrtcCreateProgram(&program, 
code.c_str(), "kernel.cu", 0, nullptr, nullptr)); - const auto& compile_result = nvrtcCompileProgram(program, static_cast<int>(option_cstrs.size()), option_cstrs.data()); + const auto compile_result = nvrtcCompileProgram(program, static_cast<int>(option_cstrs.size()), option_cstrs.data()); // Get and print compiler log size_t log_size; @@ -384,7 +325,7 @@ class NVRTCCompiler final: public Compiler { if (log_size > 1) { std::string compilation_log(log_size, '\0'); DG_NVRTC_CHECK(nvrtcGetProgramLog(program, compilation_log.data())); - fprintf(stderr, "NVRTC log: %s\n", compilation_log.c_str()); + printf("NVRTC log: %s\n", compilation_log.c_str()); } } @@ -412,11 +353,16 @@ class NVRTCCompiler final: public Compiler { DG_NVRTC_CHECK(nvrtcDestroyProgram(&program)); } }; +#endif static auto compiler = LazyInit<std::shared_ptr<Compiler>>([]() -> std::shared_ptr<Compiler> { +#ifdef DG_ENABLE_NVRTC_COMPILER if (get_env("DG_JIT_USE_NVRTC", 0)) { return std::make_shared<NVRTCCompiler>(); } +#endif + if (get_env("DG_JIT_USE_NVRTC", 0)) + printf("Warning: DG_JIT_USE_NVRTC is ignored in Kernel Hub builds; using NVCC\n"); return std::make_shared<NVCCCompiler>(); }); diff --git a/deep-gemm/csrc/jit/device_runtime.hpp b/deep-gemm/csrc/jit/device_runtime.hpp index d33743ef..2321aded 100644 --- a/deep-gemm/csrc/jit/device_runtime.hpp +++ b/deep-gemm/csrc/jit/device_runtime.hpp @@ -7,10 +7,13 @@ #include "../utils/exception.hpp" #include "../utils/lazy_init.hpp" +#define PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 3)) + namespace deep_gemm { class DeviceRuntime { int num_sms = 0, tc_util = 0; + bool enable_pdl = false; std::shared_ptr<cudaDeviceProp> cached_prop; // cuBLASLt utils @@ -18,24 +21,52 @@ class DeviceRuntime { public: // Create the cuBLASLt handle ourselves - cublasLtHandle_t cublaslt_handle{}; - std::shared_ptr<torch::Tensor> cublaslt_workspace; + cublasLtHandle_t cublaslt_handle; + torch::Tensor cublaslt_workspace; + bool use_pytorch_managed_cublaslt_handle; + bool use_temp_cublaslt_workspace; explicit DeviceRuntime() { - cublaslt_workspace = std::make_shared<torch::Tensor>(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA))); - DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle)); + + // Whether to use PyTorch cuBLASLt + // By default, we don't use it, + // as `at::cuda::getCurrentCUDABlasLtHandle` has a large CPU overhead with some PyTorch versions + use_pytorch_managed_cublaslt_handle = get_env("DG_USE_PYTORCH_CUBLASLT_HANDLE", 0) > 0; +#if not PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE + DG_HOST_ASSERT(not use_pytorch_managed_cublaslt_handle and "PyTorch does not support getting the cuBLASLt handle"); +#endif + + // Whether to create the workspace tensor on each call instead of holding one. + // Enabled by compute-sanitizer tests, which trigger `cudaErrorCudartUnloading` + // when the workspace tensor is destructed after CUDA driver shutdown.
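// For example, a compute-sanitizer run may export `DG_USE_TEMP_CUBLASLT_WORKSPACE=1` so
// that no cached workspace tensor outlives the CUDA driver, trading one extra allocation
// per cuBLASLt call for a clean shutdown (the exact invocation is up to the test harness).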
+ use_temp_cublaslt_workspace = get_env("DG_USE_TEMP_CUBLASLT_WORKSPACE", 0) > 0; + + if (not use_pytorch_managed_cublaslt_handle) + DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle)); + + if (not use_temp_cublaslt_workspace) + cublaslt_workspace = torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)); } ~DeviceRuntime() noexcept(false) { - DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle)); + if (not use_pytorch_managed_cublaslt_handle) + DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle)); } cublasLtHandle_t get_cublaslt_handle() const { +#if PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE + if (use_pytorch_managed_cublaslt_handle) + return at::cuda::getCurrentCUDABlasLtHandle(); +#endif + + // Self-managed handle return cublaslt_handle; } torch::Tensor get_cublaslt_workspace() const { - return *cublaslt_workspace; + if (use_temp_cublaslt_workspace) + return torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)); + return cublaslt_workspace; } std::shared_ptr get_prop() { @@ -56,7 +87,7 @@ class DeviceRuntime { std::string get_arch(const bool& number_only = false, const bool& support_arch_family = false) { - const auto& [major, minor] = get_arch_pair(); + const auto [major, minor] = get_arch_pair(); if (major == 10 and minor != 1) { if (number_only) return "100"; @@ -92,6 +123,14 @@ class DeviceRuntime { int get_tc_util() const { return tc_util == 0 ? 100 : tc_util; } + + void set_pdl(const bool& new_enable_pdl) { + enable_pdl = new_enable_pdl; + } + + bool get_pdl() const { + return enable_pdl; + } }; static auto device_runtime = LazyInit([](){ return std::make_shared(); }); diff --git a/deep-gemm/csrc/jit/handle.hpp b/deep-gemm/csrc/jit/handle.hpp index 34447f91..8e1ee90e 100644 --- a/deep-gemm/csrc/jit/handle.hpp +++ b/deep-gemm/csrc/jit/handle.hpp @@ -24,7 +24,7 @@ static void* get_driver_handle() { #define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \ template \ static auto lazy_##name(Args&&... 
args) -> decltype(name(args...)) { \ - using FuncType = decltype(&name); \ + using FuncType = decltype(&(name)); \ static FuncType func = nullptr; \ if (func == nullptr) { \ func = reinterpret_cast(dlsym(get_driver_handle(), #name)); \ @@ -39,6 +39,9 @@ DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryLoadFromFile); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryUnload); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuKernelGetFunction); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled); @@ -65,13 +68,13 @@ static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const s } static void unload_library(const LibraryHandle& library) { - const auto& error = cudaLibraryUnload(library); + const auto error = cudaLibraryUnload(library); DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading); } static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, const cudaStream_t& stream, const int& smem_size, - const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) { if (smem_size > 0) DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -80,17 +83,27 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, config.blockDim = block_dim; config.dynamicSmemBytes = smem_size; config.stream = stream; - config.numAttrs = 0; - config.attrs = nullptr; + // Create attributes // NOTES: must use `static` or the `attr` will be deconstructed - static LaunchAttrHandle attr; + static LaunchAttrHandle attrs[2]; + config.numAttrs = 0; + config.attrs = attrs; + + // Cluster size if (cluster_dim > 1) { + auto& attr = attrs[config.numAttrs ++]; attr.id = cudaLaunchAttributeClusterDimension; attr.val.clusterDim = {static_cast(cluster_dim), 1, 1}; - config.attrs = &attr; - config.numAttrs = 1; } + + // Dependent kernel launch + if (enable_pdl) { + auto& attr = attrs[config.numAttrs ++]; + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = 1; + } + return config; } @@ -103,19 +116,46 @@ static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& #else // Use CUDA driver API -using LibraryHandle = CUmodule; using KernelHandle = CUfunction; using LaunchConfigHandle = CUlaunchConfig; using LaunchAttrHandle = CUlaunchAttribute; +// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4 +#if CUDA_VERSION >= 12040 + #define DG_JIT_USE_LIBRARY_ENUM_KERNELS + DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryGetKernelCount); + DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryEnumerateKernels); + using LibraryHandle = CUlibrary; +#else + using LibraryHandle = CUmodule; +#endif + #define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name, - LibraryHandle *library_opt = nullptr) { + LibraryHandle *library_opt = nullptr) { LibraryHandle library; KernelHandle kernel; + +#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS + DG_CUDA_DRIVER_CHECK(lazy_cuLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + unsigned int num_kernels; + 
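// NOTES: enumeration sidesteps symbol-name lookup entirely; the kernel's mangled name is
// template-dependent, which is why the pre-12.4 `cuModuleGetFunction` path below (fed by
// the `cuobjdump -symbols` scrape in `kernel_runtime.hpp`) must recover it from the binary
// instead. A single-kernel CUBIN makes the first enumerated kernel unambiguous, which the
// `num_kernels` check right below enforces.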
DG_CUDA_DRIVER_CHECK(lazy_cuLibraryGetKernelCount(&num_kernels, library)); + if (num_kernels != 1) { + const auto dir_path = cubin_path.parent_path(); + printf("Corrupted JIT cache directory (expected 1 kernel, found %u): %s, " + "please run `rm -rf %s` and restart your task.\n", + num_kernels, dir_path.c_str(), dir_path.c_str()); + DG_HOST_ASSERT(false and "Corrupted JIT cache directory"); + } + + CUkernel cu_kernel; + DG_CUDA_DRIVER_CHECK(lazy_cuLibraryEnumerateKernels(&cu_kernel, 1, library)); + DG_CUDA_DRIVER_CHECK(lazy_cuKernelGetFunction(&kernel, cu_kernel)); +#else DG_CUDA_DRIVER_CHECK(lazy_cuModuleLoad(&library, cubin_path.c_str())); DG_CUDA_DRIVER_CHECK(lazy_cuModuleGetFunction(&kernel, library, func_name.c_str())); +#endif if (library_opt != nullptr) *library_opt = library; @@ -123,13 +163,17 @@ static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const s } static void unload_library(const LibraryHandle& library) { - const auto& error = lazy_cuModuleUnload(library); +#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS + const auto error = lazy_cuLibraryUnload(library); +#else + const auto error = lazy_cuModuleUnload(library); +#endif DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED); } static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, const cudaStream_t& stream, const int& smem_size, - const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) { if (smem_size > 0) DG_CUDA_DRIVER_CHECK(lazy_cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); @@ -142,19 +186,29 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, config.blockDimZ = block_dim.z; config.sharedMemBytes = smem_size; config.hStream = stream; - config.numAttrs = 0; - config.attrs = nullptr; + // Create attributes // NOTES: must use `static` or the `attr` will be deconstructed - static LaunchAttrHandle attr; + static LaunchAttrHandle attrs[2]; + config.numAttrs = 0; + config.attrs = attrs; + + // Cluster size if (cluster_dim > 1) { + auto& attr = attrs[config.numAttrs ++]; attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - attr.value.clusterDim.x = cluster_dim; + attr.value.clusterDim.x = static_cast(cluster_dim); attr.value.clusterDim.y = 1; attr.value.clusterDim.z = 1; - config.attrs = &attr; - config.numAttrs = 1; } + + // Dependent kernel launch + if (enable_pdl) { + auto& attr = attrs[config.numAttrs ++]; + attr.id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; + attr.value.programmaticStreamSerializationAllowed = 1; + } + return config; } diff --git a/deep-gemm/csrc/jit/include_parser.hpp b/deep-gemm/csrc/jit/include_parser.hpp new file mode 100644 index 00000000..99f2663c --- /dev/null +++ b/deep-gemm/csrc/jit/include_parser.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include + +#include "../utils/format.hpp" +#include "../utils/system.hpp" + +namespace deep_gemm { + +class IncludeParser { + std::unordered_map> cache; + + static std::vector get_includes(const std::string& code, const std::filesystem::path& file_path = "") { + std::vector includes; + const std::regex pattern(R"(#\s*include\s*[<"][^>"]+[>"])"); + std::sregex_iterator iter(code.begin(), code.end(), pattern); + const std::sregex_iterator end; + + // TODO: parse relative paths as well + for (; iter != end; ++ iter) { + const auto include_str = iter->str(); + const int len = 
include_str.length(); + if (include_str.substr(0, 10) == "#include <" and include_str[len - 1] == '>' and include_str[10] != ' ' and include_str[len - 2] != ' ') { + std::string filename = include_str.substr(10, len - 11); + if (filename.substr(0, 9) == "deep_gemm") // We only parse `<deep_gemm/...>` includes + includes.push_back(filename); + } else { + std::string error_info = fmt::format("Non-standard include: {}", include_str); + if (file_path != "") + error_info += fmt::format(" ({})", file_path.string()); + DG_HOST_UNREACHABLE(error_info); + } + } + return includes; + } + +public: + static std::filesystem::path library_include_path; + + static void prepare_init(const std::string& library_root_path) { + library_include_path = std::filesystem::path(library_root_path) / "include"; + } + + std::string get_hash_value(const std::string& code, const bool& exclude_code = true) { + std::stringstream ss; + for (const auto& i: get_includes(code)) + ss << get_hash_value_by_path(library_include_path / i) << "$"; + if (not exclude_code) + ss << "#" << get_hex_digest(code); + return get_hex_digest(ss.str()); + } + + std::string get_hash_value_by_path(const std::filesystem::path& path) { + // Check for a cache hit + // ReSharper disable once CppUseAssociativeContains + if (cache.count(path) > 0) { + const auto opt = cache[path]; + if (not opt.has_value()) + DG_HOST_UNREACHABLE(fmt::format("Circular include may occur: {}", path.string())); + return opt.value(); + } + + // Read the file and calculate the hash recursively + std::ifstream in(path); + if (not in.is_open()) + DG_HOST_UNREACHABLE(fmt::format("Failed to open: {}", path.string())); + std::string code((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>()); + cache[path] = std::nullopt; + return (cache[path] = get_hash_value(code, false)).value(); + } +}; + +DG_DECLARE_STATIC_VAR_IN_CLASS(IncludeParser, library_include_path); + +static auto include_parser = std::make_shared<IncludeParser>(); + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit/kernel_runtime.hpp b/deep-gemm/csrc/jit/kernel_runtime.hpp index 60563a1d..40597fb4 100644 --- a/deep-gemm/csrc/jit/kernel_runtime.hpp +++ b/deep-gemm/csrc/jit/kernel_runtime.hpp @@ -1,10 +1,13 @@ #pragma once +#include <chrono> + #include "../utils/exception.hpp" #include "../utils/format.hpp" #include "../utils/system.hpp" #include "device_runtime.hpp" #include "handle.hpp" +#include "include_parser.hpp" namespace deep_gemm { @@ -13,12 +16,13 @@ struct LaunchArgs { int num_threads; int smem_size; int cluster_dim; + bool enable_pdl; - LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): - grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true): + grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {} - LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): - grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true): + grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {} }; class KernelRuntime final { @@ -33,41
+37,56 @@ class KernelRuntime final { DG_HOST_ASSERT(not cuda_home.empty()); // NOLINT(*-pro-type-member-init) - const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump"; - const auto& cubin_path = dir_path / "kernel.cubin"; + const auto cuobjdump_path = cuda_home / "bin" / "cuobjdump"; + const auto cubin_path = dir_path / "kernel.cubin"; if (get_env("DG_JIT_DEBUG")) - fprintf(stderr, "Loading CUBIN: %s\n", cubin_path.c_str()); + printf("Loading CUBIN: %s\n", cubin_path.c_str()); + + // Record start time + std::chrono::high_resolution_clock::time_point start_time; + if (get_env("DG_JIT_DEBUG") or get_env("DG_JIT_PRINT_LOAD_TIME")) + start_time = std::chrono::high_resolution_clock::now(); +#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS + // Load from the library + kernel = load_kernel(cubin_path, {}, &library); +#else // Find the only symbol // TODO: use kernel enumeration for newer drivers const std::vector<std::string> illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"}; - const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str())); + const auto [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str())); DG_HOST_ASSERT(exit_code == 0); - // Parse line-by-line without std::istringstream + std::istringstream iss(symbols); std::vector<std::string> symbol_names; - size_t pos = 0; - while (pos < symbols.size()) { - size_t eol = symbols.find('\n', pos); - if (eol == std::string::npos) eol = symbols.size(); - std::string line = symbols.substr(pos, eol - pos); - pos = eol + 1; + for (std::string line; std::getline(iss, line); ) { if (line.find("STT_FUNC") == 0 and line.find("STO_ENTRY") != std::string::npos and std::none_of(illegal_names.begin(), illegal_names.end(), - [&](const auto& name) { return line.find(name) != std::string::npos; })) { - const auto& last_space = line.rfind(' '); + [&](const auto name) { return line.find(name) != std::string::npos; })) { + const auto last_space = line.rfind(' '); symbol_names.push_back(line.substr(last_space + 1)); } } - if (get_env("DG_JIT_DEBUG")) { - fprintf(stderr, "Symbol names: "); + + // Print symbols + if (symbol_names.size() != 1 or get_env("DG_JIT_DEBUG")) { + printf("Symbols:\n"); + printf(" > CUBIN: %s\n", cubin_path.c_str()); + printf(" > Raw symbols: %s\n", symbols.c_str()); + printf(" > Parsed symbols:\n"); for (const auto& symbol: symbol_names) - fprintf(stderr, "%s, ", symbol.c_str()); - fprintf(stderr, "\n"); + printf(" > %s, ", symbol.c_str()); + printf("\n"); } + DG_HOST_ASSERT(symbol_names.size() == 1); // Load from the library - DG_HOST_ASSERT(symbol_names.size() == 1); kernel = load_kernel(cubin_path, symbol_names[0], &library); +#endif + + // Print load time + if (get_env("DG_JIT_DEBUG") or get_env("DG_JIT_PRINT_LOAD_TIME")) { + std::chrono::duration<double, std::milli> load_time = std::chrono::high_resolution_clock::now() - start_time; + printf("Load time (%s): %.2lf ms\n", dir_path.c_str(), load_time.count()); + } } static void prepare_init(const std::string& cuda_home_path_by_python) { @@ -75,8 +94,19 @@ class KernelRuntime final { static bool check_validity(const std::filesystem::path& dir_path) { - return std::filesystem::exists(dir_path / "kernel.cu") and - std::filesystem::exists(dir_path / "kernel.cubin"); + if (not std::filesystem::exists(dir_path)) + return false; + + // NOTES: if the directory exists, `kernel.cu` and `kernel.cubin` must both exist, + // because the directory is created atomically via rename + if (not
std::filesystem::exists(dir_path / "kernel.cu") or + not std::filesystem::exists(dir_path / "kernel.cubin")) { + printf("Corrupted JIT cache directory (missing kernel.cu or kernel.cubin): %s, " + "please run `rm -rf %s` and restart your task.\n", + dir_path.c_str(), dir_path.c_str()); + DG_HOST_ASSERT(false and "Corrupted JIT cache directory"); + } + return true; } ~KernelRuntime() noexcept(false) { @@ -91,30 +121,42 @@ class LaunchRuntime { public: template static std::string generate(const Args& args) { - const auto& code = Derived::generate_impl(args); - if (get_env("DG_JIT_DEBUG", 0)) - fprintf(stderr, "Generated kernel code: %s\n", code.c_str()); + auto code = Derived::generate_impl(args); + + // NOTES: we require that `generate_impl`'s includes never change + static std::string include_hash; + if (include_hash.empty()) + include_hash = include_parser->get_hash_value(code); + + // TODO: optimize string concat performance + code = fmt::format("// Includes' hash value: {}\n{}", include_hash, code); + if (get_env("DG_JIT_DEBUG")) + printf("Generated kernel code:\n%s\n", code.c_str()); return code; } template static void launch(const std::shared_ptr& kernel_runtime, const Args& args) { - const auto& kernel = kernel_runtime->kernel; - const auto& stream = at::cuda::getCurrentCUDAStream(); - const LaunchArgs& launch_args = args.launch_args; - - const dim3& grid_dim = {static_cast(launch_args.grid_dim.first), - static_cast(launch_args.grid_dim.second), - 1}; - const dim3& block_dim = {static_cast(launch_args.num_threads), 1, 1}; + const auto kernel = kernel_runtime->kernel; + const auto stream = at::cuda::getCurrentCUDAStream(); + LaunchArgs launch_args = args.launch_args; + + // Allow runtime override from Python. + // NOTES: the default is enabled. 
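// Programmatic dependent launch (PDL) lets this kernel's prologue overlap the tail of the
// previous kernel on the stream; device code is then expected to gate its first dependent
// global reads (e.g. via `cudaGridDependencySynchronize()`) before consuming upstream
// results. The Python-side override ultimately calls `device_runtime->set_pdl(...)`;
// disabling it restores plain stream ordering.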
+ launch_args.enable_pdl = device_runtime->get_pdl(); + + const dim3 grid_dim = {static_cast(launch_args.grid_dim.first), + static_cast(launch_args.grid_dim.second), + 1}; + const dim3 block_dim = {static_cast(launch_args.num_threads), 1, 1}; auto config = construct_launch_config(kernel, stream, launch_args.smem_size, - grid_dim, block_dim, launch_args.cluster_dim); + grid_dim, block_dim, launch_args.cluster_dim, launch_args.enable_pdl); // Launch in the derived class if (get_env("DG_JIT_DEBUG")) { - fprintf(stderr, "Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n", + printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, pdl: %d, stream: %ld\n", launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads, - launch_args.smem_size, launch_args.cluster_dim, stream.id()); + launch_args.smem_size, launch_args.cluster_dim, launch_args.enable_pdl, stream.id()); } Derived::launch_impl(kernel, config, args); } diff --git a/deep-gemm/csrc/jit_kernels/heuristics/common.hpp b/deep-gemm/csrc/jit_kernels/heuristics/common.hpp index a49584f4..2b79a8b7 100644 --- a/deep-gemm/csrc/jit_kernels/heuristics/common.hpp +++ b/deep-gemm/csrc/jit_kernels/heuristics/common.hpp @@ -1,339 +1,54 @@ #pragma once -#include +#include +#include -#include "../../utils/math.hpp" +#include "config.hpp" +#include "runtime.hpp" #include "../../utils/layout.hpp" #include "../../utils/system.hpp" namespace deep_gemm { -struct MulticastConfig { - int num_multicast; - bool is_multicast_on_a; - - MulticastConfig(const int& num_multicast, const bool& is_multicast_on_a): - num_multicast(num_multicast), is_multicast_on_a(is_multicast_on_a) { - DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2); - } -}; - -struct SharedMemoryConfig { - int smem_size; - int swizzle_a_mode; - int swizzle_b_mode; - int swizzle_cd_mode; -}; - -struct ThreadConfig { - int num_threads; - - // SM90 - int num_tma_threads; - int num_math_threads; - - // SM100 - int num_non_epilogue_threads; - int num_epilogue_threads; - - static ThreadConfig sm90(const int& num_tma_threads, - const int& num_math_threads) { - auto config = ThreadConfig(); - config.num_threads = num_tma_threads + num_math_threads; - config.num_tma_threads = num_tma_threads; - config.num_math_threads = num_math_threads; - return config; - } - - static ThreadConfig sm100(const int& num_non_epilogue_threads, - const int& num_epilogue_threads) { - auto config = ThreadConfig(); - config.num_threads = num_non_epilogue_threads + num_epilogue_threads; - config.num_non_epilogue_threads = num_non_epilogue_threads; - config.num_epilogue_threads = num_epilogue_threads; - return config; - } -}; - -struct GemmConfig { - // Templated configs - GemmType gemm_type; - KernelType kernel_type; - MmaKind mma_kind; - at::ScalarType a_dtype, b_dtype, cd_dtype; - cute::UMMA::Major major_a; - cute::UMMA::Major major_b; - bool with_accumulation; - int block_m, block_n, block_k; - int num_stages, num_last_stages; - - // Templated device configs - int num_sms; - int tc_util; - - // Structured configs - MulticastConfig multicast_config; - SharedMemoryConfig smem_config; - ThreadConfig thread_config; -}; - -static bool is_multicast_legal(const int& shape_dim, const int& block_dim, - const int& num_multicast, const int& num_sms, - const bool& require_divisible) { - const bool& divisible = ceil_div(shape_dim, block_dim) % num_multicast == 0 or not require_divisible; - return divisible and num_sms % num_multicast == 0; -} - -template 
-static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) { - // `> 0` means interleaving - // 16B actually means non-swizzling (but interleaving) - for (const int& mode: {128, 64, 32, 16}) { - if ((block_size * static_cast(elem_size)) % mode == 0) - return mode; - } - DG_HOST_UNREACHABLE("Unreachable"); -} - template -static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const KernelType& kernel_type, - const int& m, const int& n, const int& k, - const int& block_m, const int& block_n, const int& block_k, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const MmaKind& mma_kind, const at::ScalarType& cd_dtype, - const int& num_stages, const MulticastConfig& multicast_config) { - const int& ab_elem_size = static_cast(get_element_size(mma_kind)); - const int& cd_elem_size = static_cast(c10::elementSize(cd_dtype)); - - const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m); - const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n); - const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size); - const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size); - const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0; - - // Different archs have different epilogue pipelines - const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype); - - // A/B shared memory - const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size; - const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size; - - // SF shared memory - const auto& [smem_sfa_per_stage, smem_sfb_per_stage] = - ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, mma_kind, cd_dtype); - const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k); - - // M-barriers and tensor memory pointers - const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages); - const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size(); - const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type); - - // Sum them up - int smem_size = 0; - smem_size += smem_tensor_map; - smem_size += smem_cd; - smem_size += num_stages * smem_a_per_stage; - smem_size += num_stages * smem_b_per_stage; - smem_size += num_stages * smem_sfa_per_stage; - smem_size += num_stages * smem_sfb_per_stage; - smem_size += smem_extra_sfb; - smem_size += smem_barrier; - smem_size += smem_tmem_ptr; - - return SharedMemoryConfig { - .smem_size = smem_size, - .swizzle_a_mode = swizzle_a_mode, - .swizzle_b_mode = swizzle_b_mode, - .swizzle_cd_mode = swizzle_cd_mode, +static GemmConfig get_best_config(const GemmDesc& desc) { + desc.check_validity(); + + // Choose the best layout + const auto layout_candidates = ArchSpec::get_layout_candidates(desc); + DG_HOST_ASSERT(not layout_candidates.empty()); + auto layout = layout_candidates[0]; + auto layout_info = ArchSpec::get_layout_info(desc, layout); + for (int i = 1; i < static_cast(layout_candidates.size()); ++ i) { + const auto candidate_info = ArchSpec::get_layout_info(desc, layout_candidates[i]); + if (ArchSpec::compare(candidate_info, layout_info)) + layout = layout_candidates[i], layout_info = candidate_info; + } + + // Infer other configs + const auto storage_config = ArchSpec::get_storage_config(desc, layout); + const 
auto pipeline_config = ArchSpec::get_pipeline_config(desc, layout, storage_config); + const auto launch_config = ArchSpec::get_launch_config(desc, layout); + const auto gemm_config = GemmConfig { + .layout = layout, + .storage_config = storage_config, + .pipeline_config = pipeline_config, + .launch_config = launch_config }; -} - -template -static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type, - const int& m, const int& n, const int& k, const int& num_groups, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const at::ScalarType& a_dtype, const at::ScalarType& b_dtype, - const at::ScalarType& cd_dtype, - const bool& with_accumulation, const int& num_sms) { - const auto mma_kind = (a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4); - if (mma_kind == MmaKind::BF16) { - DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16); - } else { - DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4); - DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4); - } - DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); - - // Select M/N block sizes - auto block_ms = ArchSpec::get_block_m_candidates(kernel_type, major_a, m); - if (gemm_type == GemmType::MGroupedContiguous) - block_ms = std::vector{get_mk_alignment_for_contiguous_layout()}; - if (gemm_type == GemmType::MGroupedMasked or gemm_type == GemmType::MGroupedContiguousWithPsumLayout) - block_ms = std::vector{64, 128}; // Exclude 256 for performance - auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype); - - // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B - // TODO: Optimize it - if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN) - block_ms = std::vector{128}; - if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN) - block_ns = std::vector{128}; - - // K block size is selected in a fixed manner - const auto& block_k = (mma_kind == MmaKind::BF16 ? 64 : 128); - - // Some util functions - const auto& get_num_blocks = [=](const int& block_m, const int& block_n) { - return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups; - }; - const auto& get_num_waves = [=](const int& block_m, const int& block_n) { - return ceil_div(get_num_blocks(block_m, block_n), num_sms); - }; - const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) { - const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms; - return num_last_blocks == 0 ? 
num_sms : num_last_blocks; - }; - - // Decide block sizes by waves - int best_block_m = 0, best_block_n = 0; - int best_num_waves = 0, best_last_util = 0; - for (const auto& block_m: block_ms) { - for (const auto& block_n: block_ns) { - const int& num_waves = get_num_waves(block_m, block_n); - const auto& last_util = get_last_wave_util(block_m, block_n); - if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, mma_kind, cd_dtype, m, n, k, block_m, block_n, block_k)) - continue; - - bool success = false; - if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) { - success = true; - } else if (num_waves == best_num_waves) { - // Check last wave utilization - success = last_util > best_last_util; - if (last_util == best_last_util) { - // Case 1: same `block_m`, smaller `block_n` (wasted) - success |= block_m == best_block_m and block_n < best_block_n; - // Case 2: same `block_n`, smaller `block_m` (wasted) - success |= block_n == best_block_n and block_m < best_block_m; - // Case 3: different for both `block_m` and `block_n`, larger `block_n` is better - // NOTES: don't pick `block_m/block_n` larger than shape `m/n` in this case - success |= block_m != best_block_m and block_n > best_block_n - and block_n <= n and block_m <= m; - } - } - - // Replace with the new config if successful - if (success) { - best_block_m = block_m, best_block_n = block_n; - best_num_waves = num_waves, best_last_util = last_util; - } - } - } - DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0); - - // Decide the number of TMA multicasts and whether broadcast on A - MulticastConfig best_multicast_config = {1, false}; - auto [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( - gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms); - - // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B - // TODO: Optimize it - if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN) - is_legal_on_a = false; - if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN) - is_legal_on_b = false; - - const bool is_legal[2] = {is_legal_on_b, is_legal_on_a}; - bool order[2] = {false, true}; - if (best_block_m > best_block_n) - std::swap(order[0], order[1]); - for (const bool& is_multicast_on_a: order) { - if (m >= 512 and is_legal[static_cast(is_multicast_on_a)]) { - best_multicast_config = {2, is_multicast_on_a}; - break; - } - } - - // Always pick the largest number of stage - constexpr int smem_capacity = ArchSpec::smem_capacity; - int best_num_stages = 0; - SharedMemoryConfig best_smem_config; - for (int num_stages = 32; num_stages > 0; -- num_stages) { - if (not ArchSpec::is_num_stages_legal(mma_kind, cd_dtype, num_stages, best_block_m, best_block_n, block_k)) - continue; - - best_smem_config = get_smem_config(gemm_type, kernel_type, - m, n, k, - best_block_m, best_block_n, block_k, - major_a, major_b, - mma_kind, cd_dtype, - num_stages, best_multicast_config); - if (best_smem_config.smem_size <= smem_capacity) { - best_num_stages = num_stages; - break; - } - } - DG_HOST_ASSERT(best_num_stages != 0); - - // Recompute the minimal number of SMs required - // NOTES: less L2 cache usage and less GPU frequency drop - int num_min_sms = num_sms; - if (get_env("DG_JIT_MINIMIZE_NUM_SMS", 0)) { - num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves); - num_min_sms = align(num_min_sms, best_multicast_config.num_multicast); - DG_HOST_ASSERT(num_min_sms <= num_sms); - } - - const auto& config = GemmConfig { - 
.gemm_type = gemm_type, - .kernel_type = kernel_type, - .mma_kind = mma_kind, - .a_dtype = a_dtype, - .b_dtype = b_dtype, - .cd_dtype = cd_dtype, - .major_a = major_a, - .major_b = major_b, - .with_accumulation = with_accumulation, - .block_m = best_block_m, - .block_n = best_block_n, - .block_k = block_k, - .num_stages = best_num_stages, - .num_last_stages = ceil_div(k, block_k) % best_num_stages, - .num_sms = num_min_sms, - .tc_util = device_runtime->get_tc_util(), - .multicast_config = best_multicast_config, - // ReSharper disable once CppLocalVariableMightNotBeInitialized - .smem_config = best_smem_config, - .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n) - }; - - // Only SM100 BF16 kernels support tensor core control - if (config.tc_util < 100) - DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and mma_kind == MmaKind::BF16); // Print configs for the first time if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { - auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b, - mma_kind, a_dtype, b_dtype, cd_dtype, with_accumulation, num_sms); - static std::set printed; + std::stringstream ss; + ss << desc; + const auto key = ss.str(); + + static std::unordered_set printed; if (printed.count(key) == 0) { - printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, " - "A major: %d, B major: %d, MMA kind: %d, A dtype: %s, B dtype: %s, CD dtype: %s, accumulation: %d, " - "SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, " - "SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, " - "swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n", - static_cast(gemm_type), static_cast(kernel_type), m, n, k, num_groups, - static_cast(major_a), static_cast(major_b), static_cast(mma_kind), - c10::toString(a_dtype), c10::toString(b_dtype), c10::toString(cd_dtype), - static_cast(with_accumulation), num_sms, best_block_m, best_block_n, block_k, - best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast, - static_cast(best_multicast_config.is_multicast_on_a), - best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode, - best_smem_config.swizzle_cd_mode, config.num_sms, config.thread_config.num_threads, config.tc_util); + std::cout << desc << ": " << gemm_config << ", " << layout_info << std::endl; printed.insert(key); } } - return config; + return gemm_config; } } // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/heuristics/config.hpp b/deep-gemm/csrc/jit_kernels/heuristics/config.hpp new file mode 100644 index 00000000..c06f2f16 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/heuristics/config.hpp @@ -0,0 +1,171 @@ +#pragma once + +#include +#include +#include + +#include "../../utils/math.hpp" + +namespace deep_gemm { + +/// GEMM descriptors +struct GemmDesc { + GemmType gemm_type; + KernelType kernel_type; + int m, n, k, num_groups; + at::ScalarType a_dtype, b_dtype, cd_dtype; + cute::UMMA::Major major_a; + cute::UMMA::Major major_b; + bool with_accumulation; + + // Requirements from users + int num_sms, tc_util; + std::string compiled_dims; + + // Shape for heuristic generation + int expected_m = 0, expected_n = 0, expected_k = 0, expected_num_groups = 0; + int get_expected_m() const { return expected_m > 0 ? expected_m : m; } + int get_expected_n() const { return expected_n > 0 ? 
expected_n : n; } + int get_expected_k() const { return expected_k > 0 ? expected_k : k; } + int get_expected_num_groups() const { return expected_num_groups > 0 ? expected_num_groups : num_groups; } + + MmaKind get_mma_kind() const { + return a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4; + } + + void check_validity() const { + if (get_mma_kind() == MmaKind::BF16) { + DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16); + } else { + DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4); + DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4); + } + DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); + DG_HOST_ASSERT(num_sms % 2 == 0); + } + + friend std::ostream& operator << (std::ostream& os, const GemmDesc& desc) { + MmaKind mma_kind = desc.get_mma_kind(); + os << "GemmDesc(gemm_type=" << static_cast(desc.gemm_type) + << ", kernel_type=" << static_cast(desc.kernel_type) + << ", m=" << desc.m << ", n=" << desc.n << ", k=" << desc.k + << ", num_groups=" << desc.num_groups + << ", major_a=" << static_cast(desc.major_a) + << ", major_b=" << static_cast(desc.major_b) + << ", mma_kind=" << static_cast(mma_kind) + << ", a_dtype=" << c10::toString(desc.a_dtype) + << ", b_dtype=" << c10::toString(desc.b_dtype) + << ", cd_dtype=" << c10::toString(desc.cd_dtype) + << ", with_accumulation=" << static_cast(desc.with_accumulation) + << ", num_sms=" << desc.num_sms + << ", tc_util=" << desc.tc_util + << ", compiled_dims=" << desc.compiled_dims + << ", expected_m=" << desc.expected_m + << ", expected_n=" << desc.expected_n + << ", expected_k=" << desc.expected_k + << ", expected_num_groups=" << desc.expected_num_groups << ")"; + return os; + } +}; + +/// GEMM configs +struct Layout { + int swap_ab; + int block_m, block_n, block_k; + int cluster_m, cluster_n; + + int get_cluster_size() const { + return cluster_m * cluster_n; + } + + friend std::ostream& operator << (std::ostream& os, const Layout& layout) { + os << "Layout(swap_ab=" << layout.swap_ab + << ", block_m=" << layout.block_m << ", block_n=" << layout.block_n << ", block_k=" << layout.block_k + << ", cluster_m=" << layout.cluster_m << ", cluster_n=" << layout.cluster_n << ")"; + return os; + } +}; + +struct StorageConfig { + int load_block_m, load_block_n; + int store_block_m, store_block_n; + + int swizzle_a_mode, swizzle_b_mode; + int swizzle_cd_mode; + + friend std::ostream& operator << (std::ostream& os, const StorageConfig& config) { + os << "StorageConfig(" + << "load_block_m=" << config.load_block_m << ", load_block_n=" << config.load_block_n + << ", store_block_m=" << config.store_block_m << ", store_block_n=" << config.store_block_n + << ", swizzle_a_mode=" << config.swizzle_a_mode << ", swizzle_b_mode=" << config.swizzle_b_mode + << ", swizzle_cd_mode=" << config.swizzle_cd_mode << ")"; + return os; + } +}; + +struct PipelineConfig { + int smem_size; + int num_stages; + + friend std::ostream& operator << (std::ostream& os, const PipelineConfig& config) { + os << "PipelineConfig(" + << "smem_size=" << config.smem_size + << ", num_stages=" << config.num_stages << ")"; + return os; + } +}; + +struct LaunchConfig { + int num_sms; + int num_sms_per_cluster; + int num_threads; + + int num_tma_threads; + int num_math_threads; + int num_non_epilogue_threads; + int num_epilogue_threads; + + friend std::ostream& operator << (std::ostream& os, const LaunchConfig& config) { + os << "LaunchConfig(" + << "num_sms=" << config.num_sms << ", 
num_sms_per_cluster=" << config.num_sms_per_cluster + << ", num_threads=" << config.num_threads + << ", num_tma_threads=" << config.num_tma_threads << ", num_math_threads=" << config.num_math_threads + << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads + << ", num_epilogue_threads=" << config.num_epilogue_threads << ")"; + return os; + } +}; + +struct GemmConfig { + Layout layout; + StorageConfig storage_config; + PipelineConfig pipeline_config; + LaunchConfig launch_config; + + friend std::ostream& operator << (std::ostream& os, const GemmConfig& config) { + os << "GemmConfig(" + << "layout=" << config.layout + << ", storage_config=" << config.storage_config + << ", pipeline_config=" << config.pipeline_config + << ", launch_config=" << config.launch_config << ")"; + return os; + } +}; + +/// Config comparators +struct LayoutInfo { + int num_waves; + int last_wave_util; + int64_t num_cycles; + Layout layout; + + friend std::ostream& operator << (std::ostream& os, const LayoutInfo& config) { + os << "LayoutInfo(" + << "num_waves=" << config.num_waves + << ", last_wave_util=" << config.last_wave_util + << ", num_cycles=" << config.num_cycles << ")"; + return os; + } +}; + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/heuristics/mega_moe.hpp b/deep-gemm/csrc/jit_kernels/heuristics/mega_moe.hpp new file mode 100644 index 00000000..b1ba6bd7 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -0,0 +1,240 @@ +#pragma once + +#include +#include + +#include + +#include "../../utils/exception.hpp" +#include "../../utils/math.hpp" +#include "../../utils/system.hpp" +#include "sm100.hpp" + +namespace deep_gemm { + +struct MegaMoEConfig { + // Block tiling + int block_m, block_n, block_k; + int load_block_m, load_block_n; + int store_block_m; + + // SF block sizes (UTCCP 128-aligned) + int sf_block_m, sf_block_n; + + // Pool capacity and SF-padded token count + int num_max_pool_tokens; + int num_padded_sf_pool_tokens; + + // Swizzle modes for TMA descriptors + int swizzle_acts_mode, swizzle_weights_mode; + + // Number of experts to process per wave + int num_experts_per_wave; + + // Pipeline stages and shared memory + int num_stages, smem_size; + + // Thread layout + int num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads; + + friend std::ostream& operator << (std::ostream& os, const MegaMoEConfig& config) { + os << "MegaMoEConfig(" + << "block_m=" << config.block_m << ", block_n=" << config.block_n << ", block_k=" << config.block_k + << ", load_block_m=" << config.load_block_m << ", load_block_n=" << config.load_block_n + << ", store_block_m=" << config.store_block_m + << ", sf_block_m=" << config.sf_block_m << ", sf_block_n=" << config.sf_block_n + << ", num_max_pool_tokens=" << config.num_max_pool_tokens + << ", num_padded_sf_pool_tokens=" << config.num_padded_sf_pool_tokens + << ", swizzle_acts_mode=" << config.swizzle_acts_mode << ", swizzle_weights_mode=" << config.swizzle_weights_mode + << ", num_experts_per_wave=" << config.num_experts_per_wave + << ", num_stages=" << config.num_stages << ", smem_size=" << config.smem_size + << ", num_dispatch_threads=" << config.num_dispatch_threads + << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads + << ", num_epilogue_threads=" << config.num_epilogue_threads << ")"; + return os; + } +}; + +static std::tuple get_block_config_for_mega_moe( + const int& num_ranks, const int& num_experts, + const int& num_max_tokens_per_rank, const int& num_topk, + const int& 
num_tokens) {
+    const auto& [cluster_size, block_m, store_block_m, num_epilogue_warpgroups] = [&]() -> std::tuple<int, int, int, int> {
+        float num_expected_tokens_per_expert = static_cast<float>(num_tokens) * num_ranks * num_topk / num_experts;
+        if (num_expected_tokens_per_expert <= 8.5) {
+            // Really small tokens-per-expert (e.g. RL long-tail rollout), use the smallest block_m
+            return {2, 16, 8, 2};
+        } else if (num_expected_tokens_per_expert <= 16.5) {
+            // Small batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 128
+            return {2, 32, 16, 2};
+        } else if (num_expected_tokens_per_expert <= 32.5) {
+            // Medium batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 256
+            return {2, 64, 32, 1};
+        } else if (num_expected_tokens_per_expert <= 64.5) {
+            // Large batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 512
+            return {2, 96, 16, 2};
+        } else if (num_expected_tokens_per_expert <= 96.5) {
+            // Medium batch size, medium EP, decoding, e.g. 6/384 experts, EP16, bsz 256, or EP32, bsz 128
+            return {2, 128, 32, 2};
+        } else {
+            // Prefill, or large EP decoding
+            return {2, 192, 32, 2};
+        }
+    }();
+
+    // Check whether our `block_m` lies in `kCandidateBlockM`
+    DG_HOST_ASSERT(std::any_of(
+        layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs,
+        [=](const auto& candidate) { return candidate == block_m; })
+    );
+
+    // Return configs
+    return {cluster_size, block_m, store_block_m, num_epilogue_warpgroups * 128};
+}
+
+static int get_num_experts_per_wave_for_mega_moe(
+    const int& num_experts_per_rank, const int& num_tokens, const int& num_topk,
+    const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) {
+
+    float expected_tokens_per_expert = static_cast<float>(num_tokens) * num_topk / num_experts_per_rank;
+    if (expected_tokens_per_expert < 1) {
+        // Most experts don't have tokens, calculate all experts at once
+        return num_experts_per_rank;
+    }
+
+    // Reduce per-expert block count by this factor since uneven routing leaves some experts with fewer tokens
+    constexpr int kImbalanceFactor = 2;
+
+    // Count L1 blocks per expert assuming tokens are evenly spread across experts
+    const int num_m_blocks = ceil_div(static_cast<int>(std::ceil(expected_tokens_per_expert)), block_m);
+    const int num_n_blocks = (2 * intermediate_hidden) / block_n;
+    const int num_l1_blocks_per_expert = num_m_blocks * num_n_blocks;
+
+    // Pick the smallest value whose total blocks (after imbalance reduction) can keep all SMs busy
+    int num_experts_per_wave = num_l1_blocks_per_expert > 0
+        ? 
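/*
 * [Editor's note, not part of the upstream patch] A worked instance of this
 * heuristic under assumed, hypothetical values: num_experts_per_rank = 48,
 * num_tokens = 256, num_topk = 8 gives ~42.7 expected tokens per expert; with
 * block_m = 64, block_n = 128 and intermediate_hidden = 2048, each expert has
 * ceil(43 / 64) * (2 * 2048 / 128) = 1 * 32 = 32 L1 blocks. With num_sms = 144,
 * the ternary below yields ceil_div(2 * 144, 32) = 9, and the divisor loop
 * afterwards rounds it up to 12 (the next divisor of 48), so the 48 experts are
 * processed in 4 equal waves of 12.
 */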
ceil_div(kImbalanceFactor * num_sms, num_l1_blocks_per_expert) : 1;
+    num_experts_per_wave = std::min(num_experts_per_wave, num_experts_per_rank);
+
+    // Round up to the nearest divisor of `num_experts_per_rank` so every wave processes the same number of experts
+    while (num_experts_per_wave < num_experts_per_rank and num_experts_per_rank % num_experts_per_wave != 0)
+        ++ num_experts_per_wave;
+
+    return num_experts_per_wave;
+}
+
+static std::pair<int, int> get_pipeline_config_for_mega_moe(
+    const int& smem_capacity,
+    const int& num_experts, const int& hidden,
+    const int& block_m, const int& block_n, const int& block_k, const int& store_block_m,
+    const int& sf_block_m, const int& sf_block_n,
+    const int& num_dispatch_warps, const int& num_epilogue_warps) {
+    constexpr int kSmemAlignment = 1024;
+    constexpr int kNumEpilogueStages = 2;
+    constexpr int kNumTMAStoreStages = 2;
+
+    // Always multicast on A
+    const int load_block_m = block_m / 2;
+
+    // Dispatch region
+    const int smem_expert_count_size = align(
+        num_experts * static_cast<int>(sizeof(uint32_t)), kSmemAlignment);
+    const int smem_send_buffers_size = align(
+        static_cast<int>(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()),
+        kSmemAlignment);
+    const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size;
+
+    // C/D output region: max of L1 FP8 (2 TMA stages, BLOCK_N/2 post-SwiGLU) and L2 BF16 (1 stage)
+    const auto num_epilogue_warpgroups = num_epilogue_warps / 4;
+    const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 2) * kNumTMAStoreStages;
+    const int smem_cd_l2 = num_epilogue_warpgroups * store_block_m * block_n * static_cast<int>(sizeof(nv_bfloat16));
+    const int smem_cd = std::max(smem_cd_l1, smem_cd_l2);
+
+    // Barriers (stage-independent): dispatch + tensor memory full/empty + combine (2 per epilogue warp)
+    const int smem_barriers = (num_dispatch_warps + kNumEpilogueStages * 2 + num_epilogue_warps * 2) * 8;
+
+    // Amax reduction
+    const int smem_amax_reduction = store_block_m * num_epilogue_warps * static_cast<int>(sizeof(float));
+
+    // Tensor memory pointer
+    const int smem_tmem_ptr = 4;
+
+    // SF is aligned to UTCCP 128-element granularity
+    const int smem_sfa_per_stage = sf_block_m * 4;
+    const int smem_sfb_per_stage = sf_block_n * 4;
+
+    // Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers
+    const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8;
+
+    // Fixed total
+    const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr;
+
+    // Select maximum num_stages
+    const int num_stages = (smem_capacity - smem_fixed) / smem_per_stage;
+    DG_HOST_ASSERT(num_stages >= 2);
+
+    return {num_stages, smem_fixed + num_stages * smem_per_stage};
+}
+
+static MegaMoEConfig get_mega_moe_config(
+    const int& num_ranks, const int& num_experts, const int& num_experts_per_rank,
+    const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk,
+    const int& hidden, const int& intermediate_hidden,
+    const int& num_padded_sf_pool_tokens) {
+    // Block config
+    const auto [cluster_size, block_m, store_block_m, num_epilogue_threads] =
+        get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens);
+    const int block_n = 128;
+    const int block_k = 128;
+    const int load_block_m = block_m / 2;
+    const int load_block_n = block_n;
+    const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(block_m, block_n, 
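/*
 * [Editor's note, not part of the upstream patch] The SF block sizes returned
 * here are rounded up to the UTCCP granularity of 128 elements, e.g. for
 * block_m = 96 and block_n = 128 the call yields sf_block_m = 128 and
 * sf_block_n = 128, so each pipeline stage reserves sf_block_m * 4 = 512 B of
 * SFA plus 512 B of SFB in shared memory (the `* 4` terms in
 * get_pipeline_config_for_mega_moe above).
 */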
MmaKind::MXFP8FP4);
+    const int num_max_pool_tokens = layout::get_num_max_pool_tokens(
+        num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
+    // NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle
+    const int swizzle_acts_mode = 128;
+    const int swizzle_weights_mode = 128;
+
+    // Waves
+    const int num_sms = device_runtime->get_num_sms();
+    const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe(
+        num_experts_per_rank, num_tokens, num_topk,
+        intermediate_hidden, block_m, block_n, num_sms);
+
+    // Thread layout
+    const int num_dispatch_threads = 128;
+    const int num_non_epilogue_threads = 128;
+
+    // Pipeline
+    const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe(
+        SM100ArchSpec::smem_capacity,
+        num_experts, hidden,
+        block_m, block_n, block_k, store_block_m,
+        sf_block_m, sf_block_n,
+        num_dispatch_threads / 32, num_epilogue_threads / 32);
+
+    const auto config = MegaMoEConfig {
+        block_m, block_n, block_k,
+        load_block_m, load_block_n, store_block_m,
+        sf_block_m, sf_block_n,
+        num_max_pool_tokens, num_padded_sf_pool_tokens,
+        swizzle_acts_mode, swizzle_weights_mode,
+        num_experts_per_wave,
+        num_stages, smem_size,
+        num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads
+    };
+
+    // Print configs for the first time
+    if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) {
+        const auto key = fmt::format(
+            "MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",
+            num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk);
+        static std::unordered_set<std::string> printed;
+        if (printed.count(key) == 0) {
+            std::cout << key << ": " << config << std::endl;
+            printed.insert(key);
+        }
+    }
+    return config;
+}
+
+} // namespace deep_gemm
diff --git a/deep-gemm/csrc/jit_kernels/heuristics/runtime.hpp b/deep-gemm/csrc/jit_kernels/heuristics/runtime.hpp
new file mode 100644
index 00000000..93f2a23a
--- /dev/null
+++ b/deep-gemm/csrc/jit_kernels/heuristics/runtime.hpp
@@ -0,0 +1,62 @@
+#pragma once
+
+#include "../../jit/device_runtime.hpp"
+#include "../../utils/exception.hpp"
+#include "../../utils/lazy_init.hpp"
+
+namespace deep_gemm {
+
+class HeuristicsRuntime {
+    static constexpr int kLegacyMKAlignmentForContiguousLayout = 128;
+
+    bool ignore_compile_dims = false;
+    int block_m_multiple_of = 1;
+    int block_n_multiple_of = 1;
+    int mk_alignment_for_contiguous_layout = kLegacyMKAlignmentForContiguousLayout;
+
+public:
+    void set_ignore_compile_dims(const bool& new_value) {
+        ignore_compile_dims = new_value;
+    }
+
+    bool get_ignore_compile_dims() const {
+        return ignore_compile_dims;
+    }
+
+    void set_block_size_multiple_of(const int& new_block_m_multiple_of, const int& new_block_n_multiple_of) {
+        block_m_multiple_of = new_block_m_multiple_of;
+        block_n_multiple_of = new_block_n_multiple_of;
+    }
+
+    int get_block_m_multiple_of() const {
+        return block_m_multiple_of;
+    }
+
+    int get_block_n_multiple_of() const {
+        return block_n_multiple_of;
+    }
+
+    void set_mk_alignment_for_contiguous_layout(const int& new_value) {
+        mk_alignment_for_contiguous_layout = new_value;
+    }
+
+    int get_mk_alignment_for_contiguous_layout() const {
+        return mk_alignment_for_contiguous_layout;
+    }
+
+    static int get_theoretical_mk_alignment_for_contiguous_layout(const std::optional<int>& expected_m) {
+        if (device_runtime->get_arch_major() != 10)
+            return kLegacyMKAlignmentForContiguousLayout;
+
+        int block_m 
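/*
 * [Editor's note, not part of the upstream patch] Sketch of how the loop below
 * shrinks the alignment: starting from 240 and stepping by the MMA step of 16,
 * e.g. expected_m = 100 walks 240 -> 224 -> ... -> 112 and stops there, since
 * 112 - 16 = 96 would no longer cover the expected 100 rows.
 */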
= 240, mma_step = 16;
+        if (expected_m.has_value()) {
+            // Reduce `block_m` while ensuring it covers `m`
+            for (; block_m > 32 and block_m - mma_step >= expected_m.value(); block_m -= mma_step);
+        }
+        return block_m;
+    }
+};
+
+static auto heuristics_runtime = LazyInit<std::shared_ptr<HeuristicsRuntime>>([](){ return std::make_shared<HeuristicsRuntime>(); });
+
+} // namespace deep_gemm
diff --git a/deep-gemm/csrc/jit_kernels/heuristics/sm100.hpp b/deep-gemm/csrc/jit_kernels/heuristics/sm100.hpp
index dd1e6024..7c46686b 100644
--- a/deep-gemm/csrc/jit_kernels/heuristics/sm100.hpp
+++ b/deep-gemm/csrc/jit_kernels/heuristics/sm100.hpp
@@ -2,9 +2,11 @@
 #include
 // Reuse some types in the JIT modules
-#include
+#include
 #include "common.hpp"
+#include "runtime.hpp"
+#include "utils.hpp"
 #include "../../utils/exception.hpp"
 
 namespace deep_gemm {
@@ -12,155 +14,255 @@ namespace deep_gemm {
 struct SM100ArchSpec {
     static constexpr int smem_capacity = 232448;
 
-    static std::vector<int> get_block_m_candidates(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const int& m) {
-        std::vector<int> candidates{128, 256};
-        if ((kernel_type == KernelType::Kernel1D1D or kernel_type == KernelType::KernelNoSF) and major_a == cute::UMMA::Major::K) {
-            // NOTES: `block_m = 32/64` is smaller than `LAYOUT_AD_M`, should be careful in handling this
-            if (m <= 32) candidates.push_back(32);
-            if (m <= 64) candidates.push_back(64);
+    static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
+        const int& block_m, const int& block_n, const MmaKind& mma_kind) {
+        constexpr int num_utccp_aligned_elems = 128;
+        switch (mma_kind) {
+            case MmaKind::BF16: return {0, 0};
+            case MmaKind::MXFP8FP4: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
+            default: DG_HOST_UNREACHABLE("Unknown dtype");
         }
-        return candidates;
     }
 
-    static std::vector<int> get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) {
-        // 16 is for better SM usage
-        // Stride 32 is due to low-performance swizzle-16/32B
-        std::vector<int> candidates = {16};
-        for (int i = 32; i <= 256; i += 32)
-            candidates.push_back(i);
-        return candidates;
-    }
+    static std::vector<Layout> get_layout_candidates(const GemmDesc& desc) {
+        // Block K is always fixed
+        const int block_k = 128 / get_element_size(desc.get_mma_kind());
 
-    static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
-        return block_m / (config.is_multicast_on_a ? config.num_multicast : 1);
-    }
+        // Always enable swap A/B (and multicasting if possible) for m-grouped GEMMs
+        if (desc.gemm_type == GemmType::MGroupedContiguous or
+            desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout or
+            desc.gemm_type == GemmType::MGroupedMasked) {
+            const bool swap_ab = true;
+            const auto block_n = 128;
+            const auto block_m = heuristics_runtime->get_mk_alignment_for_contiguous_layout();
+            const auto cluster_m = 1;
+            const auto cluster_n = ceil_div(desc.n, block_n) % 2 == 0 and desc.num_sms % 2 == 0 ? 2 : 1;
+            const auto layout = Layout{swap_ab, block_m, block_n, block_k, cluster_m, cluster_n};
+            std::vector<Layout> candidates = {layout};
+            return candidates;
+        }
 
-    static int get_ab_load_block_n(const MulticastConfig& config, const int& block_n) {
-        return block_n / (config.is_multicast_on_a ? 1 : config.num_multicast);
-    }
+        // Enumerate all candidates
+        std::vector<Layout> candidates;
+        for (int swap_ab = 0; swap_ab < 2; ++ swap_ab) {
+            // Block M/N candidates
+            std::vector<int> block_m_candidates;
+            std::vector<int> block_n_candidates;
+            if (swap_ab) {
+                int step = std::lcm(16, heuristics_runtime->get_block_m_multiple_of());
+                int end = 256;
+                for (int i = step; i <= end; i += step)
+                    block_m_candidates.push_back(i);
 
-    static int get_cd_store_block_m(const int& block_m) {
-        constexpr int layout_ad_m = 128;
-        return std::min(block_m, layout_ad_m);
-    }
+                // TODO: consider other block N
+                block_n_candidates = {128};
+            } else {
+                // NOTES: smaller block M can avoid TMA L2 OOB bound
+                // TODO: consider block M = 256
+                if (desc.m <= 32) block_m_candidates = {32};
+                else if (desc.m <= 64) block_m_candidates = {64};
+                else block_m_candidates = {128};
 
-    static int get_cd_store_block_n(const int& block_n) {
-        return block_n;
-    }
+                // Small block size for small shape
+                if (16 % heuristics_runtime->get_block_n_multiple_of() == 0)
+                    block_n_candidates.push_back(16);
+                int step = std::lcm(32, heuristics_runtime->get_block_n_multiple_of());
+                // For small K, fewer store blocks improve store/compute overlap and reduce epilogue bottleneck
+                int end = desc.k <= 256 ? 128 : 256;
+                for (int i = step; i <= end; i += step)
+                    block_n_candidates.push_back(i);
+            }
 
-    static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
-        return true;
-    }
+            for (int cluster_m = 1; cluster_m <= 2; ++ cluster_m) {
+                // After swapping, layout A/D multicasting can only be done on cluster N
+                if (swap_ab == 1 and cluster_m > 1)
+                    continue;
 
-    static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
-        const int& block_m, const int& block_n, const MmaKind& mma_kind) {
-        constexpr int num_utccp_aligned_elems = 128;
-        switch (mma_kind) {
-            case MmaKind::BF16: return {0, 0};
-            case MmaKind::MXFP8FP4: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
-            default: DG_HOST_UNREACHABLE("Unknown dtype");
-        }
-    }
+                for (int cluster_n = 1; cluster_n <= 2; ++ cluster_n) {
+                    // We only support cluster sizes up to 2
+                    if (cluster_m * cluster_n > 2)
+                        continue;
+
+                    // Without swapping, only cluster M multicasting is supported (layout A/D constraint)
+                    if (swap_ab == 0 and cluster_n > 1)
+                        continue;
+
+                    // SM count must be divisible
+                    if (desc.num_sms % (cluster_m * cluster_n) != 0)
+                        continue;
+
+                    for (int block_m: block_m_candidates) {
+                        // Ensure large swizzle sizes (32B swizzle yields poor performance)
+                        const auto swizzle_a_requirement = desc.a_dtype == kPackedFP4 ? 128 : 64;
+                        // Enforce swizzle alignment for MN major; otherwise check base MMA shape
+                        const auto load_block_m_requirement = desc.major_a == cute::UMMA::Major::MN ? swizzle_a_requirement : 8;
+                        if ((block_m / cluster_n) % load_block_m_requirement != 0)
+                            continue;
+
+                        // Shape must be divisible for multicast
+                        if (ceil_div(desc.m, block_m) % cluster_m != 0)
+                            continue;
+
+                        for (int block_n: block_n_candidates) {
+                            // Ensure large swizzle sizes (32B swizzle yields poor performance)
+                            const auto swizzle_b_requirement = desc.b_dtype == kPackedFP4 ? 128 : 64;
+                            // Enforce swizzle alignment for MN major; otherwise check base MMA shape
+                            const auto load_block_n_requirement = desc.major_b == cute::UMMA::Major::MN ? 
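/*
 * [Editor's note, not part of the upstream patch] Illustration of this filter
 * under assumed inputs: for an MN-major FP8 B (swizzle_b_requirement = 64), a
 * candidate with block_n = 96 and cluster_m = 1 gives a load block of 96, and
 * 96 % 64 != 0 rejects it; the same shape with a K-major B only needs the base
 * MMA alignment of 8 and is kept.
 */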
swizzle_b_requirement : 8;
+                            if ((block_n / cluster_m) % load_block_n_requirement != 0)
+                                continue;
+
+                            // Shape must be divisible for multicast
+                            if (ceil_div(desc.n, block_n) % cluster_n != 0)
+                                continue;
+
+                            // SwapAB requires that block N equals layout A/D's UMMA M
+                            constexpr int layout_ad_m = 128;
+                            if (swap_ab and block_n != layout_ad_m)
+                                continue;
 
-    static bool is_block_size_legal(const KernelType& kernel_type,
-                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
-                                    const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
-                                    const int& m, const int& n, const int& k,
-                                    const int& block_m, const int& block_n, const int& block_k) {
-        // Layout A/D does not support `block_n % 16 != 0`
-        if (block_n % 16 != 0)
-            return false;
-
-        // Performance is lower with 1D1D and `block_m == 256`
-        if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m > 128)
-            return false;
-
-        // For small K, fewer store blocks improve store/compute overlap and reduce epilogue bottleneck
-        if (k <= 256 and (block_n > 128 or block_m > 128))
-            return false;
-
-        // Check tensor memory validity
-        int sf_block_m = 0, sf_block_n = 0;
-        if (kernel_type == KernelType::Kernel1D1D) {
-            const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind);
-            sf_block_m = sf_block_m_, sf_block_n = sf_block_n_;
+                            // Check tensor memory capacity
+                            const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, desc.get_mma_kind());
+                            const auto tmem_sf_cols = desc.get_mma_kind() == MmaKind::MXFP8FP4 ? sf_block_m / 32 + sf_block_n / 32 : 0;
+                            const auto umma_n = swap_ab ? block_m : block_n;
+                            if (2 * umma_n + tmem_sf_cols > 512)
+                                continue;
+
+                            const auto layout = Layout{swap_ab, block_m, block_n, block_k, cluster_m, cluster_n};
+
+                            // When neither A nor B is MN major, 128B swizzle is always feasible
+                            if (desc.major_a == cute::UMMA::Major::K or desc.major_b == cute::UMMA::Major::K) {
+                                const auto storage_config = get_storage_config(desc, layout);
+                                if (storage_config.swizzle_a_mode != 128 or storage_config.swizzle_b_mode != 128)
+                                    continue;
+                            }
+
+                            candidates.push_back(layout);
+                        }
+                    }
+                }
+            }
         }
-        if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512)
-            return false;
-        // NOTES: when B is MN-major, we restrict `block_n` to multiples of 64,
-        // since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_n % 64 != 0`), even with 3D TMA
-        return major_b == cute::UMMA::Major::K or (block_n * get_element_size(mma_kind)) % 64 == 0;
+        DG_HOST_ASSERT(not candidates.empty());
+        return candidates;
     }
 
-    static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
-                                    const int& num_stages,
-                                    const int& block_m, const int& block_n, const int& block_k) {
-        return true;
-    }
+    static StorageConfig get_storage_config(const GemmDesc& desc, const Layout& layout) {
+        constexpr int layout_ad_m = 128;
+        constexpr int umma_step_n = 16;
+
+        // Load/store block sizes (w/o consideration of swizzling atoms, w/ consideration of loop atoms)
+        const auto load_block_m = layout.block_m / layout.cluster_n;
+        const auto load_block_n = layout.block_n / layout.cluster_m;
+        const auto store_block_m = layout.swap_ab ? umma_step_n : std::min(layout_ad_m, layout.block_m);
+        const auto store_block_n = layout.block_n;
+
+        // Decide swizzling by the inner dim
+        // TODO: support FP4 sub-byte
+        const auto swizzle_mode_a = get_swizzle_mode(
+            desc.major_a == cute::UMMA::Major::K ? 
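/*
 * [Editor's note, not part of the upstream patch] get_swizzle_mode() (see
 * heuristics/utils.hpp further down) picks the largest of {128, 64, 32, 16}
 * bytes dividing inner_dim * elem_size. For a K-major FP8 operand with
 * block_k = 128, 128 * 1 B = 128 B selects the 128B swizzle; a BF16 operand
 * with block_k = 64 gives 64 * 2 B = 128 B and also selects 128B.
 */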
layout.block_k : load_block_m, c10::elementSize(desc.a_dtype));
+        const auto swizzle_mode_b = get_swizzle_mode(
+            desc.major_b == cute::UMMA::Major::K ? layout.block_k : load_block_n, c10::elementSize(desc.b_dtype));
+        const auto swizzle_mode_cd = get_swizzle_mode(
+            store_block_n, c10::elementSize(desc.cd_dtype));
 
-    static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
-                                                        const int& m, const int& n, const int& block_m, const int& block_n,
-                                                        const int& num_sms) {
-        // TODO: support other layouts
         return {
-            false,
-            is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous
-                or (gemm_type == GemmType::Batched and num_groups <= 32)),
+            load_block_m, load_block_n,
+            store_block_m, store_block_n,
+            swizzle_mode_a, swizzle_mode_b, swizzle_mode_cd
         };
     }
 
-    static ThreadConfig get_thread_config(const KernelType& kernel_type,
-                                          const int& block_m, const int& block_n) {
-        return ThreadConfig::sm100(128, 128);
-    }
+    static PipelineConfig get_pipeline_config(const GemmDesc& desc, const Layout& layout, const StorageConfig& storage_config) {
+        constexpr int kNumMaxStages = 32;
 
-    static int get_smem_cd_size(const KernelType& kernel_type,
-                                const int& block_m, const int& block_n,
-                                const int& swizzle_cd_mode,
-                                const at::ScalarType& cd_dtype) {
-        constexpr static int layout_ad_m = 128;
-        return std::min(block_m, layout_ad_m) * swizzle_cd_mode * 2;
-    }
+        // C/D for TMA stores
+        const int smem_cd = layout.swap_ab ? storage_config.store_block_m * storage_config.store_block_n * c10::elementSize(desc.cd_dtype) * 2
+                                           : storage_config.store_block_m * storage_config.swizzle_cd_mode * 2;
+
+        // TODO: remove SF barriers for BF16 GEMMs
+        // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
+        // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
+        // NOTES: the last barrier is for tensor core utilization control
+        const int smem_barriers = kNumMaxStages * 8 * 3 + 2 * 8 * 2 + 8;
 
-    static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
-                                                          const int& block_m, const int& block_n, const int& block_k,
-                                                          const MmaKind& mma_kind, const at::ScalarType& cd_dtype) {
-        if (mma_kind == MmaKind::BF16)
-            return {0, 0};
+        // Tensor memory pointer
+        const int smem_tmem_ptr = 4;
 
+        // Calculate A/B sizes per stage
+        // TODO: consider FP4
+        const int smem_a_per_stage = storage_config.load_block_m * layout.block_k * c10::elementSize(desc.a_dtype);
+        const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype);
+
+        // Calculate SF A/B sizes per stage
        int smem_sfa_per_stage = 0;
        int smem_sfb_per_stage = 0;
-        if (kernel_type == KernelType::Kernel1D1D) {
-            const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind);
+        if (desc.kernel_type == KernelType::Kernel1D1D) {
+            const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(
+                layout.block_m, layout.block_n, desc.get_mma_kind());
            smem_sfa_per_stage = sf_block_m * 4;
            smem_sfb_per_stage = sf_block_n * 4;
-        } else {
-            smem_sfa_per_stage = block_m * 4;
-            smem_sfb_per_stage = 0;
        }
-        return {smem_sfa_per_stage, smem_sfb_per_stage};
-    }
 
-    static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
-                                       const int& block_m, const int& block_n, const int& block_k) {
-        return 0;
+        // Calculate stages
+        int smem_extra = smem_cd + smem_barriers + smem_tmem_ptr;
+        int smem_per_stage = smem_a_per_stage + 
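/*
 * [Editor's note, not part of the upstream patch] Example stage count under
 * assumed values: a BF16 GEMM with block_m = block_n = 128, block_k = 64 and no
 * multicast has smem_a = smem_b = 128 * 64 * 2 B = 16 KiB per stage and no SF
 * tiles, while smem_cd = 128 * 128 * 2 B = 32 KiB, smem_barriers = 808 B and a
 * 4 B tensor-memory pointer; (232448 - 33580) / 32768 then gives 6 stages.
 */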
smem_b_per_stage + smem_sfa_per_stage + smem_sfb_per_stage;
+        int num_stages = std::min(
+            (smem_capacity - smem_extra) / smem_per_stage,
+            kNumMaxStages);
+        return {
+            smem_extra + num_stages * smem_per_stage,
+            num_stages
+        };
     }
 
-    static int get_barrier_smem_size(const int& num_stages) {
-        // TODO: remove SF barriers for BF16 GEMMs
-        // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
-        // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
-        // NOTES: the last barrier is for tensor core utilization control
-        return num_stages * 8 * 3 + 2 * 8 * 2 + 8;
+    static LaunchConfig get_launch_config(const GemmDesc& desc, const Layout& layout) {
+        return {
+            desc.num_sms,
+            layout.get_cluster_size(),
+            256,
+            32, 128, 128, 128
+        };
     }
 
-    static int get_tmem_ptr_smem_size() {
-        return 4;
+    static LayoutInfo get_layout_info(const GemmDesc& desc, const Layout& layout) {
+        const auto num_blocks =
+            ceil_div(desc.get_expected_m(), layout.block_m) *
+            ceil_div(desc.get_expected_n(), layout.block_n) *
+            desc.get_expected_num_groups();
+        const auto num_waves = ceil_div(num_blocks, desc.num_sms);
+        const auto num_last_blocks = num_blocks % desc.num_sms;
+        const auto last_wave_util = num_last_blocks == 0 ? desc.num_sms : num_last_blocks;
+        // TODO: calculate expected cycles
+        return {num_waves, last_wave_util, 0, layout};
     }
 
-    static int get_tensormap_smem_size(const GemmType& gemm_type) {
-        return 0;
+    // A regular comparator
+    static bool compare(const LayoutInfo& a, const LayoutInfo& b) {
+        // Single wave is always better
+        if ((a.num_waves == 1 or b.num_waves == 1) and a.num_waves != b.num_waves)
+            return a.num_waves < b.num_waves;
+
+        // Doing multicast is better
+        if (a.layout.get_cluster_size() != b.layout.get_cluster_size())
+            return a.layout.get_cluster_size() > b.layout.get_cluster_size();
+
+        // Smaller number of waves is better
+        if (a.num_waves != b.num_waves)
+            return a.num_waves < b.num_waves;
+
+        // Larger last wave utilization is better
+        if (a.last_wave_util != b.last_wave_util)
+            return a.last_wave_util > b.last_wave_util;
+
+        // More stages is better
+        // Same block M, smaller block N is better
+        // Same block N, smaller block M is better
+        if (a.layout.block_m + a.layout.block_n != b.layout.block_m + b.layout.block_n)
+            return a.layout.block_m + a.layout.block_n < b.layout.block_m + b.layout.block_n;
+
+        // Less shared memory C/D, more stages is better
+        return a.layout.block_m * a.layout.block_n < b.layout.block_m * b.layout.block_n;
     }
 };
diff --git a/deep-gemm/csrc/jit_kernels/heuristics/sm90.hpp b/deep-gemm/csrc/jit_kernels/heuristics/sm90.hpp
index 2fd2e9ec..19f802f9 100644
--- a/deep-gemm/csrc/jit_kernels/heuristics/sm90.hpp
+++ b/deep-gemm/csrc/jit_kernels/heuristics/sm90.hpp
@@ -2,162 +2,244 @@
 #include
 // Reuse some types in the JIT modules
-#include
+#include
 #include "common.hpp"
+#include "utils.hpp"
+#include "../../utils/exception.hpp"
 
 namespace deep_gemm {
 
 struct SM90ArchSpec {
     static constexpr int smem_capacity = 232448;
-
-    static std::vector<int> get_block_m_candidates(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const int& m) {
-        std::vector<int> candidates{64, 128, 256};
-        if ((kernel_type == KernelType::Kernel1D2D or kernel_type == KernelType::KernelNoSF) and major_a == cute::UMMA::Major::K) {
-            // NOTES: `block_m = 16/32` is smaller than MMA M size, should be careful in handling this
-            if (m <= 16) candidates.push_back(16);
-            if (m <= 32) candidates.push_back(32);
-        }
-        return candidates;
-    }
 
-    static std::vector<int> get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) {
-        int start = 16;
+    static std::vector<Layout> get_layout_candidates(const GemmDesc& desc) {
+        // Block M candidates
+        std::vector<int> block_m_candidates;
+        if (desc.gemm_type == GemmType::Normal or
+            desc.gemm_type == GemmType::Batched or
+            desc.gemm_type == GemmType::KGroupedContiguous) {
+            // TODO: check 256's performance
+            block_m_candidates = {64, 128};
+            // NOTES: smaller block M can avoid TMA L2 OOB bound
+            if (desc.m <= 16) block_m_candidates.push_back(16);
+            if (desc.m <= 32) block_m_candidates.push_back(32);
+
+            // BF16 output GEMM supports 256
+            if (desc.cd_dtype != torch::kFloat)
+                block_m_candidates.push_back(256);
+        } else if (desc.gemm_type == GemmType::MGroupedContiguous or
+                   desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) {
+            block_m_candidates = std::vector<int>{heuristics_runtime->get_mk_alignment_for_contiguous_layout()};
+        } else if (desc.gemm_type == GemmType::MGroupedMasked) {
+            block_m_candidates = {64, 128};
+        }
 
+        // Block N candidates
+        std::vector<int> block_n_candidates;
+        int step = std::lcm(16, heuristics_runtime->get_block_n_multiple_of());
+        int start = step;
        // Avoid bank conflicts for 1D1D kernel FP32 output
-        std::vector<int> candidates;
-        if (kernel_type == KernelType::Kernel1D1D and cd_dtype == torch::kFloat) {
-            candidates.push_back(16);
+        if (desc.kernel_type == KernelType::Kernel1D1D and desc.cd_dtype == torch::kFloat) {
+            DG_HOST_ASSERT(desc.major_a == cute::UMMA::Major::K);
+            DG_HOST_ASSERT(desc.major_b == cute::UMMA::Major::K);
            start = 24;
+            block_n_candidates.push_back(16);
+        }
+        // Cap block N to avoid register spills
+        int end = 256;
+        if (desc.kernel_type == KernelType::Kernel1D2D)
+            end = 192;
+        if (desc.kernel_type == KernelType::Kernel1D1D)
+            end = 160;
+        // Enumerate
+        for (int i = start; i <= end; i += step)
+            block_n_candidates.push_back(i);
+
+        // Block K is always fixed
+        const int block_k = 128 / get_element_size(desc.get_mma_kind());
+
+        // Disable multicast for performance
+        const bool disable_multicast =
+            // The number of k-groups is large (a heuristic)
+            (desc.gemm_type == GemmType::KGroupedContiguous and desc.num_groups > 4) or
+            // Not supported
+            (desc.gemm_type == GemmType::Batched);
+
+        // Enumerate all candidates
+        std::vector<Layout> candidates;
+        for (int cluster_m = 1; cluster_m <= (disable_multicast ? 1 : 2); ++ cluster_m) {
+            for (int cluster_n = 1; cluster_n <= (disable_multicast ? 
1 : 2); ++ cluster_n) {
+                // We only support cluster sizes up to 2
+                if (cluster_m * cluster_n > 2)
+                    continue;
+
+                // SM count must be divisible
+                if (desc.num_sms % (cluster_m * cluster_n) != 0)
+                    continue;
+
+                for (int block_m: block_m_candidates) {
+                    for (int block_n: block_n_candidates) {
+                        // 1D2D kernel unroll requirement
+                        if (desc.kernel_type == KernelType::Kernel1D2D and block_n > block_k and (block_n % (block_n - block_k) != 0 and block_k % (block_n - block_k) != 0))
+                            continue;
+
+                        // Multicast legality for masked layout
+                        // TODO: add some comments about it
+                        if ((desc.gemm_type == GemmType::MGroupedMasked or desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) and
+                            ceil_div(desc.n, block_n) % (cluster_m * cluster_n) != 0)
+                            continue;
+
+                        // The block sizes cannot be too large (for enough registers), so at least one dim must be at most 128
+                        if (block_m > 128 and block_n > 128)
+                            continue;
+
+                        // Calculate swizzling
+                        const auto layout = Layout{0, block_m, block_n, block_k, cluster_m, cluster_n};
+                        const auto storage_config = get_storage_config(desc, layout);
+
+                        // Make sure swizzling is large enough (32B's performance is low)
+                        if (storage_config.swizzle_a_mode % 64 != 0 or storage_config.swizzle_b_mode % 64 != 0)
+                            continue;
+
+                        // To hide TMA latency, the stage count should be at least 3; for small matrices, at least 4
+                        int num_stages = get_pipeline_config(desc, layout, storage_config).num_stages;
+                        if (num_stages < 3 or (block_m * block_n < 128 * 192 and num_stages < 4))
+                            continue;
+
+                        candidates.push_back(layout);
+                    }
+                }
+            }
         }
-        // Push the strided options
-        for (int i = start; i <= 256; i += 16)
-            candidates.push_back(i);
+        DG_HOST_ASSERT(not candidates.empty());
        return candidates;
    }
 
-    static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
-        return block_m;
-    }
-
-    static int get_ab_load_block_n(const MulticastConfig& multicast_config, const int& block_n) {
-        return block_n;
-    }
-
-    static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) {
+    static StorageConfig get_storage_config(const GemmDesc& desc, const Layout& layout) {
        constexpr int wgmma_m = 64;
-        return single_warpgroup_sync ? wgmma_m : block_m;
-    }
-
-    static int get_cd_store_block_n(const int& block_n) {
-        return block_n;
-    }
-
-    static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
-        return cd_dtype != torch::kFloat;
-    }
-
-    static bool is_block_size_legal(const KernelType& kernel_type,
-                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
-                                    const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
-                                    const int& m, const int& n, const int& k,
-                                    const int& block_m, const int& block_n, const int& block_k) {
-        // SM90 FP32 output does not support `block_m == 256`
-        if (cd_dtype == at::kFloat and block_m == 256)
-            return false;
-
-        // Avoid large C/D shared memory for FP32 output
-        // Ensure `num_stages >= 4` (for 1D1D kernel), `num_stages >= 3` (for no-SF kernel)
-        if (block_n > 128 and cd_dtype == torch::kFloat) {
-            if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
-                return false;
-            if (kernel_type == KernelType::KernelNoSF and block_n > 200)
-                return false;
-        }
-
-        // When B is N Major, use swizzle 128B for better performance; only affects SM90 BF16 GEMM
-        if (major_b == cute::UMMA::Major::MN and block_n >= 128 and block_n % 64 != 0)
-            return false;
-        // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
-        // Or too many register spills
-        if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
-            return false;
-
-        // The block sizes cannot be too large (for enough registers), so at least one dim must be at most 128
-        return block_m <= 128 or block_n <= 128;
-    }
-
-    static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
-                                    const int& num_stages,
-                                    const int& block_m, const int& block_n, const int& block_k) {
-        // Unrolling both stages and `num_former_iters` will cause large code size
-        if (mma_kind == MmaKind::MXFP8FP4 and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
-            return num_stages <= 4;
-        return true;
-    }
-
-    static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
-                                                        const int& m, const int& n, const int& block_m, const int& block_n,
-                                                        const int& num_sms) {
-        // Disable multicast when the number of k-groups is large (a heuristic)
-        if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4)
-            return {false, false};
-
-        if (gemm_type == GemmType::Batched)
-            return {false, false};
+        // Load/store block sizes (w/o consideration of swizzling atoms, w/ consideration of loop atoms)
+        // TODO: support swap AB
+        DG_HOST_ASSERT(layout.swap_ab == 0);
+        const auto load_block_m = layout.block_m;
+        const auto load_block_n = layout.block_n;
+        // 1D1D kernel will do single warp-group stores
+        const auto store_block_m = desc.kernel_type == KernelType::Kernel1D1D ? wgmma_m : layout.block_m;
+        const auto store_block_n = layout.block_n;
+
+        // Decide swizzling by the inner dim
+        const auto swizzle_mode_a = get_swizzle_mode(
+            desc.major_a == cute::UMMA::Major::K ? layout.block_k : load_block_m, c10::elementSize(desc.a_dtype));
+        const auto swizzle_mode_b = get_swizzle_mode(
+            desc.major_b == cute::UMMA::Major::K ? layout.block_k : load_block_n, c10::elementSize(desc.b_dtype));
+        // We only enable swizzling for non-FP32 outputs
+        const auto swizzle_mode_cd = desc.cd_dtype != torch::kFloat ? 
+            get_swizzle_mode(store_block_n, c10::elementSize(desc.cd_dtype)) : 0;
 
        return {
-            is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
-            // For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
-            is_multicast_legal(m, block_m, 2, num_sms, false)
-                and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true))
+            load_block_m, load_block_n,
+            store_block_m, store_block_n,
+            swizzle_mode_a, swizzle_mode_b, swizzle_mode_cd
        };
    }
 
-    static ThreadConfig get_thread_config(const KernelType& kernel_type,
-                                          const int& block_m, const int& block_n) {
-        return ThreadConfig::sm90(128, (block_m <= 64 ? 1 : 2) * 128);
-    }
+    static PipelineConfig get_pipeline_config(const GemmDesc& desc, const Layout& layout, const StorageConfig& storage_config) {
+        constexpr int kNumMaxStages = 16;
 
-    static int get_smem_cd_size(const KernelType& kernel_type,
-                                const int& block_m, const int& block_n,
-                                const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) {
+        // TODO: consider swap AB
+        // C/D for TMA stores
        // NOTES: 1024 is for TMA swizzling alignment requirement
-        return align(block_m * block_n * static_cast<int>(c10::elementSize(cd_dtype)), 1024);
-    }
-
-    static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
-                                                          const int& block_m, const int& block_n, const int& block_k,
-                                                          const MmaKind& mma_kind, const at::ScalarType& cd_dtype) {
-        if (mma_kind == MmaKind::BF16)
-            return {0, 0};
-
-        // NOTES: 128 is for 2D TMA alignment requirement
-        int smem_sfa_per_stage = align(block_m * static_cast<int>(sizeof(float)), 128);
-        int smem_sfb_per_stage = 0;
-        if (kernel_type == KernelType::Kernel1D1D)
-            smem_sfb_per_stage = align(block_n * 4, 128);
-        return {smem_sfa_per_stage, smem_sfb_per_stage};
-    }
-
-    static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
-                                       const int& block_m, const int& block_n, const int& block_k) {
-        const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2;
-        return align(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
-    }
-
-    static int get_barrier_smem_size(const int& num_stages) {
-        return num_stages * 8 * 2;
+        const int smem_cd =
+            align(layout.block_m * layout.block_n * static_cast<int>(c10::elementSize(desc.cd_dtype)), 1024);
+        const int smem_barriers = kNumMaxStages * 8 * 2;
+
+        // Calculate A/B sizes per stage
+        const int smem_a_per_stage = storage_config.load_block_m * layout.block_k * c10::elementSize(desc.a_dtype);
+        const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype);
+
+        // Calculate SF A/B sizes per stage
+        const int smem_sfa_per_stage = desc.kernel_type == KernelType::KernelNoSF ?
+            0 : align(layout.block_m * static_cast<int>(sizeof(float)), 128);
+        const int smem_sfb_per_stage = desc.kernel_type != KernelType::Kernel1D1D ?
+            0 : align(layout.block_n * static_cast<int>(sizeof(float)), 128);
+
+        // Extra SFB sizes for 1D2D kernels
+        const int use_uniform_sfb = layout.block_k % layout.block_n == 0 ? 1 : 2;
+        const int smem_extra_sfb = desc.kernel_type != KernelType::Kernel1D2D ?
+            0 : align(ceil_div(desc.k, layout.block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
+
+        // Extra tensormap for 1D1D kernels
+        const int smem_tensormap =
+            desc.gemm_type == GemmType::KGroupedContiguous ? 
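/*
 * [Editor's note, not part of the upstream patch] A CUtensorMap is 128 B, so
 * k-grouped kernels reserve 4 * 128 B = 512 B of shared memory here for the
 * tensormaps they update on the fly; all other GEMM types reserve nothing.
 */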
4 * static_cast<int>(sizeof(CUtensorMap)) : 0;
+
+        // Calculate stages
+        const int smem_extra = smem_cd + smem_barriers + smem_extra_sfb + smem_tensormap;
+        const int smem_per_stage = smem_a_per_stage + smem_b_per_stage + smem_sfa_per_stage + smem_sfb_per_stage;
+        const int num_stages = std::min(
+            (smem_capacity - smem_extra) / smem_per_stage,
+            kNumMaxStages);
+        return {
+            smem_extra + num_stages * smem_per_stage,
+            num_stages
+        };
    }
 
-    static int get_tmem_ptr_smem_size() {
-        return 0;
+    static LaunchConfig get_launch_config(const GemmDesc& desc, const Layout& layout) {
+        const int num_tma_threads = 128;
+        const int num_math_threads = layout.block_m <= 64 ? 128 : 256;
+        return {
+            desc.num_sms,
+            layout.get_cluster_size(),
+            num_tma_threads + num_math_threads,
+            num_tma_threads, num_math_threads,
+            0, 0 // Meaningless for SM90
+        };
    }
 
-    static int get_tensormap_smem_size(const GemmType& gemm_type) {
-        return gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast<int>(sizeof(CUtensorMap)) : 0;
+    static LayoutInfo get_layout_info(const GemmDesc& desc, const Layout& layout) {
+        const auto num_blocks =
+            ceil_div(desc.get_expected_m(), layout.block_m) *
+            ceil_div(desc.get_expected_n(), layout.block_n) *
+            desc.get_expected_num_groups();
+        const auto num_waves = ceil_div(num_blocks, desc.num_sms);
+        const auto num_last_blocks = num_blocks % desc.num_sms;
+        const auto last_wave_util = num_last_blocks == 0 ? desc.num_sms : num_last_blocks;
+
+        // Approximate bandwidth constants
+        const int l2_bandwidth_per_cycle = std::min(64. * desc.num_sms, 8e6 / (1.3e3)); // B/cycle
+        const int l1_bandwidth_per_cycle = 128 * desc.num_sms; // B/cycle
+        const int wgmma_m = 64;
+        const int elem_size_ab = c10::elementSize(desc.a_dtype);
+        const int elem_size_cd = c10::elementSize(desc.cd_dtype);
+        DG_HOST_ASSERT(desc.a_dtype == desc.b_dtype);
+
+        // Data movement per block
+        int64_t expected_k = desc.get_expected_k();
+        int64_t num_bytes_l2_ab = expected_k * (layout.block_m / layout.cluster_n + layout.block_n / layout.cluster_m) * elem_size_ab;
+        int64_t num_bytes_l1_ab = expected_k * (layout.block_m + layout.block_n) * elem_size_ab;
+        int64_t num_bytes_l1_tc = expected_k * (std::max(wgmma_m, layout.block_m) + layout.block_n) * elem_size_ab +
+                                  layout.block_m * layout.block_n * elem_size_cd;
+        int64_t num_bytes_l1_l2_cd = layout.block_m * layout.block_n * elem_size_cd * (desc.with_accumulation ? 
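/*
 * [Editor's note, not part of the upstream patch] Accumulation (D += C) both
 * reads the previous C block and writes the new D block, doubling the C/D
 * traffic modeled here; without accumulation only the store is counted.
 */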
2 : 1);
+
+        // HBM bandwidth and total compute (Tensor/CUDA cores) are constant across configs
+        // We only model L1/L2 cycles as they are the primary variables between configs
+        int64_t num_l2_cycles = (num_bytes_l2_ab + num_bytes_l1_l2_cd) * num_blocks / l2_bandwidth_per_cycle;
+        int64_t num_l1_cycles = (num_bytes_l1_ab + num_bytes_l1_tc + num_bytes_l1_l2_cd) * num_blocks / l1_bandwidth_per_cycle;
+        float wave_efficiency = static_cast<float>(num_blocks) / (num_waves * desc.num_sms);
+        int64_t num_cycles = std::max(num_l1_cycles, num_l2_cycles) / wave_efficiency;
+
+        // Disable multicasting if only one wave exists
+        if (layout.cluster_n * layout.cluster_m > 1 and num_waves <= 1)
+            num_cycles = std::numeric_limits<int64_t>::max();
+
+        return {num_waves, last_wave_util, num_cycles, layout};
+    }
+
+    // A regular comparator
+    static bool compare(const LayoutInfo& a, const LayoutInfo& b) {
+        return a.num_cycles < b.num_cycles;
    }
 };
diff --git a/deep-gemm/csrc/jit_kernels/heuristics/utils.hpp b/deep-gemm/csrc/jit_kernels/heuristics/utils.hpp
new file mode 100644
index 00000000..17d2ae07
--- /dev/null
+++ b/deep-gemm/csrc/jit_kernels/heuristics/utils.hpp
@@ -0,0 +1,23 @@
+#pragma once
+
+#include
+// Reuse some types in the JIT modules
+#include
+
+#include "common.hpp"
+#include "../../utils/exception.hpp"
+
+namespace deep_gemm {
+
+template <typename size_type_t>
+static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
+    // `> 0` means interleaving
+    // 16B actually means non-swizzling (but interleaving)
+    for (const int& mode: {128, 64, 32, 16}) {
+        if ((block_size * static_cast<int>(elem_size)) % mode == 0)
+            return mode;
+    }
+    DG_HOST_UNREACHABLE("Unreachable");
+}
+
+} // namespace deep_gemm
diff --git a/deep-gemm/csrc/jit_kernels/impls/epilogue.hpp b/deep-gemm/csrc/jit_kernels/impls/epilogue.hpp
index bd21de10..1003df4c 100644
--- a/deep-gemm/csrc/jit_kernels/impls/epilogue.hpp
+++ b/deep-gemm/csrc/jit_kernels/impls/epilogue.hpp
@@ -6,7 +6,7 @@ namespace deep_gemm {
 
 static std::string get_default_epilogue_type(const std::optional<std::string>& epilogue_type) {
-    return epilogue_type.value_or("EpilogueIdentity");
+    return epilogue_type.value_or("epilogue::transform::EpilogueIdentity");
 }
 
 } // namespace deep_gemm
diff --git a/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp b/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp
index 677a89ba..7aa87526 100644
--- a/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp
+++ b/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp
@@ -1,7 +1,7 @@
 #pragma once
 
 #include
-#include
+#include "../../utils/torch_compat.hpp"
 
 #include "../heuristics/sm90.hpp"
 #include "../../jit/handle.hpp"
@@ -20,6 +20,9 @@ static int get_non_contiguous_dim(const cute::UMMA::Major& major) {
 }
 
 static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) {
+    if (heuristics_runtime->get_ignore_compile_dims())
+        return 0;
+
    for (const char& c: compiled_dims) {
        if (name == c)
            return dim;
@@ -58,8 +61,19 @@ static std::string to_string(const at::ScalarType& dtype) {
    }
 }
 
+static std::string to_string(const float& v) {
+    if (std::isfinite(v)) {
+        return fmt::format(R"({:a}f)", v);
+    } else if (std::isinf(v)) {
+        return v > 0 ? 
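/*
 * [Editor's note, not part of the upstream patch] The hex-float path above
 * round-trips exactly when pasted into generated CUDA source, e.g. a sketch of
 * the expected behaviour: to_string(0.25f) yields "0x1p-2f" and
 * to_string(1.5f) yields "0x1.8p+0f" (fmt's "{:a}" spec), so JIT-compiled
 * kernels see bit-identical constants.
 */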
"cute::numeric_limits::infinity()" + : "-cute::numeric_limits::infinity()"; + } + DG_HOST_UNREACHABLE("NaN input is not supported"); +} + static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype, - const bool& allow_tf32) { + const bool& allow_tf32, + const bool& fp4_unpacked_smem) { if (allow_tf32 and dtype == torch::kFloat) return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; @@ -68,15 +82,16 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8; -#if CUDART_VERSION >= 12080 - case kPackedFP4: return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; +#if CUDA_VERSION >= 12080 + case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B + : CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; #endif default: DG_HOST_UNREACHABLE("Unsupported dtype"); } } static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) { -#if CUDART_VERSION >= 12080 +#if CUDA_VERSION >= 12080 if (base != 0) { DG_HOST_ASSERT(base == 32 and mode == 128); return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; @@ -99,14 +114,20 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, int smem_inner_dim, int smem_outer_dim, const int& gmem_outer_stride, const int& swizzle_mode, const int& swizzle_base = 0, - const bool& allow_tf32 = false) { - const auto& elem_size = static_cast(t.element_size()); + const bool& allow_tf32 = false, + const bool& fp4_unpacked_smem = true) { + const auto elem_size = static_cast(t.element_size()); if (swizzle_mode != 0) smem_inner_dim = swizzle_mode / elem_size; - // Inner dim must be a multiple of 64B for .b4x16_p64 - if (t.scalar_type() == kPackedFP4) - DG_HOST_ASSERT(gmem_inner_dim % 128 == 0); + if (t.scalar_type() == kPackedFP4) { + // Inner dim must be a multiple of 64B for .b4x16_p64 + DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0); + + // Fix FP4 packed smem + if (not fp4_unpacked_smem and swizzle_mode != 0) + smem_inner_dim = swizzle_mode * 2; + } CUtensorMap tensor_map; const cuuint64_t gmem_dims[2] = {static_cast(gmem_inner_dim), static_cast(gmem_outer_dim)}; @@ -114,12 +135,13 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, const cuuint64_t gmem_strides[1] = {static_cast(gmem_outer_stride * elem_size), }; const cuuint32_t elem_strides[2] = {1, 1}; if (get_env("DG_JIT_DEBUG")) { - printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n", + printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d, pointer: %llu\n", gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim, - gmem_outer_stride, swizzle_mode, swizzle_base, elem_size); + gmem_outer_stride, swizzle_mode, swizzle_base, elem_size, + reinterpret_cast(t.data_ptr())); } DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled( - &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32), + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32, fp4_unpacked_smem), 2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); @@ -131,14 +153,20 @@ static CUtensorMap make_tma_3d_desc(const torch::Tensor& t, 
int smem_dim_0, int smem_dim_1, int smem_dim_2, const int& gmem_stride_0, const int& gmem_stride_1, const int& swizzle_mode, const int& swizzle_base = 0, - const bool& allow_tf32 = false) { - const auto& elem_size = static_cast(t.element_size()); + const bool& allow_tf32 = false, + const bool& fp4_unpacked_smem = true) { + const auto elem_size = static_cast(t.element_size()); if (swizzle_mode != 0) smem_dim_0 = swizzle_mode / elem_size; - // Inner dim must be a multiple of 64B for .b4x16_p64 - if (t.scalar_type() == kPackedFP4) - DG_HOST_ASSERT(gmem_dim_0 % 128 == 0); + if (t.scalar_type() == kPackedFP4) { + // Inner dim must be a multiple of 64B for .b4x16_p64 + DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_dim_0 % 128 == 0); + + // Fix fp4 packed smem + if (not fp4_unpacked_smem and swizzle_mode != 0) + smem_dim_0 = swizzle_mode * 2; + } CUtensorMap tensor_map; const cuuint64_t gmem_dims[3] = {static_cast(gmem_dim_0), static_cast(gmem_dim_1), static_cast(gmem_dim_2),}; @@ -151,7 +179,7 @@ static CUtensorMap make_tma_3d_desc(const torch::Tensor& t, gmem_stride_0, gmem_stride_1, swizzle_mode, elem_size); } DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled( - &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32), + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32, fp4_unpacked_smem), 3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); @@ -168,8 +196,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major, const bool& allow_tf32 = false) { if (num_groups > 1) DG_HOST_ASSERT(major == cute::UMMA::Major::K); - const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups); - const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m); + const auto [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups); + const auto [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m); return make_tma_2d_desc(t, gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim, @@ -186,8 +214,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major, const int& num_groups, const int& swizzle_mode, const int& swizzle_base = 0, const bool& allow_tf32 = false) { - const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n); - const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n); + const auto [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n); + const auto [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n); // `num_groups` is always applied into the outer dimensions return make_tma_2d_desc(t, diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp index bca47a3a..b9224d7c 100644 --- a/deep-gemm/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../utils/torch_compat.hpp" #include "../../jit/compiler.hpp" #include "../../jit/device_runtime.hpp" @@ -16,9 +16,7 @@ namespace deep_gemm { class SM100BF16GemmRuntime final: public LaunchRuntime { public: struct Args { - int m, n, k, num_groups; - const std::string& compiled_dims; - + GemmDesc 
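/*
 * [Editor's note, not part of the upstream patch] The refactor below replaces
 * the loose m/n/k/num_groups/compiled_dims fields with the full GemmDesc, so
 * generate() and launch_impl() can read every shape and dtype from one place;
 * callers now build Args roughly as { .gemm_desc = desc, .gemm_config = config,
 * .launch_args = ... }, as seen in sm100_bf16_gemm() further down.
 */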
gemm_desc; GemmConfig gemm_config; LaunchArgs launch_args; @@ -45,28 +43,32 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, + {}, {}, {}, {}, {} >); }}; )", - to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), - get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), - args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, - args.num_groups, - args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, - args.gemm_config.num_stages, - args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, - args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - args.gemm_config.num_sms, - to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype), - args.gemm_config.tc_util); + to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b), + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_desc.num_groups, + args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_non_epilogue_threads, args.gemm_config.launch_config.num_epilogue_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, + args.gemm_config.layout.swap_ab, + to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation, to_string(args.gemm_desc.cd_dtype), + args.gemm_desc.tc_util); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { // TODO: optimize `args` copy DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, - args.grouped_layout, args.m, args.n, args.k, + args.grouped_layout, args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_cd)); } @@ -79,45 +81,49 @@ static void sm100_bf16_gemm(const torch::Tensor& a, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - const auto& config = get_best_config( - GemmType::Normal, KernelType::KernelNoSF, - m, n, k, 1, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), c.has_value(), - device_runtime->get_num_sms()); - - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - 
SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); // Launch const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = 1, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm100_bf16_gemm", code); + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_gemm", code); SM100BF16GemmRuntime::launch(runtime, args); } @@ -130,53 +136,61 @@ static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, const std::string& compiled_dims, const bool& use_psum_layout, const std::optional& expected_m_for_psum_layout) { - const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + const auto gemm_type = use_psum_layout ? + GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // Only psum layout can use expected m + if (expected_m_for_psum_layout) + DG_HOST_ASSERT(use_psum_layout); // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. // Otherwise, treat the contiguous layout as a whole. - const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m; - const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? 
num_groups : 1; - - const auto& config = get_best_config( - gemm_type, KernelType::KernelNoSF, - // NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole - m_for_config, n, k, num_groups_for_config, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), false, - device_runtime->get_num_sms()); - - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); + const auto desc = GemmDesc { + .gemm_type = gemm_type, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m_for_psum_layout.value_or(m), + .expected_n = n, .expected_k = k, + .expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1 + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); // Launch - const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, + const SM100BF16GemmRuntime::Args args = { + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = grouped_layout.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code); + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code); SM100BF16GemmRuntime::launch(runtime, args); } @@ -188,45 +202,50 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a, const int& 
expected_m, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - const auto& config = get_best_config( - GemmType::MGroupedMasked, KernelType::KernelNoSF, - expected_m, n, k, num_groups, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), false, - device_runtime->get_num_sms()); - - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), num_groups, - config.smem_config.swizzle_cd_mode); + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedMasked, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); // Launch - const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, + const SM100BF16GemmRuntime::Args args = { + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = masked_m.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code); + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code); SM100BF16GemmRuntime::launch(runtime, args); } @@ -241,54 +260,59 @@ static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a, DG_HOST_ASSERT(major_a == 
cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); int sum_k = 0; - for (const auto& k: ks) { + for (const auto k: ks) { sum_k += k; DG_HOST_ASSERT(k % 128 == 0); } - const auto& num_groups = static_cast(ks.size()); + const auto num_groups = static_cast(ks.size()); // Get config using max K for better performance - const auto& max_k = *std::max_element(ks.begin(), ks.end()); - const auto& config = get_best_config( - GemmType::KGroupedContiguous, KernelType::KernelNoSF, - m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), c.has_value(), - device_runtime->get_num_sms()); + const auto max_k = *std::max_element(ks.begin(), ks.end()); + const auto desc = GemmDesc { + .gemm_type = GemmType::KGroupedContiguous, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); // Create tensor descriptors - const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, - SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(0)), 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, - SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(0)), 1, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(1)), num_groups, - config.smem_config.swizzle_cd_mode); + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(0)), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(0)), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(1)), num_groups, + config.storage_config.swizzle_cd_mode); // Launch kernel const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = sum_k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = ks_tensor.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm100_bf16_k_grouped_gemm", code); + const auto code = 
SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_k_grouped_gemm", code); SM100BF16GemmRuntime::launch(runtime, args); } @@ -297,46 +321,46 @@ static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, const torch::Tensor& tensor_d, const int& b, const int& h, const int& r, const int& d, const std::string& compiled_dims = "nk") { - const auto& config = get_best_config( - GemmType::Batched, KernelType::KernelNoSF, - b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K, - tensor_a.scalar_type(), tensor_b.scalar_type(), - tensor_d.scalar_type(), false, - device_runtime->get_num_sms()); - - const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); - const auto& tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, - config.block_k, load_block_m, 1, - tensor_a.stride(0), tensor_a.stride(1), - config.smem_config.swizzle_a_mode); - const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); - const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, - config.block_k, load_block_n, 1, - tensor_b.stride(1), tensor_b.stride(0), - config.smem_config.swizzle_b_mode); - const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m); - const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n); - const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, - store_block_n, store_block_m, 1, - tensor_d.stride(0), tensor_d.stride(1), - config.smem_config.swizzle_cd_mode); + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::KernelNoSF, + .m = b, .n = d, .k = r, .num_groups = h, + .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(), + .cd_dtype = tensor_d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, + config.layout.block_k, config.storage_config.load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.layout.block_k, config.storage_config.load_block_n, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, + config.storage_config.store_block_n, config.storage_config.store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.storage_config.swizzle_cd_mode); // Launch const SM100BF16GemmRuntime::Args& args = { - .m = b, .n = d, .k = r, - .num_groups = h, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code); + const auto code = SM100BF16GemmRuntime::generate(args); + const auto 
runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code); SM100BF16GemmRuntime::launch(runtime, args); } @@ -345,46 +369,46 @@ static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, const torch::Tensor& tensor_d, const int& b, const int& h, const int& r, const int& d, const std::string& compiled_dims = "nk") { - const auto& config = get_best_config( - GemmType::Batched, KernelType::KernelNoSF, - b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN, - tensor_a.scalar_type(), tensor_b.scalar_type(), - tensor_d.scalar_type(), false, - device_runtime->get_num_sms()); - - const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); - const auto& tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, - config.block_k, load_block_m, 1, - tensor_a.stride(0), tensor_a.stride(1), - config.smem_config.swizzle_a_mode); - const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); - const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, - load_block_n, config.block_k, 1, - tensor_b.stride(1), tensor_b.stride(0), - config.smem_config.swizzle_b_mode); - const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m); - const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n); - const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, - store_block_n, store_block_m, 1, - tensor_d.stride(0), tensor_d.stride(1), - config.smem_config.swizzle_cd_mode); + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::KernelNoSF, + .m = b, .n = r, .k = d, .num_groups = h, + .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(), + .cd_dtype = tensor_d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::MN, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, + config.layout.block_k, config.storage_config.load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.storage_config.load_block_n, config.layout.block_k, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, + config.storage_config.store_block_n, config.storage_config.store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.storage_config.swizzle_cd_mode); // Launch const SM100BF16GemmRuntime::Args& args = { - .m = b, .n = r, .k = d, - .num_groups = h, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", code); + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", 
code); SM100BF16GemmRuntime::launch(runtime, args); } diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp index dc8766cc..2ec5d12a 100644 --- a/deep-gemm/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../utils/torch_compat.hpp" #include "../../jit/compiler.hpp" #include "../../jit/device_runtime.hpp" @@ -85,11 +85,11 @@ static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a, // NOTES: we select 4 as start, as it is tested to be faster than values > 4 int num_stages = 4, smem_size = 0; while (true) { - const int& smem_cd = block_m * swizzle_cd_mode * 2; - const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t); - const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); - const int& smem_barrier = SM100ArchSpec::get_barrier_smem_size(num_stages); - const int& smem_tmem_ptr = SM100ArchSpec::get_tmem_ptr_smem_size(); + const int smem_cd = block_m * swizzle_cd_mode * 2; + const int smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t); + const int smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); + const int smem_barrier = num_stages * 8 * 3 + 2 * 8 * 2 + 8; + const int smem_tmem_ptr = 4; smem_size = 0; smem_size += smem_cd; @@ -112,11 +112,11 @@ static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a, num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode); } - const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); - const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); - const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode); + const auto tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); + const auto tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); + const auto tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode); - const SM100BmkBnkMnRuntime::Args& args = { + const SM100BmkBnkMnRuntime::Args args = { .s = s, .m = m, .n = n, .k = k, .block_m = block_m, .block_n = block_n, .block_k = block_k, .split_factor = split_factor, @@ -129,8 +129,8 @@ static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d }; - const auto& code = SM100BmkBnkMnRuntime::generate(args); - const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code); + const auto code = SM100BmkBnkMnRuntime::generate(args); + const auto runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code); SM100BmkBnkMnRuntime::launch(runtime, args); } diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp new file mode 100644 index 00000000..efd6d555 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp @@ -0,0 +1,459 @@ +#pragma once + +#include "../../utils/torch_compat.hpp" + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" + +#include "epilogue.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + GemmDesc 
gemm_desc; + GemmConfig gemm_config; + LaunchArgs launch_args; + // TODO: move into descriptor + const std::optional epilogue_type; + + // TODO: move into descriptor + int gran_k_a, gran_k_b; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + // TODO: rename files + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_fp4_gemm_1d1d_impl< + {}, {}, + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, + {}, {}, + {}, {}, {}, + {} + >); +}}; +)", + to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b), + args.gran_k_a, args.gran_k_b, + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_desc.num_groups, + args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_non_epilogue_threads, args.gemm_config.launch_config.num_epilogue_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, + args.gemm_config.layout.swap_ab, + to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation, + to_string(args.gemm_desc.a_dtype), to_string(args.gemm_desc.b_dtype), to_string(args.gemm_desc.cd_dtype), + get_default_epilogue_type(args.epilogue_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_cd)); + } +}; + +static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const auto cd = c.value_or(d); + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, 
n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, static_cast(d.size(-1)), + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k_a, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k_b, 1, 0); + + // Launch + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = epilogue_type, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto gemm_type = use_psum_layout ? + GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // Only psum layout can use expected m + if (expected_m_for_psum_layout) + DG_HOST_ASSERT(use_psum_layout); + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto desc = GemmDesc { + .gemm_type = gemm_type, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims, + .expected_m = expected_m_for_psum_layout.value_or(m), + .expected_n = n, .expected_k = k, + .expected_num_groups = expected_m_for_psum_layout.has_value() ? 
num_groups : 1 + }; + const auto config = get_best_config(desc); + + // Create tensor descriptors + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k_a, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k_b, num_groups, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = grouped_layout.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedMasked, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims, + .expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Create tensor descriptors + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, 
+ config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k_a, num_groups, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k_b, num_groups, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const int& gran_k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + DG_HOST_ASSERT(gran_k == 32 or gran_k == 128); + const int gran_k_a = gran_k; + const int gran_k_b = gran_k; + + int sum_k = 0, sum_sf_k = 0; + for (const auto k: ks) { + sum_k += k, sum_sf_k += ceil_div(k, gran_k * 4); + DG_HOST_ASSERT(k % gran_k == 0); + } + const auto num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto max_k = *std::max_element(ks.begin(), ks.end()); + const auto desc = GemmDesc { + .gemm_type = GemmType::KGroupedContiguous, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims, + .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Create tensor descriptors + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(0)), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(0)), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(1)), num_groups, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * gran_k_a * 4, + config.layout.block_m, gran_k_a, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * gran_k_b * 4, 
+ config.layout.block_n, gran_k_b, 1, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& batch_size, const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = batch_size, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const int load_block_m = config.storage_config.load_block_m; + const auto [inner_dim_a, outer_dim_a] = get_inner_outer_dims(major_a, k, m); + const auto [inner_block_a, outer_block_a] = get_inner_outer_dims(major_a, config.layout.block_k, load_block_m); + const auto tensor_map_a = make_tma_3d_desc(a, inner_dim_a, outer_dim_a, batch_size, + inner_block_a, outer_block_a, 1, + a.stride(major_a == cute::UMMA::Major::K ? 1 : 2), + a.stride(0), + config.storage_config.swizzle_a_mode); + + const int load_block_n = config.storage_config.load_block_n; + const auto [inner_dim_b, outer_dim_b] = get_inner_outer_dims(major_b, k, n); + const auto [inner_block_b, outer_block_b] = get_inner_outer_dims(major_b, config.layout.block_k, load_block_n); + const auto tensor_map_b = make_tma_3d_desc(b, inner_dim_b, outer_dim_b, batch_size, + inner_block_b, outer_block_b, 1, + b.stride(major_b == cute::UMMA::Major::K ? 
1 : 2), + b.stride(0), + config.storage_config.swizzle_b_mode); + + const int store_block_m = config.storage_config.store_block_m; + const int store_block_n = config.storage_config.store_block_n; + const auto tensor_map_cd = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.storage_config.swizzle_cd_mode); + + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k_a, batch_size, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k_b, batch_size, 0); + + // Launch + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_fp8_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp new file mode 100644 index 00000000..64d5cda1 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp @@ -0,0 +1,220 @@ +#pragma once + +#include "../../utils/torch_compat.hpp" + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "runtime_utils.hpp" + +#include +#include + +#include "../heuristics/mega_moe.hpp" + +namespace deep_gemm { + +class SM100FP8FP4MegaMoERuntime final : public LaunchRuntime { +public: + struct Args { + // Templated arguments + int num_max_tokens_per_rank; + int hidden, intermediate_hidden; + int num_experts, num_topk; + int num_ranks; + float activation_clamp; + bool fast_math; + MegaMoEConfig config; + + // Runtime arguments + void* y; + int* cumulative_local_expert_recv_stats; + int num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + + // Tensormap + CUtensorMap tensor_map_l1_acts; + CUtensorMap tensor_map_l1_acts_sf; + CUtensorMap tensor_map_l1_weights; + CUtensorMap tensor_map_l1_weights_sf; + CUtensorMap tensor_map_l1_output; + CUtensorMap tensor_map_l2_acts; + CUtensorMap tensor_map_l2_acts_sf; + CUtensorMap tensor_map_l2_weights; + CUtensorMap tensor_map_l2_weights_sf; + + // Launch configs + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_fp4_mega_moe_impl< + {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, + {}, + {}, + {}, {}, {}, + {}, {}, + {}, + {} + >); +}}; +)", args.num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + args.config.num_experts_per_wave, + args.config.block_m, args.config.block_n, args.config.block_k, + args.config.store_block_m, + args.config.sf_block_m, args.config.sf_block_n, + args.config.num_max_pool_tokens, + 
args.config.num_padded_sf_pool_tokens,
+    args.config.num_stages,
+    args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads,
+    args.launch_args.grid_dim.first, args.num_ranks,
+    to_string(args.activation_clamp),
+    args.fast_math ? "true" : "false");
+    }
+
+    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
+        // TODO: optimize `args` copy
+        DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
+            args.y,
+            args.cumulative_local_expert_recv_stats,
+            args.num_tokens,
+            args.sym_buffer_ptrs,
+            args.tensor_map_l1_acts,
+            args.tensor_map_l1_acts_sf,
+            args.tensor_map_l1_weights,
+            args.tensor_map_l1_weights_sf,
+            args.tensor_map_l1_output,
+            args.tensor_map_l2_acts,
+            args.tensor_map_l2_acts_sf,
+            args.tensor_map_l2_weights,
+            args.tensor_map_l2_weights_sf
+        ));
+    }
+};
+
+static void sm100_fp8_fp4_mega_moe(
+    const torch::Tensor& y,
+    const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf,
+    const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf,
+    const torch::Tensor& l1_weights, const torch::Tensor& l2_weights,
+    const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf,
+    const std::optional<torch::Tensor> cumulative_local_expert_recv_stats,
+    const std::vector& sym_buffer_ptrs,
+    const int& rank_idx, const int& num_max_tokens_per_rank,
+    const int& num_experts_per_rank,
+    const int& num_tokens, const int& num_topk,
+    const int& hidden, const int& intermediate_hidden,
+    const float& activation_clamp,
+    const bool& fast_math
+) {
+    const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
+    const auto num_experts = num_experts_per_rank * num_ranks;
+    const auto num_padded_sf_pool_tokens = static_cast<int>(l1_acts_sf.size(0));
+
+    // Heuristics
+    const auto config = get_mega_moe_config(
+        num_ranks, num_experts, num_experts_per_rank,
+        num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden, num_padded_sf_pool_tokens);
+
+    // Make tensormap
+    constexpr int kGranK = 32;
+    const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
+        hidden, config.num_max_pool_tokens,
+        config.block_k, config.load_block_m,
+        static_cast(l1_acts.stride(-2)),
+        config.swizzle_acts_mode);
+    const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
+        config.num_padded_sf_pool_tokens, hidden,
+        config.sf_block_m, kGranK,
+        1, 0);
+    const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights,
+        hidden, num_experts_per_rank * intermediate_hidden * 2,
+        config.block_k, config.load_block_n,
+        static_cast(l1_weights.stride(-2)),
+        config.swizzle_weights_mode);
+    const auto tensor_map_l1_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_weights_sf,
+        intermediate_hidden * 2, hidden,
+        config.block_n, kGranK,
+        num_experts_per_rank, 0);
+    // NOTES: L1 output and L2 activations are essentially the same tensor.
+    // Post-SwiGLU output has half the N width (`BLOCK_N / 2` per input tile),
+    // so the swizzle mode is also halved (128 -> 64).
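+    // Example (illustrative, assuming `config.block_n == 128` and a 128B activation
+    // swizzle): the L1 GEMM emits `128 / 2 = 64` post-SwiGLU columns per tile, which
+    // is why the L1-output descriptor below uses `config.block_n / 2` as its block
+    // width and `config.swizzle_acts_mode / 2` (128B -> 64B) as its swizzle mode.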
+    const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
+        intermediate_hidden, config.num_max_pool_tokens,
+        config.block_n / 2, config.store_block_m,
+        static_cast(l2_acts.stride(-2)),
+        config.swizzle_acts_mode / 2);
+    const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
+        intermediate_hidden, config.num_max_pool_tokens,
+        config.block_k, config.load_block_m,
+        static_cast(l2_acts.stride(-2)),
+        config.swizzle_acts_mode);
+    const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
+        config.num_padded_sf_pool_tokens, intermediate_hidden,
+        config.sf_block_m, kGranK,
+        1, 0);
+    const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights,
+        intermediate_hidden, num_experts_per_rank * hidden,
+        config.block_k, config.load_block_n,
+        static_cast(l2_weights.stride(-2)),
+        config.swizzle_weights_mode);
+    const auto tensor_map_l2_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_weights_sf,
+        hidden, intermediate_hidden,
+        config.block_n, kGranK,
+        num_experts_per_rank, 0);
+
+    // Stats can be optional
+    int* cumulative_local_expert_recv_stats_ptr = nullptr;
+    if (cumulative_local_expert_recv_stats.has_value())
+        cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr<int>();
+
+    // Launch
+    const auto num_sms = device_runtime->get_num_sms();
+    const SM100FP8FP4MegaMoERuntime::Args args = {
+        .num_max_tokens_per_rank = num_max_tokens_per_rank,
+        .hidden = hidden, .intermediate_hidden = intermediate_hidden,
+        .num_experts = num_experts, .num_topk = num_topk,
+        .num_ranks = num_ranks,
+        .activation_clamp = activation_clamp,
+        .fast_math = fast_math,
+        .config = config,
+        .y = y.data_ptr(),
+        .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr,
+        .num_tokens = num_tokens,
+        .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx),
+        .tensor_map_l1_acts = tensor_map_l1_acts,
+        .tensor_map_l1_acts_sf = tensor_map_l1_acts_sf,
+        .tensor_map_l1_weights = tensor_map_l1_weights,
+        .tensor_map_l1_weights_sf = tensor_map_l1_weights_sf,
+        .tensor_map_l1_output = tensor_map_l1_output,
+        .tensor_map_l2_acts = tensor_map_l2_acts,
+        .tensor_map_l2_acts_sf = tensor_map_l2_acts_sf,
+        .tensor_map_l2_weights = tensor_map_l2_weights,
+        .tensor_map_l2_weights_sf = tensor_map_l2_weights_sf,
+        .launch_args = LaunchArgs(num_sms,
+            config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads,
+            config.smem_size, 2)
+    };
+
+    const auto code = SM100FP8FP4MegaMoERuntime::generate(args);
+    const auto runtime = compiler->build("sm100_fp8_fp4_mega_moe", code);
+    SM100FP8FP4MegaMoERuntime::launch(runtime, args);
+}
+
+} // namespace deep_gemm
diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
index 404369a4..41681cb6 100644
--- a/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
+++ b/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
@@ -1,6 +1,6 @@
 #pragma once
 
-#include
+#include "../../utils/torch_compat.hpp"
 
 #include "../../jit/compiler.hpp"
 #include "../../jit/device_runtime.hpp"
diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp
index bdb5b11d..e91a5d41 100644
--- a/deep-gemm/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp
+++ b/deep-gemm/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp
@@ -1,6 +1,6 @@
 #pragma once
 
-#include
+#include "../../utils/torch_compat.hpp"
 
 #include
"../../jit/compiler.hpp" #include "../../jit/device_runtime.hpp" @@ -79,21 +79,21 @@ static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a, DG_HOST_ASSERT(n <= 128 and n % 8 == 0); DG_HOST_ASSERT(k % block_k == 0); - const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); - const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, - block_m, block_k, - static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, - get_swizzle_mode(block_k, a.element_size()), 0, - true); - const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, - block_n, block_k, - static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, - get_swizzle_mode(block_k, b.element_size()), 0, - true); - const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n, - block_m, block_n, - static_cast(d.stride(-2)), 1, - swizzle_cd_mode) + const auto swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) : make_tma_3d_desc(d, n, m, num_splits, block_n, block_m, 1, static_cast(d.stride(-2)), @@ -135,14 +135,14 @@ static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a, .num_stages = num_stages, .num_mma_threads = num_mma_threads, .num_cast_and_reduce_threads = num_cast_and_reduce_threads, - .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1), + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, .sqr_sum = sqr_sum.data_ptr() }; - const auto& code = SM100BF16HCPrenormGemmRuntime::generate(args); - const auto& runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code); + const auto code = SM100BF16HCPrenormGemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code); SM100BF16HCPrenormGemmRuntime::launch(runtime, args); } diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index 6291d0d9..43fa2913 100644 --- a/deep-gemm/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../utils/torch_compat.hpp" #include "../../jit/compiler.hpp" #include "../../jit/kernel_runtime.hpp" @@ -14,9 +14,7 @@ namespace deep_gemm { class SM90BF16GemmRuntime final: public LaunchRuntime { public: struct Args { - int m, n, k, num_groups; - const std::string& compiled_dims; - + GemmDesc gemm_desc; GemmConfig gemm_config; LaunchArgs launch_args; @@ -49,24 +47,29 @@ static void __instantiate_kernel() {{ }}; )", // TODO: add CD dtype - to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), - get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', 
args.compiled_dims), - args.num_groups, - args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, - args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, - args.gemm_config.num_stages, - args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, - args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - args.gemm_config.num_sms, - to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, - to_string(args.gemm_config.cd_dtype)); + to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b), + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_desc.num_groups, + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_config.storage_config.swizzle_a_mode, + args.gemm_config.storage_config.swizzle_b_mode, + args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads, + // TODO: refactor with cluster M/N + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, + to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation, + to_string(args.gemm_desc.cd_dtype)); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { // TODO: optimize `args` copy DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.grouped_layout, - args.m, args.n, args.k, + args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_cd)); } @@ -79,46 +82,50 @@ static void sm90_bf16_gemm(const torch::Tensor& a, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - const auto& config = get_best_config( - GemmType::Normal, KernelType::KernelNoSF, - m, n, k, 1, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), c.has_value(), - device_runtime->get_num_sms()); + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); // Requires no TMA splits - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m), 
- SM90ArchSpec::get_cd_store_block_n(config.block_n), + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, static_cast(d.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); + config.storage_config.swizzle_cd_mode); // Launch const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = 1, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd, }; - const auto& code = SM90BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm90_bf16_gemm", code); + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_gemm", code); SM90BF16GemmRuntime::launch(runtime, args); } @@ -128,51 +135,67 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, const torch::Tensor& m_indices, const int& num_groups, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); DG_HOST_ASSERT(k % 64 == 0); - const auto& config = get_best_config( - GemmType::MGroupedContiguous, KernelType::KernelNoSF, - m, n, k, 1, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), false, - device_runtime->get_num_sms()); + const auto gemm_type = use_psum_layout ? + GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // Only psum layout can use expected m + if (expected_m_for_psum_layout) + DG_HOST_ASSERT(use_psum_layout); + + const auto desc = GemmDesc { + .gemm_type = gemm_type, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m_for_psum_layout.value_or(m), + .expected_n = n, .expected_k = k, + .expected_num_groups = expected_m_for_psum_layout.has_value() ? 
num_groups : 1 + }; + const auto config = get_best_config(desc); // Requires no TMA splits - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m), - SM90ArchSpec::get_cd_store_block_n(config.block_n), + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, static_cast(d.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); + config.storage_config.swizzle_cd_mode); // Launch const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = m_indices.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd, }; - const auto& code = SM90BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); SM90BF16GemmRuntime::launch(runtime, args); } @@ -188,46 +211,51 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); DG_HOST_ASSERT(k % 64 == 0); - const auto& config = get_best_config( - GemmType::MGroupedMasked, KernelType::KernelNoSF, - expected_m, n, k, num_groups, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), false, - device_runtime->get_num_sms()); + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedMasked, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m, .expected_n = 0, .expected_k = 0, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); // Requires no TMA splits - const auto& tensor_map_a = 
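// Across the grouped BF16 variants, the type-erased grouped_layout slot
// carries different metadata: m_indices.data_ptr() for M-grouped contiguous,
// masked_m.data_ptr() for M-grouped masked, and ks_tensor.data_ptr() for
// K-grouped. The kernel presumably reinterprets it per GemmType; the usual
// DeepGEMM reading (an assumption here, not stated in the patch) is that
// m_indices[i] names the group owning row i, while masked_m[g] gives group
// g's valid row count.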
make_tma_a_desc(major_a, a, m, k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m), - SM90ArchSpec::get_cd_store_block_n(config.block_n), + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, static_cast(d.stride(-2)), num_groups, - config.smem_config.swizzle_cd_mode); + config.storage_config.swizzle_cd_mode); // Launch const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = masked_m.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd, }; - const auto& code = SM90BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); SM90BF16GemmRuntime::launch(runtime, args); } @@ -242,54 +270,59 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); int sum_k = 0; - for (const auto& k: ks) { + for (const auto k: ks) { sum_k += k; DG_HOST_ASSERT(k % 128 == 0); } - const auto& num_groups = static_cast(ks.size()); + const auto num_groups = static_cast(ks.size()); // Get config using max K for better performance - const auto& max_k = *std::max_element(ks.begin(), ks.end()); - const auto& config = get_best_config( - GemmType::KGroupedContiguous, KernelType::KernelNoSF, - m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), c.has_value(), - device_runtime->get_num_sms()); + const auto max_k = *std::max_element(ks.begin(), ks.end()); + const auto desc = GemmDesc { + .gemm_type = GemmType::KGroupedContiguous, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = 
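// The K-grouped path scans ks once to validate 128-alignment and accumulate
// the total K, then selects the config from the largest group ("Get config
// using max K for better performance"). A self-contained sketch of that scan:
//
//   #include <algorithm>
//   #include <cassert>
//   #include <vector>
//   int checked_sum_k(const std::vector<int>& ks) {
//       int sum_k = 0;
//       for (const int k : ks) {
//           assert(k % 128 == 0);  // each group's K must be 128-aligned
//           sum_k += k;
//       }
//       return sum_k;
//   }
//   // and the heuristic K: *std::max_element(ks.begin(), ks.end())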
device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); // Create tensor descriptors - const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(0)), 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(0)), 1, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m), - SM90ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(1)), num_groups, - config.smem_config.swizzle_cd_mode); + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(0)), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(0)), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(1)), num_groups, + config.storage_config.swizzle_cd_mode); // Launch kernel const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = sum_k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = ks_tensor.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd, }; - const auto& code = SM90BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm90_bf16_k_grouped_gemm", code); + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_k_grouped_gemm", code); SM90BF16GemmRuntime::launch(runtime, args); } @@ -298,45 +331,50 @@ static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, const torch::Tensor& tensor_d, const int& b, const int& h, const int& r, const int& d, const std::string& compiled_dims = "nk") { - const auto& config = get_best_config( - GemmType::Batched, KernelType::KernelNoSF, - b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K, - tensor_a.scalar_type(), tensor_b.scalar_type(), - tensor_d.scalar_type(), false, - device_runtime->get_num_sms()); - - const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); - const auto& tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, - config.block_k, load_block_m, 1, - tensor_a.stride(0), tensor_a.stride(1), - config.smem_config.swizzle_a_mode); - const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); - const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, - 
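// sm90_bf16_bhr_hdr_bhd is the einsum "bhr,hdr->bhd": A[b,h,r] x B[h,d,r]
// contracts over r with h as the batch axis, which is exactly the batched
// mapping here (m = b, n = d, k = r, num_groups = h; positional in the
// removed get_best_config call above, named in the GemmDesc below). Worked
// shapes (values hypothetical): b=64, h=16, r=512, d=128 gives 16 batched
// GEMMs of 64 x 128 x 512, i.e. A: [64,16,512], B: [16,128,512], D: [64,16,128].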
config.block_k, load_block_n, 1, - tensor_b.stride(1), tensor_b.stride(0), - config.smem_config.swizzle_b_mode); - const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); - const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); - const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::KernelNoSF, + .m = b, .n = d, .k = r, .num_groups = h, + .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(), + .cd_dtype = tensor_d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const int load_block_m = config.storage_config.load_block_m; + const auto tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, + config.layout.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.storage_config.swizzle_a_mode); + const int load_block_n = config.storage_config.load_block_n; + const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.layout.block_k, load_block_n, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.storage_config.swizzle_b_mode); + const int store_block_m = config.storage_config.store_block_m; + const int store_block_n = config.storage_config.store_block_n; + const auto tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, store_block_n, store_block_m, 1, tensor_d.stride(0), tensor_d.stride(1), - config.smem_config.swizzle_cd_mode); + config.storage_config.swizzle_cd_mode); + // Launch const SM90BF16GemmRuntime::Args& args = { - .m = b, .n = d, .k = r, - .num_groups = h, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd, }; - const auto& code = SM90BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code); + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code); SM90BF16GemmRuntime::launch(runtime, args); } @@ -345,45 +383,49 @@ static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, const torch::Tensor& tensor_d, const int& b, const int& h, const int& r, const int& d, const std::string& compiled_dims = "nk") { - const auto& config = get_best_config( - GemmType::Batched, KernelType::KernelNoSF, - b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN, - tensor_a.scalar_type(), tensor_b.scalar_type(), - tensor_d.scalar_type(), false, - device_runtime->get_num_sms()); - - const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); - const auto& tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, - config.block_k, load_block_m, 1, - tensor_a.stride(0), tensor_a.stride(1), - config.smem_config.swizzle_a_mode); - const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); - const auto& tensor_map_b = 
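// The sibling sm90_bf16_bhd_hdr_bhr is "bhd,hdr->bhr": here d is contracted
// and r becomes the output column, matching the mapping (m = b, n = r,
// k = d, num_groups = h) used by both the removed call above and the new
// GemmDesc below. major_b switches from K to MN, presumably because B's
// contracted axis is no longer its contiguous one under this access pattern;
// the TMA descriptors swap block roles accordingly.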
make_tma_3d_desc(tensor_b, r, d, h, - load_block_n, config.block_k, 1, - tensor_b.stride(1), tensor_b.stride(0), - config.smem_config.swizzle_b_mode); - const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); - const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); - const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::KernelNoSF, + .m = b, .n = r, .k = d, .num_groups = h, + .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(), + .cd_dtype = tensor_d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::MN, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const int load_block_m = config.storage_config.load_block_m; + const auto tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, + config.layout.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.storage_config.swizzle_a_mode); + const int load_block_n = config.storage_config.load_block_n; + const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + load_block_n, config.layout.block_k, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.storage_config.swizzle_b_mode); + const int store_block_m = config.storage_config.store_block_m; + const int store_block_n = config.storage_config.store_block_n; + const auto tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, store_block_n, store_block_m, 1, tensor_d.stride(0), tensor_d.stride(1), - config.smem_config.swizzle_cd_mode); + config.storage_config.swizzle_cd_mode); // Launch const SM90BF16GemmRuntime::Args& args = { - .m = b, .n = r, .k = d, - .num_groups = h, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd, }; - const auto& code = SM90BF16GemmRuntime::generate(args); - const auto& runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code); + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code); SM90BF16GemmRuntime::launch(runtime, args); } diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp index 8441e997..19a1556e 100644 --- a/deep-gemm/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../utils/torch_compat.hpp" #include "../../jit/compiler.hpp" #include "../../jit/device_runtime.hpp" @@ -84,9 +84,9 @@ static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a, // Select best number of stages int num_stages = 4, smem_size = 0; while (true) { - const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t); - const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); - const int& smem_barrier = SM90ArchSpec::get_barrier_smem_size(num_stages); + const int smem_a_per_stage = block_m * block_k * 
sizeof(cutlass::bfloat16_t); + const int smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); + const int smem_barrier = num_stages * 8 * 2; smem_size = 0; smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages; @@ -108,8 +108,8 @@ static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a, num_stages, smem_size, swizzle_ab_mode); } - const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); - const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); + const auto tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); + const auto tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); const SM90BmkBnkMnRuntime::Args& args = { .s = s, .m = m, .n = n, .k = k, @@ -123,8 +123,8 @@ static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a, .tensor_map_b = tensor_map_b, .d = d.data_ptr() }; - const auto& code = SM90BmkBnkMnRuntime::generate(args); - const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code); + const auto code = SM90BmkBnkMnRuntime::generate(args); + const auto runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code); SM90BmkBnkMnRuntime::launch(runtime, args); } diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp index 002b3873..7d6cb5c9 100644 --- a/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../utils/torch_compat.hpp" #include "../../jit/compiler.hpp" #include "../../jit/device_runtime.hpp" @@ -15,9 +15,7 @@ namespace deep_gemm { class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime { public: struct Args { - int m, n, k, num_groups; - const std::string& compiled_dims; - + GemmDesc gemm_desc; GemmConfig gemm_config; LaunchArgs launch_args; @@ -52,15 +50,17 @@ static void __instantiate_kernel() {{ >); }}; )", - get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), - args.num_groups, - args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, - args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, - args.gemm_config.num_stages, - args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, - args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), - to_string(args.gemm_config.cd_dtype)); + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_desc.num_groups, + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, to_string(args.gemm_desc.gemm_type), + to_string(args.gemm_desc.cd_dtype)); } static void launch_impl(const KernelHandle& kernel, 
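// The stage-count search above shrinks num_stages until the per-stage A/B
// tiles plus barriers fit in shared memory. The inlined "num_stages * 8 * 2"
// that replaces SM90ArchSpec::get_barrier_smem_size() is presumably one
// 8-byte mbarrier pair (full + empty) per stage. A sketch of the sizing
// logic, with the capacity left as an assumed parameter (roughly 227 KB per
// block on SM90):
//
//   int pick_num_stages(int block_m, int block_n, int block_k, int smem_capacity) {
//       for (int num_stages = 4; num_stages > 0; --num_stages) {
//           const int ab_bytes = (block_m + block_n) * block_k * 2;  // bf16 tiles
//           const int barrier_bytes = num_stages * 8 * 2;            // full+empty mbarriers
//           if (ab_bytes * num_stages + barrier_bytes <= smem_capacity)
//               return num_stages;
//       }
//       return 0;  // nothing fits; the caller should assert
//   }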
const LaunchConfigHandle& config, Args args) { @@ -68,7 +68,7 @@ static void __instantiate_kernel() {{ args.gmem_a_ptr, args.gmem_b_ptr, args.grouped_layout, args.tensor_map_buffer, - args.m, args.n, args.k, + args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, args.tensor_map_a_base, args.tensor_map_b_base, args.tensor_map_sfa, args.tensor_map_sfb, args.tensor_map_cd)); @@ -85,44 +85,48 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - const auto& config = get_best_config( - GemmType::Normal, KernelType::Kernel1D1D, - m, n, k, 1, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), c.has_value(), - device_runtime->get_num_sms()); + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); // Requires no TMA splits - DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); - DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); - - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, k, 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, k, 1, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, 1, 0); - const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, - config.block_n, config.block_k, 1, 0); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m, true), - SM90ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), 1, - 0); + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, k, 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, k, 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, config.layout.block_k, 1, 0); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + 0); // Launch const SM90FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = k, - .num_groups = 1, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, 
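// "Requires no TMA splits": the asserts above pin swizzle_a/b_mode to
// block_k, so one swizzle atom spans the whole K tile and a single TMA copy
// fills it. The scale-factor descriptors use swizzle mode 0 and, judging by
// the ceil_div(k, 128) bookkeeping in this file, one FP32 scale per 128 K
// elements (the "1D" in 1D1D; an interpretation, not stated in the patch).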
config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .gmem_a_ptr = nullptr, .gmem_b_ptr = nullptr, .grouped_layout = nullptr, @@ -133,8 +137,8 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, .tensor_map_sfb = tensor_map_sfb, .tensor_map_cd = tensor_map_cd, }; - const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); - const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); + const auto code = SM90FP8Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_gemm_1d1d", code); SM90FP8Gemm1D1DRuntime::launch(runtime, args); } @@ -151,54 +155,61 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - // Get config using max K for better performance - const auto& num_groups = static_cast(ks.size()); - const auto& max_k = *std::max_element(ks.begin(), ks.end()); - const auto& config = get_best_config( - GemmType::KGroupedContiguous, KernelType::Kernel1D1D, - m, n, max_k, num_groups, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), c.has_value(), - device_runtime->get_num_sms()); - - // Requires no TMA splits - DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); - DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); - - int first_k = 0, sum_k = 0, sum_sf_k = 0; + // TODO: refactor with the mk alignment function + const auto num_groups = static_cast(ks.size()); + int first_k = 0, sum_k = 0, sum_sf_k = 0, max_k = 0; for (int i = 0; i < num_groups; ++ i) { if (first_k == 0 and ks[i] != 0) first_k = ks[i]; sum_k += ks[i], sum_sf_k += ceil_div(ks[i], 128); + max_k = std::max(max_k, ks[i]); DG_HOST_ASSERT(ks[i] % 128 == 0); } - const auto& tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, first_k, 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, first_k, 1, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128, - config.block_m, config.block_k, 1, 0); - const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128, - config.block_n, config.block_k, 1, 0); - const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m, true), - SM90ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), num_groups, - config.smem_config.swizzle_cd_mode); + + // Get config using max K for better performance + const auto desc = GemmDesc { + .gemm_type = GemmType::KGroupedContiguous, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + 
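// sum_sf_k accumulates ceil_div(ks[i], 128), the number of 128-wide
// scale-factor rows per group, so sfa/sfb can be addressed as one packed
// [.., sum_sf_k * 128] region by the make_tma_sf_desc calls below. ceil_div
// is the usual integer round-up:
//
//   constexpr int ceil_div(int a, int b) {
//       return (a + b - 1) / b;  // e.g. ceil_div(384, 128) == 3
//   }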
.expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + + const auto tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k, + config.storage_config.load_block_m, + config.layout.block_k, first_k, 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k, + config.storage_config.load_block_n, + config.layout.block_k, first_k, 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128, + config.layout.block_m, config.layout.block_k, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128, + config.layout.block_n, config.layout.block_k, 1, 0); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); // Launch const SM90FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = sum_k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), .gmem_a_ptr = a.data_ptr(), .gmem_b_ptr = b.data_ptr(), .grouped_layout = ks_tensor.data_ptr(), @@ -209,8 +220,8 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te .tensor_map_sfb = tensor_map_sfb, .tensor_map_cd = tensor_map_cd, }; - const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); - const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); + const auto code = SM90FP8Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_gemm_1d1d", code); SM90FP8Gemm1D1DRuntime::launch(runtime, args); } diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index b29017f8..b2fa08af 100644 --- a/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../utils/torch_compat.hpp" #include "../../jit/compiler.hpp" #include "../../jit/device_runtime.hpp" @@ -17,14 +17,13 @@ namespace deep_gemm { class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime { public: struct Args { - cute::UMMA::Major major_sfb; - int m, n, k, num_groups; - const std::string& compiled_dims; - const std::optional& epilogue_type; - + GemmDesc gemm_desc; GemmConfig gemm_config; LaunchArgs launch_args; + // TODO: move this into `gemm_desc` + const std::optional& epilogue_type; + cute::UMMA::Major major_sfb; void *sfb, *grouped_layout; CUtensorMap tensor_map_a; CUtensorMap tensor_map_b; @@ -45,7 +44,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, {}, - {}, {}, + {}, {}, {}, {}, {}, {}, {}, @@ -55,14 +54,16 @@ static void __instantiate_kernel() {{ )", // TODO: add CD dtype to_string(args.major_sfb), - get_compiled_dim(args.m, 'm', args.compiled_dims), 
get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), - args.num_groups, - args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, - args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, - args.gemm_config.num_stages, args.gemm_config.num_last_stages, - args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, - args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_desc.num_groups, + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, to_string(args.gemm_desc.gemm_type), get_default_epilogue_type(args.epilogue_type)); } @@ -70,7 +71,7 @@ static void __instantiate_kernel() {{ // TODO: optimize `args` copy DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sfb, args.grouped_layout, - args.m, args.n, args.k, + args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.tensor_map_sfa)); } @@ -87,45 +88,49 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - const auto& config = get_best_config( - GemmType::Normal, KernelType::Kernel1D2D, - m, n, k, 1, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), c.has_value(), - device_runtime->get_num_sms()); + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); // Requires no TMA splits - DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); - DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = 
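// Naming note: "1D2D" refers to the scaling granularity, per-128-channel (1D)
// scale factors for A, fed through tensor_map_sfa, versus coarser per-block
// (2D) factors for B, which is why sfb travels as a raw pointer with its own
// major_sfb rather than a TMA descriptor. This reading follows DeepGEMM's
// fine-grained FP8 scheme and is an inference, not stated in the patch.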
make_tma_cd_desc(d, m, static_cast(d.size(-1)), - SM90ArchSpec::get_cd_store_block_m(config.block_m), - SM90ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); - const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, 1, 0); + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, static_cast(d.size(-1)), + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, 1, 0); // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { - .major_sfb = major_sfb, - .m = m, .n = n, .k = k, - .num_groups = 1, - .compiled_dims = compiled_dims, - .epilogue_type = epilogue_type, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = epilogue_type, + .major_sfb = major_sfb, .sfb = sfb.data_ptr(), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, @@ -133,8 +138,8 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, .tensor_map_d = tensor_map_d, .tensor_map_sfa = tensor_map_sfa, }; - const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); - const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + const auto code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_gemm_1d2d", code); SM90FP8Gemm1D2DRuntime::launch(runtime, args); } @@ -144,49 +149,65 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons const torch::Tensor& m_indices, const int& num_groups, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, - const std::string& compiled_dims) { + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - const auto& config = get_best_config( - GemmType::MGroupedContiguous, KernelType::Kernel1D2D, - m, n, k, 1, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), false, - device_runtime->get_num_sms()); + const auto gemm_type = use_psum_layout ? 
+ GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // Only psum layout can use expected m + if (expected_m_for_psum_layout) + DG_HOST_ASSERT(use_psum_layout); + + const auto desc = GemmDesc { + .gemm_type = gemm_type, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m_for_psum_layout.value_or(m), + .expected_n = n, .expected_k = k, + .expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1 + }; + const auto config = get_best_config(desc); // Requires no TMA splits - DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); - DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m), - SM90ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); - const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, 1, 0); + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, 1, 0); // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { - .major_sfb = major_sfb, - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .major_sfb = major_sfb, .sfb = sfb.data_ptr(), .grouped_layout = m_indices.data_ptr(), 
.tensor_map_a = tensor_map_a, @@ -194,8 +215,8 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons .tensor_map_d = tensor_map_d, .tensor_map_sfa = tensor_map_sfa, }; - const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); - const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code); + const auto code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code); SM90FP8Gemm1D2DRuntime::launch(runtime, args); } @@ -210,45 +231,50 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - const auto& config = get_best_config( - GemmType::MGroupedMasked, KernelType::Kernel1D2D, - expected_m, n, k, num_groups, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), false, - device_runtime->get_num_sms()); + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedMasked, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); // Requires no TMA splits - DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); - DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); - const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, - SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), - config.block_k, - static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, - config.smem_config.swizzle_a_mode); - const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, - SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), - config.block_k, - static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, - config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m), - SM90ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), num_groups, - config.smem_config.swizzle_cd_mode); - const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, num_groups, 0); + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); + const 
auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, num_groups, 0); // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { - .major_sfb = major_sfb, - .m = m, .n = n, .k = k, - .num_groups = num_groups, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .major_sfb = major_sfb, .sfb = sfb.data_ptr(), .grouped_layout = masked_m.data_ptr(), .tensor_map_a = tensor_map_a, @@ -256,8 +282,8 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to .tensor_map_d = tensor_map_d, .tensor_map_sfa = tensor_map_sfa, }; - const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); - const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); + const auto code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); SM90FP8Gemm1D2DRuntime::launch(runtime, args); } @@ -271,51 +297,55 @@ static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - const auto& config = get_best_config( - GemmType::Batched, KernelType::Kernel1D2D, - m, n, k, batch_size, major_a, major_b, - a.scalar_type(), b.scalar_type(), - d.scalar_type(), c.has_value(), - device_runtime->get_num_sms()); + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = batch_size, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); // Requires no TMA splits - DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); - DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); - const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); - const auto& tensor_map_a = make_tma_3d_desc(a, k, m, batch_size, - config.block_k, load_block_m, 1, - a.stride(1), - a.stride(0), - config.smem_config.swizzle_a_mode); - - const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); - const auto& tensor_map_b = make_tma_3d_desc(b, k, n, batch_size, - config.block_k, load_block_n, 1, - b.stride(1), - b.stride(0), - config.smem_config.swizzle_b_mode); - - const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); - const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); - const auto& tensor_map_d = make_tma_3d_desc(d, n, m, batch_size, - store_block_n, store_block_m, 1, - d.stride(1), d.stride(0), - config.smem_config.swizzle_cd_mode); - - const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, batch_size, 0); + 
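// sm90_fp8_bmm rides the same grouped machinery as the masked/contiguous
// paths: batch_size simply becomes num_groups, the 3D TMA descriptors below
// carry the batch stride, and no grouped_layout metadata is needed (nullptr),
// since every batch entry shares the same m, n, k.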
DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const int load_block_m = config.storage_config.load_block_m; + const auto tensor_map_a = make_tma_3d_desc(a, k, m, batch_size, + config.layout.block_k, load_block_m, 1, + a.stride(1), + a.stride(0), + config.storage_config.swizzle_a_mode); + + const int load_block_n = config.storage_config.load_block_n; + const auto tensor_map_b = make_tma_3d_desc(b, k, n, batch_size, + config.layout.block_k, load_block_n, 1, + b.stride(1), + b.stride(0), + config.storage_config.swizzle_b_mode); + + const int store_block_m = config.storage_config.store_block_m; + const int store_block_n = config.storage_config.store_block_n; + const auto tensor_map_d = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.storage_config.swizzle_cd_mode); + + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, batch_size, 0); // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { - .major_sfb = major_sfb, - .m = m, .n = n, .k = k, - .num_groups = batch_size, - .compiled_dims = compiled_dims, - .epilogue_type = std::nullopt, + .gemm_desc = desc, .gemm_config = config, - .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, - config.smem_config.smem_size, - config.multicast_config.num_multicast), + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .major_sfb = major_sfb, .sfb = sfb.data_ptr(), .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, @@ -323,8 +353,8 @@ static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, .tensor_map_d = tensor_map_d, .tensor_map_sfa = tensor_map_sfa, }; - const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); - const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + const auto code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_gemm_1d2d", code); SM90FP8Gemm1D2DRuntime::launch(runtime, args); } diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp index 63a47c32..4a10d697 100644 --- a/deep-gemm/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../utils/torch_compat.hpp" #include "../../jit/compiler.hpp" #include "../../jit/device_runtime.hpp" @@ -81,21 +81,21 @@ static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a, DG_HOST_ASSERT(n <= 32 and n % 8 == 0); DG_HOST_ASSERT(k % block_k == 0); - const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); - const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, - block_m, block_k, - static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, - get_swizzle_mode(block_k, a.element_size()), 0, - true); - const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, - block_n, block_k, - static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, - get_swizzle_mode(block_k, b.element_size()), 0, - true); - const auto& tensor_map_d = num_splits == 1 ? 
make_tma_cd_desc(d, m, n, - block_m, block_n, - static_cast(d.stride(-2)), 1, - swizzle_cd_mode) + const auto swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) : make_tma_3d_desc(d, n, m, num_splits, block_n, block_m, 1, static_cast(d.stride(-2)), @@ -138,14 +138,14 @@ static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a, .num_stages = num_stages, .num_math_threads = num_math_threads, .num_tma_threads = num_tma_threads, - .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1), + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_d = tensor_map_d, .sqr_sum = sqr_sum.data_ptr() }; - const auto& code = SM90BF16HCPrenormGemmRuntime::generate(args); - const auto& runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code); + const auto code = SM90BF16HCPrenormGemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code); SM90BF16HCPrenormGemmRuntime::launch(runtime, args); } diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_clean_logits.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_clean_logits.hpp index fdb91a03..ebe4c7a6 100644 --- a/deep-gemm/csrc/jit_kernels/impls/smxx_clean_logits.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_clean_logits.hpp @@ -17,7 +17,8 @@ class SMXXCleanLogitsRuntime final: public LaunchRuntime int* cu_seq_len_k_start; int* cu_seq_len_k_end; - float* logits; + void* logits; + at::ScalarType logits_dtype; int block_kv; int num_warps; @@ -33,10 +34,10 @@ using namespace deep_gemm; static void __instantiate_kernel() {{ auto ptr = reinterpret_cast(&smxx_clean_logits< - {}, {}, {} + {}, {}, {}, {} >); }}; -)", args.next_n, args.block_kv, args.num_warps); +)", args.next_n, args.block_kv, args.num_warps, to_string(args.logits_dtype)); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -65,14 +66,15 @@ static void smxx_clean_logits(const torch::Tensor& logits, .stride_logits = stride_logits, .cu_seq_len_k_start = cu_seq_len_k_start.has_value() ? 
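// smxx_clean_logits is now templated on the logits dtype: the host keeps a
// type-erased void* next to an at::ScalarType, and to_string(logits_dtype)
// is spliced into the JIT'd instantiation as the new fourth template
// argument. (The tf32 prenorm LaunchArgs above also drops its trailing ", 1",
// presumably because the cluster size of 1 became a default argument.)
// A sketch of the host-side erasure, names hypothetical:
//
//   struct CleanLogitsView { void* data; at::ScalarType dtype; };
//   CleanLogitsView erase(const torch::Tensor& t) {
//       return {t.data_ptr(), t.scalar_type()};  // dtype travels beside the pointer
//   }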
cu_seq_len_k_start.value().data_ptr() : nullptr, .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), - .logits = logits.data_ptr(), + .logits = logits.data_ptr(), + .logits_dtype = logits.scalar_type(), .block_kv = block_kv, .num_warps = num_warps, .launch_args = LaunchArgs(device_runtime->get_num_sms(), num_warps * 32, smem_size) }; - const auto& code = SMXXCleanLogitsRuntime::generate(args); - const auto& runtime = compiler->build("smxx_clean_logits", code); + const auto code = SMXXCleanLogitsRuntime::generate(args); + const auto runtime = compiler->build("smxx_clean_logits", code); SMXXCleanLogitsRuntime::launch(runtime, args); } diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_cublaslt.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_cublaslt.hpp index dc20e334..7f29b0a5 100644 --- a/deep-gemm/csrc/jit_kernels/impls/smxx_cublaslt.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_cublaslt.hpp @@ -46,7 +46,7 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a, DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); #if DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE - const int& math_sms = device_runtime->get_num_sms(); + const int math_sms = device_runtime->get_num_sms(); DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms))); #endif @@ -57,10 +57,10 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a, #endif // Get cuBLASLt handle, workspace, and stream - const auto& handle = device_runtime->get_cublaslt_handle(); - const auto& workspace = device_runtime->get_cublaslt_workspace(); - const auto& workspace_bytes = workspace.nbytes(); - const auto& stream = at::cuda::getCurrentCUDAStream(); + const auto handle = device_runtime->get_cublaslt_handle(); + const auto workspace = device_runtime->get_cublaslt_workspace(); + const auto workspace_bytes = workspace.nbytes(); + const auto stream = at::cuda::getCurrentCUDAStream(); // Algorithm selection cublasLtMatmulPreference_t pref; @@ -77,7 +77,7 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a, DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM"); // Call: D = alpha * (A @ B) + beta * C - const float& alpha = 1.0, beta = accumulate ? 1.0 : 0.0; + const float alpha = 1.0, beta = accumulate ? 1.0 : 0.0; DG_CUBLASLT_CHECK(cublasLtMatmul(handle, // Light handle desc, // Operation description &alpha, // Alpha @@ -99,47 +99,36 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a, } static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs, - const std::optional& acc, const torch::Tensor& out, const int& m, const int& n, const int& k, - const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) { - const auto& trans_a = b_major == cute::UMMA::Major::K ? CUBLAS_OP_T : CUBLAS_OP_N; - const auto& trans_b = a_major == cute::UMMA::Major::K ? CUBLAS_OP_N : CUBLAS_OP_T; - - // Duplicate the accumulator if necessary - // TODO: remove this - if (acc.has_value()) { - if (acc->data_ptr() == out.data_ptr()) { - DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides()); - } else { - out.copy_(acc.value()); - } - } + const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major, + const bool& accumulate) { + const auto trans_a = b_major == cute::UMMA::Major::K ? CUBLAS_OP_T : CUBLAS_OP_N; + const auto trans_b = a_major == cute::UMMA::Major::K ? 
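// The old cublaslt_gemm copied an optional accumulator tensor into out before
// the matmul; the new signature takes a plain accumulate flag and leans on
// the library epilogue instead: D = alpha * (A @ B) + beta * C with alpha = 1
// and beta = accumulate ? 1 : 0 (per the comment in call_cublaslt_api), C
// presumably bound to the same buffer as D.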
CUBLAS_OP_N : CUBLAS_OP_T; // Matrix layouts - const auto& cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type()); - const auto& cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type()); - const auto& cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type()); - const auto& layout_a = b_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0)) - : get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1)); - const auto& layout_b = a_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0)) - : get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1)); - const auto& layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0)); - - call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, acc.has_value()); + const auto cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type()); + const auto cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type()); + const auto cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type()); + const auto layout_a = b_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0)) + : get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1)); + const auto layout_b = a_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0)) + : get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1)); + const auto layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, accumulate); } - static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out, const int& b, const int& h, const int& r, const int& d) { - const auto& m = d, n = b, k = r; - const auto& trans_a = CUBLAS_OP_T; - const auto& trans_b = CUBLAS_OP_N; + const auto m = d, n = b, k = r; + const auto trans_a = CUBLAS_OP_T; + const auto trans_b = CUBLAS_OP_N; // Matrix layouts - const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0)); - const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); - const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); + const auto layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0)); + const auto layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); + const auto layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false); } @@ -147,14 +136,14 @@ static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out, const int& b, const int& h, const int& r, const int& d) { - const auto& m = r, n = b, k = d; - const auto& trans_a = CUBLAS_OP_N; - const auto& trans_b = CUBLAS_OP_N; + const auto m = r, n = b, k = d; + const auto trans_a = CUBLAS_OP_N; + const auto trans_b = CUBLAS_OP_N; // Matrix layouts - const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0)); - const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); - const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); + const auto layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, 
rhs.stride(0)); + const auto layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); + const auto layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false); } diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp new file mode 100644 index 00000000..3be10c98 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp @@ -0,0 +1,328 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SMXXFP8MQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int seq_len; + int seq_len_kv; + int max_seqlen_k; + int stride_logits; + int num_heads, head_dim; + bool is_compressed_logits; + + int num_q_stages; + int num_kv_stages; + int block_q; + int block_kv; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + void* logits; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + at::ScalarType logits_dtype; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_mqa_logits< + {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", arch, arch, + args.num_heads, args.head_dim, + args.is_compressed_logits, + args.block_q, args.block_kv, + args.num_q_stages, args.num_kv_stages, + args.launch_args.grid_dim.first, + args.num_specialized_threads, args.num_math_threads, + to_string(args.logits_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, + args.max_seqlen_k, args.stride_logits, + args.cu_seq_len_k_start, args.cu_seq_len_k_end, + args.logits, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv, const torch::Tensor& kv_scales, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const torch::Tensor& logits, + const at::ScalarType& logits_dtype, + const int& seq_len, const int& seq_len_kv, + const int& max_seqlen_k, const int& stride_logits, + const int& num_heads, const int& head_dim, + const int& block_q, const int& block_kv) { + constexpr int num_specialized_threads = 128; + constexpr int num_q_stages = 3, num_kv_stages = 3; + const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 
256 : 512); + + // Use compressed logits format when max_seqlen_k is specified + const bool is_compressed_logits = (max_seqlen_k > 0); + + // Construct TMAs + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + const auto tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads, + head_dim, block_q * num_heads, head_dim, head_dim); + const auto tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv, + head_dim, block_kv, head_dim, head_dim); + // According to the driver API, the minimal alignment is 256 bytes + // So it is safe for us to do a 16-byte OOB + const auto tensor_map_kv_scales = make_tma_2d_desc(kv_scales, + get_tma_aligned_size(seq_len_kv, static_cast(kv_scales.element_size())), + 1, block_kv, 1, 0, 0); + const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len, + num_heads, block_q, num_heads, 0); + + // Calculate shared memory size + int smem_size = 0; + const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast(q.element_size()); + const int smem_weight_size_per_stage = block_q * num_heads * static_cast(weights.element_size()); + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv.element_size()); + const int kv_scale_size_per_stage = block_kv * static_cast(kv_scales.element_size()); + smem_size += num_q_stages * smem_q_size_per_stage; + smem_size += num_kv_stages * smem_kv_size_per_stage; + smem_size += num_q_stages * smem_weight_size_per_stage; + smem_size += num_kv_stages * kv_scale_size_per_stage; + smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8; + smem_size += 4; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SMXXFP8MQALogitsRuntime::Args args = { + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .max_seqlen_k = max_seqlen_k, + .stride_logits = stride_logits, + .num_heads = num_heads, .head_dim = head_dim, + .is_compressed_logits = is_compressed_logits, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .block_q = block_q, + .block_kv = block_kv, + .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr(), + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .logits_dtype = logits_dtype, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto code = SMXXFP8MQALogitsRuntime::generate(args); + const auto runtime = compiler->build("smxx_fp8_mqa_logits", code); + SMXXFP8MQALogitsRuntime::launch(runtime, args); +} + +class SM100FP4MQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int seq_len; + int seq_len_kv; + int max_seqlen_k; + int stride_logits; + int num_heads, head_dim; + bool is_compressed_logits; + + int num_q_stages; + int num_kv_stages; + int block_q; + int block_kv; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + void* logits; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_sf_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_sf_kv; + CUtensorMap tensor_map_weights; + at::ScalarType logits_dtype; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const 
Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp4_mqa_logits< + {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", args.num_heads, args.head_dim, + args.is_compressed_logits, + args.block_q, args.block_kv, + args.num_q_stages, args.num_kv_stages, + args.launch_args.grid_dim.first, + args.num_specialized_threads, args.num_math_threads, + to_string(args.logits_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, + args.max_seqlen_k, args.stride_logits, + args.cu_seq_len_k_start, args.cu_seq_len_k_end, + args.logits, + args.tensor_map_q, args.tensor_map_sf_q, + args.tensor_map_kv, args.tensor_map_sf_kv, + args.tensor_map_weights + )); + } +}; + +static void sm100_fp4_mqa_logits(const torch::Tensor& q, const torch::Tensor& sf_q, + const torch::Tensor& kv, const torch::Tensor& sf_kv, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const torch::Tensor& logits, + const at::ScalarType& logits_dtype, + const int& seq_len, const int& seq_len_kv, + const int& max_seqlen_k, const int& stride_logits, + const int& num_heads, const int& head_dim, + const int& block_q, const int& block_kv) { + constexpr int num_specialized_threads = 128; + const int num_math_threads = 2 * 128; + constexpr int num_q_stages = 3, num_kv_stages = 6, num_tmem_stages = 3; + + // Use compressed logits format when max_seqlen_k is specified + const bool is_compressed_logits = (max_seqlen_k > 0); + + // Construct TMAs + // `head_dim` must be 128 for 64B swizzling + DG_HOST_ASSERT(head_dim == 128); + const auto tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads, + head_dim, block_q * num_heads, + static_cast(q.stride(1)), + head_dim / 2, 0, false, false); + const auto tensor_map_sf_q = make_tma_2d_desc(sf_q, num_heads, seq_len, + num_heads, block_q, + static_cast(sf_q.stride(0)), 0); + const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len, + num_heads, block_q, + static_cast(weights.stride(0)), 0); + const auto tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv, + head_dim, block_kv, + static_cast(kv.stride(0)), + head_dim / 2, 0, false, false); + // According to the driver API, the minimal alignment is 256 bytes + // So it is safe for us to do a 16-byte OOB + const auto tensor_map_sf_kv = make_tma_2d_desc(sf_kv, + get_tma_aligned_size(seq_len_kv, static_cast(sf_kv.element_size())), 1, + block_kv, 1, 0, 0); + + // Calculate shared memory size + const int smem_q_size_per_stage = block_q * num_heads * head_dim / 2; + const int smem_sf_q_size_per_stage = align(block_q * num_heads, 128) * sizeof(int); + const int smem_kv_size_per_stage = block_kv * head_dim / 2; + const int smem_sf_kv_size_per_stage = align(block_kv, 128) * sizeof(int); + const int smem_weight_size_per_stage = block_q * num_heads * sizeof(float); + + const int smem_barriers = (num_q_stages + num_kv_stages + num_tmem_stages) * 2 * 8; + const int smem_tmem_ptr = 4; + const int smem_size = num_q_stages * (smem_q_size_per_stage + smem_sf_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages 
* (smem_kv_size_per_stage + smem_sf_kv_size_per_stage) + + smem_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SM100FP4MQALogitsRuntime::Args args = { + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .max_seqlen_k = max_seqlen_k, + .stride_logits = stride_logits, + .num_heads = num_heads, .head_dim = head_dim, + .is_compressed_logits = is_compressed_logits, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .block_q = block_q, + .block_kv = block_kv, + .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr(), + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_sf_q = tensor_map_sf_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_sf_kv = tensor_map_sf_kv, + .tensor_map_weights = tensor_map_weights, + .logits_dtype = logits_dtype, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto code = SM100FP4MQALogitsRuntime::generate(args); + const auto runtime = compiler->build("sm100_fp4_mqa_logits", code); + SM100FP4MQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp new file mode 100644 index 00000000..2a3288ee --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp @@ -0,0 +1,463 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime { +public: + struct Args { + int aligned_batch_size; + int split_kv; + int num_sms; + bool is_varlen; + + int batch_size; + int next_n; + bool is_context_lens_2d; + int* context_lens; + int* indices; + int* schedule_metadata; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sched::smxx_paged_mqa_logits_metadata< + {}, {}, {}, {} + >); +}}; +)", args.aligned_batch_size, args.split_kv, args.num_sms, args.is_varlen ? 
"true" : "false"); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + args.next_n, + args.is_context_lens_2d, + args.context_lens, + args.indices, + args.schedule_metadata + )); + } +}; + +static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, + const torch::Tensor& schedule_metadata, + const int& batch_size, const int& next_n, + const int& block_kv, const int& num_sms, + const bool& is_context_lens_2d, + const bool& is_varlen, const int* indices_ptr) { + constexpr int split_kv = 256; + constexpr int num_threads = 32; + const int aligned_batch_size = align(batch_size, 32); + DG_HOST_ASSERT(split_kv % block_kv == 0); + + // Shared memory: prefix_sum[kAlignedBatchSize] + varlen_atom_token_start/context_len[kAlignedBatchSize] + varlen_num_atoms + const int smem_size = (3 * aligned_batch_size + 1) * static_cast(sizeof(int)); + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SMXXPagedMQALogitsMetadataRuntime::Args& args = { + .aligned_batch_size = aligned_batch_size, + .split_kv = split_kv, + .num_sms = num_sms, + .is_varlen = is_varlen, + .batch_size = batch_size, + .next_n = next_n, + .is_context_lens_2d = is_context_lens_2d, + .context_lens = context_lens.data_ptr(), + .indices = const_cast(indices_ptr), + .schedule_metadata = schedule_metadata.data_ptr(), + .launch_args = LaunchArgs(1, num_threads, smem_size) + }; + const auto code = SMXXPagedMQALogitsMetadataRuntime::generate(args); + const auto runtime = compiler->build("smxx_paged_mqa_logits_metadata", code); + SMXXPagedMQALogitsMetadataRuntime::launch(runtime, args); +} + +class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int batch_size; + int next_n; + int num_heads; + int head_dim; + int block_kv; + bool is_context_lens_2d; + bool is_varlen; + int block_table_stride; + int logits_stride; + + int num_q_stages; + int num_kv_stages; + int split_kv; + + int* context_lens; + void* logits; + int* block_table; + int* indices; + int* schedule_meta; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + at::ScalarType logits_dtype; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_paged_mqa_logits< + {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", arch, arch, + args.next_n, args.num_heads, + args.head_dim, args.block_kv, + args.is_context_lens_2d, args.is_varlen ? 
"true" : "false", + args.num_q_stages, args.num_kv_stages, + args.split_kv, + args.num_specialized_threads, args.num_math_threads, + to_string(args.logits_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + args.logits_stride, args.block_table_stride, + args.context_lens, args.logits, + args.block_table, args.indices, args.schedule_meta, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv_cache, + const torch::Tensor& kv_cache_scales, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& logits, + const torch::Tensor& block_table, + const torch::Tensor& indices, + const torch::Tensor& schedule_meta, + const at::ScalarType& logits_dtype, + const int& batch_size, const int& next_n, + const int& num_heads, const int& head_dim, + const int& num_kv_blocks, const int& block_kv, + const bool& is_context_lens_2d, + const bool& is_varlen, + const int& logits_stride, + const int& block_table_stride, + const int& num_sms, + const int& split_kv) { + const int num_specialized_threads = 128; + const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64); + const int num_math_warp_groups = split_kv / mma_m; + const int num_math_threads = num_math_warp_groups * 128; + const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3); + DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0); + + // Construct TMAs + const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1; + const auto tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads, + head_dim, next_n_atom * num_heads, + static_cast(q.stride(2)), + head_dim); + const auto tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks, + head_dim, block_kv, 1, + static_cast(kv_cache.stride(1)), + static_cast(kv_cache.stride(0)), + head_dim); + + const auto tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks, + block_kv, 1, + static_cast(kv_cache_scales.stride(0)), 0); + const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, batch_size * next_n, + num_heads, next_n_atom, + static_cast(weights.stride(0)), 0); + + // Calculate shared memory size + int smem_size = 0; + if (device_runtime->get_arch_major() == 9) { + const int swizzle_alignment = head_dim * 8; + + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); + const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); + + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv_cache.element_size()); + const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast(kv_cache_scales.element_size()), swizzle_alignment); + const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); + + // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90 + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + 
const int smem_tmem_ptr = 4; + + smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(next_n == 1 or next_n == 2); + } else { + const int smem_q_size_per_stage = next_n_atom * num_heads * head_dim * static_cast(q.element_size()); + const int smem_kv_size_per_stage = split_kv * head_dim * static_cast(kv_cache.element_size()); + const int smem_kv_scale_size_per_stage = split_kv * static_cast(kv_cache_scales.element_size()); + const int smem_weight_size_per_stage = next_n_atom * num_heads * static_cast(weights.element_size()); + + const int smem_barriers = (num_q_stages + num_kv_stages) * 2 * 8; + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + smem_size = num_q_stages * (smem_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages * (smem_kv_size_per_stage + smem_kv_scale_size_per_stage) + + smem_barriers + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + } + + // Launch + const SMXXFP8PagedMQALogitsRuntime::Args args = { + .batch_size = batch_size, + .next_n = next_n, + .num_heads = num_heads, + .head_dim = head_dim, + .block_kv = block_kv, + .is_context_lens_2d = is_context_lens_2d, + .is_varlen = is_varlen, + .block_table_stride = block_table_stride, + .logits_stride = logits_stride, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .split_kv = split_kv, + .context_lens = context_lens.data_ptr(), + .logits = logits.data_ptr(), + .block_table = block_table.data_ptr(), + .indices = is_varlen ? indices.data_ptr() : nullptr, + .schedule_meta = schedule_meta.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .logits_dtype = logits_dtype, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_sms, + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto code = SMXXFP8PagedMQALogitsRuntime::generate(args); + const auto runtime = compiler->build("smxx_fp8_paged_mqa_logits", code); + SMXXFP8PagedMQALogitsRuntime::launch(runtime, args); +} + +class SM100FP4PagedMQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int batch_size; + int next_n; + int num_heads; + int head_dim; + int block_kv; + bool is_context_lens_2d; + bool is_varlen; + int block_table_stride; + int logits_stride; + + int num_q_stages; + int num_kv_stages; + int split_kv; + + int* context_lens; + void* logits; + int* block_table; + int* indices; + int* schedule_meta; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_sf_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_sf_kv; + CUtensorMap tensor_map_weights; + at::ScalarType logits_dtype; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp4_paged_mqa_logits< + {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", args.next_n, args.num_heads, + args.head_dim, args.block_kv, + args.is_context_lens_2d, args.is_varlen ? 
"true" : "false", + args.num_q_stages, args.num_kv_stages, + args.split_kv, + args.num_specialized_threads, args.num_math_threads, + to_string(args.logits_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + args.logits_stride, args.block_table_stride, + args.context_lens, args.logits, + args.block_table, args.indices, args.schedule_meta, + args.tensor_map_q, args.tensor_map_sf_q, + args.tensor_map_kv, args.tensor_map_sf_kv, + args.tensor_map_weights + )); + } +}; + +static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& sf_q, + const torch::Tensor& kv_cache, + const torch::Tensor& kv_cache_sf, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& logits, + const torch::Tensor& block_table, + const torch::Tensor& indices, + const torch::Tensor& schedule_meta, + const at::ScalarType& logits_dtype, + const int& batch_size, const int& next_n, + const int& num_heads, const int& head_dim, + const int& num_kv_blocks, const int& block_kv, + const bool& is_context_lens_2d, + const bool& is_varlen, + const int& logits_stride, + const int& block_table_stride, + const int& num_sms, + const int& split_kv) { + const int num_specialized_threads = 128; + const int num_math_threads = 2 * 128; + DG_HOST_ASSERT(split_kv == 256 and logits_stride % split_kv == 0); + + // TODO: tuning num_stages + const int num_q_stages = 3, num_kv_stages = 10, num_tmem_stages = 3; + const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1; + + // `head_dim` must be 128 for 64B swizzling + DG_HOST_ASSERT(head_dim == 128); + + // Using 2D TMA as tensor q is asserted contiguous + const auto tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads, + head_dim, next_n_atom * num_heads, + static_cast(q.stride(2)), + head_dim / 2, 0, false, false); + // NOTES: `sf_q` is a 3D tensor, while `weights` is a 2D tensor + const auto tensor_map_sf_q = make_tma_2d_desc(sf_q, num_heads, batch_size * next_n, + num_heads, next_n_atom, + static_cast(sf_q.stride(1)), 0); + const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, batch_size * next_n, + num_heads, next_n_atom, + static_cast(weights.stride(0)), 0); + + const auto tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks, + head_dim, block_kv, 1, + static_cast(kv_cache.stride(1)), + static_cast(kv_cache.stride(0)), + head_dim / 2, 0, false, false); + const auto tensor_map_sf_kv = make_tma_2d_desc(kv_cache_sf, block_kv, num_kv_blocks, + block_kv, 1, + static_cast(kv_cache_sf.stride(0)), 0); + + // Calculate shared memory size + const int smem_q_size_per_stage = next_n_atom * num_heads * head_dim / 2; + const int smem_sf_q_size_per_stage = align(next_n_atom * num_heads, 128) * sizeof(int); + const int smem_kv_size_per_stage = split_kv * head_dim / 2; + const int smem_sf_kv_size_per_stage = align(split_kv, 128) * sizeof(int); + const int smem_weight_size_per_stage = next_n_atom * num_heads * sizeof(float); + + const int smem_barriers = (num_q_stages + num_kv_stages + num_tmem_stages) * 2 * 8; + const int smem_tmem_ptr = 4; + const int smem_size = num_q_stages * (smem_q_size_per_stage + smem_sf_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages * (smem_kv_size_per_stage + smem_sf_kv_size_per_stage) + + smem_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const 
SM100FP4PagedMQALogitsRuntime::Args args = { + .batch_size = batch_size, + .next_n = next_n, + .num_heads = num_heads, + .head_dim = head_dim, + .block_kv = block_kv, + .is_context_lens_2d = is_context_lens_2d, + .is_varlen = is_varlen, + .block_table_stride = block_table_stride, + .logits_stride = logits_stride, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .split_kv = split_kv, + .context_lens = context_lens.data_ptr(), + .logits = logits.data_ptr(), + .block_table = block_table.data_ptr(), + .indices = is_varlen ? indices.data_ptr() : nullptr, + .schedule_meta = schedule_meta.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_sf_q = tensor_map_sf_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_sf_kv = tensor_map_sf_kv, + .tensor_map_weights = tensor_map_weights, + .logits_dtype = logits_dtype, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_sms, + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto code = SM100FP4PagedMQALogitsRuntime::generate(args); + const auto runtime = compiler->build("sm100_fp4_paged_mqa_logits", code); + SM100FP4PagedMQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_layout.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_layout.hpp index 0b9eebd7..49fa8833 100644 --- a/deep-gemm/csrc/jit_kernels/impls/smxx_layout.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_layout.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include "../../utils/torch_compat.hpp" #include "../../jit/kernel_runtime.hpp" #include "../../jit/compiler.hpp" @@ -72,7 +72,7 @@ static void __instantiate_kernel() {{ class PackFP32IntoUE8M0Runtime final: public LaunchRuntime { public: struct Args { - int num_groups, mn, sf_k, packed_sf_k; + int num_groups, mn, sf_k, packed_sf_k, gran_k; int block_mn, block_packed_sf_k; void *sf, *out, *ks; @@ -95,32 +95,32 @@ static void __instantiate_kernel() {{ static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, - args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k)); + args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k, args.gran_k)); } }; static std::tuple preprocess_sf(const torch::Tensor& sf) { // NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA - const auto& dim = sf.dim(); + const auto dim = sf.dim(); DG_HOST_ASSERT(dim == 2 or dim == 3); DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat); - const auto& batched_sf = dim == 2 ? sf.unsqueeze(0) : sf; + const auto batched_sf = dim == 2 ? 
sf.unsqueeze(0) : sf; - const auto& [num_groups, mn, sf_k] = get_shape<3>(batched_sf); - const auto& tma_aligned_mn = get_tma_aligned_size(mn, static_cast(sf.element_size())); + const auto [num_groups, mn, sf_k] = get_shape<3>(batched_sf); + const auto tma_aligned_mn = get_tma_aligned_size(mn, static_cast(sf.element_size())); return {dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf}; } static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { - const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + const auto [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); // The last kernel already gives a column-major TMA aligned layout if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn) return (dim == 2) ? batched_sf.squeeze(0) : batched_sf; - const auto& out = torch::empty_strided({num_groups, mn, sf_k}, - {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, - batched_sf.options()); + const auto out = torch::empty_strided({num_groups, mn, sf_k}, + {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, + batched_sf.options()); if (not batched_sf.is_contiguous()) { // Fallback to PyTorch's slow copy if not contiguous @@ -129,7 +129,7 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { } else { constexpr int block_mn = 64; constexpr int num_threads = 512; - const auto& smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast(sizeof(float)); + const auto smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast(sizeof(float)); const TransposeFP32Runtime::Args& args = { .mn = mn, .sf_k = sf_k, @@ -139,25 +139,25 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size) }; - const auto& code = TransposeFP32Runtime::generate(args); - const auto& runtime = compiler->build("transpose_fp32", code); + const auto code = TransposeFP32Runtime::generate(args); + const auto runtime = compiler->build("transpose_fp32", code); TransposeFP32Runtime::launch(runtime, args); } return (dim == 2) ? out.squeeze(0) : out; } static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) { - const auto& sf_reshaped = (sf.dim() == 2) ? sf.unsqueeze(0) : sf; + const auto sf_reshaped = (sf.dim() == 2) ? 
sf.unsqueeze(0) : sf; // First, convert into UE8M0 `uint8_t` - const auto& ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8); + const auto ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8); // Second, make padded packed tensors - const auto& [num_groups, mn, k] = get_shape<3>(sf_reshaped); - const auto& aligned_mn = get_tma_aligned_size(mn, 4); - const auto& aligned_k = align(k, 4); + const auto [num_groups, mn, k] = get_shape<3>(sf_reshaped); + const auto aligned_mn = get_tma_aligned_size(mn, 4); + const auto aligned_k = align(k, 4); - const auto& options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8); + const auto options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8); auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options); // ReSharper disable once CppExpressionWithoutSideEffects padded.slice(1, 0, mn).slice(2, 0, k).copy_(ue8m0_tensor); @@ -172,11 +172,11 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const to } static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) { - const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); - const auto& packed_sf_k = ceil_div(sf_k, 4); - const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k}, - {packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn}, - at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt)); + const auto [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + const auto packed_sf_k = ceil_div(sf_k, 4); + const auto out = torch::empty_strided({num_groups, mn, packed_sf_k}, + {packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn}, + at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt)); // Launch the kernel if (batched_sf.is_contiguous()) { if ((mn * sf_k) % 4 != 0 and num_groups > 1) @@ -193,8 +193,8 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4) }; - const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args); - const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code); + const auto code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args); + const auto runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code); TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args); } else { if (mn % 4 != 0 or num_groups > 1) @@ -217,8 +217,8 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) }; - const auto& code = PackFP32IntoUE8M0Runtime::generate(args); - const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code); + const auto code = PackFP32IntoUE8M0Runtime::generate(args); + const auto runtime = compiler->build("pack_fp32_into_ue8m0", code); PackFP32IntoUE8M0Runtime::launch(runtime, args); } return (dim == 2) ? 
out.squeeze(0) : out; @@ -226,18 +226,20 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf, const torch::Tensor& ks_tensor, - const std::vector& ks) { - const auto& [sf_k, mn] = get_shape<2>(sf); - const auto& num_groups = static_cast(ks.size()); + const std::vector& ks, + const int gran_k) { + DG_HOST_ASSERT(gran_k == 32 or gran_k == 128); + const auto [sf_k, mn] = get_shape<2>(sf); + const auto num_groups = static_cast(ks.size()); int ref_sf_k = 0, packed_sf_k = 0; - for (const auto& k: ks) - ref_sf_k += ceil_div(k, 128), packed_sf_k += ceil_div(k, 512); + for (const auto k: ks) + ref_sf_k += ceil_div(k, gran_k), packed_sf_k += ceil_div(k, gran_k * 4); DG_HOST_ASSERT(sf.is_contiguous()); DG_HOST_ASSERT(ref_sf_k == sf_k); DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0); - const auto& out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt)); + const auto out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt)); constexpr int block_mn = 128; constexpr int block_packed_sf_k = 16; @@ -247,6 +249,7 @@ static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(cons .mn = mn, .sf_k = sf_k, .packed_sf_k = packed_sf_k, + .gran_k = gran_k, .block_mn = block_mn, .block_packed_sf_k = block_packed_sf_k, .sf = sf.data_ptr(), @@ -255,8 +258,8 @@ static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(cons .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) }; - const auto& code = PackFP32IntoUE8M0Runtime::generate(args); - const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code); + const auto code = PackFP32IntoUE8M0Runtime::generate(args); + const auto runtime = compiler->build("pack_fp32_into_ue8m0", code); PackFP32IntoUE8M0Runtime::launch(runtime, args); return out; } diff --git a/deep-gemm/csrc/python_api.cpp b/deep-gemm/csrc/python_api.cpp index 0354f1f8..a966afe1 100644 --- a/deep-gemm/csrc/python_api.cpp +++ b/deep-gemm/csrc/python_api.cpp @@ -6,6 +6,7 @@ #include "apis/hyperconnection.hpp" #include "apis/gemm.hpp" #include "apis/layout.hpp" +#include "apis/mega.hpp" #include "apis/runtime.hpp" #ifndef TORCH_EXTENSION_NAME @@ -22,5 +23,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { deep_gemm::hyperconnection::register_apis(m); deep_gemm::gemm::register_apis(m); deep_gemm::layout::register_apis(m); + deep_gemm::mega::register_apis(m); deep_gemm::runtime::register_apis(m); } diff --git a/deep-gemm/csrc/utils/exception.hpp b/deep-gemm/csrc/utils/exception.hpp index 2aa27066..417dd3b4 100644 --- a/deep-gemm/csrc/utils/exception.hpp +++ b/deep-gemm/csrc/utils/exception.hpp @@ -42,7 +42,7 @@ do { \ #ifndef DG_NVRTC_CHECK #define DG_NVRTC_CHECK(cmd) \ do { \ - const auto& e = (cmd); \ + const auto e = (cmd); \ if (e != NVRTC_SUCCESS) { \ throw DGException("NVRTC", __FILE__, __LINE__, nvrtcGetErrorString(e)); \ } \ @@ -52,7 +52,7 @@ do { \ #ifndef DG_CUDA_DRIVER_CHECK #define DG_CUDA_DRIVER_CHECK(cmd) \ do { \ - const auto& e = (cmd); \ + const auto e = (cmd); \ if (e != CUDA_SUCCESS) { \ std::stringstream ss; \ const char *name, *info; \ @@ -66,7 +66,7 @@ do { \ #ifndef DG_CUDA_RUNTIME_CHECK #define DG_CUDA_RUNTIME_CHECK(cmd) \ do { \ - const auto& e = (cmd); \ + const auto e = (cmd); \ if (e != cudaSuccess) { \ std::stringstream ss; \ ss << static_cast(e) << " (" << cudaGetErrorName(e) 
<< ", " << cudaGetErrorString(e) << ")"; \ @@ -97,7 +97,7 @@ inline const char* cublasGetStatusString(cublasStatus_t status) { #define DG_CUBLASLT_CHECK(cmd) \ do { \ - const auto& e = (cmd); \ + const auto e = (cmd); \ if (e != CUBLAS_STATUS_SUCCESS) { \ std::ostringstream ss; \ ss << static_cast(e) << " (" << cublasGetStatusString(e) << ")"; \ diff --git a/deep-gemm/csrc/utils/format.hpp b/deep-gemm/csrc/utils/format.hpp index b89f9c83..c649fdaf 100644 --- a/deep-gemm/csrc/utils/format.hpp +++ b/deep-gemm/csrc/utils/format.hpp @@ -6,6 +6,7 @@ // Uses std::string concatenation instead of std::ostringstream to avoid // potential locale/ABI issues with ostringstream across different platforms. +#include #include #include #include @@ -41,6 +42,13 @@ inline std::string to_str(const T& v) { } } +template +inline std::string to_hex_float_str(const T& v) { + std::ostringstream os; + os << std::hexfloat << v; + return os.str(); +} + // Overload for C string literals (arrays) template inline std::string to_str(const char (&s)[N]) { @@ -80,6 +88,11 @@ std::string format_impl(std::string_view fmt, result += to_str(first); result += format_impl(fmt.substr(i + 2), rest...); return result; + } else if (i + 3 < fmt.size() && fmt[i + 1] == ':' && + fmt[i + 2] == 'a' && fmt[i + 3] == '}') { + result += to_hex_float_str(first); + result += format_impl(fmt.substr(i + 4), rest...); + return result; } else { result += fmt[i++]; } diff --git a/deep-gemm/csrc/utils/hash.hpp b/deep-gemm/csrc/utils/hash.hpp index ff36ef39..9efe6408 100644 --- a/deep-gemm/csrc/utils/hash.hpp +++ b/deep-gemm/csrc/utils/hash.hpp @@ -1,14 +1,12 @@ #pragma once -#include #include -#include namespace deep_gemm { static uint64_t fnv1a(const std::vector& data, const uint64_t& seed) { uint64_t h = seed; - const uint64_t& prime = 0x100000001b3ull; + const uint64_t prime = 0x100000001b3ull; for (const char& c: data) { h ^= static_cast(c); h *= prime; @@ -17,22 +15,21 @@ static uint64_t fnv1a(const std::vector& data, const uint64_t& seed) { } static std::string get_hex_digest(const std::vector& data) { - const auto& state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull); - const auto& state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull); + const auto state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull); + const auto state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull); // Split-mix 64 - const auto& split_mix = [](uint64_t z) { + const auto split_mix = [](uint64_t z) { z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ull; z = (z ^ (z >> 27)) * 0x94d049bb133111ebull; return z ^ (z >> 31); }; - // Use snprintf instead of ostringstream - char buf[64]; - snprintf(buf, sizeof(buf), "%016lx%016lx", - (unsigned long)split_mix(state_0), - (unsigned long)split_mix(state_1)); - return std::string(buf); + std::ostringstream oss; + oss << std::hex << std::setfill('0') + << std::setw(16) << split_mix(state_0) + << std::setw(16) << split_mix(state_1); + return oss.str(); } static std::string get_hex_digest(const std::string& data) { diff --git a/deep-gemm/csrc/utils/layout.hpp b/deep-gemm/csrc/utils/layout.hpp index c9ac9514..a003f5af 100644 --- a/deep-gemm/csrc/utils/layout.hpp +++ b/deep-gemm/csrc/utils/layout.hpp @@ -1,7 +1,7 @@ #pragma once #include -#include +#include "torch_compat.hpp" #include "math.hpp" #include "exception.hpp" @@ -116,9 +116,4 @@ static torch::Tensor check_sf_layout(const torch::Tensor& sf, return sf; } -// Value matrix layout -static int get_mk_alignment_for_contiguous_layout() { - return 128; -} - } // namespace deep_gemm diff --git a/deep-gemm/csrc/utils/math.hpp 
diff --git a/deep-gemm/csrc/utils/math.hpp b/deep-gemm/csrc/utils/math.hpp
index 2af48e83..6ce9adc6 100644
--- a/deep-gemm/csrc/utils/math.hpp
+++ b/deep-gemm/csrc/utils/math.hpp
@@ -1,13 +1,14 @@
+// TODO: merge this file with `math.cuh` (the device part)
 #pragma once

-#include <torch/python.h>
+#include "torch_compat.hpp"
 #include "exception.hpp"

 namespace deep_gemm {

-// TODO: Use `torch::kFloat4_e2m1fn_x2`
-constexpr auto kPackedFP4 = torch::kUInt8;
+// TODO: use `torch::kFloat4_e2m1fn_x2`
+constexpr auto kPackedFP4 = torch::kInt8;

 template <typename T>
 static T ceil_div(const T& a, const T& b) {
diff --git a/deep-gemm/csrc/utils/system.hpp b/deep-gemm/csrc/utils/system.hpp
index 2c97066f..fda020be 100644
--- a/deep-gemm/csrc/utils/system.hpp
+++ b/deep-gemm/csrc/utils/system.hpp
@@ -16,7 +16,7 @@ namespace deep_gemm {
 // ReSharper disable once CppNotAllPathsReturnValue
 template <typename dtype_t>
 static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) {
-    const auto& c_str = std::getenv(name.c_str());
+    const auto c_str = std::getenv(name.c_str());
     if (c_str == nullptr)
         return default_value;
@@ -34,7 +34,7 @@ static dtype_t get_env(const std::string& name, const dtype_t& default_value = d

 static std::tuple<int, std::string> call_external_command(std::string command) {
     command = command + " 2>&1";
-    const auto& deleter = [](FILE* f) { if (f) pclose(f); };
+    const auto deleter = [](FILE* f) { if (f) pclose(f); };
     std::unique_ptr<FILE, decltype(deleter)> pipe(popen(command.c_str(), "r"), deleter);
     DG_HOST_ASSERT(pipe != nullptr);
@@ -42,7 +42,10 @@
     std::string output;
     while (fgets(buffer.data(), buffer.size(), pipe.get()))
         output += buffer.data();
-    const auto& exit_code = WEXITSTATUS(pclose(pipe.release()));
+    const auto status = pclose(pipe.release());
+    // NOTES: if the child was killed by a signal (e.g., SIGINT from Ctrl+C),
+    // WEXITSTATUS would incorrectly return 0. Treat signal death as failure.
+    const auto exit_code = WIFEXITED(status) ? WEXITSTATUS(status) : 128 + WTERMSIG(status);

     return {exit_code, output};
 }
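The `128 + WTERMSIG(status)` convention above mirrors what shells report for a signal-killed child (130 for SIGINT, 137 for SIGKILL), so callers can treat any non-zero value uniformly as failure. A standalone sketch of the decoding, assuming a POSIX host (not part of the patch):

    #include <cstdio>
    #include <sys/wait.h>

    // Mirror the patch's logic: a normal exit reports its exit code,
    // a signal death reports 128 + signal number instead of a bogus 0
    static int decode_status(int status) {
        return WIFEXITED(status) ? WEXITSTATUS(status) : 128 + WTERMSIG(status);
    }

    int main() {
        FILE* pipe = popen("exit 3", "r");
        if (pipe == nullptr)
            return 1;
        const int status = pclose(pipe);
        std::printf("exit code: %d\n", decode_status(status));  // prints 3
        return 0;
    }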
@@ -68,13 +71,13 @@ static std::vector<std::filesystem::path> collect_files(const std::filesystem::p

 static std::filesystem::path make_dirs(const std::filesystem::path& path) {
     // OK if existed
     std::error_code capture;
-    const bool& created = std::filesystem::create_directories(path, capture);
+    const bool created = std::filesystem::create_directories(path, capture);
     if (not (created or capture.value() == 0)) {
         DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}",
                                         path.c_str(), created, capture.value()));
     }
     if (created and get_env<int>("DG_JIT_DEBUG"))
-        fprintf(stderr, "Create directory: %s\n", path.c_str());
+        printf("Create directory: %s\n", path.c_str());
     return path;
 }
@@ -85,11 +88,41 @@ static std::string get_uuid() {
     }());
     static std::uniform_int_distribution<uint32_t> dist;

-    // Use snprintf instead of stringstream
-    char buf[64];
-    std::snprintf(buf, sizeof(buf), "%d-%08x-%08x-%08x",
-                  getpid(), dist(gen), dist(gen), dist(gen));
-    return std::string(buf);
+    std::stringstream ss;
+    ss << getpid() << "-"
+       << std::hex << std::setfill('0')
+       << std::setw(8) << dist(gen) << "-"
+       << std::setw(8) << dist(gen) << "-"
+       << std::setw(8) << dist(gen);
+    return ss.str();
+}
+
+static void safe_remove_all(const std::filesystem::path& path) {
+    std::error_code ec;
+    if (not std::filesystem::exists(path, ec) or ec)
+        return;
+
+    // A single file
+    if (not std::filesystem::is_directory(path, ec) or ec) {
+        std::filesystem::remove(path, ec);
+        return;
+    }
+
+    // Remove directory
+    auto it = std::filesystem::directory_iterator(path,
+        std::filesystem::directory_options::skip_permission_denied, ec);
+    for (auto end = std::filesystem::directory_iterator(); it != end and not ec;) {
+        const auto entry_path = it->path();
+
+        // Advance the iterator before removing: erasing the entry first would invalidate it
+        it.increment(ec);
+        if (ec)
+            break;
+
+        // Recursively clean
+        safe_remove_all(entry_path);
+    }
+    std::filesystem::remove(path, ec);
+}

 } // deep_gemm
diff --git a/deep-gemm/csrc/utils/torch_compat.hpp b/deep-gemm/csrc/utils/torch_compat.hpp
new file mode 100644
index 00000000..9b24d894
--- /dev/null
+++ b/deep-gemm/csrc/utils/torch_compat.hpp
@@ -0,0 +1,36 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+// DeepGEMM upstream uses the torch:: C++ API from torch/python.h. Kernel Hub
+// builds use Python's limited ABI, so CUDA TUs must avoid Python/pybind headers.
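A usage sketch (hypothetical consumer file, not part of the patch) for the aliasing shim that follows: upstream-style `torch::` spellings keep compiling, but resolve to ATen with no Python/pybind dependency.

    #include "torch_compat.hpp"

    // torch::empty, torch::kByte and torch::kCUDA below are the aliases the
    // shim declares; nothing here ever touches torch/python.h
    at::Tensor make_workspace(const int64_t num_bytes) {
        return torch::empty({num_bytes},
                            torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA));
    }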
+namespace torch { +namespace indexing = at::indexing; + +using at::Tensor; +using at::TensorOptions; +using c10::ScalarType; + +using at::arange; +using at::empty; +using at::empty_like; +using at::empty_strided; +using at::from_blob; +using at::tensor; +using at::zeros; + +inline constexpr auto kBFloat16 = at::kBFloat16; +inline constexpr auto kByte = at::kByte; +inline constexpr auto kFloat = at::kFloat; +inline constexpr auto kFloat8_e4m3fn = at::kFloat8_e4m3fn; +inline constexpr auto kFloat32 = at::kFloat; +inline constexpr auto kInt = at::kInt; +inline constexpr auto kInt8 = at::kChar; +inline constexpr auto kInt32 = at::kInt; +inline constexpr auto kInt64 = at::kLong; +inline constexpr auto kUInt8 = at::kByte; +inline constexpr auto kCUDA = c10::kCUDA; +} // namespace torch diff --git a/deep-gemm/deep_gemm/__init__.py b/deep-gemm/deep_gemm/__init__.py index 1c07f5d9..a9542e2f 100644 --- a/deep-gemm/deep_gemm/__init__.py +++ b/deep-gemm/deep_gemm/__init__.py @@ -19,6 +19,10 @@ get_num_sms, set_tc_util, get_tc_util, + set_ignore_compile_dims, + set_block_size_multiple_of, + set_pdl, + get_pdl, ) # cuBLASLt Kernels @@ -56,14 +60,16 @@ einsum, fp8_einsum, # Attention kernels - fp8_mqa_logits, + fp8_fp4_mqa_logits, get_paged_mqa_logits_metadata, + fp8_fp4_paged_mqa_logits, + # Attention kernels (legacy) + fp8_mqa_logits, fp8_paged_mqa_logits, # Hyperconnection kernels tf32_hc_prenorm_gemm, # Layout kernels transform_sf_into_required_layout, - get_mk_alignment_for_contiguous_layout ) # Some alias for legacy supports @@ -74,6 +80,14 @@ # Expected behavior for CUDA runtime version before 12.1 pass +# Mega kernels +from .mega import ( + SymmBuffer, + get_symm_buffer_for_mega_moe, + transform_weights_for_mega_moe, + fp8_fp4_mega_moe, +) + # Some utils from . import testing from . import utils @@ -109,4 +123,4 @@ def _find_cuda_home() -> str: _find_cuda_home() # CUDA home ) -__version__ = '2.3.0' +__version__ = '2.5.0' diff --git a/deep-gemm/deep_gemm/include/deep_gemm/comm/barrier.cuh b/deep-gemm/deep_gemm/include/deep_gemm/comm/barrier.cuh new file mode 100644 index 00000000..eb9858d8 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/comm/barrier.cuh @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include +#include +#include + +namespace deep_gemm::comm { + +CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() { + // Perform cluster_sync with `barrier.cluster.arrive.relaxed` + // This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); +} + +template +CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope) { + // NOTES: the implementation idea is from `cooperative_groups::this_grid().sync()` + static constexpr uint32_t kFinishSumTag = 0x80000000u; + sync_scope(); + if (thread_idx == 0) { + const auto count_ptr = workspace.get_grid_sync_count_ptr(); + const auto old_value = ptx::atomic_add_rel( + count_ptr, sm_idx == 0 ? 
(kFinishSumTag - (kNumSMs - 1)) : 1); + uint32_t new_value; + do { + new_value = ptx::ld_acq(count_ptr); + } while (((new_value ^ old_value) & kFinishSumTag) == 0); + } + sync_scope(); +} + +template +CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace, + const layout::SymBuffer& sym_buffer, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope, + const bool& sync_prologue = true, + const bool& sync_epilogue = true) { + DG_STATIC_ASSERT(kNumRanks <= kNumThreads, "Insufficient threads"); + + // Grid sync before NVLink signaling + if (sync_prologue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); + + // NVLink cross-rank barrier, only SM 0 participates + if (sm_idx == 0) { + auto* counter_ptr = workspace.get_nvl_barrier_counter_ptr(); + const auto status = (*counter_ptr) & 3; + const auto signal_phase = status & 1, signal_sign = status >> 1; + auto* signal_ptr = workspace.get_nvl_barrier_signal_ptr(signal_phase); + + // Send signals to remote ranks + if (thread_idx < kNumRanks) + ptx::red_add_rel_sys(sym_buffer.map(signal_ptr, thread_idx), signal_sign ? -1 : 1); + sync_scope(); + + // Update status and wait arrival (with 30s timeout, at 2 GHz) + constexpr int64_t kNumTimeoutCycles = 30ll * 2000000000ll; + if (thread_idx == 0) { + ptx::red_add(counter_ptr, 1); + const int target = signal_sign ? 0 : static_cast(kNumRanks); + const auto start_clock = clock64(); + while (ptx::ld_acq_sys(signal_ptr) != target) { + if (clock64() - start_clock >= kNumTimeoutCycles) { + printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n", + sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag); + DG_DEVICE_ASSERT(false and "NVLink barrier timeout"); + } + } + } + } + + // Grid sync after NVLink completion + if (sync_epilogue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); +} + +} // namespace deep_gemm::comm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/compile.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/compile.cuh new file mode 100644 index 00000000..e93c43fb --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/compile.cuh @@ -0,0 +1,18 @@ +#pragma once + +#include + +#if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__)) or defined(__CUDACC_RTC__) or defined(__CLION_IDE__) +#define DG_IN_CUDA_COMPILATION +#endif + +#if defined(__NVCC__) || (defined(__clang__) and defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ __host__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#else +#define CUTLASS_HOST_DEVICE_NOINLINE +#define CUTLASS_DEVICE_NOINLINE +#endif diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/cute_tie.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/cute_tie.cuh index cd2aace7..a3a8b62a 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/common/cute_tie.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/cute_tie.cuh @@ -1,5 +1,7 @@ #pragma once +#include + namespace cute { struct ignore_t { diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/exception.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/exception.cuh new file mode 100644 index 00000000..78acf747 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/exception.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include +#include + +#ifdef __CLION_IDE__ + 
+CUTLASS_HOST_DEVICE void host_device_printf(const char* format, ...) {
+    asm volatile("trap;");
+}
+
+#define printf host_device_printf
+#endif
+
+#ifndef DG_DEVICE_ASSERT
+#define DG_DEVICE_ASSERT(cond) \
+do { \
+    if (not (cond)) { \
+        printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
+        asm("trap;"); \
+    } \
+} while (0)
+#endif
+
+#ifndef DG_TRAP_ONLY_DEVICE_ASSERT
+#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
+do { \
+    if (not (cond)) \
+        asm("trap;"); \
+} while (0)
+#endif
+
+#ifndef DG_STATIC_ASSERT
+#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
+#endif
+
+#ifndef DG_UNIFIED_ASSERT
+#ifdef DG_IN_CUDA_COMPILATION
+#define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond)
+#else
+#define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond)
+#endif
+#endif
diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/math.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/math.cuh
new file mode 100644
index 00000000..0f0d2504
--- /dev/null
+++ b/deep-gemm/deep_gemm/include/deep_gemm/common/math.cuh
@@ -0,0 +1,149 @@
+#pragma once
+
+#include
+#include
+#include
+
+namespace deep_gemm::math {
+
+/// Pointer operations
+template <typename dtype_t>
+CUTLASS_HOST_DEVICE dtype_t* advance_ptr(void* ptr, const uint64_t num_bytes) {
+    return reinterpret_cast<dtype_t*>(static_cast<uint8_t*>(ptr) + num_bytes);
+}
+
+/// Math functions
+template <typename T>
+CUTLASS_HOST_DEVICE T ceil_div(T a, T b) {
+    return (a + b - 1) / b;
+}
+
+template <typename T>
+CUTLASS_HOST_DEVICE constexpr T constexpr_ceil_div(T a, T b) {
+    return (a + b - 1) / b;
+}
+
+template <typename T, bool kDoCeilAlignment = true>
+CUTLASS_HOST_DEVICE T align(T a, T b) {
+    return (kDoCeilAlignment ? ceil_div(a, b) : (a / b)) * b;
+}
+
+template <typename T>
+CUTLASS_HOST_DEVICE constexpr T constexpr_align(T a, T b) {
+    return constexpr_ceil_div(a, b) * b;
+}
+
+template <typename T>
+CUTLASS_HOST_DEVICE constexpr T constexpr_gcd(T a, T b) {
+    return b == 0 ? a : constexpr_gcd(b, a % b);
+}
+
+template <typename T>
+CUTLASS_HOST_DEVICE constexpr T constexpr_min(T a, T b) {
+    return a < b ? a : b;
+}
+
+template <typename T>
+CUTLASS_DEVICE void swap(T& a, T& b) {
+    T temp = a;
+    a = b;
+    b = temp;
+}
+
+#ifdef DG_IN_CUDA_COMPILATION
+CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) {
+#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)
+    return __ffma2_rn(a, b, c);
+#else
+    return make_float2(
+        __fmaf_rn(a.x, b.x, c.x),
+        __fmaf_rn(a.y, b.y, c.y)
+    );
+#endif
+}
+
+CUTLASS_HOST_DEVICE float fast_rcp(const float& x) {
+    float ret;
+    asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x));
+    return ret;
+}
+
+/// Casting
+template <typename old_t>
+CUTLASS_DEVICE int cast_into_bf16_and_pack(old_t& x, old_t& y) {
+    auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
+    return *reinterpret_cast<int*>(&bf16x2);
+}
+
+CUTLASS_DEVICE float fast_pow2(const int& x) {
+    uint32_t bits_x = (x + 127) << 23;
+    return *reinterpret_cast<float*>(&bits_x);
+}
+
+CUTLASS_DEVICE int fast_log2_ceil(float x) {
+    const auto bits = *reinterpret_cast<uint32_t*>(&x);
+    const auto exp = bits >> 23;
+    const auto man = bits & ((1 << 23) - 1);
+    return exp - 127 + (man != 0);
+}
+
+template <bool kUseUE8M0>
+CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) {
+    DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0");
+    const float2 finfo_factor = {1.0f / 448.0f, 1.0f / 448.0f};
+    const auto scaled = __fmul2_rn(amax, finfo_factor);
+    const auto exp_x = fast_log2_ceil(scaled.x);
+    const auto exp_y = fast_log2_ceil(scaled.y);
+    sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x);
+    sf.y = fast_pow2(exp_y), sf_inv.y = fast_pow2(-exp_y);
+}
+
+/// Reduction
+CUTLASS_DEVICE uint32_t warp_inclusive_sum(uint32_t value, const uint32_t& lane_idx) {
+    #pragma unroll
+    for (uint32_t offset = 1; offset < 32; offset <<= 1) {
+        const uint32_t synced = __shfl_up_sync(0xffffffff, value, offset);
+        if (lane_idx >= offset)
+            value += synced;
+    }
+    return value;
+}
+
+// Operation functors
+template <typename T> struct ReduceSum { CUTLASS_DEVICE T operator()(T a, T b) const { return a + b; } };
+template <typename T> struct ReduceMax { CUTLASS_DEVICE T operator()(T a, T b) const { return a > b ? a : b; } };
+template <typename T> struct ReduceMin { CUTLASS_DEVICE T operator()(T a, T b) const { return a < b ? a : b; } };
+template <typename T> struct ReduceAnd { CUTLASS_DEVICE T operator()(T a, T b) const { return a & b; } };
+template <typename T> struct ReduceOr { CUTLASS_DEVICE T operator()(T a, T b) const { return a | b; } };
+
+// Unified reduction function
+template <uint32_t kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T, typename Op>
+CUTLASS_DEVICE T warp_reduce(T value, Op op) {
+    DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
+                     kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
+                     "Invalid number of lanes");
+    constexpr uint32_t mask = 0xffffffff;
+    if constexpr (kIntergroupReduce) {
+        if constexpr (kNumLanesPerGroup <= 1)  value = op(value, __shfl_xor_sync(mask, value, 1));
+        if constexpr (kNumLanesPerGroup <= 2)  value = op(value, __shfl_xor_sync(mask, value, 2));
+        if constexpr (kNumLanesPerGroup <= 4)  value = op(value, __shfl_xor_sync(mask, value, 4));
+        if constexpr (kNumLanesPerGroup <= 8)  value = op(value, __shfl_xor_sync(mask, value, 8));
+        if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
+    } else {
+        if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
+        if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
+        if constexpr (kNumLanesPerGroup >= 8)  value = op(value, __shfl_xor_sync(mask, value, 4));
+        if constexpr (kNumLanesPerGroup >= 4)  value = op(value, __shfl_xor_sync(mask, value, 2));
+        if constexpr (kNumLanesPerGroup >= 2)  value = op(value, __shfl_xor_sync(mask, value, 1));
+    }
+    return value;
+}
+
+// Convenience aliases
+template <uint32_t kNumLanesPerGroup = 32, typename T>
+CUTLASS_DEVICE T warp_reduce_sum(T value) {
+    return warp_reduce<kNumLanesPerGroup>(value, ReduceSum<T>{});
+}
+#endif
+
+} // namespace deep_gemm::math
diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/tma_copy.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/tma_copy.cuh
new file mode 100644
index 00000000..2c5bf708
--- /dev/null
+++ b/deep-gemm/deep_gemm/include/deep_gemm/common/tma_copy.cuh
@@ -0,0 +1,92 @@
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+namespace deep_gemm::tma {
+
+template <typename dtype_t, uint32_t BLOCK_INNER, uint32_t kSwizzleMode>
+constexpr uint32_t get_inner_block_atom_size() {
+    return kSwizzleMode == 0 ?
BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +CUTLASS_DEVICE void +copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +} // namespace deep_gemm::tma diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/types.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/types.cuh new file mode 100644 index 00000000..e07df0af --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/types.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include + +namespace deep_gemm { + +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, 
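+    // NOTES: mixed MX-scaled FP8/FP4 operands; `get_element_size` below
+    // reports 1 byte per element for this kind, and FP4 packing (two
+    // elements per byte) appears to be handled at the kernel level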
+}; + +constexpr CUTLASS_HOST_DEVICE int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr CUTLASS_HOST_DEVICE bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/utils.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/utils.cuh index 8fb6c2fc..3a5f7ad6 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/utils.cuh @@ -1,167 +1,24 @@ #pragma once -#include -#include #include -#include -#include -#include "cute_tie.cuh" +#include -#ifdef __CLION_IDE__ - -__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { - asm volatile("trap;"); -} - -#define printf host_device_printf -#endif - -#ifndef DG_DEVICE_ASSERT -#define DG_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ - asm("trap;"); \ - } \ -} while (0) -#endif - -#ifndef DG_TRAP_ONLY_DEVICE_ASSERT -#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) \ - asm("trap;"); \ -} while (0) -#endif - -#ifndef DG_STATIC_ASSERT -#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) -#endif - -namespace deep_gemm { +namespace deep_gemm::utils { template struct PatternVisitor { FuncT func; - __device__ __host__ + CUTLASS_HOST_DEVICE explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} - __device__ __host__ - auto operator [](const uint32_t& i) { + CUTLASS_HOST_DEVICE + auto operator [](const uint32_t& i) const { return func(i); } }; -template -__device__ __host__ T ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ T align(T a, T b) { - return ceil_div(a, b) * b; -} - -template -__device__ __host__ constexpr T constexpr_align(T a, T b) { - return constexpr_ceil_div(a, b) * b; -} - -template -__device__ __host__ constexpr T constexpr_gcd(T a, T b) { - return b == 0 ? 
a : constexpr_gcd(b, a % b); -} - -template -__forceinline__ __device__ void swap(T& a, T& b) { - T temp = a; - a = b; - b = temp; -} - -__forceinline__ __device__ uint32_t get_sm_idx() { - uint32_t sm_idx; - asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); - return sm_idx; -} - -__forceinline__ __device__ uint32_t get_lane_idx() { - uint32_t lane_id; - asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); - return lane_id; -} - -__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { - uint32_t ret; - asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float2 ld_shared(const float2* ptr) { - float2 ret; - asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float4 ld_shared(const float4* ptr) { - float4 ret; - asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { - uint4 ret; - asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float ld_shared(const float* ptr) { - float ret; - asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ void st_shared(const float* ptr, float val) { - asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); -} - -__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { - asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); -} - -__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { - asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); -} - -__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { - asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); -} - -__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { - asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); -} - -__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) { - asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); -} - -template -__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { - auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); - return *reinterpret_cast(&bf16x2); -} - -__device__ __forceinline__ void prefetch_l1(void *ptr) { - asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); -} - template struct Vectorized { static auto zeros() { @@ -180,4 +37,14 @@ struct Vectorized { using vec_t = decltype(zeros()); }; -} // namespace `deep_gemm` +template +CUTLASS_DEVICE constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if constexpr (kNumCols <= 32) return 32; + if constexpr (kNumCols <= 64) return 64; + if constexpr (kNumCols <= 128) return 128; + if constexpr (kNumCols <= 256) return 256; + return 512; +} + +} // namespace deep_gemm::utils diff --git 
a/deep-gemm/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh b/deep-gemm/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh new file mode 100644 index 00000000..bf0e460c --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh @@ -0,0 +1,137 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M waves + constexpr auto kNumMWaves = BLOCK_M / STORE_BLOCK_M; + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]); + + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = base_m_idx + w * STORE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(base_n_idx + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = tmem_base_addr + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = smem_base_ptr + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared( + smem_ptr, + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]) + ); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +} + +} // namespace deep_gemm::epilogue diff --git a/deep-gemm/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh b/deep-gemm/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh new file mode 100644 index 00000000..f3f5351e --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh @@ -0,0 +1,144 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd_swap_ab(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& effective_m, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // NOTES: The epilogue requires a full warpgroup to read all 128 TMEM rows, + // implying STORE_BLOCK_N must be 128. 
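+    // NOTES: with swap-AB the MMA emits the output transposed, so the output's
+    // N dimension runs along the 128 TMEM lanes; draining one store block thus
+    // takes a full 128-thread warpgroup, and the BF16 path below restores the
+    // row-major SMEM layout through transposed STSM (`SM90_U32x4_STSM_T`).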
+ DG_STATIC_ASSERT(STORE_BLOCK_N == 128, "STORE_BLOCK_N must be 128 to match TMEM rows"); + + // TMA checks + constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumSwizzleAtomRows = 8; + DG_STATIC_ASSERT(kSwizzleCDMode == 128, "TMA D must be 128B swizzled"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M blocks + const auto num_stores = effective_m / STORE_BLOCK_M; + for (uint32_t s = 0; s < num_stores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) { + uint32_t tmem_addr = tmem_base_addr + + s * STORE_BLOCK_M + // Store stage offset + i * kNumSwizzleAtomRows; // In-block offset + uint32_t values[kNumSwizzleAtomRows]; + + // Warps cooperatively write an atomic block to shared memory + DG_STATIC_ASSERT(STORE_BLOCK_N_ATOM % 32 == 0, "Invalid block sizes"); + constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32; + uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode; + uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode; + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + outer_atom_offset + inner_atom_offset; + + if constexpr (cute::is_same_v) { + // NOTES: Swizzling is not required in this case, but used here for consistency with other cases + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + uint32_t col = lane_idx / 4; + + #pragma unroll + for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) { + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes + + (lane_idx % 4) * sizeof(float); + ptx::st_shared(reinterpret_cast(smem_ptr), values[row]); + } + } else { + // Load from TMEM using `.16x256b` shape to satisfy STSM layout requirements + // Start from lane index 0 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + // Start from lane index 16 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Destination shared memory address + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + + // Store matrix with transposition + ptx::SM90_U32x4_STSM_T::copy(math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage 
needs to do this + if (s == num_stores - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ i) { + auto smem_ptr = smem_cd[tma_stage_idx] + i * STORE_BLOCK_M * STORE_BLOCK_N_ATOM; + uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M; + uint32_t n_idx = epilogue_type_t::apply_index_n(base_n_idx + i * STORE_BLOCK_N_ATOM); + + // Issue 2D or 3D TMA store + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx); + } + } + cute::tma_store_arrive(); + } + __syncwarp(); + } +} + +} // namespace deep_gemm::epilogue diff --git a/deep-gemm/deep_gemm/include/deep_gemm/epilogue/transform.cuh b/deep-gemm/deep_gemm/include/deep_gemm/epilogue/transform.cuh new file mode 100644 index 00000000..0266f4d4 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/epilogue/transform.cuh @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace deep_gemm::epilogue::transform { + +struct EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 and + kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +} // namespace deep_gemm::epilogue::transform diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 0227b3e8..a60e2de8 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -4,14 +4,18 @@ #include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_bf16_gemm_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -48,41 +53,31 @@ sm100_bf16_gemm_impl(int* grouped_layout, if constexpr (kWithAccumulation) DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); - // Configs + // MMA Configs constexpr uint32_t LAYOUT_AD_M = 128; - constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; - constexpr uint32_t kNumTMAStoreStages = 2; - DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); - DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); - DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode"); - - // Overwrite shape constants if the compiler gives - shape_m = SHAPE_M != 0 ? 
SHAPE_M : shape_m; - shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; - shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; - - // Utils - bool is_leader_cta = cute::block_rank_in_cluster() == 0; - const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - - // 2-CTA MMA + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 16; constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); - constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; - DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); - DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M; DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); // Share memory sizes - constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t); constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); @@ -91,41 +86,54 @@ sm100_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); // NOTES: Make sure we have enough shared memory for UMMA padding - static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); - DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); - - // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size - // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` - constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 
1 : 2; + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bound for UMMA"); // Real tensor memory size and offsets - constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * UMMA_N; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Synchronize the cluster before 2-CTA TMEM allocation + kNumMulticast > 1 ? cute::cluster_sync() : void(); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors at the very beginning - if (warp_idx == 0 and cute::elect_one_sync()) { + if (warp_idx == 0) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_cd); } + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + // D/A/B shared memory - auto smem_cd = PatternVisitor([&](const uint32_t& i) { + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); }); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); - auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); - auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2; // Fill the tensor memory pointer @@ -159,9 +167,13 @@ sm100_bf16_gemm_impl(int* grouped_layout, } kNumMulticast > 
1 ? cute::cluster_sync() : __syncthreads(); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler( + shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; @@ -178,16 +190,20 @@ sm100_bf16_gemm_impl(int* grouped_layout, // TMA load warp // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); // Compute offsets // NOTES: the group is always concatenated with the outer dimension - uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( shape_m, BLOCK_M, m_block_idx); - uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( shape_n, BLOCK_N, n_block_idx, m_block_idx); // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major @@ -195,14 +211,14 @@ sm100_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or kMajorA == cute::UMMA::Major::K, "Invalid major"); uint32_t k_idx = k_block_idx * BLOCK_K; - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); // Add 2 CTA offsets if constexpr (kNumMulticast > 1) { - m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); } @@ -210,16 +226,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); const uint32_t batch_idx = (kIsBatchedMM ? 
scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); // Arrive at full barriers @@ -235,17 +251,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, // MMA issue warp // NOTES: only the leader CTA will do this // Make instruction descriptor - // TODO: refactor `UMMA_M` calculation - constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); - constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); - auto instr_desc = cute::UMMA::make_instr_desc(); + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc() + : cute::UMMA::make_instr_desc(); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); // Merged stages only happens in NT normal GEMM cases constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -262,7 +277,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // UMMA and empty barrier arrival alias auto umma_arrive = [](const uint64_t* barrier) { @@ -279,36 +294,45 @@ sm100_bf16_gemm_impl(int* grouped_layout, // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting if (do_tmem_full_arrive) umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); }; + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + // Launch MMAs - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait TMA arrival full_barriers[stage_idx]->wait(phase); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA in the leader CTA - using mma_t = cute::conditional_t; - const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); - const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + using mma_t = cute::conditional_t; + const auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); - mma_t::fma(a_desc, b_desc, - accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, - k_block_idx > 0 or k > 0, - runtime_instr_desc); + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + if (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); } } } + __syncwarp(); // Commit to the mbarrier object // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` @@ -319,15 +343,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, if constexpr (kTensorCoreUtilControl < 100) { // For utilization control 
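+                // NOTES: `kNumUMMACycles` below approximates the duration of the
+                // issued UMMA work, assuming 2 * M * N * K FLOPs at ~8192 FLOPs
+                // per cycle; the busy-wait then pads (100 - util) / util of that
+                // time as idle cycles. E.g. `kTensorCoreUtilControl == 50` sleeps
+                // for about one UMMA-block duration per block, i.e. a ~50% tensor
+                // core duty cycle.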
umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + __syncwarp(); // Wait for last UMMA to be done tensor_core_full_barrier->wait(tensor_core_phase); tensor_core_phase ^= 1; // Sleep for certain cycles - constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumUMMACycles = (2ull * UMMA_M * UMMA_N * BLOCK_K) / 8192ull; constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; - const auto& start_clock = clock64(); + const auto start_clock = clock64(); if (cute::elect_one_sync()) while (clock64() - start_clock < kNumDummyCycles) {} __syncwarp(); @@ -336,9 +361,9 @@ sm100_bf16_gemm_impl(int* grouped_layout, } // To safely deconstruct barriers, we need another round of waits - const auto& iter_idx = scheduler.current_iter - 1; + const auto iter_idx = scheduler.current_iter - 1; if (kNumMulticast > 1 and iter_idx >= 0) { - const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); } } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { @@ -348,19 +373,10 @@ sm100_bf16_gemm_impl(int* grouped_layout, // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. // NOTES: we also forbid two CTAs to share the same SM and its tensor memory - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); - - // TMA checks - constexpr uint32_t kNumBankGroupBytes = 16; - constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); - DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); - DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Share store pipeline between blocks uint32_t tma_stage_idx = 0; - auto advance_store_pipeline = [&]() { - tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; - }; // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { @@ -369,108 +385,47 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Wait UMMA arrival tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM - DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); - DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - - // Iterate over M waves - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - // Issue every swizzled atom and pipeline STSM and TMA store - constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; - #pragma unroll - for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { - // Wait shared memory to be released - if (epilogue_warp_idx == 0) - cute::tma_store_wait(); - cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); - - // The pipeline stage - const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; - const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; - - // Store into shared memory - #pragma unroll - for (uint32_t i = 
0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); - - // Reshape the atom in another view and swizzle - // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` - // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` - // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern - constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? (i) : (bank_group_index % 8); - col ^= row % (kSwizzleCDMode / 16); - - // Source and destination memory address - uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset - w * BLOCK_N + // Wave offset - s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset - auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer - epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - - // Load from tensor memory, store into shared memory - uint32_t values[kNumElemsPerBankGroup]; - if constexpr (cute::is_same_v) { - // For FP32 output, read and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, - values[0], values[1], values[2], values[3]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); - } else { - // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, - values[0], values[1], values[2], values[3], - values[4], values[5], values[6], values[7]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, - cast_into_bf16_and_pack(values[0], values[1]), - cast_into_bf16_and_pack(values[2], values[3]), - cast_into_bf16_and_pack(values[4], values[5]), - cast_into_bf16_and_pack(values[6], values[7])); - } - } - - // Notify tensor memory empty (only at the leader CTA) arrival ASAP - // NOTES: only the last stage needs to do this - if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { - tcgen05_before_thread_sync(); - tmem_empty_barriers[accum_stage_idx]->arrive(0u); - } - __syncwarp(); - - // Synchronize all threads and issue TMA - cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - if constexpr (kGemmType == GemmType::Batched) { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], - n_idx, m_idx, scheduler.current_group_idx); - } else { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); - } - cute::tma_store_arrive(); - } - } + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx< + (not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + 
epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, kNumTmemCols); } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh index 86303347..13bb0872 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -5,18 +5,19 @@ #include #include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__global__ void __launch_bounds__(kNumThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumThreads, 1) sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -30,7 +31,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); @@ -51,24 +52,24 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, } // Real tensor memory size and offsets - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); // Fill D/A/B - auto smem_cd = PatternVisitor([&](const uint32_t& i) { + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); }); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2); // Fill the tensor memory 
pointer @@ -93,14 +94,17 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, __syncthreads(); // Block indices - const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); - const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M); const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; const uint32_t n_block_idx = mn_block_idx % num_n_blocks; const uint32_t m_block_idx = mn_block_idx / num_n_blocks; const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (warp_idx == 0) { // TMA load warp for (uint32_t s = 0; s < num_total_stages; ++ s) { @@ -115,8 +119,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Issue TMAs if (cute::elect_one_sync()) { - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); } // Arrive at full barriers @@ -134,8 +138,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, auto instr_desc = cute::UMMA::make_instr_desc(); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -147,14 +151,14 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, "Invalid MMA instruction shape"); // Wait tensor memory empty barrier arrival - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Launch MMAs for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait TMA arrival const auto& stage_idx = s % kNumStages; full_barriers[stage_idx]->wait((s / kNumStages) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA in the leader CTA const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); @@ -163,9 +167,11 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); - SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); } } @@ -180,7 +186,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`. 
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory if (warp_idx == 2) - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // TMA checks constexpr uint32_t kNumBankGroupBytes = 16; @@ -191,7 +197,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Wait UMMA arrival tmem_full_barrier->wait(0); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); @@ -239,7 +245,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, values[0], values[1], values[2], values[3]); cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); } // Synchronize all threads and issue TMA diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh new file mode 100644 index 00000000..b8a99fd0 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh @@ -0,0 +1,457 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, + const uint32_t logits_stride, + const uint32_t* cu_seq_len_k_start, + const uint32_t* cu_seq_len_k_end, + logits_dtype_t* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQ = math::constexpr_align(BLOCK_Q * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(BLOCK_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQ = BLOCK_Q * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(BLOCK_KV == kNumMathWarpGroups * UMMA_M and BLOCK_KV % kNumUTCCPAlignedElems == 0, "Invalid `BLOCK_KV`"); + + // Shared memory configs + static constexpr 
uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQ * sizeof(int); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i; + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i; + }); + const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages); + auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i); + }); + auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; }); + auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2); + + // Tensor memory configs + constexpr uint32_t kNumAccumTmemCols = BLOCK_Q * kNumHeads * kNumTmemStages; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQ / 32; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Initialize barriers + if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(1); + } + #pragma unroll + for 
(uint32_t i = 0; i < kNumTmemStages; ++i) { + full_tmem_barriers[i]->init(1); + empty_tmem_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + + // Allocate tensor memory + if (warp_idx == kSpecWarpStart + 2) + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + __syncthreads(); + + // Scheduler + const uint32_t num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + auto load_schedule = [&](const uint32_t& q_idx) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto row_idx = cute::min(q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cute::min(cu_seq_len_k_start[row_idx], seq_len_kv); + seq_k_end[i] = cute::min(cu_seq_len_k_end[row_idx], seq_len_kv); + start = cute::min(start, seq_k_start[i]); + end = cute::max(end, seq_k_end[i]); + } + // TMA alignment requirements for SF KV + start = start / 4 * 4; + return {start, math::ceil_div(end - start, BLOCK_KV)}; + }; + + // Make Q, KV and TMEM pipeline + auto make_pipeline = [](const uint32_t& num_stages) { + // Return current stage and phase, and advance pipeline by steps + return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple { + uint32_t current_idx = iter_idx; + iter_idx += step; + return {current_idx % num_stages, (current_idx / num_stages) & 1}; + }; + }; + auto advance_q_pipeline = make_pipeline(kNumQStages); + auto advance_kv_pipeline = make_pipeline(kNumKVStages); + auto advance_tmem_pipeline = make_pipeline(kNumTmemStages); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading Q + cutlass::arch::warpgroup_reg_dealloc(); + + // Enumerate Q blocks + if (cute::elect_one_sync()) { + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Wait Q consumer release + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_q[q_stage_idx], 0, q_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_idx * BLOCK_Q); + tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_idx * BLOCK_Q); + full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQ * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + + if (cute::elect_one_sync()) { + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + 
static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx], 0, kv_start + kv_idx * BLOCK_KV); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx], + kv_start + kv_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQ / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize umma desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim 
/ 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma( + a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this into `deep_gemm/ptx/tcgen05.cuh` + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline. Without this, UMMA can consume + // kNumQStages Q blocks before math warps release any, causing a + // circular dependency: UMMA waits full_q -> TMA_Q waits empty_q + // -> Math waits full_tmem -> UMMA (already moved on). + empty_q_barriers[q_stage_idx]->arrive(); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = threadIdx.x; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr uint32_t N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[BLOCK_Q][kNumHeads]; + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + // TODO: optimize bank conflicts + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Calculate KV offset in advance + auto kv_offset = kv_start + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + } + + // 
Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + // TODO: optimize performance + const auto q_offset = (q_idx * BLOCK_Q + i) * static_cast(logits_stride); + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; + } else { + logits[q_offset + kv_offset] = result; + } + __syncwarp(); + } + } + + // Release last Q empty + empty_q_barriers[q_stage_idx]->arrive(); + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh new file mode 100644 index 00000000..d9add534 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh @@ -0,0 +1,510 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 
2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQAtom = math::constexpr_align(kNextNAtom * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(SPLIT_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQAtom = kNextNAtom * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == kNumMathWarpGroups * UMMA_M and SPLIT_KV % kNumUTCCPAlignedElems == 0, "Invalid `SPLIT_KV`"); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQAtom * sizeof(int); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i; + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i; + }); + const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages); + auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i); + }); + auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; }); + auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& 
i) { return tmem_barrier_ptr + kNumTmemStages + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2); + + // Tensor memory configs + constexpr uint32_t kNumAccumTmemCols = kNextNAtom * kNumHeads * kNumTmemStages; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQAtom / 32; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Initialize barriers + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 2) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumTmemStages; ++i) { + full_tmem_barriers[i]->init(1); + empty_tmem_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Scheduler + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + using Scheduler = sched::PagedMQALogitsScheduler; + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); + + // Make Q, KV and TMEM pipeline + auto make_pipeline = [](const uint32_t& num_stages) { + // Return current stage and phase, and advance pipeline by steps + return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple { + uint32_t current_idx = iter_idx; + iter_idx += step; + return {current_idx % num_stages, (current_idx / num_stages) & 1}; + }; + }; + auto advance_q_pipeline = make_pipeline(kNumQStages); + auto advance_kv_pipeline = make_pipeline(kNumKVStages); + auto advance_tmem_pipeline = make_pipeline(kNumTmemStages); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading Q + cutlass::arch::warpgroup_reg_dealloc(); + + if (cute::elect_one_sync()) { + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + // Persistently schedule over blocks + // Initialize outside valid range to indicate no previous task + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, _, __; + while (scheduler.fetch_next_task(q_atom_idx, _, __)) { + // Issue TMA Q when (q_idx, atom_idx) changes + if (q_atom_idx != last_q_atom_idx) { + // Wait Q consumer release + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx); + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_q[q_stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx); + tma::copy(&tensor_map_weights, 
full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx); + full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + last_q_atom_idx = q_atom_idx; + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + // Persistently schedule over blocks + uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage; + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, num_kv; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, num_kv)) { + // Reset block table cache on kv restart + if (q_atom_idx != last_q_atom_idx) + kv_block_idx_ptr = 32; + last_q_atom_idx = q_atom_idx; + + // Coalesced load of block table + if (kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? block_table[block_table_offset + kv_idx + lane_idx] : 0; + } + __syncwarp(); + + // Broadcast KV block indices + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `SPLIT_KV`"); + + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + + // Issue TMA KV + if (cute::elect_one_sync()) { + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim / 2) * i, + 0, 0, kv_block_idx[i]); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + // Wait TMA Q arrivals + uint32_t q_stage_idx, q_phase; + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE(advance_q_pipeline(), q_stage_idx, q_phase); + 
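+                    // NOTES: `advance_q_pipeline` comes from the `make_pipeline` closure above:
+                    // iteration i yields stage `i % kNumQStages` and phase `(i / kNumQStages) & 1`,
+                    // so the wait parity flips each time the stage ring wraps. A standalone
+                    // host-side sketch of the same arithmetic (illustrative only, assuming
+                    // kNumQStages == 3; not part of this kernel):
+                    //
+                    //     #include <cassert>
+                    //     #include <utility>
+                    //     int main() {
+                    //         unsigned iter = 0;
+                    //         const unsigned num_stages = 3;
+                    //         auto advance = [&](unsigned step = 1) {
+                    //             const unsigned cur = iter;
+                    //             iter += step;
+                    //             return std::pair{cur % num_stages, (cur / num_stages) & 1u};
+                    //         };
+                    //         auto [s0, p0] = advance();  assert(s0 == 0 and p0 == 0);
+                    //         auto [s1, p1] = advance(2); assert(s1 == 1 and p1 == 0);
+                    //         auto [s2, p2] = advance();  assert(s2 == 0 and p2 == 1);  // wrapped, phase flipped
+                    //     }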
+ // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQAtom / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + } + last_q_atom_idx = q_atom_idx; + + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize UMMA desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma(a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this PTX into headers + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + 
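+            // NOTES: the pack expansion below turns `cute::index_sequence` indices into the
+            // N / 2 register operands of one TMEM load (each `uint64_t` operand aliases two
+            // packed `float`s of `accum`). A minimal standalone analogue of the trick, with
+            // a hypothetical variadic `copy` standing in for the CUTLASS loader (illustrative
+            // only; the real loader is selected via `cute::conditional_t` just below):
+            //
+            //     #include <cstddef>
+            //     #include <cstdint>
+            //     #include <utility>
+            //     template <typename... Regs>
+            //     void copy(uint32_t /*addr*/, Regs&... /*regs*/) {}  // assumed stand-in
+            //     template <int N>
+            //     void load_example(uint32_t addr, float* accum) {
+            //         [&]<std::size_t... Is>(std::index_sequence<Is...>) {
+            //             copy(addr, reinterpret_cast<uint64_t*>(accum)[Is]...);
+            //         }(std::make_index_sequence<N / 2>{});
+            //     }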
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[kNextNAtom][kNumHeads]; + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + bool is_paired_atom = false; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + + // Release last Q empty + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrivals + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + float4 raw = ptx::ld_shared((float4*)(smem_weights[q_stage_idx] + i * kNumHeads + j)); + weights[i][j + 0] = raw.x; + weights[i][j + 1] = raw.y; + weights[i][j + 2] = raw.z; + weights[i][j + 3] = raw.w; + } + } + + // Check if this atom pairs two tokens from the same sequence + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2); + } + } + last_q_atom_idx = q_atom_idx; + + // Calculate KV offset in advance + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + + // Only loop over valid iterations + #pragma unroll + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride)] = result; + __syncwarp(); + } + + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + }; + + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % 
kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); + } + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh new file mode 100644 index 00000000..0bc6a3fe --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh @@ -0,0 +1,514 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // MMA Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // SF configs + constexpr uint32_t kNumUTCCPAlignedElems = 128; + constexpr uint32_t SF_BLOCK_M = math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = math::constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + DG_STATIC_ASSERT((kGemmType != GemmType::KGroupedContiguous) or kGranKA == kGranKB, "K-grouped SF requires kGranKA == kGranKB"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 
16 : cute::min(BLOCK_M, LAYOUT_AD_M);
+    constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
+    constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads : STORE_BLOCK_M;
+    DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
+
+    // Shared memory sizes
+    constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
+    constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
+    constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
+    constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
+    constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
+    constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
+    DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
+                     "Shared memory of A/B must be aligned to 1024 bytes");
+    // NOTES: Make sure we have enough shared memory for UMMA padding
+    constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
+    DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bounds for UMMA");
+
+    // Tensor memory size and offsets
+    constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
+    constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
+    constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
+    constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols();
+    constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
+    constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
+    DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
+
+    // Synchronize the cluster before 2-CTA TMEM allocation
+    kNumMulticast > 1 ? cute::cluster_sync() : void();
+
+    // Utils
+    const bool is_leader_cta = cute::block_rank_in_cluster() == 0;
+    const auto warp_idx = cutlass::canonical_warp_idx_sync();
+    const auto lane_idx = ptx::get_lane_idx();
+
+    // Prefetch TMA descriptors at the very beginning
+    if (warp_idx == 0) {
+        cute::prefetch_tma_descriptor(&tensor_map_a);
+        cute::prefetch_tma_descriptor(&tensor_map_b);
+        cute::prefetch_tma_descriptor(&tensor_map_sfa);
+        cute::prefetch_tma_descriptor(&tensor_map_sfb);
+        cute::prefetch_tma_descriptor(&tensor_map_cd);
+    }
+
+    // Overwrite shape constants if given at compile time
+    shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
+    shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
+    shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
+    const auto shape_sfa_k = math::ceil_div(shape_k, kGranKA * 4);
+    const auto shape_sfb_k = math::ceil_div(shape_k, kGranKB * 4);
+
+    // Align to 1024 bytes for swizzle-128B
+    extern __shared__ __align__(1024) uint8_t smem_buffer[];
+
+    // D/A/B shared memory
+    auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
+    });
+    auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
+    });
+    auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
+    });
+
+    // SFA/SFB shared memory
+    auto sf_start_ptr = reinterpret_cast(smem_b[kNumStages]);
+    auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) {
+        return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
+    });
+    auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) {
+        return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
+    });
+
+    // Barriers and tensor memory pointer
+    auto barrier_start_ptr = reinterpret_cast(smem_sfb[kNumStages]);
+    auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
+    auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
+    auto with_sf_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
+    auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
+    auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
+    auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
+
+    // Initialize barriers
+    if (warp_idx == 1 and cute::elect_one_sync()) {
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumStages; ++ i) {
+            // Arrive at all CTAs
+            full_barriers[i]->init(1);
+            empty_barriers[i]->init(1);
+            // Arrive only at the leader CTA
+            with_sf_full_barriers[i]->init(kNumMulticast * 32);
+        }
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
+            // Arrive at all CTAs
+            tmem_full_barriers[i]->init(1);
+            // Arrive only at the leader CTA
+            tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
+        }
+
+        // Make the initialized barriers visible in the async proxy
+        cutlass::arch::fence_barrier_init();
+    } else if (warp_idx == 2) {
+        // Allocate tensor memory
+        Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
+    }
+    kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
+
+    // Wait for primary kernel completion
+    cudaGridDependencySynchronize();
+
+    // Block scheduler
+    uint32_t m_block_idx, n_block_idx;
+    auto scheduler = sched::Scheduler(
+        shape_m, shape_n, shape_k, grouped_layout);
+
+    // Pipeline and TMA phases
+    uint32_t stage_idx = 0, phase = 0;
+    auto advance_pipeline = [&](uint32_t& k_block_idx) {
+        ++ k_block_idx;
+
+        // Flip the phase only when wrapping back to the first stage
+        stage_idx = stage_idx == kNumStages - 1 ?
0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 
1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + uint32_t sfa_m_idx = m_block_idx * BLOCK_M; + uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>( + shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)); + tma::copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = scheduler.template get_global_idx( + shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx); + tma::copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc_block_scaled() + : cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + + // Launch MMAs + const auto num_total_k_blocks = 
math::ceil_div(scheduler.current_shape_k, BLOCK_K); + #pragma unroll 4 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // Do SF copy at certain stages + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + + // Issue UMMA + using mma_t = cute::conditional_t< + kNumMulticast == 1, ptx::SM100_MMA_MXF8F6F4_SS, ptx::SM100_MMA_MXF8F6F4_2x1SM_SS>; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto runtime_instr_desc = kSwapAB ? 
+ mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfb_id, sfa_id): + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + a_desc.lo = mma::sm100::advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + if constexpr (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFA, kTmemStartColOfSFB); + } + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. 
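+    // NOTES: the accumulator ring used below is a two-stage TMEM double buffer:
+    // block iteration i drains TMEM columns starting at `(i % kNumEpilogueStages) * UMMA_N`
+    // with wait phase `(i / kNumEpilogueStages) & 1`, so the MMA warp can fill one
+    // stage while the epilogue stores the other. Illustrative index helpers
+    // (hypothetical, mirroring `accum_stage_idx`/`accum_phase_idx` below):
+    //
+    //     __device__ uint32_t accum_stage(uint32_t iter) { return iter % kNumEpilogueStages; }
+    //     __device__ uint32_t accum_phase(uint32_t iter) { return (iter / kNumEpilogueStages) & 1; }
+    //     __device__ uint32_t tmem_base(uint32_t iter)   { return accum_stage(iter) * UMMA_N; }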
+ // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + ptx::tcgen05_after_thread_sync(); + + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } + } + } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? 
cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only supports sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh new file mode 100644 index 00000000..b2adc6c7 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh @@ -0,0 +1,1380 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t STORE_BLOCK_M, + uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm100_fp8_fp4_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights_sf) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::TMEM::Allocator2Sm; + + // Template checks + DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of MMA epilogue and combine threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + + // Thread indices + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const uint32_t sm_idx = blockIdx.x; + const
uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights_sf); + } + + // Workspaces + const auto workspace = layout::Workspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + // Token and buffer layouts + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Registered inputs + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxTokensPerRank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumMaxTokensPerRank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, kNumMaxTokensPerRank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, kNumMaxTokensPerRank, + input_topk_idx_buffer.get_end_ptr()); + + // SF and its buffer configs + constexpr uint32_t kGranK = 32; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); + DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "No padding is needed for SFB"); + + // UTCCP 4x32 transpose index mapping within each 128-element group + const auto transform_sf_token_idx = [](const uint32_t& token_idx_in_expert) { + const uint32_t idx = token_idx_in_expert % BLOCK_M; + return token_idx_in_expert / BLOCK_M * SF_BLOCK_M + + (idx & ~127u) + (idx & 31u) * 4 + ((idx >> 5) & 3u); + }; + + // L1 inputs + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxPoolTokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, kNumMaxPoolTokens, + l1_sf_buffer.get_end_ptr()); + + // L2 inputs + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + l1_topk_weights_buffer.get_end_ptr() + ); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + l2_token_buffer.get_end_ptr() + ); + + // Combine inputs + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, kNumTopk, 
kNumMaxTokensPerRank, + l2_sf_buffer.get_end_ptr() + ); + + // Data types + // NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1) + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + + // MMA configs + // NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; + constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + + // Swizzle configs + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t kSwizzleCDMode = 128; + DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); + + // Epilogue configs + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Shared memory + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + // Shared memory sizes + // NOTES: FP8 CD output for L1 (2 TMA stages, BLOCK_N/2 post-SwiGLU), BF16 output for L2 (no TMA, a single stage) + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + constexpr uint32_t SMEM_CD_L1_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages; + constexpr uint32_t SMEM_CD_L2_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? 
SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; + constexpr uint32_t SMEM_CD_L1_SIZE_PER_STAGE = SMEM_CD_L1_SIZE / kNumTMAStoreStages; + constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE = + SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + DG_STATIC_ASSERT(SMEM_CD_SIZE % kSharedMemoryAlignment == 0 and + SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and + SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0, + "Shared memory of CD/A/B must be aligned to 1024 bytes"); + + // Tensor memory size + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Assign shared memory for dispatch warps + const auto smem_expert_count = reinterpret_cast(smem_buffer); + const auto smem_send_buffers = layout::Buffer( + fp8_token_layout, kNumDispatchWarps, 1, + math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); + + // GEMM shared memory: C/D, A, B + // NOTES: GEMM shared memory starts after the dispatch region, aligned to 1024 bytes + auto smem_gemm_base = math::advance_ptr( + smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + ); + + // D/A/B shared memory + auto smem_cd = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, i * SMEM_CD_L1_SIZE_PER_STAGE); + }); + auto smem_cd_l2 = smem_cd[0]; + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SF shared memory: SFA and SFB per pipeline stage + auto sf_start_ptr = math::advance_ptr(smem_gemm_base, + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Epilogue amax reduction shared memory + auto smem_amax_reduction = reinterpret_cast(smem_sfb[kNumStages]); + + // Barriers and tensor memory pointer + auto barrier_start_ptr = reinterpret_cast(smem_amax_reduction + STORE_BLOCK_M * kNumEpilogueWarps / 2); + auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + 
kNumEpilogueStages + i); }); + auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + i); }); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2); + + // A cluster sync is essential for 2-CTA tensor memory allocation + comm::cluster_sync_with_relaxed_arrive(); + + // Initialization + if (warp_idx == 0) { + // Clean shared memory + if (cute::elect_one_sync()) + ptx::st_shared_bulk(smem_expert_count, kNumExperts * sizeof(uint32_t)); + } else if (warp_idx == 1) { + // Init m-barriers for dispatch + #pragma unroll + for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) + dispatch_barriers[i]->init(1); + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Init GEMM barriers + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(2 * 2); + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) + combine_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 3) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + // NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`, + // and `barrier.cluster.wait.aligned` is by default `.acquire` + comm::cluster_sync_with_relaxed_arrive(); + + // Task scheduler + auto scheduler = sched::MegaMoEScheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L1_SHAPE_N, L1_SHAPE_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, + kNumExpertsPerWave, + kNumSMs, kNumRanks>(workspace); + + // MMA pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip the phase only when wrapping back to the first stage + stage_idx = stage_idx == kNumStages - 1 ?
0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Intra-SM Barrier indices + constexpr uint32_t kDispatchBarrierIdx = 0; + constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; + constexpr uint32_t kEpilogueFullBarrierIdx = 2; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; + + // NVLink barrier tags + constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; + constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; + constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3; + + // Adjust registers + constexpr uint32_t kNumDispatchRegisters = 48; + constexpr uint32_t kNumNonEpilogueRegisters = 40; + constexpr uint32_t kNumEpilogueRegisters = 208; + DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + // Grid sync index assignments (dispatch and epilogue use separate counters to avoid conflicts) + constexpr uint32_t kDispatchGridSyncIndex = 0; + constexpr uint32_t kEpilogueGridSyncIndex = 1; + + // Different warp roles + if (warp_idx < kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // Dispatch warps + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; + const auto read_topk_idx = [&](const auto& process) { + // TODO: figure out better unrolling + // Now, `unroll` is better than `unroll 8` + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + // Allocate slots for each token-topk + int expert_idx = -1; + if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) { + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + if (expert_idx >= 0) + process(i * kNumTopk + lane_idx, expert_idx); + } + __syncwarp(); + } + }; + + // Count experts' tokens + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + atomicAdd_block(smem_expert_count + expert_idx, 1); + }); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Get SM offset (~6.5 us) + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); + smem_expert_count[i] = static_cast( + ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Write source indices (~2 us with 512 tokens) + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + const auto dst_rank_idx = expert_idx / kNumExpertsPerRank; + const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1); + const auto dst_ptr = workspace.get_src_token_topk_idx_ptr( + expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx); + *sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx; + }); + + // Grid sync + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); } + ); + + // Write expert count + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const auto dst_rank_idx = i / kNumExpertsPerRank; + const auto dst_local_expert_idx = i % kNumExpertsPerRank; + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + 
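// NOTES (illustration, added): `expert_status` is the packed 64-bit counter accumulated above, + // presumably `(num_contributing_adds << 32) | num_tokens`; `& 0xffffffff` below extracts the + // token count, while the full packed value is added into the remote sum at system scope +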
*sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Barrier before pulling + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* After the grid sync above, there are no more writes by other SMs (except 0) */ false, + /* After the NVLink barrier, there is a grid sync */ true + ); + + // Ensure the epilogue barrier cannot run concurrently with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Pull token data and SF from remote ranks into local L1 buffer + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + // Cache expert token counts in registers (same pattern as scheduler) + scheduler.fetch_expert_recv_count(); + + // Per-rank counts for current expert (re-loaded when expert changes) + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + // Advance the expert until the token index falls within its range + int old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++ current_expert_idx >= kNumExpertsPerRank) + break; + + // Update pool block offset for the new expert + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + + // Move start and end to the next expert + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(current_expert_idx); + } + + // All tokens are finished + if (current_expert_idx >= kNumExpertsPerRank) + break; + + // Load per-rank counts when expert changes + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t j = i * 32 + lane_idx; + // TODO: this is not coalesced + stored_rank_count[i] = j < kNumRanks ?
+ static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; + } + } + + // Round-robin rank selection via iterative min-peeling + uint32_t current_rank_in_expert_idx; + uint32_t remaining[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = stored_rank_count[i]; + uint32_t offset = 0; + uint32_t token_idx_in_expert = token_idx - expert_start_idx; + uint32_t slot_idx = token_idx_in_expert; + uint32_t token_idx_in_rank; + while (true) { + // Compute active count and min across all ranks + // NOTES: reduce within each lane first, then warp-reduce once + uint32_t num_actives_in_lane = 0; + uint32_t min_in_lane = 0xffffffff; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + num_actives_in_lane += remaining[i] > 0; + if (remaining[i] > 0) + min_in_lane = cute::min(min_in_lane, remaining[i]); + } + const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); + const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); + + // Hit in the current round + const uint32_t num_round_tokens = length * num_active_ranks; + if (slot_idx < num_round_tokens) { + const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; + uint32_t num_seen_ranks = 0; + current_rank_in_expert_idx = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); + const uint32_t num_active_lanes = __popc(mask); + if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) + current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); + num_seen_ranks += num_active_lanes; + } + token_idx_in_rank = offset + (slot_idx / num_active_ranks); + break; + } + + // Move into the next round + slot_idx -= num_round_tokens; + offset += length; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= cute::min(remaining[i], length); + } + + // Read source token-topk index (written by remote dispatch via NVLink) + const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( + current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); + const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; + + // TMA load token from remote rank into shared memory + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + // Load and store SF (overlaps with TMA token load) + constexpr uint32_t kNumSFUint32 = kHidden / 128; + DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + const auto sf_pool_token_idx = expert_pool_block_offset * SF_BLOCK_M + + transform_sf_token_idx(token_idx_in_expert); + #pragma unroll + for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFUint32, 32u); ++ i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFUint32) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; + } + __syncwarp(); + + // Store weights and token data + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + 
token_idx_in_expert; + if (cute::elect_one_sync()) { + // Load weights + const auto weight = *sym_buffer.map( + input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx, + current_rank_in_expert_idx); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + + // Wait for TMA token load to complete + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); + + // Store token to local L1 buffer via TMA + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + + // Write source metadata for combine write-back + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, src_topk_idx}; + + // Wait for token TMA store to complete + cute::tma_store_arrive(); + ptx::tma_store_wait<0>(); + ptx::red_add_rel( + workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1); + } + __syncwarp(); + } + + // Clean the workspace for the next use, and also accumulate stats + // NOTES: this is overlapped with the combine reduction epilogue + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); + if (sm_idx == 0) { + // SM 0: clear expert send count + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) + *workspace.get_expert_send_count_ptr(i) = 0; + } else { + // Other SMs: clean blocks + for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { + // Read expert token count before clearing + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); + + // Compute expert pool block offset + expert_pool_block_offset = scheduler.get_pool_block_offset(i); + + // Wait until the recv count has been read by all threads + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Clean expert token count, and add cumulative results + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } + + // Clean per-rank token count + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + + // Clean L1 and L2 arrival state + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; + *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } + + // Wait for all ranks to finish cleaning + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* Before the NVLink barrier, there is a grid sync */ true, + /* No sync is needed at the end of the kernel */ false + ); + } else if (warp_idx == kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for tokens with SFA
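+ // NOTES (illustration, added): readiness is tracked per pool block. For L1, dispatch warps bump + // a per-block counter once per pulled token, so waiting for `count == valid_m` guarantees the + // whole [BLOCK_M, K] tile is resident; for L2, each finished L1 N block sets one bit in a 64-bit + // mask, and since the L1 output blocks are half-width, `2 * num_k_blocks` bits must be set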
+ scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts : &tensor_map_l1_acts; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K; + const auto shape_sfa_k = math::ceil_div(shape_k, kGranK * 4u); + + // Compute pool block offset for this expert + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + + // Wait for the entire token block to arrive for linear 1 + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = scheduler.template get_valid_m(); + while (ptx::ld_acq(ptr) != expected); + } else { + // The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait for 2x L1 blocks to arrive + // NOTES: originally we waited for blocks on demand to overlap the L1 computation + // with L2, but this optimization is a net negative when `num_experts_per_wave` + // already guarantees L1's completion when L2 starts, so we removed it. + // In the future, if `num_experts_per_wave` is not large enough + // due to a small `num_experts_per_rank`, we may need to add it back or add a switch + DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes"); + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts + // to avoid undefined behavior when `num_k_blocks == 32` + const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1; + while (ptx::ld_acq_gpu(ptr) != expected); + } + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute token offset from pool block index + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; + uint32_t sfa_k_idx = k_block_idx; + + // Add the 2-CTA offset for the non-leader CTA + if (not is_leader_cta) + m_idx += scheduler.template get_valid_m() / 2; + + // TMA copy tokens and SFA, then arrive at full barrier + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 1) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for weights with SF + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_b_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + const auto tensor_map_sfb_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights_sf : &tensor_map_l1_weights_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ?
L2_SHAPE_K : L1_SHAPE_K; + const auto shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N; + const auto shape_sfb_k = math::ceil_div(shape_k, kGranK * 4u); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute weight offset + uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx; + + // TMA copy weights with SF + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); + tma::copy( + tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 2) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM MMA issue warp (only the leader CTA will run) + if (is_leader_cta) { + // Make instruction descriptor with block scaling + // NOTES: always swap A/B + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< + b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, + cute::UMMA::Major::K, cute::UMMA::Major::K + >(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Dynamic update of UMMA N based on effective M + mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m()); + + // Wait tensor memory empty barrier arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + constexpr uint16_t kCTAMask = (1 << 2) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Launch MMAs + #pragma unroll 2 + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA load completion + full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // UTCCP copy SFA and SFB to TMEM + using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + + // Issue UMMA + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const auto runtime_instr_desc = + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); + a_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( + b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_k_blocks - 1); + } 
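+ // NOTES (illustration, added): `empty_barrier_arrive(true)` fires exactly once per tile, on the + // final K block, so the tensor memory full barrier only flips after the whole accumulation has + // been committed; earlier K blocks merely release their A/B stage back to the TMA producers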
+ }); + + // To safely deconstruct barriers, we need another round of waits + if (current_iter_idx > 0) { + const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1; + tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx); + } + } + } else if (warp_idx == kNumDispatchWarps + 3) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // GEMM epilogue warps + const auto epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const auto epilogue_wg_idx = epilogue_warp_idx / 4; + const auto epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const auto warp_idx_in_wg = epilogue_warp_idx % 4; + DG_STATIC_ASSERT((kNumDispatchWarps + kNumMMANonEpilogueWarps) % 4 == 0 and + kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps"); + + // TODO: support effective block M + // NOTES: + // - 2 warpgroups divide the whole BM into BM / 2 + // - 4 warps divide the whole BN into BN / 4 + // - BM / 2 is further divided into store blocks of size `STORE_BLOCK_M` + // - `STORE_BLOCK_M` is further divided into atoms of size `ATOM_M` + constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t ATOM_M = 8; + constexpr uint32_t kNumBankGroupBytes = 16u; + constexpr uint32_t kNumAtomsPerStore = STORE_BLOCK_M / ATOM_M; + DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M"); + DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M"); + DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M"); + DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N"); + + // Ensure the epilogue barrier cannot run concurrently with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Wait UMMA arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_phase); + ptx::tcgen05_after_thread_sync(); + + // Compute offsets + // NOTES: use shuffle here to let NVCC know warp divergence won't happen + const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m(), 0); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t n_idx = n_block_idx * BLOCK_N; + + if (block_phase == sched::BlockPhase::Linear1) { + // Unified L1 epilogue: in-place SwiGLU using granularity-8 interleaved weights + // With `SM100_TMEM_LOAD_16dp256b1x`, gate/up pairs are: + // (values[0], values[2]), (values[1], values[3]), + // (values[4], values[6]), (values[5], values[7]) + float stored_cached_weight = 0; + + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { +
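// NOTES (assumed example values, added): with BLOCK_M = 128, 2 epilogue warpgroups and + // STORE_BLOCK_M = 32, each warpgroup owns WG_BLOCK_M = 64 rows, i.e. 2 store blocks of 32 rows, + // each split into 4 atoms of ATOM_M = 8 rows, with one pair of TMEM loads per atom +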
// Early break if the entire store block is beyond the valid token range + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + // Iterate all atoms in the store block + float2 swiglu_values[kNumAtomsPerStore * 2]; + float2 amax_values[kNumAtomsPerStore]; + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + const uint32_t j = s * kNumAtomsPerStore + i; + + // Load weights from global into register cache per 32 tokens + DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size"); + if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) { + stored_cached_weight = *l1_topk_weights_buffer + .get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx) + .get_base_ptr(); + } + + // Load weights from register cache + const float2 weights = { + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 0), + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 1) + }; + + // Load from TMEM + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Signal tensor memory consumed on the last atom + if (j == WG_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Apply SwiGLU: silu(gate) * up + // Gate/up pairs: (0, 2), (1, 3), (4, 6), (5, 7) + auto fp32_values = reinterpret_cast(values); + #pragma unroll + for (uint32_t k = 0; k < 2; ++ k) { + auto bf16_gate = __float22bfloat162_rn(make_float2(fp32_values[k * 4], fp32_values[k * 4 + 1])); + auto bf16_up = __float22bfloat162_rn(make_float2(fp32_values[k * 4 + 2], fp32_values[k * 4 + 3])); + + // Clamp + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { + bf16_gate = __hmin2(bf16_gate, {kActivationClamp, kActivationClamp}); + bf16_up = __hmax2(bf16_up, {-kActivationClamp, -kActivationClamp}); + bf16_up = __hmin2(bf16_up, {kActivationClamp, kActivationClamp}); + } + + // SwiGLU + auto gate = __bfloat1622float2(bf16_gate); + auto neg_gate_exp = make_float2( + kFastMath ? __expf(-gate.x) : expf(-gate.x), + kFastMath ? 
__expf(-gate.y) : expf(-gate.y)); + const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp); + if constexpr (kFastMath) { + gate = __fmul2_rn(gate, {math::fast_rcp(denom.x), math::fast_rcp(denom.y)}); + } else { + gate = {gate.x / denom.x, gate.y / denom.y}; + } + const auto up = __bfloat1622float2(bf16_up); + swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights); + } + + // Amax reduction + amax_values[i].x = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].x), cute::abs(swiglu_values[i * 2 + 1].x)), + math::ReduceMax()); + amax_values[i].y = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].y), cute::abs(swiglu_values[i * 2 + 1].y)), + math::ReduceMax()); + if (lane_idx < 4) + smem_amax_reduction[epilogue_warp_idx * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx] = amax_values[i]; + __syncwarp(); + } + + // Wait shared memory release from previous TMA store + // And fence `smem_amax_reduction` + const uint32_t tma_stage_idx = s % kNumTMAStoreStages; + ptx::tma_store_wait(); + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Cast to FP8 E4M3 and store into shared memory + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + // Reduce amax + const float2 wp_amax = + smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4]; + amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); + amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); + + // Calculate SF + float2 sf, sf_inv; + math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv); + + // Cast + const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); + const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); + const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y)); + + // STSM + uint32_t row = lane_idx; + uint32_t col = warp_idx_in_wg; + const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N + + i * ATOM_M * L1_OUT_BLOCK_N + + row * L1_OUT_BLOCK_N + + (col ^ (row / 2)) * kNumBankGroupBytes; + ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr); + + // Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout) + // Only one warp per pair writes (both hold the same SF after cross-warp reduce) + // Each lane < 4 holds SF for 2 rows (sf.x and sf.y) + if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { + const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; + const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; + const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); + const auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + // NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4 + // NOTES: originally there was: + // - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2 + // - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)` + // We find out that + // 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside + // 2. 
`lane_idx * 2` controls the lowest 3 bits of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside + // This reduces the number of computation instructions. + const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + __builtin_assume(token_base_idx < BLOCK_M); + const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; + const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; + sf_base_ptr[sf_addr] = + (*reinterpret_cast(&sf.x) >> 23); + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = + (*reinterpret_cast(&sf.y) >> 23); + } + __syncwarp(); + } + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Issue TMA store after all atoms in this store block + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N, + out_n_idx, + m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); + cute::tma_store_arrive(); + } + __syncwarp(); + } + + // Notify L2 + // TODO: use a smaller epilogue sync scope + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(L2_SHAPE_K <= 64 * L1_OUT_BLOCK_N, "L2 shape K is too large"); + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx + ); + } + __syncwarp(); + } else { + DG_STATIC_ASSERT(STORE_BLOCK_M % 8 == 0, "Invalid store M"); + constexpr uint32_t kNumRowsPerWarp = STORE_BLOCK_M / 8; + + // L2 BF16 epilogue: write GEMM output to remote combine buffer via NVLink + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + // TODO: check performance + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / ATOM_M; ++ i) { + // Load from TMEM using .16x256b shape to satisfy STSM layout requirements + // Start from lane index 0 and 16 + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Wait for shared memory release from the previous NVLink store + // NOTES: skip for the first store block since the prior full barrier already ensures completion + if (i == 0 and s > 0) + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Signal tensor memory consumed + if (s == WG_BLOCK_M / STORE_BLOCK_M - 1 and i == STORE_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Store into shared memory + // NOTES: only use first 16 lanes for address
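+ // NOTES (illustration, added): the `(col ^ row) * kNumBankGroupBytes` term below is the usual + // shared memory swizzle: XOR-ing the 16-byte bank-group column with the row index spreads + // consecutive rows across banks, keeping the STSM stores and later loads conflict-free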
+ // NOTES: 2 warps share a BF16 swizzle atom + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (warp_idx_in_wg / 2) * STORE_BLOCK_M * kSwizzleCDMode + + i * ATOM_M * kSwizzleCDMode + + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + ptx::SM90_U32x4_STSM_T::copy( + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr + ); + } + + // Wait shared memory ready + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Write into remote buffers + // One warp per row; the layout here differs from the shared memory store above + const uint32_t row_in_atom = (warp_idx_in_wg * 2 + lane_idx / 16) % ATOM_M; + const uint32_t bank_group_idx = lane_idx % 8; + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_store = j * 8 + warp_idx_in_wg * 2 + lane_idx / 16; + const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + row_in_store; + + // Skip padding rows beyond the actual token count for this expert + if (m_idx_in_block >= valid_m) + break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + // Read from shared memory + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (lane_idx % 16 / 8) * STORE_BLOCK_M * kSwizzleCDMode + + row_in_store * kSwizzleCDMode + + (bank_group_idx ^ row_in_atom) * kNumBankGroupBytes; + const auto packed = ptx::ld_shared(reinterpret_cast(smem_ptr)); + + // Write into remote + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + const auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * static_cast(sizeof(nv_bfloat16)) + (lane_idx % 16) * static_cast(sizeof(float4))); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + + // Ensure the next epilogue is safe to use shared memory + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + }); + + // Deallocate tensor memory + // NOTES: must be called by the same logical warp ID on both CTAs + if (epilogue_warp_idx == 0) + Allocator().free(0, kNumTmemCols); + + // NVLink barrier (grid sync + cross-rank signal + grid sync): ~4 us + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } + ); + + // Barrier with dispatch warps, so that they can clean the workspace + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Combine: reduce top-k results and write back + // NOTES: reuse the shared memory from the buffer start up to the barrier region + // 1 token, 1 topk latency: ~3 us + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); + + // 3 chunk slots are needed: 2 load stages and 1 store + constexpr uint32_t kNumChunkSlots = 3; + constexpr uint32_t kNumMaxRegistersForBuffer = 128; + + // NOTES: either 1 or 2 chunks for simplicity
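+ // NOTES (assumed example values, added): with kHidden = 7168, kNumHiddenBytes = 14336 exceeds + // the register budget (kHidden > 32 * kNumMaxRegistersForBuffer = 4096), so kNumChunks = 2 and + // each chunk is 7168 bytes = 448 uint4 = 14 uint4 per lane; a 4096-wide hidden would fit in a + // single chunk provided the shared memory condition also held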
+ // NOTES: restricted by both the shared memory and register budgets + constexpr uint32_t kNumChunks = + kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2; + constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); + constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; + DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); + DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); + DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); + DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); + DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements (one per lane)"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); + + // Verify combined shared memory budget at runtime + DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( + reinterpret_cast(barrier_start_ptr) - smem_buffer)); + + // Per-warp buffer: 2 stage load buffers + 1 store buffer + const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { + return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); + }); + const auto combine_store_buffer = math::advance_ptr(smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); + + // Per-warp barriers + auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return combine_barriers[i + epilogue_warp_idx * 2]; + }); + + // Iterate over all tokens + uint32_t combine_phase = 0; + uint32_t load_stage_idx = 0; + for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; + token_idx < num_tokens; + token_idx += kNumSMs * kNumEpilogueWarps) { + // Read top-k slot indices: each lane reads one slot, then broadcast via exchange + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk");
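+ // NOTES (illustration, added): `total_mask` is a warp-wide validity bitmap built by + // `__ballot_sync`; the per-chunk loop below peels one set bit per call via `__ffs` and + // `mask ^= 1 << slot_idx`, visiting the top-k sources in ascending slot order with a + // two-stage load pipeline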
+ const int stored_topk_slot_idx = lane_idx < kNumTopk ? + static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1; + const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0); + + // Iterate all chunks + for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) { + const uint32_t chunk_byte_offset = chunk * kNumChunkBytes; + + // Move mask and load + uint32_t mask = total_mask; + const auto move_mask_and_load = [&](const uint32_t& i) { + if (mask) { + // Move + const uint32_t slot_idx = __ffs(mask) - 1; + mask ^= 1 << slot_idx; + + // Load + if (cute::elect_one_sync()) { + const auto src_ptr = math::advance_ptr( + combine_token_buffer.get_rank_buffer(slot_idx) + .get_data_buffer(token_idx).get_base_ptr(), + chunk_byte_offset); + ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes); + ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes); + } + __syncwarp(); + return true; + } + return false; + }; + + // Load the first selection + bool do_reduce = move_mask_and_load(load_stage_idx); + + // Accumulate all top-k contributions for this chunk in float registers + float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {}; + while (do_reduce) { + // Prefetch next top-k into the buffer while current is being accumulated + do_reduce = move_mask_and_load(load_stage_idx ^ 1); + + // Accumulate + combine_load_barriers[load_stage_idx]->wait(combine_phase); + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx]; + const auto bf16_values = reinterpret_cast(&uint4_values); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]); + } + combine_phase ^= load_stage_idx; + load_stage_idx ^= 1; + } + + // Cast + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + uint4 casted; + auto casted_bf16 = reinterpret_cast(&casted); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]); + + // Wait for shared memory release, then write + if (j == 0) { + ptx::tma_store_wait<0>(); + __syncwarp(); + } + ptx::st_shared(combine_store_buffer + j * 32 + lane_idx, + casted.x, casted.y, casted.z, casted.w); + } + __syncwarp(); + + // TMA store the token chunk + if (cute::elect_one_sync()) { + cute::tma_store_fence(); + ptx::tma_store_1d( + math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset), + combine_store_buffer, kNumChunkBytes); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only supports sm_100f"); +#endif +} + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 45a603ad..7ce008e5 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + if (kNumMulticast > 1) + cute::cluster_sync(); + // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { #pragma
@@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, } } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, kNumTmemCols); } + + // Deallocate tensor memory + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only supports sm_100f"); diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh index 180a308b..e6744f59 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -6,27 +6,31 @@ #include #include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; -using namespace deep_gemm::sm100; - template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, - const uint32_t max_seqlen_k, const uint64_t stride_logits, + const uint32_t max_seqlen_k, const uint32_t stride_logits, uint32_t* cu_seq_len_k_start, uint32_t* cu_seq_len_k_end, - float* logits, + logits_dtype_t* logits, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, @@ -35,26 +39,26 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` // Q should be loaded only once per block - const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); // Types using Barrier = cutlass::arch::ClusterTransactionBarrier; - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warp_in_group_idx = warp_idx % 4; - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; // Prefetch TMA descriptors DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); cute::prefetch_tma_descriptor(&tensor_map_kv_scales); cute::prefetch_tma_descriptor(&tensor_map_weights); } - __syncwarp(); // Shared memory configs // NOTES: weight may be unaligned @@ -62,7 +66,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * 
sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); // Align to 512 bytes for swizzle-64B extern __shared__ __align__(512) uint8_t smem_buffer[]; @@ -75,19 +79,19 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); // Data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); @@ -95,76 +99,77 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TMA barriers auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); - auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); - auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); // Tensor memory allocation auto tmem_ptr_in_smem = reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 
2); // Initialize barriers DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); - const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32)); - const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1)); - if (is_tma_load_warp and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); + empty_q_barriers[i]->init(kNumMathThreads + 32); } #pragma unroll for (uint32_t i = 0; i < kNumKVStages; ++ i) { full_kv_barriers[i]->init(1); empty_kv_barriers[i]->init(kNumMathThreads); } - #pragma unroll - for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { - full_umma_barriers[i]->init(1); - empty_umma_barriers[i]->init(128); - } - - // Make initialized barrier visible in async proxy cutlass::arch::fence_barrier_init(); - } else if (is_umma_warp) { + } + if (warp_idx == kSpecWarpStart + 1) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } // Allocate tensor memory cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); } __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 24; - constexpr uint32_t kNumMathRegisters = 240; + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; // Block scheduler - uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; - const auto& get_next_block_q_idx = [&]() -> cute::tuple { - return {block_q_idx + gridDim.x, q_iter_idx + 1}; + uint32_t block_q_idx = sm_idx, q_iter_idx = 0; + const auto get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + kNumSMs, q_iter_idx + 1}; }; uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; - const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); - seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); - seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cu_seq_len_k_start[q_idx]; + seq_k_end[i] = cu_seq_len_k_end[q_idx]; start = min(start, min(seq_k_start[i], seq_len_kv)); end = max(end, min(seq_k_end[i], seq_len_kv)); } + // TMA alignment requirements for SF KV start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase - start, ceil_div(end - start, BLOCK_KV)}; // Task info + start, math::ceil_div(end - start, BLOCK_KV)}; // Task info }; // KV pipeline uint32_t num_total_kv_blocks = 0; - const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { return { (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase @@ -177,13 +182,16 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, constexpr uint32_t 
UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; - if (is_tma_load_warp) { + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { cutlass::arch::warpgroup_reg_dealloc(); // Prefetch - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + const auto issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); }; if (cute::elect_one_sync() and block_q_idx < num_q_blocks) @@ -209,10 +217,10 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); // Issue TMA KV - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } num_total_kv_blocks += num_kv_blocks; @@ -221,11 +229,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } } - } else if (is_umma_warp) { + } else if (warp_idx == kSpecWarpStart + 1) { cutlass::arch::warpgroup_reg_dealloc(); // Require full allocation - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Make UMMA desc auto instr_desc = cute::UMMA::make_instr_descwait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { - auto a_desc = make_umma_desc( + auto a_desc = mma::sm100::make_umma_desc( smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); - auto b_desc = make_umma_desc( + auto b_desc = mma::sm100::make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); } @@ -266,23 +274,37 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, } num_total_kv_blocks += num_kv_blocks; + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline + empty_q_barriers[q_stage_idx]->arrive(); + // Jump to the next block CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } - } else if (warp_idx >= kNumMathThreads / 32) { + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { cutlass::arch::warpgroup_reg_dealloc(); - } else if (warp_idx < kNumMathThreads / 32) { + } else if (warp_idx < kSpecWarpStart) { 
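This branch structure gives every warp a single role for the kernel's lifetime, and the `warpgroup_reg_dealloc`/`warpgroup_reg_alloc` calls on entry rebalance the register file between producer and consumer warpgroups using the register budgets declared above (`kNumSpecializedRegisters`/`kNumMathRegisters`); in CUTLASS these helpers wrap the `setmaxnreg` PTX instruction. A hedged sketch of the pattern, with illustrative thread counts and budgets, compiled for sm_90a or later:

    // 12 warps: warps 0-7 are two consumer (math) warpgroups,
    // warps 8-11 are one producer (TMA/MMA-issue) warpgroup.
    __global__ void __launch_bounds__(384, 1) warp_specialized_kernel() {
        const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
        if (warp_idx >= 8) {
            // Producer warps hold few live values, so they donate registers
            asm volatile("setmaxnreg.dec.sync.aligned.u32 40;");
            // ... issue async loads and MMAs ...
        } else {
            // Consumer warpgroups carry the large accumulator footprint
            asm volatile("setmaxnreg.inc.sync.aligned.u32 232;");
            // ... reductions and epilogue stores ...
        }
    }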
cutlass::arch::warpgroup_reg_alloc(); // Offsets - const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - const auto& warp_offset = warp_idx * 32; - const auto& v_offset = lane_idx; + const auto tmem_start = warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t<N == 32, cute::SM100_TMEM_LOAD_32dp32b32x, cute::SM100_TMEM_LOAD_32dp32b64x>; + [&]<size_t... Is>(cute::index_sequence<Is...>) { + Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...); + }(cute::make_index_sequence<N>{}); + cutlass::arch::fence_view_async_tmem_load(); + };
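The `tmem_load` helper above expands a compile-time index pack through an immediately-invoked generic lambda, so a single `Loader::copy` call can name all `N` destination registers. A self-contained C++20 illustration of that unpacking idiom (all names here are hypothetical, not DeepGEMM API):

    #include <cstdint>
    #include <cstdio>
    #include <utility>

    // Stand-in for a copy primitive that must receive each destination
    // register as a separate reference argument (like the TMEM loaders above).
    template <typename... Regs>
    void copy_all(uint32_t base, Regs&... regs) {
        uint32_t v = base;
        ((regs = v++), ...);  // left-to-right fold: regs[i] = base + i
    }

    template <int N>
    void load(uint32_t base, uint32_t (&accum)[N]) {
        // The generic lambda turns the array into a pack of N lvalue
        // references, expanded from an index_sequence.
        [&]<std::size_t... Is>(std::index_sequence<Is...>) {
            copy_all(base, accum[Is]...);
        }(std::make_index_sequence<N>{});
    }

    int main() {
        uint32_t accum[4];
        load(100, accum);
        std::printf("%u %u %u %u\n", accum[0], accum[1], accum[2], accum[3]);
        return 0;
    }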
- // Preload weights - constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads); - float weights[BLOCK_Q][kNumWeightsInReg]; - DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + // Local register buffers + float weights[BLOCK_Q][kNumHeads]; while (block_q_idx < num_q_blocks) { CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); @@ -293,9 +315,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Read weights #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) { - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); - } + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); } // Compute over KV blocks @@ -307,82 +329,59 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset); + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); // Wait UMMA arrival full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); // Reduce over the head dim and store - const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; - static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + const auto kv_offset = kv_start + kv_block_idx * BLOCK_KV + math_thread_idx; DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q; - DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems"); - uint32_t shifted_accum[kNumLDTMElems]; - auto tmem_load = [&](auto... Is) { - if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 128) { - cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); - } - }; - [&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{}); - cutlass::arch::fence_view_async_tmem_load(); - - tcgen05_before_thread_sync(); - empty_umma_barriers[warpgroup_idx]->arrive(); - #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads); + // Load accumulator from TMEM + float accum[kNumHeads]; + tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + } + // Accumulate weighted ReLU in parallel auto sum_0 = make_float2(0, 0); auto sum_1 = make_float2(0, 0); - const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + const auto transform = [&](const uint32_t& j, const float2& sum) { auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); auto b = make_float2(weights[i][j], weights[i][j + 1]); return __ffma2_rn(a, b, sum); }; #pragma unroll - for (int j = 0; j < kNumWeightsInReg; j += 4) { - sum_0 = transform_reg(j, sum_0); - sum_1 = transform_reg(j + 2, sum_1); - } - - const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), - ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); - return __ffma2_rn(a, b, sum); - }; - - #pragma unroll - for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { - sum_0 = transform_smem(j, sum_0); - sum_1 = transform_smem(j + 2, sum_1); + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); } auto sum = __fadd2_rn(sum_0, sum_1); - float result = scale_kv * (sum.x + sum.y); + auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y)); // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast<uint64_t>(stride_logits); if constexpr (kIsCompressedLogits) { - if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result; + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; } else { - logits[q_idx * stride_logits + kv_offset + v_offset] = result; + logits[q_offset + kv_offset] = result; } + __syncwarp(); } } num_total_kv_blocks += num_kv_blocks; @@ -393,12 +392,12 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Jump to the next block CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } - } - // Free tensor memory - __syncthreads(); - if (is_tma_load_warp) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } } } // namespace deep_gemm
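Several of the ported kernels (see the `cudaGridDependencySynchronize()` calls above and in the paged variant below) lean on CUDA programmatic dependent launch: the secondary kernel may begin while the primary is still resident, and only stalls at that call until the primary's output is visible. A minimal sketch of the launch pattern, assuming CUDA 11.8+ and an sm_90-class device; the API names are standard CUDA runtime, the kernels are illustrative:

    #include <cuda_runtime.h>

    __global__ void primary_kernel(float* buf) {
        buf[threadIdx.x] = static_cast<float>(threadIdx.x);
        // Optional: signal dependency completion before the kernel fully exits
        cudaTriggerProgrammaticLaunchCompletion();
    }

    __global__ void secondary_kernel(const float* buf, float* out) {
        // Prologue work that does not read `buf` could overlap here
        cudaGridDependencySynchronize();  // block until primary's writes are visible
        out[threadIdx.x] = buf[threadIdx.x] * 2.0f;
    }

    void launch_pair(float* buf, float* out, cudaStream_t stream) {
        primary_kernel<<<1, 128, 0, stream>>>(buf);

        // Opt in to overlapped (programmatic) launch of the dependent kernel
        cudaLaunchAttribute attr{};
        attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
        attr.val.programmaticStreamSerializationAllowed = 1;

        cudaLaunchConfig_t config{};
        config.gridDim = dim3(1);
        config.blockDim = dim3(128);
        config.stream = stream;
        config.attrs = &attr;
        config.numAttrs = 1;
        cudaLaunchKernelEx(&config, secondary_kernel, buf, out);
    }

diff --git 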
a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh index 7058c40f..9a5bddbf 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -6,56 +6,65 @@ #include #include +#include +#include +#include #include -#include -#include - -#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; -using namespace deep_gemm::sm100; - template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, - const uint64_t logits_stride, const uint64_t block_table_stride, - const uint32_t* context_lens, float* logits, - const uint32_t* block_table, const uint32_t* schedule_meta, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { using Barrier = cutlass::arch::ClusterTransactionBarrier; - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; // Prefetch TMA descriptors DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); cute::prefetch_tma_descriptor(&tensor_map_kv_scales); cute::prefetch_tma_descriptor(&tensor_map_weights); } - __syncwarp(); + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 
2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); // Shared memory configs static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; - static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); - static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; @@ -63,43 +72,40 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Q and KV data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); }); constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); // Barriers and TMEM pointer on shared memory const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; - auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); - auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return 
umma_barrier_ptr + i; }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); - constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; + constexpr uint32_t kNumTmemCols = kNextNAtom * kNumHeads * kNumMathWarpGroups; DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); - const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); - const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); - const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); // Initialize barriers - if (is_tma_load_warp and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); + empty_q_barriers[i]->init(kNumMathThreads + 32); } #pragma unroll for (uint32_t i = 0; i < kNumKVStages; ++ i) { @@ -108,7 +114,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } cutlass::arch::fence_barrier_init(); } - if (is_umma_warp) { + if (warp_idx == kSpecWarpStart + 1) { if (cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { @@ -123,79 +129,92 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 40; - constexpr uint32_t kNumMathRegisters = 232; + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); // Scheduler constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + using Scheduler = sched::PagedMQALogitsScheduler; DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); // Q and KV pipeline - const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase }; - const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase }; - uint32_t q_iter_idx = 0, kv_iter_idx = 0; // UMMA settings // Construct instruction with layout D constexpr uint32_t UMMA_M = 128; constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); - constexpr uint32_t UMMA_N = kNextN * kNumHeads; + constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); - if (is_tma_load_warp) { - // TMA warp-group for loading data + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading data cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) { if (cute::elect_one_sync()) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], 
smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; - // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none - uint32_t q_idx = batch_size, kv_idx, num_kv; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx, num_kv; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; bool fetched_next_task; // Prefetch the first Q - if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) - issue_tma_q(0, next_q_idx), q_iter_idx = 1; + if ((fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_atom_idx), q_iter_idx = 1; - int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_ptr = 32; uint32_t kv_block_idx_storage; while (fetched_next_task) { - // Prefetch next Q when current Q changes - bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); - q_idx = next_q_idx; + // Prefetch next Q when (q, atom) changes + const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size); + bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance); + + if (q_atom_idx != next_q_atom_idx) + kv_block_idx_ptr = 32; + + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; num_kv = next_num_kv; // Read KV block index - // TODO: deal with `-1`? - if (kv_idx == 0 or kv_block_idx_ptr == 32) { + // TODO(xuzhean): consider -1 + if (kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; - kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? 
block_table[block_table_offset + kv_idx + lane_idx] : 0; } + __syncwarp(); DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); // Wait Q consumer release and issue TMA Q if (prefetch_q) { CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); - issue_tma_q(q_stage_idx, q_idx + 1); + issue_tma_q(q_stage_idx, q_atom_idx + next_advance); } - int kv_block_idx[kNumBlocksPerSplit]; + uint32_t kv_block_idx[kNumBlocksPerSplit]; #pragma unroll - for (int i = 0; i < kNumBlocksPerSplit; ++ i) + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); kv_block_idx_ptr += kNumBlocksPerSplit; @@ -205,45 +224,53 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, if (cute::elect_one_sync()) { #pragma unroll - for (int i = 0; i < kNumBlocksPerSplit; ++ i) { - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, - 0, 0, 1, kv_block_idx[i]); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, - 0, kv_block_idx[i]); + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) { + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); } full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } // Fetch next task - fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv); } - } else if (is_umma_warp) { + } else if (warp_idx == kSpecWarpStart + 1) { cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Require full allocation - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Make UMMA desc auto instr_desc = cute::UMMA::make_instr_desc(); auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - uint32_t q_idx = batch_size, kv_idx; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; uint32_t umma_phase = 1; - while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - if (q_idx != next_q_idx) { + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + if (q_atom_idx != next_q_atom_idx) { + // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); full_q_barriers[q_stage_idx]->wait(q_phase); } - q_idx = next_q_idx; + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; + // Wait KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); @@ -251,12 +278,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; 
++ i) { empty_umma_barriers[i]->wait(umma_phase); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { - auto a_desc = make_umma_desc( + auto a_desc = mma::sm100::make_umma_desc( smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); - auto b_desc = make_umma_desc( + auto b_desc = mma::sm100::make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); } @@ -264,29 +291,46 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } umma_phase ^= 1; } - } else if (is_math_warp) { - // Math warp-groups for WGMMA + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Offsets - const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - const uint32_t thread_idx = threadIdx.x; + const auto math_warpgroup_idx = warpgroup_idx; + const auto tmem_start = math_warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t<N == 32, cute::SM100_TMEM_LOAD_32dp32b32x, cute::SM100_TMEM_LOAD_32dp32b64x>; + [&]<size_t... Is>(cute::index_sequence<Is...>) { + Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...); + }(cute::make_index_sequence<N>{}); + cutlass::arch::fence_view_async_tmem_load(); + }; - // Weights - constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? 
kNumHeads : cute::min(48, kNumHeads)); - float weights[kNextN][kNumWeightsInReg]; - DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + // Local register buffers + float weights[kNextNAtom][kNumHeads]; - // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none - uint32_t q_idx = batch_size, kv_idx; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; uint32_t umma_phase = 0; + bool is_paired_atom = false; - while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - // Current Q changes - if (q_idx != next_q_idx) { - // Release Last Q empty + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + // Q or atom changes + if (q_atom_idx != next_q_atom_idx) { + // Release last Q empty if (q_iter_idx > 0) empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); @@ -296,30 +340,34 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Read weights #pragma unroll - for (uint32_t i = 0; i < kNextN; ++ i) { - for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2); } } - // Get current Q and KV index - q_idx = next_q_idx; + // Get current task indices + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV; - // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` // Wait TMA KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); // Wait UMMA arrival - full_umma_barriers[warpgroup_idx]->wait(umma_phase); - tcgen05_after_thread_sync(); + full_umma_barriers[math_warpgroup_idx]->wait(umma_phase); + ptx::tcgen05_after_thread_sync(); umma_phase ^= 1; // Release KV empty @@ -327,72 +375,65 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Reduce over the head dim and store DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; - uint32_t shifted_accum[kNumLDTMElems]; - DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); - auto tmem_load = [&](auto... 
Is) { - if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 128) { - cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); - - tcgen05_before_thread_sync(); - empty_umma_barriers[warpgroup_idx]->arrive(); - - #pragma unroll - for (uint32_t i = 0; i < kNextN; ++ i) { - auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); - - auto sum_0 = make_float2(0, 0); - auto sum_1 = make_float2(0, 0); - const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(weights[i][j], weights[i][j + 1]); - return __ffma2_rn(a, b, sum); - }; + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + float accum[kNumHeads]; #pragma unroll - for (int j = 0; j < kNumWeightsInReg; j += 4) { - sum_0 = transform_reg(j, sum_0); - sum_1 = transform_reg(j + 2, sum_1); + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(scale_kv * (sum.x + sum.y)); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride) + math_thread_idx] = result; + __syncwarp(); } - const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), - ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); - return __ffma2_rn(a, b, sum); - }; - - #pragma unroll - for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { - sum_0 = transform_smem(j, sum_0); - sum_1 = transform_smem(j + 2, sum_1); - } - - auto sum = __fadd2_rn(sum_0, sum_1); - float result = scale_kv * (sum.x + sum.y); + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[math_warpgroup_idx]->arrive(); + }; - // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + thread_idx] = result; + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); } } - } else { - cutlass::arch::warpgroup_reg_dealloc(); - } - // Free tensor memory - __syncthreads(); - if (is_umma_warp) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + // Free tensor memory + 
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } } } // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh index 4e4ff21d..aaf7fd9a 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -4,20 +4,22 @@ #include -#include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__device__ __forceinline__ +CUTLASS_DEVICE uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { // Calculate the index of the bank group to be written in the atom - const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + const auto bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); // Reshape the atom in another view and swizzle // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` @@ -37,7 +39,7 @@ template -__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -58,7 +60,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); // Align to 1024 bytes for swizzle-128B extern __shared__ __align__(1024) uint8_t smem_buffer[]; @@ -70,7 +72,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); // Real tensor memory size and offsets - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); // Prefetch TMA descriptors at the very beginning if (warp_idx == 0 and cute::elect_one_sync()) { @@ -82,20 +84,20 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Data on shared memory (layout as ordered below) // Fill D/A/B pointers auto smem_cd = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); - auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + 
(kNumStages * 3 + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; // Fill the tensor memory pointer @@ -121,7 +123,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } __syncthreads(); - constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); @@ -131,6 +133,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const uint32_t m_offset = shape_m * k_split_idx; const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Dispatch warps into different roles if (warp_idx < kNumMMAThreads / 32) { // TMA load warp @@ -145,8 +150,8 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, uint32_t k_idx = k_offset + s * BLOCK_K; // Issue TMAs - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); // Arrive at full barriers constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; @@ -168,7 +173,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); const uint32_t& b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; // Checks for MMA instructions @@ -185,7 +190,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const auto& stage_idx = s % kNumStages; const auto& cast_stage_idx = s % kNumCastStages; full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); @@ -194,7 +199,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); } @@ -218,7 +223,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Wait UMMA arrival tmem_full_barrier->wait(0); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); @@ -239,7 +244,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, values[0], values[1], values[2], values[3]); cutlass::arch::fence_view_async_tmem_load(); if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); if constexpr (BLOCK_M == 64) __syncwarp(); } @@ -290,9 +295,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, #pragma unroll for (uint32_t i = 0; i < kNumLoads; i += 2) { auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); - sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], - uint32_values[0][i + 1], uint32_values[1][i + 1], - smem_ptr); + ptx::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); } // Wait tensor memory empty @@ -321,15 +326,15 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, cutlass::arch::fence_view_async_tmem_store(); // Arrive for issuing MMAs - tcgen05_before_thread_sync(); + ptx::tcgen05_before_thread_sync(); full_cast_barriers[cast_stage_idx]->arrive(); } // Intra-warp reduction and write back #pragma unroll for (uint32_t u = 0; u < 2; ++ u) { - const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); - const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + const auto reduced_sum = math::warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; if (lane_idx % 4 == 0 and m_idx < shape_m) sqr_sum[m_offset + m_idx] = reduced_sum; } diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh index 7a77e4e8..84a149eb 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -11,14 +11,19 @@ #include #include +#include #include -#include -#include +#include +#include +#include 
+#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_bf16_gemm_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -51,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; // Types - using WGMMA = typename BF16MMASelector::type; + using WGMMA = typename mma::sm90::BF16MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); @@ -61,7 +66,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; // Shared memory - static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); + static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); @@ -71,7 +76,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Configs const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); + const uint32_t lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors at the very beginning if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -88,17 +93,17 @@ sm90_bf16_gemm_impl(int* grouped_layout, // D/A/B shared memory auto smem_d = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { @@ -119,9 +124,12 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr uint32_t kNumTMARegisters = 48; constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 
248 : 224; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0; @@ -151,7 +159,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); @@ -159,31 +167,30 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[stage_idx]; - const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); - const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); // Issue TMAs constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); const uint32_t batch_idx = (kIsBatchedMM ? 
scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); } } @@ -203,8 +210,8 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Merged stages only happens in NT normal GEMM cases constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; - auto a_desc = make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); - auto b_desc = make_gmma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm90::make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); + auto b_desc = mma::sm90::make_gmma_desc(smem_b[0], 0, 0); const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); @@ -229,10 +236,10 @@ sm90_bf16_gemm_impl(int* grouped_layout, }; // TODO: remove some useless computation for unaligned Ms - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { - const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); - const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); // Wait TMA arrivals full_barriers[stage_idx]->wait(phase); @@ -240,26 +247,26 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; - a_desc.reg32_[0] = advance_gmma_desc_lo( + const uint32_t atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; + a_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo( a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K); - b_desc.reg32_[0] = advance_gmma_desc_lo( + b_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo( b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K); WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1); } } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + 
ptx::warpgroup_wait<0>(); // Notify barrier arrival empty_barrier_arrive(stage_idx); @@ -324,7 +331,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, } // NOTES: only 16 lanes' addresses are used - SM90_U32x2_STSM_N::copy( + ptx::SM90_U32x2_STSM_N::copy( __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), smem_ptr @@ -341,8 +348,8 @@ sm90_bf16_gemm_impl(int* grouped_layout, auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); - st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); + ptx::st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); + ptx::st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); } } } @@ -350,7 +357,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); // Use TMA store to write back to global memory - const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh index 191a4fe2..7c344296 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -4,26 +4,32 @@ #include #include +#include #include -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, float *d) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Types - using WGMMA = typename BF16MMASelector::type; + using WGMMA = typename mma::sm90::BF16MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); @@ -33,7 +39,7 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, // Configs const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); + const uint32_t lane_idx = ptx::get_lane_idx(); DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M"); DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads"); DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads"); @@ -48,17 +54,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, // Align to 1024 bytes for swizzle-128B // Fill shared memory pointers extern __shared__ __align__(1024) uint8_t smem_buffer[]; 
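// [Porting note] The `utils::PatternVisitor` used below carves `smem_buffer`
// without materializing pointer arrays: `operator[]` simply re-runs the
// captured lambda. A minimal sketch consistent with how this patch uses it
// (the authoritative definition lives in deep_gemm/common/utils.cuh; treat
// the signatures here as assumptions):
namespace deep_gemm::utils {

template <typename Func>
struct PatternVisitor {
    Func func;

    __device__ __host__
    explicit PatternVisitor(Func&& func): func(std::forward<Func>(func)) {}

    // Recompute the address pattern on every access; with a constant-foldable
    // lambda this costs no registers compared to a cached pointer array
    __device__ __host__
    auto operator[](const uint32_t& i) { return func(i); }
};

} // namespace deep_gemm::utils

// Usage, as in the hunks below -- each access recomputes the staged offset:
//   auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
//       return reinterpret_cast<__nv_bfloat16*>(smem_buffer + i * SMEM_A_SIZE_PER_STAGE);
//   });
//   __nv_bfloat16* stage0 = smem_a[0];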
- auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { @@ -80,14 +86,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, constexpr uint32_t kNumMathRegisters = 232; // Block indices - const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); - const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M); const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; const uint32_t n_block_idx = mn_block_idx % num_n_blocks; const uint32_t m_block_idx = mn_block_idx / num_n_blocks; const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (warp_idx >= kNumMathThreads / 32) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); @@ -98,18 +107,18 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, #pragma unroll for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait consumer release - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1); auto& full_barrier = *full_barriers[stage_idx]; - const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; - const uint32_t& k_idx = sk_idx % SHAPE_K; - const uint32_t& s_idx = sk_idx / SHAPE_K; + const uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + const uint32_t k_idx = sk_idx % SHAPE_K; + const uint32_t s_idx = sk_idx / SHAPE_K; constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16); - tma_copy( + tma::copy( &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); - tma_copy( + tma::copy( &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); } @@ -125,32 +134,32 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, // Launch MMAs for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait TMA arrivals - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; full_barriers[stage_idx]->wait((s / kNumStages) & 1); // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - 
warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); WGMMA::wgmma(desc_a, desc_b, accum, 1); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival at the last warpgroup wave empty_barriers[stage_idx]->arrive(); } - const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; - const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; + const auto row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; + const auto col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { if (col + i * 8 >= SHAPE_N) diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh index cdd28fcb..195d431f 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -6,18 +6,26 @@ #include #include +#include #include #include #include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, int* grouped_layout, cute::TmaDescriptor* tensor_map_buffer, @@ -45,7 +53,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); @@ -55,13 +63,13 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; // Shared memory - static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0); + static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? 
sizeof(cute::TmaDescriptor) * 2 : 0); static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment"); // Configs @@ -83,47 +91,41 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); // Tensor maps on shared and global memory - auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * i); - }); - auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * (2 + i)); - }); - auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; }); - auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; }); + auto smem_tensor_map_a = reinterpret_cast(smem_buffer); + auto smem_tensor_map_b = smem_tensor_map_a + 1; + auto gmem_tensor_map_a = tensor_map_buffer + blockIdx.x * 2; + auto gmem_tensor_map_b = gmem_tensor_map_a + 1; // Data on shared memory auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); - auto smem_a = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE)); }); - auto smem_sfb = PatternVisitor([&](const uint32_t& i) { + auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); }); // Barriers on shared memory constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE); - auto full_barriers = PatternVisitor([&](const uint32_t& i) { + auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier)))); }); - auto 
empty_barriers = PatternVisitor([&](const uint32_t& i) { + auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); }); if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { // Load tensormap A/B to shared memory if constexpr (kGemmType == GemmType::KGroupedContiguous) { - *smem_tensor_map_a[0] = tensor_map_a_base; - *smem_tensor_map_a[1] = tensor_map_a_base; - *smem_tensor_map_b[0] = tensor_map_b_base; - *smem_tensor_map_b[1] = tensor_map_b_base; + *smem_tensor_map_a = tensor_map_a_base; + *smem_tensor_map_b = tensor_map_b_base; } // Initialize barriers @@ -149,12 +151,15 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24); constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); // TMA and MMA pipeline - const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { + const auto get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase }; uint32_t iter_idx = 0; @@ -165,9 +170,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // NOTES: only one thread (or warp) will be used if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { - const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base; - const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base; - uint32_t last_group_idx = kNumGroups, sum_k = 0; + uint32_t last_group_idx = kNumGroups; // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { @@ -177,35 +180,27 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? 
kNumTMAMulticast : 1; DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - - const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); - const uint32_t& m_idx = m_block_idx * BLOCK_M; - const uint32_t& n_idx = n_block_idx * BLOCK_N; - - if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) { - const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1; - const uint32_t& next_stage_idx = stage_idx ^ 1; + + const uint32_t num_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + const uint32_t m_idx = m_block_idx * BLOCK_M; + const uint32_t n_idx = n_block_idx * BLOCK_N; + + if (kGemmType == GemmType::KGroupedContiguous && last_group_idx != scheduler.current_group_idx) { last_group_idx = scheduler.current_group_idx; - // Prepare next tensor map - sum_k += scheduler.current_shape_k; - if (scheduler.next_group_idx < kNumGroups) { - tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m); - tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n); - tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); - tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); - *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); - *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]); - tensor_map_release_cta(); - } - - // Get current tensor map - if (scheduler.current_num_valid_groups > 0) { - tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]); - tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]); - current_tensor_map_a = gmem_tensor_map_a[stage_idx]; - current_tensor_map_b = gmem_tensor_map_b[stage_idx]; - } + // Directly update current tensor map + const uint64_t current_k_offset = scheduler.current_k_cumsum; + ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_a, gmem_a_ptr + current_k_offset * shape_m); + ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n); + ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k); + ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k); + *(gmem_tensor_map_a) = *(smem_tensor_map_a); + *(gmem_tensor_map_b) = *(smem_tensor_map_b); + ptx::tensor_map_release_gpu(); + + // Immediately acquire current tensor map + ptx::tensor_map_acquire_gpu(gmem_tensor_map_a); + ptx::tensor_map_acquire_gpu(gmem_tensor_map_b); } #pragma unroll kNumPipelineUnrolls @@ -216,12 +211,14 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Issue TMA auto& full_barrier = *full_barriers[stage_idx]; - const uint32_t& k_idx = k_block_idx * BLOCK_K; - const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; - tma_copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); - tma_copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); - tma_copy(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); - tma_copy(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + 
const uint32_t k_idx = k_block_idx * BLOCK_K; + const uint32_t sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; + const auto tensor_map_a_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_a : &tensor_map_a_base); + const auto tensor_map_b_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_b : &tensor_map_b_base); + tma::copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); + tma::copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma::copy(tensor_map_a_ptr, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma::copy(tensor_map_b_ptr, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); } } @@ -248,9 +245,9 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, while (scheduler.get_next_block(m_block_idx, n_block_idx)) { // Accumulation for WGMMA or CUDA promotion DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); - const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); - const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); - const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K); + const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); + const uint32_t num_k_blocks = math::ceil_div(current_shape_k, BLOCK_K); float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; float2 scales_b[WGMMA::kNumAccum / 4]; @@ -272,30 +269,30 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Read A scales // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0); - auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1); + auto scale_a_0 = ptx::ld_shared(smem_sfa[stage_idx] + r_0); + auto scale_a_1 = ptx::ld_shared(smem_sfa[stage_idx] + r_1); // Read B scales #pragma unroll for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) - scales_b[i] = ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); + scales_b[i] = ptx::ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + 
ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival empty_barrier_arrive(stage_idx); @@ -318,12 +315,12 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, cutlass::arch::NamedBarrier::sync(128, math_wg_idx); // Store to D shared memory - const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); - const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + const auto smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); #pragma unroll for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); - st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + ptx::st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + ptx::st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); } cute::tma_store_fence(); cutlass::arch::NamedBarrier::sync(128, math_wg_idx); diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 9247304c..aa412484 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -10,17 +10,21 @@ #include #include -#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { +CUTLASS_DEVICE void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { if (num_former_iters == kNumFormerIters) { func(cute::Int{}); return; @@ -35,12 +39,12 @@ template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -50,10 +54,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + DG_STATIC_ASSERT( + math::constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or + (math::constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); @@ -64,23 +70,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Shared memory static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); + static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); static 
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); - const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); - const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K); - const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); + const uint32_t shape_k_scales = math::ceil_div(shape_k, BLOCK_K); + const uint32_t shape_n_sfb = math::ceil_div(shape_n, BLOCK_K); + const uint32_t smem_sfb_size = math::align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); // NOTES: Make sure we have enough shared memory for WGMMA padding static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); // Configs - const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); + const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K); const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const uint32_t lane_idx = get_lane_idx(); + const uint32_t lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors at the very beginning if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -97,22 +103,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Data on shared memory auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); }); auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); // Fill barriers auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); - auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); - auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); // Initialize barriers DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); @@ -136,9 +142,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, constexpr uint32_t kNumTMARegisters = 40; 
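// [Porting note] The `cudaGridDependencySynchronize()` call added to each
// kernel in this patch ("Wait for primary kernel completion") opts it into
// programmatic dependent launch (PDL): the grid may begin while its
// predecessor is still draining and only blocks at that call until the
// predecessor's writes are visible. That only pays off when the host launches
// with programmatic stream serialization enabled. Hedged host-side sketch;
// `kernel`, `grid`, `block`, `smem_size`, `stream`, and `args` are
// placeholders, not names from this patch, and `Kernel` must be a plain
// `__global__` function pointer for `cudaLaunchKernelEx` to accept it:
#include <cuda_runtime.h>
#include <utility>

template <typename Kernel, typename... Args>
static cudaError_t launch_with_pdl(Kernel kernel, const dim3& grid, const dim3& block,
                                   const uint32_t& smem_size, cudaStream_t stream, Args&&... args) {
    // Allow this launch to overlap with the previous kernel on the stream
    cudaLaunchAttribute attr{};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1;

    cudaLaunchConfig_t config{};
    config.gridDim = grid;
    config.blockDim = block;
    config.dynamicSmemBytes = smem_size;
    config.stream = stream;
    config.attrs = &attr;
    config.numAttrs = 1;
    return cudaLaunchKernelEx(&config, kernel, std::forward<Args>(args)...);
}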
constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0; @@ -177,15 +186,15 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[stage_idx]; const uint32_t k_idx = k_block_idx * BLOCK_K; - tma_copy(&tensor_map_a, &full_barrier, + tma::copy(&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), num_tma_multicast_a, batch_idx); - tma_copy(&tensor_map_sfa, &full_barrier, - smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), + tma::copy(&tensor_map_sfa, &full_barrier, + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), num_tma_multicast_a); // Issue TMA B - tma_copy(&tensor_map_b, &full_barrier, + tma::copy(&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), num_tma_multicast_b, batch_idx); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); @@ -206,8 +215,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; - auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); - auto b_desc = make_smem_desc(smem_b[0], 1); + auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1); const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); @@ -225,14 +234,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Load B scales with math warp-groups // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks if (threadIdx.x >= 32) { - auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; #pragma unroll for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) - st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb)); + ptx::st_shared(smem_sfb + i, i < shape_k_scales ? 
local_sfb[i * stride_k_sfb] : local_sfb[(i - shape_k_scales) * stride_k_sfb + stride_n_sfb]); } cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); @@ -259,22 +268,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Skip useless computations if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { // The compiler must know the dynamic variable `num_former_iters`'s real value - constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; - constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr bool kShouldOptimize = BLOCK_K / math::constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = math::constexpr_gcd(BLOCK_K, BLOCK_N) / 8; constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; // Dispatch `num_former_iters` and launch MMAs dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) { #pragma unroll 8 for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { - const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); - const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); // Read B scales - float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1; + float scale_b_0 = ptx::ld_shared(smem_sfb + k_block_idx), scale_b_1; // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales); + scale_b_1 = ptx::ld_shared(smem_sfb + k_block_idx + shape_k_scales); // Wait TMA arrivals full_barriers[stage_idx]->wait(phase); @@ -286,25 +295,25 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Read A scales // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; - auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; + auto scale_a_0 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; + auto scale_a_1 = do_wgmma_store ? 
ptx::ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; // Commit WGMMA instructions #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16; b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16; WGMMA::wgmma(a_desc, b_desc, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Notify barrier arrival at the last warpgroup wave if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) @@ -325,7 +334,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters; + const bool predicate = kMustUseUniformedScaleB or i < num_former_iters; shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; @@ -399,7 +408,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, } // NOTES: only 16 lanes' addresses are used - SM90_U32x2_STSM_N::copy( + ptx::SM90_U32x2_STSM_N::copy( __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), smem_ptr diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh index d58c7162..225af441 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -7,36 +7,31 @@ #include #include +#include +#include #include -#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - -// ReSharper disable once CppNotAllPathsReturnValue -template -static constexpr int to_swizzle_cute_type() { - DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); - if constexpr (kHeadDim == 32) - return static_cast(cute::SM90::GMMA::LayoutType::B32); - if constexpr (kHeadDim == 64) - return static_cast(cute::SM90::GMMA::LayoutType::B64); - if constexpr (kHeadDim == 128) - return static_cast(cute::SM90::GMMA::LayoutType::B128); -} - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) + uint32_t kNumSMs, + uint32_t kNumTMAThreads, uint32_t kNumMathThreads, + typename logits_dtype_t> +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, - const uint32_t max_seqlen_k, const uint64_t stride_logits, + const uint32_t max_seqlen_k, const uint32_t stride_logits, uint32_t* cu_seq_len_k_start, uint32_t* cu_seq_len_k_end, - float* logits, + logits_dtype_t* logits, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const 
__grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, @@ -44,10 +39,10 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TODO: consider TMA multicast // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` // Q should be load only at once for a block - const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; // Prefetch TMA descriptors @@ -74,19 +69,19 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i); @@ -94,13 +89,13 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TMA barriers auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); // Initialize barriers - const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; + const bool is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; if (is_tma_load_warp and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { @@ -123,38 +118,43 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, constexpr uint32_t kNumMathRegisters = 112; // Block scheduler - 
uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; - const auto& get_next_block_q_idx = [&]() -> cute::tuple { - return {block_q_idx + gridDim.x, q_iter_idx + 1}; + const auto sm_idx = blockIdx.x; + uint32_t block_q_idx = sm_idx, q_iter_idx = 0; + const auto get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + kNumSMs, q_iter_idx + 1}; }; uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; - const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); - seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); - seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cu_seq_len_k_start[q_idx]; + seq_k_end[i] = cu_seq_len_k_end[q_idx]; start = min(start, min(seq_k_start[i], seq_len_kv)); end = max(end, min(seq_k_end[i], seq_len_kv)); } + // TMA alignment requirements for SF KV start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase - start, ceil_div(end - start, BLOCK_KV)}; // Task info + start, math::ceil_div(end - start, BLOCK_KV)}; // Task info }; // KV pipeline uint32_t num_total_kv_blocks = 0; - const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { return { (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase }; }; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (threadIdx.x >= kNumMathThreads) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); @@ -165,8 +165,8 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Prefetch const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); }; if (cute::elect_one_sync() and block_q_idx < num_q_blocks) @@ -192,9 +192,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); // Issue TMA KV - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -212,7 +212,7 @@ void 
sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const auto& thread_idx = threadIdx.x % kNumMathThreads; const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0); const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + const auto& lane_idx = ptx::get_lane_idx(); float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4]; const auto& warp_offset = warp_idx * 16; @@ -230,7 +230,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, for (uint32_t i = 0; i < BLOCK_Q; ++ i) { #pragma unroll for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); } // Compute over KV blocks @@ -242,29 +242,31 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); - float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); + float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); + float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); // Issue WGMMA DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, - to_swizzle_cute_type(), 0, kHeadDim * 8); - auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, - to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_a = mma::sm90::make_smem_desc( + smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = mma::sm90::make_smem_desc( + smem_q[q_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); @@ -278,7 +280,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { auto shifted_accum = accum + i * kNumAccumPerReduce; - const auto& transform = [&](const uint32_t& j) { + const auto transform = [&](const uint32_t& j) { return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; }; @@ -302,16 +304,15 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, } // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast(stride_logits); if constexpr (kIsCompressedLogits) { if (seq_k_start[i] <= 
kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; + logits[q_offset + kv_offset + v_0_offset - seq_k_start[i]] = static_cast(v_0); if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; + logits[q_offset + kv_offset + v_1_offset - seq_k_start[i]] = static_cast(v_1); } else { - logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; - logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1; + logits[q_offset + kv_offset + v_0_offset] = static_cast(v_0); + logits[q_offset + kv_offset + v_1_offset] = static_cast(v_1); } } } diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh index 482a85a8..cc2592bb 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -6,133 +6,46 @@ #include #include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -template -__global__ __launch_bounds__(32, 1) -void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, - const uint32_t* context_lens, uint32_t* schedule_metadata) { - DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); - const uint32_t lane_idx = get_lane_idx(); - - uint32_t num_segs[kAlignedBatchSize / 32]; - #pragma unroll - for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { - const uint32_t q_idx = k * 32 + lane_idx; - const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); - const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0); - num_segs[k] = ceil_div(context_len, SPLIT_KV); - } - - __shared__ uint32_t prefix_sum[kAlignedBatchSize]; - uint32_t sum = 0; - #pragma unroll - for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { - uint32_t x = num_segs[k]; - #pragma unroll - for (uint32_t offset = 1; offset < 32; offset <<= 1) { - const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset); - x += (lane_idx >= offset ? y : 0); - } - x += sum; - prefix_sum[k * 32 + lane_idx] = x; - sum = __shfl_sync(0xffffffff, x, 31); - } - - const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs; - for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { - uint32_t seg_starts = sm_idx * q + min(sm_idx, r); - uint32_t q_idx = 0; - while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts) - ++ q_idx; - const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]); - __syncwarp(); - - schedule_metadata[sm_idx * 2] = q_idx; - schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; - } -} - -template -struct PagedMQALogitsScheduler { - uint32_t batch_size; - const uint32_t* context_lens; - - uint32_t current_q_idx, current_kv_idx; - uint32_t end_q_idx, end_kv_idx; - uint32_t current_num_kv; - - __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) { - const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); - return q_idx < batch_size ? 
ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0; - } - - __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, - const uint32_t* context_lens, const uint32_t* schedule_meta) { - this->batch_size = batch_size; - this->context_lens = context_lens; - - const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); - const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); - current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; - end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; - - current_num_kv = get_num_kv(current_q_idx); - } - - __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { - q_idx = current_q_idx; - kv_idx = current_kv_idx; - num_kv = current_num_kv; - - if (q_idx == end_q_idx and kv_idx == end_kv_idx) - return false; - - current_kv_idx += kNumBlocksPerSplit; - if (current_kv_idx >= current_num_kv) { - ++ current_q_idx; - current_kv_idx = 0; - current_num_kv = get_num_kv(current_q_idx); - } - - return true; - } - - __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const { - return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx; - } -}; - -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) + uint32_t kNumTMAThreads, uint32_t kNumMathThreads, + typename logits_dtype_t> +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, - const uint64_t logits_stride, const uint64_t block_table_stride, - const uint32_t* context_lens, float* logits, - const uint32_t* block_table, const uint32_t* schedule_meta, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits"); + // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + const auto warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; @@ -150,15 +63,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = 
math::constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + - constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + math::constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + - constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + math::constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; @@ -166,31 +79,31 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Q data and barriers on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); }); auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); // Separate math warpgroups and tma load warps into KV groups // Each math warpgroup corresponds to a tma load warp - const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + const auto kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? 
(threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); // Per group KV data and barriers on shared memory - const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + const auto smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); }); auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); // Initialize barriers if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -218,15 +131,19 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, constexpr uint32_t kNumTMARegisters = 64; constexpr uint32_t kNumMathRegisters = 104; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Scheduler - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + auto scheduler = sched::PagedMQALogitsScheduler( + blockIdx.x, batch_size, context_lens, schedule_meta, indices); DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); // Q and KV pipeline - const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase }; - const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase }; uint32_t q_iter_idx = 0, kv_iter_idx = 0; @@ -237,10 +154,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, if (kv_group_idx >= kNumMathWarpGroups) return; - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { if (kv_group_idx == 0 and cute::elect_one_sync()) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx * kNextN); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; @@ -259,7 +176,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, while (fetched_next_task) { // Prefetch next Q when current Q changes - bool prefetch_q = (q_idx != next_q_idx and 
scheduler.exist_q_idx(next_q_idx + 1)); + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_atom_idx(next_q_idx + 1)); q_idx = next_q_idx; kv_idx = next_kv_idx; num_kv = next_num_kv; @@ -276,9 +193,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, if (kv_idx == 0 or kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? - __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + block_table[q_idx * static_cast(block_table_stride) + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)] : 0); } - const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + const auto kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); // Wait KV consumer release CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); @@ -286,10 +203,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // Issue TMA KV if (cute::elect_one_sync()) { - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -301,9 +218,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, cutlass::arch::warpgroup_reg_alloc(); float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; - const auto& sub_warp_offset = (warp_idx % 4) * 16; - const auto& v_0_offset = lane_idx / 4 + 0; - const auto& v_1_offset = lane_idx / 4 + 8; + const auto sub_warp_offset = (warp_idx % 4) * 16; + const auto v_0_offset = lane_idx / 4 + 0; + const auto v_1_offset = lane_idx / 4 + 8; // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none uint32_t q_idx = batch_size, kv_idx; @@ -326,7 +243,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, for (uint32_t i = 0; i < kNextN; ++ i) { #pragma unroll for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); } } @@ -335,7 +252,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + auto kv_offset = q_idx * kNextN * static_cast(logits_stride) + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` // Wait TMA KV arrival @@ -347,25 +264,29 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { - auto desc_a = 
make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); - auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_a = mma::sm90::make_smem_desc( + smem_kv[kv_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = mma::sm90::make_smem_desc( + smem_q[q_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum[i]); // Read per-KV scales - float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); - float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); // Wait WGMMA - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); @@ -378,7 +299,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, #pragma unroll for (uint32_t i = 0; i < kNextN; ++ i) { auto shifted_accum = accum + i * kNumAccumPerReduce; - const auto& transform = [&](const uint32_t& j) { + const auto transform = [&](const uint32_t& j) { return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; }; @@ -396,15 +317,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // Inter-thread reduction #pragma unroll for (uint32_t j = 0; j < 2; ++ j) { - const auto& offset = static_cast(1u << j); + const auto offset = static_cast(1u << j); v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); } // Store into the global memory // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + v_0_offset] = v_0; - logits[kv_offset + i * logits_stride + v_1_offset] = v_1; + logits[kv_offset + i * static_cast(logits_stride) + v_0_offset] = static_cast(v_0); + logits[kv_offset + i * static_cast(logits_stride) + v_1_offset] = static_cast(v_1); } } } diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh index e3bf9847..93b14100 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -5,20 +5,23 @@ #include #include -#include +#include #include -#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__device__ __forceinline__ +CUTLASS_DEVICE uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; - const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + const auto bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; @@ -35,7 +38,7 @@ template -__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumMathThreads + 
kNumTMAThreads, 1) sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -56,7 +59,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); // Align to 1024 bytes for swizzle-128B extern __shared__ __align__(1024) uint8_t smem_buffer[]; @@ -76,17 +79,17 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Data on shared memory (layout as ordered below) // Fill D/A/B pointers auto smem_cd = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { @@ -101,7 +104,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } __syncthreads(); - constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); @@ -113,12 +116,15 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, constexpr uint32_t kNumTMARegisters = 40; constexpr uint32_t kNumMathRegisters = 256; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // TMA load warp if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { cutlass::arch::warpgroup_reg_dealloc(); for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait consumer release - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); // Compute offsets @@ -126,8 +132,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, uint32_t k_idx = k_offset + s * BLOCK_K; // Issue TMAs - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); // Arrive at full barriers constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; @@ -135,7 +141,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } for (uint32_t s = num_total_stages; 
s < num_total_stages + kNumStages; ++ s) { - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); } } else if (warp_idx < kNumMathThreads / 32) { @@ -148,7 +154,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, constexpr uint32_t WGMMA_N = BLOCK_N; constexpr uint32_t WGMMA_K = 8; - using WGMMA = typename TF32MMASelector::type; + using WGMMA = typename mma::sm90::TF32MMASelector::type; float accum[WGMMA::kNumAccum] = {0}; constexpr uint32_t kNumBankGroupBytes = 16; @@ -196,14 +202,14 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; } - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); if (s > 0) empty_barriers[(s - 1) % kNumStages]->arrive(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; @@ -213,18 +219,19 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { #pragma unroll for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { - auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + auto b_desc = mma::sm90::make_smem_desc( + smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); } } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum[i]); } - const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); - const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + const auto& reduced_sum_0 = math::warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = math::warp_reduce_sum<4>(sqr_sum_acc_1); const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); if (lane_idx % 4 == 0) { @@ -233,7 +240,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, if (m_idx + 8 < shape_m) sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; } - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); // Write accum to shared memory @@ -260,8 +267,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // 0/1 write to the same row, 2/3 write to another row auto values = reinterpret_cast(accum + i * 2); - st_shared(smem_ptr, values[0], values[1]); - st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1]); + ptx::st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); } cute::tma_store_fence(); cutlass::arch::NamedBarrier::sync(128, 1); diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh index cc9e5e6b..2f66b980 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -3,21 +3,24 @@ #include #include -#include +#include +#include namespace deep_gemm { -template -__global__ __launch_bounds__(kNumWarps * 32, 1) +template 
+CUTLASS_GLOBAL __launch_bounds__(kNumWarps * 32, 1) void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits, - const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { - const uint32_t& num_sms = gridDim.x; - const uint32_t& sm_idx = blockIdx.x; - const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - constexpr float neg_inf = -cute::numeric_limits::infinity(); + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, logits_dtype_t* logits) { + const uint32_t num_sms = gridDim.x; + const uint32_t sm_idx = blockIdx.x; + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t); + const logits_dtype_t neg_inf = -cute::numeric_limits::infinity(); // Allocate filled `-inf` shared memory - extern __shared__ __align__(1024) float smem_buffer[]; + extern __shared__ __align__(1024) logits_dtype_t smem_buffer[]; #pragma unroll for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) smem_buffer[i] = neg_inf; @@ -25,38 +28,42 @@ void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const __syncthreads(); // Assign sequence to each warp - const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx, - const uint32_t& start, const uint32_t& total) -> cute::tuple { - const auto& per = total / num, rem = total % num; - return {start + idx * per + min(idx, rem), per + (idx < rem)}; + const auto assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto per = total / num, rem = total % num; + return {start + idx * per + cute::min(idx, rem), per + (idx < rem)}; }; CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (cute::elect_one_sync()) { for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { - const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); - const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; - const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + const auto ks = cu_seq_len_k_start == nullptr ? 
0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { - const auto& right = min(left + BLOCK_KV, static_cast(stride_logits)); + const auto right = cute::min(left + BLOCK_KV, static_cast(stride_logits)); if (right <= ks or ke <= left) { - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(logits_dtype_t)); } else { if (left < aligned_ks) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(logits_dtype_t)); if (aligned_ke < right) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(logits_dtype_t)); } } } } + __syncwarp(); for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { - const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); - const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; - const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; for (uint32_t j = aligned_ks; j < ks; ++ j) logits[i * stride_logits + j] = neg_inf; for (uint32_t j = ke; j < aligned_ke; ++ j) diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh index bea70002..a977c554 100644 --- a/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -1,13 +1,16 @@ #pragma once +#include #include +#include +#include namespace deep_gemm { template -__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { - typedef typename Vectorized::vec_t in_vec_t; +CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename utils::Vectorized::vec_t in_vec_t; constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; @@ -15,16 +18,19 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { extern __shared__ float smem_buffer[]; constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); - const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); // Shift into the block sf = sf + static_cast(blockIdx.y) * mn * SF_K; out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Load 
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { - auto in_vec = __ldg(local_sf + i); + auto in_vec = local_sf[i]; const auto& in_values = reinterpret_cast(&in_vec); const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; @@ -39,26 +45,29 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; - out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ptx::ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); } } // NOTES: the two kernels below always pack the K dimension template -__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { +CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { extern __shared__ uint32_t smem_buffer[]; // Shapes and strides - constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumPackedSFK = math::constexpr_ceil_div(SF_K, 4u); constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); - const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); // Shift into the group sf = sf + static_cast(blockIdx.y) * mn * SF_K; out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Load FP32 SFs DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); @@ -66,13 +75,13 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con const auto num_uint4 = num_values / 4; #pragma unroll for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { - const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); - st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + const auto& [x, y, z, w] = reinterpret_cast(local_sf)[i]; + ptx::st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); } // Fill unaligned values as well if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) - st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + ptx::st_shared(smem_buffer + unaligned_idx, local_sf[unaligned_idx]); __syncthreads(); // Pack into UE8M0 and store @@ -85,7 +94,7 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con #pragma unroll for (uint32_t j = 0; j < 4; ++ j) { const auto sf_k_idx = sf_k_pack_idx * 4 + j; - values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + values[j] = sf_k_idx < SF_K ? 
ptx::ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; } // Pack and store @@ -101,8 +110,9 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con template -__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, - const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { +CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k, + const uint32_t gran_k) { // Always packing the K dimension // NOTES: should also assert `mn % 4 == 0` at launch DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); @@ -120,11 +130,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, // Each warp is responsible for a packed row const auto warp_idx = threadIdx.x / 32; - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; if (warp_idx >= in_block_packed_sf_k) return; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Make an offset on the input uint32_t input_offset = 0; if constexpr (kNumGroups > 1) { @@ -134,18 +147,18 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, #pragma unroll for (uint32_t i = 0; i < 4; ++ i) { const auto group_idx = lane_idx * 4 + i; - group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0; + group_ks[i] = group_idx < kNumGroups ? ks[group_idx] : 0; } __syncwarp(); // Make the offset sf_k = 0; - auto sum_packed_sf_k = 0; + uint32_t sum_packed_sf_k = 0; #pragma unroll for (uint32_t i = 0; i < kNumGroups; ++ i) { - const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / gran_k, i / 4); sf_k += sf_k_in_group; - sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + sum_packed_sf_k += math::ceil_div(sf_k_in_group, 4u); if (packed_sf_k_idx < sum_packed_sf_k) break; if (const auto remainder = sf_k_in_group % 4; remainder > 0) @@ -153,14 +166,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, } } - for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + for (uint32_t mn_idx = ptx::get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { // Load uint4 values[4]; #pragma unroll for (uint32_t j = 0; j < 4; ++ j) { values[j] = make_uint4(0, 0, 0, 0); if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) - values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + values[j] = reinterpret_cast(sf + sf_k_idx * mn)[mn_idx]; } // Pack and store diff --git a/deep-gemm/deep_gemm/include/deep_gemm/layout/mega_moe.cuh b/deep-gemm/deep_gemm/include/deep_gemm/layout/mega_moe.cuh new file mode 100644 index 00000000..13520c60 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/layout/mega_moe.cuh @@ -0,0 +1,260 @@ +#pragma once + +#include + +#include +#include + +namespace deep_gemm::layout { + +static constexpr int kNumCandidateBlockMs = 7; +static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192}; +static constexpr int kMaxCandidateBlockM = 192; +static constexpr int kMinCandidateBlockM = 8; +static constexpr int kLCMCandidateBlockM = 384; + +// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M +template 
+CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
+                                                        T num_experts_per_rank) {
+    const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank;
+    const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
+    return math::constexpr_align(
+        num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1),
+        static_cast<T>(kLCMCandidateBlockM));
+}
+
+// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M
+template <typename T>
+CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) {
+    return (num_max_pool_tokens / block_m) * math::constexpr_align(block_m, static_cast<T>(128));
+}
+
+// Per-token source metadata for combine write-back
+struct TokenSrcMetadata {
+    uint32_t rank_idx;
+    uint32_t token_idx;
+    uint32_t topk_idx;
+};
+
+struct Workspace {
+    void* base;
+    uint32_t num_ranks, num_experts;
+    uint32_t num_experts_per_rank;
+    uint32_t num_max_tokens_per_rank;
+    uint32_t num_max_recv_tokens_per_expert;
+
+    // Pool capacity: all local experts share a contiguous token pool
+    uint32_t num_max_pool_tokens;
+    uint32_t num_max_pool_blocks;
+
+    // For both grid barrier and NVLink barrier
+    static constexpr uint64_t kNumBarrierSignalBytes = 32;
+
+    CUTLASS_HOST_DEVICE
+    Workspace(void* base,
+              const uint32_t& num_ranks,
+              const uint32_t& num_experts,
+              const uint32_t& num_max_tokens_per_rank,
+              const uint32_t& num_topk):
+        base(base),
+        num_ranks(num_ranks), num_experts(num_experts),
+        num_max_tokens_per_rank(num_max_tokens_per_rank) {
+        num_experts_per_rank = num_experts / num_ranks;
+        num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
+        num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
+        num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
+    }
+
+    CUTLASS_HOST_DEVICE
+    uint64_t get_num_bytes() const {
+        uint64_t num_bytes = 0;
+
+        // Barrier
+        num_bytes += kNumBarrierSignalBytes;
+
+        // Expert send/recv count
+        num_bytes += num_experts * sizeof(uint64_t) * 2;
+
+        // Expert recv count sum
+        num_bytes += num_experts_per_rank * sizeof(uint64_t);
+
+        // L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask)
+        num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t);
+
+        // L2 block arrival mask
+        num_bytes += num_max_pool_blocks * sizeof(uint64_t);
+
+        // Dispatch pulling source token-topk
+        num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int);
+
+        // Combine push source indices
+        num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata);
+
+        // Align to TMA descriptor requirements
+        num_bytes = math::align(num_bytes, 16);
+        return num_bytes;
+    }
+
+    CUTLASS_HOST_DEVICE
+    void* get_end_ptr() const {
+        return math::advance_ptr(base, get_num_bytes());
+    }
+
+    // Grid sync counters: `kNumBarrierSignalBytes` layout
+    // [ 0..15]: 4 x `uint32_t` grid sync counters
+    // [16..19]: `uint32_t` NVLink barrier counter
+    // [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1)
+    static constexpr uint32_t kNumMaxGridSyncCounters = 4;
+
+    template <uint32_t kIndex>
+    CUTLASS_DEVICE
+    uint32_t* get_grid_sync_count_ptr() const {
+        DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds");
+        return static_cast<uint32_t*>(base) + kIndex;
+    }
+
+    CUTLASS_DEVICE
+    uint32_t* get_nvl_barrier_counter_ptr() const {
+        return static_cast<uint32_t*>(base) + kNumMaxGridSyncCounters;
+    }
+
+    CUTLASS_DEVICE
+    int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const {
+        // NOTES: the signal is signed, as we may subtract from it
+        return math::advance_ptr<int>(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int));
+    }
+
+    CUTLASS_DEVICE
+    uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const {
+        return math::advance_ptr<uint64_t>(base, kNumBarrierSignalBytes) + expert_idx;
+    }
+
+    CUTLASS_DEVICE
+    uint64_t* get_expert_recv_count_ptr(
+        const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const {
+        return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx;
+    }
+
+    CUTLASS_DEVICE
+    uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const {
+        return get_expert_send_count_ptr(num_experts * 2) + expert_idx;
+    }
+
+    CUTLASS_DEVICE
+    uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const {
+        const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank);
+        return reinterpret_cast<uint32_t*>(base) + pool_block_idx;
+    }
+
+    CUTLASS_DEVICE
+    uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const {
+        // Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned
+        const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u));
+        return reinterpret_cast<uint64_t*>(base) + pool_block_idx;
+    }
+
+    // For dispatch pulling
+    CUTLASS_DEVICE
+    uint32_t* get_src_token_topk_idx_ptr(
+        const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const {
+        const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks);
+        return reinterpret_cast<uint32_t*>(base) +
+               expert_idx * (num_ranks * num_max_recv_tokens_per_expert) +
+               rank_idx * num_max_recv_tokens_per_expert + token_idx;
+    }
+
+    // For combine usages
+    CUTLASS_DEVICE
+    TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const {
+        const auto base = reinterpret_cast<TokenSrcMetadata*>(get_src_token_topk_idx_ptr(num_experts_per_rank));
+        return base + pool_token_idx;
+    }
+};
+
+struct Data {
+    uint32_t num_bytes;
+    bool require_tma_alignment;
+    void* base;
+
+    CUTLASS_HOST_DEVICE
+    constexpr explicit Data(
+        const uint32_t& num_bytes,
+        const bool& require_tma_alignment = true,
+        void* base = nullptr) :
+        num_bytes(num_bytes), require_tma_alignment(require_tma_alignment), base(base) {
+        DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment);
+    }
+
+    template <typename dtype_t>
+    CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const {
+        return static_cast<dtype_t>(num_bytes);
+    }
+
+    template <typename dtype_t>
+    CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
+        return static_cast<dtype_t*>(base);
+    }
+
+    CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) {
+        base = ptr;
+    }
+};
+
+struct Buffer {
+    Data data_layout;
+    uint32_t num_ranks;
+    uint32_t num_max_tokens_per_rank;
+
+    void* base;
+
+    CUTLASS_HOST_DEVICE
+    Buffer(const Data& data_layout,
+           const uint32_t& num_ranks,
+           const uint32_t& max_num_tokens_per_rank,
+           void* base = nullptr) :
+        data_layout(data_layout),
+        num_ranks(num_ranks), num_max_tokens_per_rank(max_num_tokens_per_rank),
+        base(base) {}
+
+    CUTLASS_HOST_DEVICE
+    uint64_t get_num_bytes_per_rank() const {
+        return num_max_tokens_per_rank * data_layout.get_num_bytes<uint64_t>();
+    }
+
+    CUTLASS_HOST_DEVICE
+    uint64_t get_num_bytes() const {
+        return get_num_bytes_per_rank() * num_ranks;
+    }
+
+    template <typename dtype_t>
+    CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
+        return static_cast<dtype_t*>(base);
+    }
+
+    CUTLASS_HOST_DEVICE
+    void* get_end_ptr() const {
return math::advance_ptr(base, get_num_bytes()); + } + + CUTLASS_HOST_DEVICE + Buffer get_rank_buffer(const uint32_t& rank_idx) const { + return { + data_layout, + 1, num_max_tokens_per_rank, + math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx) + }; + } + + CUTLASS_HOST_DEVICE + Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const { + DG_DEVICE_ASSERT(num_ranks == 1 or global); + return Data( + data_layout.num_bytes, + data_layout.require_tma_alignment, + math::advance_ptr(base, data_layout.get_num_bytes() * token_idx) + ); + } +}; + +} // namespace deep_gemm::layout diff --git a/deep-gemm/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh b/deep-gemm/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh new file mode 100644 index 00000000..7f11aabc --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace deep_gemm::layout { + +constexpr static uint32_t kNumMaxRanks = 72; + +template +struct SymBuffer { + int64_t base; + int64_t offsets[kNumMaxRanks]; + uint32_t rank_idx; + + DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks"); + + SymBuffer() = default; + + template + explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) { + const auto size = static_cast(c.size()); + base = c[rank_idx]; + for (uint32_t i = 0; i < kNumMaxRanks; ++ i) + offsets[i] = i < size ? (c[i] - base) : 0; + } + +#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__) + template + CUTLASS_DEVICE ptr_t get_base_ptr() const { + return reinterpret_cast(base); + } + + template + CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const { + int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast(ptr); + return *reinterpret_cast(&mapped_ptr); + } +#endif +}; + +} // namespace deep_gemm::layout diff --git a/deep-gemm/deep_gemm/include/deep_gemm/mma/sm100.cuh b/deep-gemm/deep_gemm/include/deep_gemm/mma/sm100.cuh new file mode 100644 index 00000000..0c554f4c --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/mma/sm100.cuh @@ -0,0 +1,151 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace deep_gemm::mma::sm100 { + +/// Shared memory descriptor +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + const uint32_t& stride_byte_offset, const uint32_t& leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +CUTLASS_DEVICE +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = 
static_cast(uint_ptr >> 4); +} + +CUTLASS_DEVICE +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16; +} + +/// UMMA descriptors +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +CUTLASS_DEVICE +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : tma::get_inner_block_atom_size(); +} + +template +CUTLASS_DEVICE +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto layout_type = to_umma_layout_type(); + const auto num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = tma::get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + math::swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(layout_type, + 
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +CUTLASS_DEVICE uint64_t make_runtime_instr_desc_with_sf_id( + cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptorBlockScaled& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptor& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +} // namespace deep_gemm::mma::sm100 diff --git a/deep-gemm/deep_gemm/include/deep_gemm/mma/sm90.cuh b/deep-gemm/deep_gemm/include/deep_gemm/mma/sm90.cuh new file mode 100644 index 00000000..2c061940 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/mma/sm90.cuh @@ -0,0 +1,293 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace deep_gemm::mma::sm90 { + +/// MMA +template +struct FP8MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return 
MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if 
+
+template <typename MMA, uint32_t N_>
+struct TF32MMARS {
+    template <size_t... Idx>
+    CUTLASS_DEVICE static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
+        using namespace cute::SM90::GMMA;
+        MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
+    }
+
+    CUTLASS_DEVICE static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
+        call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<kNumAccum>{});
+    }
+
+    static constexpr int M = 64;
+    static constexpr int N = N_;
+    static constexpr int K = 8;
+    static constexpr int kNumAccum = M * N / 128;
+};
+
+template <uint32_t N, bool kUseRS>
+struct TF32MMASelector {
+    static constexpr auto select_mma() {
+        using namespace cute::SM90::GMMA;
+        if constexpr (kUseRS) {
+            if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
+            DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
+        }
+    }
+
+    static constexpr auto select_type() {
+        if constexpr (kUseRS) {
+            return TF32MMARS<decltype(select_mma()), N>();
+        } else {
+            DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
+        }
+    }
+
+    using type = decltype(select_type());
+};
+
+/// Shared memory descriptor
+template <typename PointerType>
+CUTLASS_DEVICE cute::GmmaDescriptor
+make_smem_desc(PointerType smem_ptr, const int& layout_type,
+               const uint32_t& leading_byte_offset = 0,
+               const uint32_t& stride_byte_offset = 1024) {
+    // NOTES: the default LBO and SBO are for K-major types
+    cute::GmmaDescriptor desc;
+    const auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+    desc.bitfield.start_address_ = uint_ptr >> 4;
+    desc.bitfield.layout_type_ = layout_type;
+    desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
+    desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
+    desc.bitfield.base_offset_ = 0;
+    return desc;
+}
+
+template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
+constexpr uint32_t get_inner_block_atom_size() {
+    return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t);
+}
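+
+// Worked example (illustrative): an FP8 K-major tile with `BLOCK_K == 128` and
+// 128B swizzling consists of 8 x 128B atoms, so `make_gmma_desc` below passes
+//     stride_byte_offset  = 8 * 128 * 1 = 1024  (encoded as 1024 >> 4 = 64)
+//     leading_byte_offset = 0                   (a single atom along K)
+// and `make_smem_desc` stores both byte offsets right-shifted by 4, matching
+// the hardware descriptor encoding.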
+
+template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
+CUTLASS_DEVICE
+constexpr uint32_t get_gmma_desc_stride_k() {
+    return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
+}
+
+// ReSharper disable once CppNotAllPathsReturnValue
+template <uint32_t kSwizzleMode>
+constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() {
+    DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
+                     kSwizzleMode == 32 or kSwizzleMode == 64 or
+                     kSwizzleMode == 128, "Invalid swizzling mode");
+
+    // Normal cases
+    if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
+    if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
+    if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32;
+    if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64;
+    if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128;
+}
+
+template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
+CUTLASS_DEVICE
+uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) {
+    return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
+}
+
+template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
+CUTLASS_DEVICE
+cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
+    const uint32_t stride_k = get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
+    const auto layout_type = to_gmma_layout_type<kSwizzleMode>();
+    constexpr uint32_t num_non_contiguous = 128 / 16;
+    if constexpr (kMajorMode == cute::UMMA::Major::K) {
+        // NOTES: for K-major layout, the swizzle must be 128B (also, the atom index must be 0), as `BLOCK_K` is always 128
+        DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
+
+        // Atom size: 8 x `kSwizzleMode` (in bytes, on K)
+        // {SBO, LBO} means the byte stride between atoms on {MN, K}
+        // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
+        const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
+        const uint32_t leading_byte_offset = 0;
+        return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<int>(layout_type),
+                              leading_byte_offset, stride_byte_offset);
+    } else {
+        constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
+
+        // Must have no in-atom MN-idx
+        // NOTES: no worries about the runtime assert, the `mn_idx` values are compile-time constants
+        DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
+        DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
+
+        // Atom size: `kSwizzleMode` (in bytes, on MN) x 8
+        // NOTES: `kSwizzleMode == 16` means non-swizzling but interleaving
+        // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
+        // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
+        uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
+        uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
+        if constexpr (kSwizzleMode == 16)
+            math::swap(stride_byte_offset, leading_byte_offset);
+        return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<int>(layout_type),
+                              leading_byte_offset, stride_byte_offset);
+    }
+}
+
+// ReSharper disable once CppNotAllPathsReturnValue
+template <uint32_t kHeadDim>
+static constexpr int to_swizzle_cute_type() {
+    DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
+    if constexpr (kHeadDim == 32)
+        return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
+    if constexpr (kHeadDim == 64)
+        return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
+    if constexpr (kHeadDim == 128)
+        return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
+}
+
+} // namespace deep_gemm::mma::sm90
diff --git a/deep-gemm/deep_gemm/include/deep_gemm/ptx/ld_st.cuh b/deep-gemm/deep_gemm/include/deep_gemm/ptx/ld_st.cuh
new file mode 100644
index 00000000..c3e03bec
--- /dev/null
+++ b/deep-gemm/deep_gemm/include/deep_gemm/ptx/ld_st.cuh
@@ -0,0 +1,251 @@
+#pragma once
+
+#include
+#include
+
+namespace deep_gemm::ptx {
+
+// Compatibility: 256 bits LD/ST instructions
+#if defined(CUDART_VERSION) and CUDART_VERSION >= 13000
+using longlong4_t = longlong4_32a;
+#define make_longlong4_t make_longlong4_32a
+#else
+struct alignas(32) longlong4_t { long long x, y, z, w; };
+CUTLASS_HOST_DEVICE longlong4_t make_longlong4_t(
+    const long long& x, const long long& y, const long long& z, const long long& w) {
+    return {x, y, z, w};
+}
+#endif
+
+/// LD/ST matrix
+// TODO: remove `struct`
+struct SM90_U32x2_LDSM_N {
+    CUTLASS_DEVICE static void
+    copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) {
+        asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
+                     : "=r"(dst_0), "=r"(dst_1)
+                     : "l"(__cvta_generic_to_shared(smem_src)));
+    }
+};
+
+struct SM90_U32x4_LDSM_N {
+    CUTLASS_DEVICE static void
+    copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) {
+        asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
+                     : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3)
+                     : "l"(__cvta_generic_to_shared(smem_src)));
+    }
+};
+
+template <typename dtype_t>
+struct SM90_U32x2_STSM_N {
+    CUTLASS_DEVICE static void
+    copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
+        DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
+        const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
+        asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
+                     :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1]));
+    }
+};
+
+template <typename dtype_t>
+struct SM90_U32x4_STSM_T {
+    CUTLASS_DEVICE static void
+    copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
+        DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
+        const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
+                                 *reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
+        asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n"
+                     :: "l"(__cvta_generic_to_shared(smem_dst)),
+                        "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
+    }
+};
+
+template <typename dtype_t>
+struct SM100_U8x4_STSM_T {
+    __device__ __forceinline__ static void
+    copy(dtype_t src_0, void* smem_dst) {
+        DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
+        const uint32_t src = *reinterpret_cast<uint32_t*>(&src_0);
+        asm volatile("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n"
+                     :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src));
+    }
+};
+
+template <typename dtype_t>
+struct SM100_U8x8_STSM_T {
+    __device__ __forceinline__ static void
+    copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
+        DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
+        const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
+        asm volatile("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n"
+                     :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1]));
+    }
+};
+
+/// Shared memory
+CUTLASS_DEVICE uint32_t ld_shared(const uint32_t* ptr) {
+    uint32_t ret;
+    asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr)));
+    return ret;
+}
+
+CUTLASS_DEVICE float2 ld_shared(const float2* ptr) {
+    float2 ret;
+    asm
volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); +} + +CUTLASS_DEVICE void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); +} + +CUTLASS_DEVICE void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +CUTLASS_DEVICE void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +CUTLASS_DEVICE void st_shared_bulk(void* smem_ptr, const uint32_t& num_bytes) { + // `size` must be 64-bit before PTX ISA 9.0 + asm volatile("st.bulk.weak.shared::cta [%0], %1, 0;" :: + "l"(__cvta_generic_to_shared(smem_ptr)), "l"(static_cast(num_bytes))); +} + +/// Global memory +CUTLASS_DEVICE uint64_t ld_volatile(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.volatile.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.gpu.global.b32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.sys.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +/// Atomics +CUTLASS_DEVICE uint64_t atomic_add(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint64_t atomic_add_sys(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.sys.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& value) { + uint32_t ret; + asm volatile("atom.release.gpu.global.add.u32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); + return ret; +} + +CUTLASS_DEVICE void red_add(const int* ptr, const int& 
value) { + asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_or_rel_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.sys.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_or_rel_gpu(uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.gpu.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_add_rel(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.release.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add_rel_sys(const int* ptr, const int& value) { + asm volatile("red.release.sys.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE int ld_acq_sys(const int* ptr) { + int ret; + asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq_sys(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_gpu(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.gpu.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +/// Predicated loads +CUTLASS_DEVICE longlong4_t ld_gez_pred(const longlong4_t* ptr, const int& pred) { + longlong4_t ret = make_longlong4_t(0, 0, 0, 0); + asm volatile( + "{\n\t" + " .reg .pred p;\n\t" + " setp.ge.s32 p, %5, 0;\n\t" + " @p ld.global.L2::256B.v4.s64 {%0, %1, %2, %3}, [%4];\n\t" + "}" + : "+l"(ret.x), "+l"(ret.y), "+l"(ret.z), "+l"(ret.w) + : "l"(ptr), "r"(pred) + : "memory"); + return ret; +} + +/// Prefetch +CUTLASS_DEVICE void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +} // namespace deep_gemm::ptx diff --git a/deep-gemm/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh b/deep-gemm/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh new file mode 100644 index 00000000..528b3dd1 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh @@ -0,0 +1,168 @@ +#pragma once + +namespace deep_gemm::ptx { + +/// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" 
+ "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +/// Tensor memory operations +CUTLASS_DEVICE void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +CUTLASS_DEVICE void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +} // namespace deep_gemm::ptx diff --git a/deep-gemm/deep_gemm/include/deep_gemm/ptx/tma.cuh b/deep-gemm/deep_gemm/include/deep_gemm/ptx/tma.cuh new file mode 100644 index 00000000..1530a3ed --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/ptx/tma.cuh @@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +namespace deep_gemm::ptx { + +// Tensor-map 
instructions
+CUTLASS_DEVICE void tensor_map_release_gpu() {
+    asm volatile ("fence.proxy.tensormap::generic.release.gpu;" ::: "memory");
+}
+
+CUTLASS_DEVICE void tensor_map_acquire_gpu(const cute::TmaDescriptor* gmem_desc_ptr) {
+    auto gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
+    asm volatile ("fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" :: "l"(gmem_int_desc) : "memory");
+}
+
+CUTLASS_DEVICE void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) {
+    auto smem_int_desc = static_cast<uint32_t>(__cvta_generic_to_shared(smem_desc));
+    const auto new_int64_addr = reinterpret_cast<int64_t>(new_addr);
+    asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr));
+}
+
+CUTLASS_DEVICE void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
+    auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
+    asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
+#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3)))
+    asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
+#else
+    DG_STATIC_ASSERT(false, "Invalid CUDA version");
+#endif
+}
+
+/// TMA instructions
+CUTLASS_DEVICE void mbarrier_arrive(
+    cutlass::arch::ClusterTransactionBarrier* ptr) {
+    asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0]; \n\t" ::
+                 "r"(static_cast<uint32_t>(__cvta_generic_to_shared(ptr))));
+}
+
+CUTLASS_DEVICE void mbarrier_arrive_and_set_tx(
+    cutlass::arch::ClusterTransactionBarrier* ptr, const uint32_t& num_bytes) {
+    asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" ::
+                 "r"(num_bytes), "r"(static_cast<uint32_t>(__cvta_generic_to_shared(ptr))));
+}
+
+CUTLASS_DEVICE void mbarrier_wait_and_flip_phase(
+    cutlass::arch::ClusterTransactionBarrier* ptr, uint32_t& phase) {
+    // NOTES: `0x989680` (10,000,000) is the suspend-time hint for `try_wait`
+    asm volatile(
+        "{\n\t"
+        ".reg .pred       P1; \n\t"
+        "LAB_WAIT: \n\t"
+        "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
+        "@P1 bra DONE; \n\t"
+        "bra     LAB_WAIT; \n\t"
+        "DONE: \n\t"
+        "}" ::
+        "r"(static_cast<uint32_t>(__cvta_generic_to_shared(ptr))),
+        "r"(phase), "r"(0x989680));
+    phase ^= 1;
+}
+
+CUTLASS_DEVICE void tma_load_1d(
+    const void* dst_ptr, const void* src_ptr,
+    cutlass::arch::ClusterTransactionBarrier* mbarrier_ptr,
+    const uint32_t& num_bytes,
+    const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_FIRST) {
+    // NOTES: normally, the loaded part will be evicted soon
+    asm volatile(
+        "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n" ::
+        "r"(static_cast<uint32_t>(__cvta_generic_to_shared(dst_ptr))),
+        "l"(src_ptr),
+        "r"(num_bytes),
+        "r"(static_cast<uint32_t>(__cvta_generic_to_shared(mbarrier_ptr))),
+        "l"(hint)
+        : "memory");
+}
+
+CUTLASS_DEVICE void tma_store_1d(
+    const void* dst_ptr, const void* src_ptr, const uint32_t& num_bytes,
+    const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_NORMAL) {
+    // NOTES: normally, the stored part will be used soon
+    asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n" ::
+                 "l"(dst_ptr),
+                 "r"(static_cast<uint32_t>(__cvta_generic_to_shared(src_ptr))),
+                 "r"(num_bytes),
+                 "l"(hint)
+                 : "memory");
}
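+
+// Typical pairing of the helpers above (illustrative sketch only):
+//     // Producer: announce the expected byte count, then issue the bulk copy
+//     mbarrier_arrive_and_set_tx(&barrier, num_bytes);
+//     tma_load_1d(smem_dst, gmem_src, &barrier, num_bytes);
+//     // Consumer: spin until the transaction lands, flipping `phase` in place
+//     mbarrier_wait_and_flip_phase(&barrier, phase);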
+
+template <int32_t kNumRemainingWaits = 0>
+__forceinline__ __device__ void tma_store_wait() {
+    // NOTES: this function does not use the `.read` variant
+    asm volatile("cp.async.bulk.wait_group %0;" :: "n"(kNumRemainingWaits) : "memory");
+}
+
+CUTLASS_DEVICE
+void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier& mbarrier,
+                 void* smem_ptr, const uint32_t& col_idx, const int4& row_idxs, const uint64_t& cache_hint) {
+    const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
+    const auto mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier);
+    asm volatile(
+        "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
+        :
+        : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
+          "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
+          "r"(mbarrier_addr), "l"(cache_hint)
+        : "memory"
+    );
+}
+
+} // namespace deep_gemm::ptx
diff --git a/deep-gemm/deep_gemm/include/deep_gemm/ptx/utils.cuh b/deep-gemm/deep_gemm/include/deep_gemm/ptx/utils.cuh
new file mode 100644
index 00000000..5c27166b
--- /dev/null
+++ b/deep-gemm/deep_gemm/include/deep_gemm/ptx/utils.cuh
@@ -0,0 +1,53 @@
+#pragma once
+
+#include
+#include
+
+#include
+
+namespace deep_gemm::ptx {
+
+CUTLASS_DEVICE uint32_t get_sm_idx() {
+    uint32_t sm_idx;
+    asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx));
+    return sm_idx;
+}
+
+CUTLASS_DEVICE uint32_t get_lane_idx() {
+    uint32_t lane_id;
+    asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id));
+    return lane_id;
+}
+
+CUTLASS_DEVICE void sync_aligned(const uint32_t& num_threads, const uint32_t& barrier_idx) {
+    asm volatile("bar.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads));
+}
+
+CUTLASS_DEVICE void sync_unaligned(const uint32_t& num_threads, const uint32_t& barrier_idx) {
+    asm volatile("barrier.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads));
+}
+
+template <typename dtype_t>
+CUTLASS_DEVICE dtype_t exchange(dtype_t ptr, const uint32_t& src_lane_idx) {
+    DG_STATIC_ASSERT(sizeof(dtype_t) % sizeof(uint32_t) == 0, "Invalid dtype");
+    const auto send_int_values = reinterpret_cast<uint32_t*>(&ptr);
+    dtype_t recv_dtype;
+    auto recv_int_values = reinterpret_cast<uint32_t*>(&recv_dtype);
+    #pragma unroll
+    for (uint32_t i = 0; i < sizeof(dtype_t) / sizeof(uint32_t); ++ i)
+        recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], static_cast<int>(src_lane_idx));
+    return recv_dtype;
+}
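+
+// Illustrative usage of `exchange` (not part of this patch): broadcasting a
+// 16-byte POD from lane 3 to the whole warp, one 32-bit shuffle per word.
+//     struct Task { uint32_t a, b, c, d; };        // hypothetical payload
+//     task = deep_gemm::ptx::exchange(task, 3u);   // all lanes now hold lane 3's task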
volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +} // namespace deep_gemm::ptx diff --git a/deep-gemm/deep_gemm/include/deep_gemm/scheduler/gemm.cuh b/deep-gemm/deep_gemm/include/deep_gemm/scheduler/gemm.cuh new file mode 100644 index 00000000..5cd50c66 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/scheduler/gemm.cuh @@ -0,0 +1,300 @@ +#pragma once + +#include +#include + +namespace deep_gemm::sched { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto candidate: {8u, 16u}) { + const auto usage = kIsMulticastOnA ? + candidate * BLOCK_N + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for contiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + CUTLASS_DEVICE void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = grouped_layout[group_idx]; + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + CUTLASS_DEVICE explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, + const uint32_t& shape_k, int* grouped_layout = nullptr) { + num_m_blocks = math::ceil_div(shape_m, BLOCK_M); + num_n_blocks = math::ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = grouped_layout[0]; + num_m_blocks = math::ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + CUTLASS_DEVICE void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group 
size"); + + // Swizzle for better L2 usages + const auto primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + CUTLASS_DEVICE uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, grouped_layout[m_block_idx * BLOCK_M]) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + // For swap A/B and psum layout only + CUTLASS_DEVICE uint32_t get_aligned_effective_m_in_block(const uint32_t& m_block_idx) const { + constexpr uint32_t UMMA_STEP_N = 16; + DG_STATIC_ASSERT(BLOCK_M % UMMA_STEP_N == 0, "Invalid alignment"); + if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) + return math::align(m_block_idx == last_psum_m / BLOCK_M + num_m_blocks - 1 ? 
+
+    template <IndexType kIndexType = IndexType::MN, bool kWithGroupOffset = true>
+    CUTLASS_DEVICE uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
+                                           const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
+        if constexpr (kGemmType == GemmType::Normal) {
+            return block_idx * block_size;
+        } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
+            const auto offset = kWithGroupOffset ? cute::max(0, grouped_layout[m_block_idx * BLOCK_M]) : 0;
+            return offset * shape_dim + block_idx * block_size;
+        } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
+            const auto offset = kWithGroupOffset ? current_group_idx : 0;
+            return offset * shape_dim + block_idx * block_size;
+        } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
+            uint32_t offset = 0;
+            if constexpr (kWithGroupOffset) {
+                if constexpr (kIndexType == IndexType::MN)
+                    offset = current_group_idx * shape_dim;
+                else if constexpr (kIndexType == IndexType::K)
+                    offset = current_k_cumsum;
+                else if constexpr (kIndexType == IndexType::SF_K)
+                    offset = current_sf_k_cumsum;
+            }
+            return offset + block_idx * block_size;
+        } else if constexpr (kGemmType == GemmType::Batched) {
+            // Ignore `kWithGroupOffset`, and apply the offset for `IndexType::SF_K`
+            const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0;
+            return offset * shape_dim + block_idx * block_size;
+        }
+    }
+
+    // For swap A/B and psum layout only
+    CUTLASS_DEVICE uint32_t get_aligned_effective_m_in_block(const uint32_t& m_block_idx) const {
+        constexpr uint32_t UMMA_STEP_N = 16;
+        DG_STATIC_ASSERT(BLOCK_M % UMMA_STEP_N == 0, "Invalid alignment");
+        if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout)
+            return math::align(m_block_idx == last_psum_m / BLOCK_M + num_m_blocks - 1 ?
+                               current_psum_m - m_block_idx * BLOCK_M : BLOCK_M, UMMA_STEP_N);
+        return BLOCK_M;
+    }
+
+    CUTLASS_DEVICE bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
+        const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x;
+
+        if constexpr (kGemmType == GemmType::MGroupedMasked) {
+            while (true) {
+                // End of the task
+                if (current_group_idx == kNumGroups)
+                    return false;
+
+                // Within the current group
+                num_m_blocks = math::ceil_div(static_cast<uint32_t>(grouped_layout[current_group_idx]), BLOCK_M);
+                const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks;
+                if (next_block_idx < current_m_block_cumsum * num_n_blocks)
+                    break;
+
+                // Move to check the next group
+                current_group_idx ++, current_m_cumsum = current_m_block_cumsum;
+            }
+
+            get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
+        } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
+            while (true) {
+                // Within the current group
+                if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks)
+                    break;
+
+                // Move to check the next group
+                if (++ current_group_idx == kNumGroups)
+                    return false;
+
+                // NOTES: `num_m_blocks` varies with the increase of the group index
+                last_psum_m = math::align(current_psum_m, BLOCK_M);
+                current_psum_m = grouped_layout[current_group_idx];
+                current_m_block_cumsum += num_m_blocks;
+                num_m_blocks = math::ceil_div(current_psum_m - last_psum_m, BLOCK_M);
+            }
+
+            get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx);
+
+            // NOTES: `last_psum_m` is aligned with block M
+            m_block_idx += last_psum_m / BLOCK_M;
+        } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
+            while (true) {
+                // End of the task
+                if (current_group_idx == kNumGroups)
+                    return false;
+
+                // Within the current group
+                if (next_block_idx < (current_num_valid_groups + 1) * num_blocks)
+                    break;
+
+                // Move to check the next group
+                current_k_cumsum += current_shape_k;
+                current_sf_k_cumsum += math::ceil_div(current_shape_k, SF_K_ALIGNMENT);
+                current_num_valid_groups ++;
+
+                current_group_idx = next_group_idx ++;
+                current_shape_k = next_shape_k;
+                get_next_k_group(next_group_idx, next_shape_k);
+            }
+
+            get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_blocks, m_block_idx, n_block_idx);
+        } else if constexpr (kGemmType == GemmType::Batched) {
+            if (next_block_idx >= num_blocks * kNumGroups)
+                return false;
+
+            current_group_idx = next_block_idx / num_blocks;
+            const auto block_idx = next_block_idx - current_group_idx * num_blocks;
+            if constexpr (kIsMulticastOnA) {
+                m_block_idx = block_idx / num_n_blocks;
+                n_block_idx = block_idx % num_n_blocks;
+            } else {
+                m_block_idx = block_idx % num_m_blocks;
+                n_block_idx = block_idx / num_m_blocks;
+            }
+        } else {
+            if (next_block_idx >= num_blocks)
+                return false;
+
+            // For SM90 only
+            // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
+            is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or  // Always aligned on N (constant bypass)
+                                num_m_blocks % kNumMulticast == 0 or  // Always aligned on M (constant bypass)
+                                (next_block_idx ^ 1) < num_blocks;    // Peer CTA in bound
+            get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);
+        }
+        return true;
+    }
+
+    // For SM90 only
+    CUTLASS_DEVICE bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
+        if (num_blocks_in_group == 1)
+            return false;
+        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or
+                      kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
+                      kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
+            return true;
+        } else {
+            DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid GEMM type");
+            if constexpr (kIsMulticastOnA) {
+                return true;
+            } else {
+                const auto group_idx = grouped_layout[m_block_idx * BLOCK_M];
+                const auto peer_group_idx = grouped_layout[(m_block_idx ^ 1) * BLOCK_M];
+                return group_idx == peer_group_idx;
+            }
+        }
+    }
+
+    // For SM90 only
+    // ReSharper disable once CppNotAllPathsReturnValue
+    CUTLASS_DEVICE bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
+        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
+            return true;
+        } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
+            return grouped_layout[m_offset + m_block_idx * BLOCK_M] >= 0;
+        } else if constexpr (kGemmType == GemmType::MGroupedMasked) {
+            return m_offset + m_block_idx * BLOCK_M < grouped_layout[current_group_idx];
+        } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
+            return m_offset + m_block_idx * BLOCK_M < current_psum_m;
+        } else {
+            // Unreachable
+            DG_TRAP_ONLY_DEVICE_ASSERT(false);
+        }
+    }
+};
+
+#pragma clang diagnostic pop
+
+} // namespace deep_gemm::sched
diff --git a/deep-gemm/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh b/deep-gemm/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh
new file mode 100644
index 00000000..cdbecccd
--- /dev/null
+++ b/deep-gemm/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh
@@ -0,0 +1,221 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace deep_gemm::sched {
+
+// Computation phase for the current block
+enum class BlockPhase {
+    None = 0,
+    Linear1 = 1,
+    Linear2 = 2
+};
+
+template <uint32_t kNumRanks, uint32_t kNumSMs,
+          uint32_t kNumExpertsPerRank, uint32_t kNumExpertsPerWave,
+          uint32_t L1_SHAPE_N, uint32_t L1_SHAPE_K,
+          uint32_t L2_SHAPE_N, uint32_t L2_SHAPE_K,
+          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
+          uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N, uint32_t kNumL1BlockKs = L1_SHAPE_K / BLOCK_K,
+          uint32_t kNumL2BlockNs = L2_SHAPE_N / BLOCK_N, uint32_t kNumL2BlockKs = L2_SHAPE_K / BLOCK_K,
+          uint32_t kNumExpertsPerLane = math::constexpr_ceil_div(kNumExpertsPerRank, 32u)>
+struct MegaMoEScheduler {
+    DG_STATIC_ASSERT(L1_SHAPE_N % BLOCK_N == 0, "Invalid shape");
+    DG_STATIC_ASSERT(L2_SHAPE_N % BLOCK_N == 0, "Invalid shape");
+    DG_STATIC_ASSERT(L1_SHAPE_K % BLOCK_K == 0, "Invalid shape");
+    DG_STATIC_ASSERT(L2_SHAPE_K % BLOCK_K == 0, "Invalid shape");
+    DG_STATIC_ASSERT(kNumExpertsPerRank % kNumExpertsPerWave == 0, "Invalid wave config");
+
+    // NOTES: N block counts must be even so that 2 adjacent CTAs in a cluster
+    // always land on the same `m_block_idx` with `n_block_idx` differing by 1
+    DG_STATIC_ASSERT(kNumSMs % 2 == 0, "Number of SMs must be even for 2-CTA cluster");
+    DG_STATIC_ASSERT(kNumL1BlockNs % 2 == 0, "L1 N block count must be even for 2-CTA cluster");
+    DG_STATIC_ASSERT(kNumL2BlockNs % 2 == 0, "L2 N block count must be even for 2-CTA cluster");
+
+    // Arrival counts
+    const layout::Workspace& workspace;
+
+    // Scheduler state
+    BlockPhase next_phase = BlockPhase::Linear1;
+
+    // Current expert and block indices
+    uint32_t current_local_expert_idx = 0;
+    uint32_t current_num_tokens = 0;
+    uint32_t current_pool_block_offset = 0;
+    uint32_t block_idx = 0;
+    uint32_t m_block_idx = 0;
+    uint32_t n_block_idx = 0;
+
+    // Pre-cached per-expert token counts (filled during `for_each_block` init)
+    // Layout: `stored_num_tokens_per_expert[i]` holds expert (i * 32 + lane_idx)'s count
+    uint32_t stored_num_tokens_per_expert[kNumExpertsPerLane] = {};
+
+    CUTLASS_DEVICE explicit MegaMoEScheduler(const layout::Workspace& workspace): workspace(workspace) {
+        block_idx = blockIdx.x;
+    }
+
+    CUTLASS_DEVICE uint32_t get_wave_expert_end_idx() const {
+        return math::align(current_local_expert_idx + 1, kNumExpertsPerWave);
+    }
+
+    CUTLASS_DEVICE uint32_t get_num_tokens(const uint32_t& expert_idx) const {
+        uint32_t valid_value;
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) {
+            valid_value = (expert_idx == i * 32 + ptx::get_lane_idx()) ?
+                stored_num_tokens_per_expert[i] : valid_value;
+        }
+        return ptx::exchange(valid_value, expert_idx % 32);
+    }
+
+    // Get the pool block offset for a given expert index from the per-lane token count array
+    CUTLASS_DEVICE uint32_t get_pool_block_offset(const uint32_t& expert_idx) {
+        uint32_t num_blocks = 0;
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) {
+            if (i * 32 + ptx::get_lane_idx() < expert_idx)
+                num_blocks += math::ceil_div(stored_num_tokens_per_expert[i], BLOCK_M);
+        }
+        return __reduce_add_sync(0xffffffff, num_blocks);
+    }
+
+    CUTLASS_DEVICE void advance_expert_idx() {
+        current_pool_block_offset += get_current_num_m_blocks();
+        current_local_expert_idx += 1;
+        current_num_tokens = get_num_tokens(current_local_expert_idx);
+    }
+
+    CUTLASS_DEVICE void set_expert_idx(const uint32_t& expert_idx) {
+        current_local_expert_idx = expert_idx;
+        current_num_tokens = get_num_tokens(expert_idx);
+        current_pool_block_offset = get_pool_block_offset(expert_idx);
+    }
+
+    CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const {
+        return current_pool_block_offset;
+    }
+
+    CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const {
+        return math::ceil_div(current_num_tokens, BLOCK_M);
+    }
+
+    template <bool kDoUMMAAligned>
+    CUTLASS_DEVICE uint32_t get_valid_m() const {
+        const auto m = cute::min(current_num_tokens - m_block_idx * BLOCK_M, BLOCK_M);
+        return kDoUMMAAligned ? math::align(m, 16u) : m;
+    }
+
+    CUTLASS_DEVICE bool fetch_next_l1_block() {
+        const auto wave_end_expert_idx = get_wave_expert_end_idx();
+        while (current_local_expert_idx < wave_end_expert_idx) {
+            const auto num_m_blocks = get_current_num_m_blocks();
+            m_block_idx = block_idx / kNumL1BlockNs;
+            if (m_block_idx < num_m_blocks)
+                return true;
+
+            // The current expert is fully assigned, move to the next
+            block_idx -= num_m_blocks * kNumL1BlockNs;
+            advance_expert_idx();
+        }
+        return false;
+    }
+
+    CUTLASS_DEVICE bool fetch_next_l2_block() {
+        const auto wave_end_expert_idx = get_wave_expert_end_idx();
+        while (current_local_expert_idx < wave_end_expert_idx) {
+            const auto num_m_blocks = get_current_num_m_blocks();
+            if (block_idx < num_m_blocks * kNumL2BlockNs) {
+                m_block_idx = block_idx / kNumL2BlockNs;
+                return true;
+            }
+
+            // The current expert is fully assigned, move to the next
+            block_idx -= num_m_blocks * kNumL2BlockNs;
+            advance_expert_idx();
+        }
+        return false;
+    }
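+
+    // How the lane-cached counts are read back (illustrative trace): with
+    // kNumExpertsPerLane == 2, expert 37 lives in slot i = 1 of lane 5
+    // (37 == 1 * 32 + 5), so `get_num_tokens(37)` selects that slot on lane 5
+    // and `ptx::exchange(value, 37 % 32)` broadcasts it to the whole warp.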
+
+    // Core state machine: assigns the next block
+    CUTLASS_DEVICE cute::tuple<BlockPhase, uint32_t, uint32_t, uint32_t> get_next_block() {
+        while (true) {
+            if (current_local_expert_idx >= kNumExpertsPerRank)
+                break;
+
+            if (next_phase == BlockPhase::Linear1) {
+                if (fetch_next_l1_block()) {
+                    // Found a new L1 block
+                    n_block_idx = block_idx - m_block_idx * kNumL1BlockNs;
+                    // Jump to the next block
+                    block_idx += kNumSMs;
+                    return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx};
+                } else {
+                    // L1 for the current wave is complete, transition to L2
+                    next_phase = BlockPhase::Linear2;
+                    set_expert_idx(math::align(current_local_expert_idx - 1, kNumExpertsPerWave));
+                }
+            } else {
+                if (fetch_next_l2_block()) {
+                    // Found a new L2 block
+                    n_block_idx = block_idx - m_block_idx * kNumL2BlockNs;
+                    // Jump to the next block
+                    block_idx += kNumSMs;
+                    return {BlockPhase::Linear2, current_local_expert_idx, m_block_idx, n_block_idx};
+                } else {
+                    // Move to L1 of the next wave
+                    next_phase = BlockPhase::Linear1;
+                }
+            }
+        }
+
+        // All waves and experts are fully processed
+        return {BlockPhase::None, 0, 0, 0};
+    }
+
+    CUTLASS_DEVICE void fetch_expert_recv_count() {
+        // NOTES: each lane caches the experts at indices (i * 32 + lane_idx)
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) {
+            const auto expert_idx = i * 32 + ptx::get_lane_idx();
+            uint64_t value = 0;
+            if (expert_idx < kNumExpertsPerRank) {
+                do {
+                    value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx));
+                } while (static_cast<uint32_t>(value >> 32) != kNumSMs * kNumRanks);
+            }
+            stored_num_tokens_per_expert[i] = static_cast<uint32_t>(value);
+        }
+        __syncwarp();
+    }
+
+    template <typename Func>
+    CUTLASS_DEVICE void for_each_block(Func&& func) {
+        // Wait for all expert counters to be finalized
+        fetch_expert_recv_count();
+
+        // Initialize the current expert with 0
+        set_expert_idx(0);
+
+        // Iterate over all blocks
+        // TODO: add swizzle within expert waves for better L2 cache utilization
+        while (true) {
+            CUTE_TIE_DECL(get_next_block(), block_phase, current_local_expert_idx, m_block_idx, n_block_idx);
+            if (block_phase == BlockPhase::None)
+                break;
+
+            func(block_phase, current_local_expert_idx,
+                 block_phase == BlockPhase::Linear2 ? kNumL2BlockKs : kNumL1BlockKs,
+                 m_block_idx, n_block_idx);
+        }
+    }
+};
+
+} // namespace deep_gemm::sched
diff --git a/deep-gemm/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh
new file mode 100644
index 00000000..548bbbc6
--- /dev/null
+++ b/deep-gemm/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh
@@ -0,0 +1,239 @@
+#pragma once
+
+#include
+#include
+#include
+
+namespace deep_gemm::sched {
+
+template <uint32_t kNumSMs, uint32_t SPLIT_KV, uint32_t kAlignedBatchSize, bool kIsVarlen>
+CUTLASS_GLOBAL __launch_bounds__(32, 1)
+void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
+                                    const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) {
+    DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
+    const uint32_t lane_idx = ptx::get_lane_idx();
+
+    // Wait for primary kernel completion
+    cudaGridDependencySynchronize();
+
+    __shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize];
+    __shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize];
+    __shared__ uint32_t varlen_num_atoms_shared;
+    uint32_t num_items;
+
+    if constexpr (kIsVarlen) {
+        if (lane_idx == 0) {
+            uint32_t t = 0, atom_count = 0;
+            while (t < batch_size) {
+                varlen_atom_token_start[atom_count] = t;
+                const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]);
+                varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t];
+                t += is_paired ? 2 : 1;
+                ++ atom_count;
+            }
+            varlen_num_atoms_shared = atom_count;
+        }
+        __syncwarp();
+        num_items = varlen_num_atoms_shared;
+    } else {
+        num_items = batch_size;
+    }
+
+    // Compute `num_segs` and its prefix sum
+    uint32_t num_segs[kAlignedBatchSize / 32];
+    #pragma unroll
+    for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
+        const uint32_t q_idx = k * 32 + lane_idx;
+        uint32_t context_len;
+        if constexpr (kIsVarlen) {
+            context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0);
+        } else {
+            const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
+            context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
+        }
+        num_segs[k] = math::ceil_div(context_len, SPLIT_KV);
+    }
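+
+    // The scan below is a standard warp-level Hillis-Steele inclusive prefix
+    // sum (illustrative trace for 4 lanes holding {3, 1, 2, 4}):
+    //     offset 1 -> {3, 4, 3, 6}
+    //     offset 2 -> {3, 4, 6, 10}
+    // `sum` then carries the running total across successive 32-item chunks.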
+
+    __shared__ uint32_t prefix_sum[kAlignedBatchSize];
+    uint32_t sum = 0;
+    #pragma unroll
+    for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
+        uint32_t x = num_segs[k];
+        #pragma unroll
+        for (uint32_t offset = 1; offset < 32; offset <<= 1) {
+            const uint32_t y = __shfl_up_sync(0xffffffff, x, offset);
+            x += (lane_idx >= offset ? y : 0);
+        }
+        x += sum;
+        prefix_sum[k * 32 + lane_idx] = x;
+        sum = __shfl_sync(0xffffffff, x, 31);
+    }
+
+    // SM work distribution
+    if constexpr (kIsVarlen) {
+        const uint32_t total = sum;
+        const uint32_t q = total / kNumSMs, r = total % kNumSMs;
+        for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
+            uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
+            uint32_t lo = 0, hi = num_items;
+            while (lo < hi) {
+                const uint32_t mid = (lo + hi) / 2;
+                const bool pred = prefix_sum[mid] <= seg_starts;
+                lo = pred ? mid + 1 : lo;
+                hi = pred ? hi : mid;
+            }
+            const uint32_t atom_idx = lo;
+            const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]);
+            const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size);
+            __syncwarp();
+
+            schedule_metadata[sm_idx * 2] = q_atom_idx;
+            schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
+        }
+    } else {
+        const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1;
+        const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom);
+        const uint32_t total = sum * num_next_n_atoms;
+        const uint32_t q = total / kNumSMs, r = total % kNumSMs;
+        for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
+            uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
+            uint32_t lo = 0, hi = batch_size;
+            while (lo < hi) {
+                const uint32_t mid = (lo + hi) / 2;
+                const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts;
+                lo = pred ? mid + 1 : lo;
+                hi = pred ? hi : mid;
+            }
+            const uint32_t q_idx = lo;
+            const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
+            const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
+            const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
+            const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
+            const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
+            __syncwarp();
+
+            schedule_metadata[sm_idx * 2] = q_atom_idx;
+            schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
+        }
+    }
+}
+
+// Conditional storage for the varlen indices pointer (EBO: zero cost when unused)
+template <bool kIsVarlen>
+struct IndicesStorage {
+    const uint32_t* indices;
+};
+
+template <>
+struct IndicesStorage<false> {};
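+
+// Layout note (illustrative): `schedule_metadata` holds `kNumSMs + 1` pairs of
+// `{q_atom_idx, kv_split_idx}`; SM `i` owns the half-open range between pair `i`
+// and pair `i + 1`, and the scheduler below reads each pair back as a `uint2`.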
+
+template <bool kIsVarlen, bool kIsContextLens2D, uint32_t kNextN, uint32_t kNumNextNAtoms,
+          uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit>
+struct PagedMQALogitsScheduler : IndicesStorage<kIsVarlen> {
+    const uint32_t* context_lens;
+    uint32_t batch_size;
+
+    uint32_t current_q_atom_idx, current_kv_idx;
+    uint32_t end_q_atom_idx, end_kv_idx;
+    uint32_t current_num_kv;
+
+    CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) {
+        if constexpr (kIsVarlen) {
+            return q_atom_idx;
+        } else {
+            static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
+            static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
+            if constexpr (kPadOddN) {
+                return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom;
+            } else {
+                return q_atom_idx * kNextNAtom;
+            }
+        }
+    }
+
+    CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) {
+        if constexpr (kIsVarlen) {
+            return q_atom_idx;
+        } else {
+            return q_atom_idx / kNumNextNAtoms;
+        }
+    }
+
+    CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
+        if constexpr (kIsVarlen) {
+            const bool is_paired = (q_atom_idx + 1 < batch_size and
+                                    this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]);
+            const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx];
+            return math::ceil_div(ctx_len, BLOCK_KV);
+        } else {
+            const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
+            const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
+            return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
+        }
+    }
+
+    CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size,
+                                                    const uint32_t* context_lens,
+                                                    const uint32_t* schedule_meta, const uint32_t* indices) {
+        this->context_lens = context_lens;
+        this->batch_size = batch_size;
+        if constexpr (kIsVarlen) {
+            this->indices = indices;
+        }
+
+        const auto current_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx];
+        const auto end_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx + 1];
+        current_q_atom_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
+        end_q_atom_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
+
+        current_num_kv = get_num_kv(current_q_atom_idx);
+    }
+
+    // Advance step in `q_atom_idx` space when moving to the next atom.
+    // Varlen: 1 or 2, depending on whether consecutive tokens share the same sequence.
+    // Non-varlen: always 1 (one atom unit).
+    CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const {
+        if constexpr (kIsVarlen) {
+            return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1;
+        } else {
+            return 1;
+        }
+    }
+
+    // Whether `num_kv` should be refreshed after advancing to `q_atom_idx`.
+    // Varlen: always refresh (each atom may have a different context length).
+    // Non-varlen: only at atom-group boundaries (atoms within a group share the context length).
+ CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + return true; + } else { + return q_atom_idx % kNumNextNAtoms == 0; + } + } + + CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) { + q_atom_idx = current_q_atom_idx; + kv_idx = current_kv_idx; + num_kv = current_num_kv; + + if (current_q_atom_idx == end_q_atom_idx and current_kv_idx == end_kv_idx) + return false; + + current_kv_idx += kNumBlocksPerSplit; + if (current_kv_idx >= current_num_kv) { + current_kv_idx = 0; + current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx); + if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) { + current_num_kv = get_num_kv(current_q_atom_idx); + } + } + return true; + } + + CUTLASS_DEVICE bool exist_q_atom_idx(const uint32_t& q_atom_idx) const { + return q_atom_idx < end_q_atom_idx or (q_atom_idx == end_q_atom_idx and 0 < end_kv_idx); + } +}; + +} // namespace deep_gemm::sched diff --git a/deep-gemm/deep_gemm/legacy/a_fused_m_grouped_gemm.py b/deep-gemm/deep_gemm/legacy/a_fused_m_grouped_gemm.py index 3f1f5294..41b35d53 100644 --- a/deep-gemm/deep_gemm/legacy/a_fused_m_grouped_gemm.py +++ b/deep-gemm/deep_gemm/legacy/a_fused_m_grouped_gemm.py @@ -47,7 +47,7 @@ def a_fused_m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, # Compute acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, K, BLOCK_SIZE_K): - k_range = (k + tl.arange(0, BLOCK_SIZE_K)).to(tl.int64) + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) k_mask = k_range < K a_ptrs = a_ptr + rows[:, None] * K + k_range[None, :] b_ptrs = b_ptr + batch_id * K * N + k_range[:, None] * (1 if IS_B_K_MAJOR else N) + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) diff --git a/deep-gemm/deep_gemm/legacy/b_fused_k_grouped_gemm.py b/deep-gemm/deep_gemm/legacy/b_fused_k_grouped_gemm.py index a642204b..7df8741f 100644 --- a/deep-gemm/deep_gemm/legacy/b_fused_k_grouped_gemm.py +++ b/deep-gemm/deep_gemm/legacy/b_fused_k_grouped_gemm.py @@ -50,7 +50,7 @@ def b_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, # Compute acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(k_start, k_end, BLOCK_SIZE_K): - k_range = (k + tl.arange(0, BLOCK_SIZE_K)).to(tl.int64) + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) rows = tl.load(k_indices_ptr + k_range).to(tl.int64) a_ptrs = a_ptr + m_range[:, None] + k_range[None, :] * M b_ptrs = b_ptr + rows[:, None] * N + n_range[None, :] diff --git a/deep-gemm/deep_gemm/mega/__init__.py b/deep-gemm/deep_gemm/mega/__init__.py new file mode 100644 index 00000000..670b409d --- /dev/null +++ b/deep-gemm/deep_gemm/mega/__init__.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import torch +from typing import Tuple, Optional +from ..utils.math import align + +# noinspection PyBroadException +try: + # noinspection PyProtectedMember + import torch.distributed._symmetric_memory as symm_mem + import torch.distributed as dist +except Exception as exception: + print(f'Failed to load mega kernels, please check your PyTorch version: {exception}') + +from .. 
import _C + + +class SymmBuffer: + def __init__(self, group: dist.ProcessGroup, + # MoE arguments + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu'): + self.group = group + self.num_experts = num_experts + self.num_max_tokens_per_rank = num_max_tokens_per_rank + self.num_topk = num_topk + self.hidden = hidden + self.intermediate_hidden = intermediate_hidden + + # Allocate a symmetric buffer + num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe( + group.size(), num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda') + self.handle = symm_mem.rendezvous(self.buffer, group=group) + self.buffer.zero_() + self.group.barrier() + torch.cuda.synchronize() + + # Create input buffer views + (self.x, self.x_sf, + self.topk_idx, self.topk_weights, + self.l1_acts, self.l1_acts_sf, + self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer) + + def destroy(self): + self.handle = None + self.buffer = None + self.group = None + self.x = None + self.x_sf = None + + +def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup, + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu') -> SymmBuffer: + # Token count must be aligned to block sizes + num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe()) + + return SymmBuffer( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + + +def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + # [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] 
instead of [gate | up] + def interleave(t, gran: int = 8) -> torch.Tensor: + g, n, *rest = t.shape + half = n // 2 + gate = t[:, :half].reshape(g, half // gran, gran, *rest) + up = t[:, half:].reshape(g, half // gran, gran, *rest) + return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest)) + + return interleave(l1_weights[0]), interleave(l1_weights[1]) + + +def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: + num_groups, mn, packed_sf_k = sf.shape + assert sf.dtype == torch.int and mn % 128 == 0 + result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k) + .transpose(2, 3) + .reshape(num_groups, mn, packed_sf_k)) + return torch.empty_like(sf).copy_(result) + + +def transform_weights_for_mega_moe( + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + # L1: interleave gate/up, then transpose SF for UTCCP + l1_interleaved = _interleave_l1_weights(l1_weights) + l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1])) + # L2: only transpose SF for UTCCP + l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1])) + return l1_weights, l2_weights + + +def fp8_fp4_mega_moe(y: torch.Tensor, + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor], + sym_buffer: SymmBuffer, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + recipe: Tuple[int, int, int] = (1, 1, 32), + activation: str = 'swiglu', + activation_clamp: Optional[float] = None, + fast_math: bool = True): + _C.fp8_fp4_mega_moe( + y, + l1_weights, l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + ) diff --git a/deep-gemm/deep_gemm/testing/bench.py b/deep-gemm/deep_gemm/testing/bench.py index 2c752da2..552b9aa1 100644 --- a/deep-gemm/deep_gemm/testing/bench.py +++ b/deep-gemm/deep_gemm/testing/bench.py @@ -1,6 +1,7 @@ import os import sys import torch +from typing import Callable, Optional def bench(fn, num_warmups: int = 5, num_tests: int = 10, @@ -78,7 +79,8 @@ def __exit__(self, *_): def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, trace_path: str = None, flush_l2: bool = True, - with_multiple_kernels: bool = False): + with_multiple_kernels: bool = False, + barrier: Optional[Callable] = None): assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) is_tuple = isinstance(kernel_names, tuple) @@ -96,14 +98,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, # Profile suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress with suppress(): - schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) - profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule, acc_events=True) with profiler: for i in range(2): for _ in range(num_tests): if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + if barrier is not None: + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + # 
noinspection PyProtectedMember + torch.cuda._sleep(int(2e7)) # ~10ms + barrier() fn() + torch.cuda.synchronize() profiler.step() # Parse the profiling table @@ -111,7 +120,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names if not with_multiple_kernels: for name in kernel_names: - assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table {prof_lines}' # Save chrome traces if trace_path is not None: diff --git a/deep-gemm/deep_gemm/utils/__init__.py b/deep-gemm/deep_gemm/utils/__init__.py index e8f859a2..a0dc6f78 100644 --- a/deep-gemm/deep_gemm/utils/__init__.py +++ b/deep-gemm/deep_gemm/utils/__init__.py @@ -1,3 +1,4 @@ from . import math, layout from .layout import * from .math import * +from .dist import init_dist, uneven_all_gather diff --git a/deep-gemm/deep_gemm/utils/dist.py b/deep-gemm/deep_gemm/utils/dist.py new file mode 100644 index 00000000..426c3967 --- /dev/null +++ b/deep-gemm/deep_gemm/utils/dist.py @@ -0,0 +1,74 @@ +import inspect +import os +import torch +import torch.distributed as dist +from typing import Tuple + +_local_rank = None + + +def init_dist(local_rank: int, num_local_ranks: int) -> Tuple[int, int, dist.ProcessGroup]: + # NOTES: you may rewrite this function with your own cluster settings + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + node_rank = int(os.getenv('RANK', 0)) + + # Set local rank + global _local_rank + _local_rank = local_rank + + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') + dist.init_process_group(**params) + torch.set_default_device('cuda') + torch.cuda.set_device(local_rank) + + return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes))) + + +def uneven_all_gather(tensor: torch.Tensor, dim: int = 0, group: dist.ProcessGroup = None) -> torch.Tensor: + world_size = dist.get_world_size(group) + + # Exchange sizes + local_dim_size = torch.tensor([tensor.shape[dim]], device=tensor.device, dtype=torch.long) + all_dim_sizes = [torch.zeros_like(local_dim_size) for _ in range(world_size)] + dist.all_gather(all_dim_sizes, local_dim_size, group=group) + all_dim_sizes = [s.item() for s in all_dim_sizes] + max_dim_size = max(all_dim_sizes) + + # Pad + if tensor.shape[dim] < max_dim_size: + pad_shape = list(tensor.shape) + pad_shape[dim] = max_dim_size - tensor.shape[dim] + padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor_padded = torch.cat([tensor, padding], dim=dim) + else: + tensor_padded = tensor.contiguous() + + # All-gather + gathered = [torch.zeros_like(tensor_padded) for _ in range(world_size)] + dist.all_gather(gathered, tensor_padded, group=group) + + # Remove padding + trimmed = [ + torch.narrow(gathered[i], dim, 0, all_dim_sizes[i]) + for i in range(world_size) + ] + return torch.cat(trimmed, dim=dim) + + +def dist_print(s: str = '', once_in_node: bool = False) -> None: + global _local_rank + assert _local_rank is not None + if 
not once_in_node or _local_rank == 0: + print(s, flush=True) + dist.barrier() diff --git a/deep-gemm/deep_gemm/utils/layout.py b/deep-gemm/deep_gemm/utils/layout.py index 790e0d66..6512c5ab 100644 --- a/deep-gemm/deep_gemm/utils/layout.py +++ b/deep-gemm/deep_gemm/utils/layout.py @@ -10,7 +10,11 @@ pass # Valid for all CUDA versions -from .._C import get_mk_alignment_for_contiguous_layout +from .._C import ( + set_mk_alignment_for_contiguous_layout, + get_mk_alignment_for_contiguous_layout, + get_theoretical_mk_alignment_for_contiguous_layout, +) # Some alias get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/deep-gemm/deep_gemm/utils/math.py b/deep-gemm/deep_gemm/utils/math.py index c65026e5..f1582ed5 100644 --- a/deep-gemm/deep_gemm/utils/math.py +++ b/deep-gemm/deep_gemm/utils/math.py @@ -11,21 +11,30 @@ def align(x: int, y: int) -> int: def ceil_to_ue8m0(x: torch.Tensor): - assert x.view(-1).amax().item() > 0 - return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + bits = x.abs().float().view(torch.int) + exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).bool().int() + return (exp.clamp(1, 254) << 23).view(torch.float) -def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: +def pack_ue8m0_to_int(x: torch.Tensor): + assert x.dtype == torch.float and x.size(-1) % 4 == 0 + assert (x.view(torch.int) & ((1 << 23) - 1) == 0).all() + return (x.view(torch.int) >> 23).to(torch.uint8).view(torch.int) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape padded_n = align(n, gran_k) x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) x_padded[:, :n] = x - x_view = x_padded.view(m, -1, gran_k) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + x_view = x_padded.view(m, padded_n // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, padded_n // gran_k).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf + x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous() + return x_fp8, pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: @@ -70,13 +79,14 @@ def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: code = idx.to(torch.uint8) sign = (x < 0) & (idx != 0) code = code | (sign.to(torch.uint8) << 3) - return code # uint8, 0..15 + return code.view(torch.int8) -def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: m, n = x.shape assert n % 2 == 0 + assert not use_packed_ue8m0 or use_ue8m0 padded_n = align(n, gran_k) x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) x_padded[:, :n] = x @@ -85,23 +95,49 @@ def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) - sf = x_amax / 6.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = x_view * (1.0 / sf.unsqueeze(2)) - codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, 
(m, padded_n) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # int8, (m, padded_n) codes2 = codes.view(m, padded_n // 2, 2) - packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 - return packed[:, :n // 2].contiguous(), sf + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # int8 + return packed[:, :n // 2].contiguous(), pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: - assert a.dtype == torch.uint8 + assert a.dtype == torch.int8 assert a.dim() == 2 m, n2 = a.shape n = n2 * 2 assert (m % 2) == 0 lo = a & 0x0F hi = (a >> 4) & 0x0F - codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes = torch.empty((m, n), device=a.device, dtype=torch.int8) codes[:, 0::2], codes[:, 1::2] = lo, hi codes_t = codes.transpose(0, 1).contiguous() codes2 = codes_t.view(n, m // 2, 2) out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) - return out.contiguous() \ No newline at end of file + return out.contiguous() + + +def _dequantize_from_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + fp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device, dtype=torch.float) + sign, value_idx = (x & 0x08) != 0, (x & 0x07).to(torch.int) + value = fp4_values[value_idx] + return torch.where(sign & (value_idx != 0), -value, value) + + +def unpack_ue8m0_from_int(packed_sf: torch.Tensor) -> torch.Tensor: + return (packed_sf.view(torch.uint8).to(torch.int) << 23).view(torch.float) + + +def cast_back_from_fp4(packed: torch.Tensor, sf: torch.Tensor, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> torch.Tensor: + m, n2 = packed.shape + n = n2 * 2 + if use_packed_ue8m0: + sf = unpack_ue8m0_from_int(sf) + unpacked = torch.zeros((m, n), dtype=torch.int8, device=packed.device) + unpacked[:, ::2] = packed & 0x0F + unpacked[:, 1::2] = (packed >> 4) & 0x0F + x_dequantized = _dequantize_from_fp4_e2m1(unpacked) + group_idx = torch.arange(n, device=packed.device) // gran_k + x_restored = x_dequantized * sf[:, group_idx] + return x_restored \ No newline at end of file diff --git a/deep-gemm/scripts/quick_plot_pm.py b/deep-gemm/scripts/quick_plot_pm.py new file mode 100644 index 00000000..3aee8b86 --- /dev/null +++ b/deep-gemm/scripts/quick_plot_pm.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +"""Plot a curated set of NCU PM metrics from an .ncu-rep report. + +Usage: + python scripts/quick_plot_pm.py [report.ncu-rep] + +By default the script saves a PNG next to the report. +With --interactive, it opens a Qt window instead. +""" + +import argparse +import csv +import io +import subprocess +from dataclasses import dataclass + +import matplotlib +import numpy as np + + +@dataclass(frozen=True) +class MetricSpec: + name: str + metric: str + kind: str + category: str + aliases: tuple[str, ...] = () + + +@dataclass(frozen=True) +class ResolvedMetricSpec: + name: str + metric: str + kind: str + category: str + + +@dataclass(frozen=True) +class MetricSeries: + name: str + metric: str + category: str + unit: str + values: tuple[float, ...] 
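+    # Example (hypothetical numbers, for orientation only): a resolved and probed
+    # series may look like
+    #   MetricSeries(name='DRAM Throughput',
+    #                metric='FBSP.TriageCompute.dram__throughput.avg.pct_of_peak_sustained_elapsed',
+    #                category='DRAM', unit='%', values=(12.5, 48.0, 91.2, 87.3))
+    # i.e. one float per PM-sampling interval of the profiled kernel.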
+ + +CATEGORY_ORDER = [ + "Overview", + "SM", + "L1", + "L2", + "DRAM", + "Interconnect", +] + + +KIND_SUFFIXES = { + "pct_peak": [".avg.pct_of_peak_sustained_elapsed"], + "pct": [".pct", ".avg.pct_of_peak_sustained_elapsed"], + "avg": [".avg"], + "sum": [".sum"], + "avg_per_second": [".avg.per_second"], + "sum_per_second": [".sum.per_second"], + "avg_per_cycle_active": [".avg.per_cycle_active"], + "avg_per_cycle_elapsed": [".avg.per_cycle_elapsed"], + "sum_per_cycle_elapsed": [".sum.per_cycle_elapsed"], +} + + +# Curated from scripts/ncu-metrics.txt, with a few corrections against +# `ncu --query-metrics --chip gb100`: +# - Blocks launched uses `gr__ctas_launched_realtime` +# - SM active cycles uses `sm__cycles_active` +# - L2 throughput for GCC requests uses `lts__t_sector_throughput_srcunit_gcc` +# - C2C throughput uses `ctc__throughput` +# - NVLink RX metrics use the `NVLRX` domain +CURATED_METRICS = [ + MetricSpec("Blocks Launched", "FE_B.TriageCompute.gr__ctas_launched_realtime", "sum_per_cycle_elapsed", "Overview"), + MetricSpec("Average Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "avg_per_cycle_elapsed", "Overview"), + MetricSpec("Total Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "sum_per_cycle_elapsed", "Overview"), + MetricSpec("Average CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "avg_per_cycle_elapsed", "Overview"), + MetricSpec("Total CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "sum_per_cycle_elapsed", "Overview"), + MetricSpec("SM Active Cycles", "TPC.TriageCompute.sm__cycles_active", "avg", "SM"), + MetricSpec("Executed IPC Active", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_active", "SM"), + MetricSpec("Executed IPC Elapsed", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_elapsed", "SM"), + MetricSpec("SM Throughput", "TPC.TriageCompute.sm__inst_executed_realtime", "pct_peak", "SM"), + MetricSpec("SM ALU Pipe Throughput", "TPC.TriageCompute.sm__inst_executed_pipe_alu_realtime", "pct_peak", "SM"), + MetricSpec("SM FMA Pipe Throughput", "TPC.TriageCompute.sm__pipe_fma_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM FMA Heavy Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmaheavy_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM FMA Light Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmalite_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM Tensor Pipe Throughput", "TPC.TriageCompute.sm__pipe_tensor_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM TMEM Pipe Throughput", "SM_A.TriageCompute.sm__mem_tensor_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM Uniform Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_uniform_realtime", "pct_peak", "SM"), + MetricSpec("SM XU Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_xu_realtime", "pct_peak", "SM"), + MetricSpec("L1 Throughput", "SM_A.TriageCompute.l1tex__throughput", "pct_peak", "L1"), + MetricSpec("L1 Sectors", "SM_B.TriageCompute.l1tex__t_sectors", "sum", "L1"), + MetricSpec("L1 Hit Rate", "SM_B.TriageCompute.l1tex__t_sector_hit_rate", "pct", "L1"), + MetricSpec("L1 Lookup Hit", "SM_B.TriageCompute.l1tex__t_sectors_lookup_hit", "sum", "L1"), + MetricSpec("L1 Lookup Miss", "SM_B.TriageCompute.l1tex__t_sectors_lookup_miss", "sum", "L1"), + MetricSpec("L1 Wavefronts (Data)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts", "avg", "L1"), + MetricSpec("L1 Wavefronts (LGDS)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_lgds", 
"avg", "L1"), + MetricSpec("L1 Wavefronts (Shared)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_shared", "avg", "L1"), + MetricSpec("L2 Throughput", "LTS.TriageCompute.lts__throughput", "pct_peak", "L2"), + MetricSpec("L2 Throughput for L1 Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_tex", "pct_peak", "L2"), + MetricSpec("L2 Throughput for GCC Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_gcc", "pct_peak", "L2"), + MetricSpec("L2 Throughput to DRAM", "LTS.TriageCompute.lts__t_sector_throughput_srcnode_fbp", "pct_peak", "L2"), + MetricSpec("SysL2 Throughput to Peer Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_peer", "pct_peak", "L2"), + MetricSpec("SysL2 Throughput to System Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_sysmem", "pct_peak", "L2"), + MetricSpec("L2 Hit Rate", "LTS.TriageCompute.lts__average_t_sector_hit_rate_realtime", "pct", "L2"), + MetricSpec("L2 Hit Rate From L1", "LTS.TriageCompute.lts__average_t_sector_hit_rate_srcunit_tex_realtime", "pct", "L2"), + MetricSpec("DRAM Frequency", "FBSP.TriageCompute.dram__cycles_elapsed", "avg_per_second", "DRAM"), + MetricSpec("DRAM Throughput", "FBSP.TriageCompute.dram__throughput", "pct_peak", "DRAM"), + MetricSpec("DRAM Read Throughput", "FBSP.TriageCompute.dram__read_throughput", "pct_peak", "DRAM"), + MetricSpec("DRAM Write Throughput", "FBSP.TriageCompute.dram__write_throughput", "pct_peak", "DRAM"), + MetricSpec("C2C Throughput", "TriageCompute.ctc__throughput", "pct_peak", "Interconnect", aliases=("TriageCompute.ctx__throughput",)), + MetricSpec("NVLink Transmitted Throughput", "NVLTX.TriageCompute.nvltx__bytes", "pct_peak", "Interconnect"), + MetricSpec("NVLink Received Throughput", "NVLRX.TriageCompute.nvlrx__bytes", "pct_peak", "Interconnect"), + MetricSpec("NVLink Transmitted Bandwidth", "NVLTX.TriageCompute.nvltx__bytes", "sum_per_second", "Interconnect"), + MetricSpec("NVLink Received Bandwidth", "NVLRX.TriageCompute.nvlrx__bytes", "sum_per_second", "Interconnect"), + MetricSpec("PCIe Throughput", "PCI.TriageCompute.pcie__throughput", "pct_peak", "Interconnect"), + MetricSpec("PCIe Read Bandwidth", "PCI.TriageCompute.pcie__read_bytes", "sum_per_second", "Interconnect"), + MetricSpec("PCIe Write Bandwidth", "PCI.TriageCompute.pcie__write_bytes", "sum_per_second", "Interconnect"), +] + + +def _run_csv_command(command, timeout): + result = subprocess.run(command, capture_output=True, text=True, timeout=timeout) + if result.returncode != 0 and not result.stdout: + return None + reader = csv.reader(io.StringIO(result.stdout)) + return list(reader) + + +def _query_available_metrics(chip): + result = subprocess.run( + ["ncu", "--query-metrics", "--chip", chip], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + raise RuntimeError(result.stderr.strip() or f"failed to query metrics for chip {chip}") + + metrics = set() + for line in result.stdout.splitlines(): + parts = line.split() + if not parts: + continue + token = parts[0] + if "__" not in token: + continue + metrics.add(token) + return metrics + + +def _metric_candidates(metric): + candidates = [metric] + marker = ".TriageCompute." 
+ if marker in metric: + candidates.append(metric.split(marker, 1)[1]) + return candidates + + +def resolve_metric_specs(chip): + available = _query_available_metrics(chip) + resolved = [] + missing = [] + + for spec in CURATED_METRICS: + candidates = [] + for metric in (spec.metric, *spec.aliases): + candidates.extend(_metric_candidates(metric)) + + actual_metric = next((metric for metric in candidates if metric in available), None) + if actual_metric is None: + missing.append(spec) + continue + + resolved.append(ResolvedMetricSpec(spec.name, actual_metric, spec.kind, spec.category)) + + return resolved, missing + + +def _parse_metric_values(raw): + if not raw or raw == "no data": + return () + + try: + if raw.startswith("(") and raw.endswith(")"): + rest = raw[1:-1] + return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip()) + if " (" in raw: + _agg, rest = raw.split(" (", 1) + rest = rest.rstrip(")") + return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip()) + return (float(raw.replace(",", "")),) + except ValueError: + return () + + +def _probe_metric_series(report, metric_name): + rows = _run_csv_command( + [ + "ncu", + "--import", + report, + "--page", + "raw", + "--csv", + "--metrics", + metric_name, + "--print-metric-instances", + "values", + ], + timeout=60, + ) + if not rows or len(rows) < 3 or len(rows[0]) <= 11: + return None + + header, units, row = rows[0], rows[1], rows[2] + unit = units[11] if len(units) > 11 else "" + raw = row[11] if len(row) > 11 else "" + values = _parse_metric_values(raw) + return header[11], unit, values + + +def collect_metric_series(report, resolved_specs): + collected = [] + skipped = [] + + for spec in resolved_specs: + series = None + for suffix in KIND_SUFFIXES[spec.kind]: + probe = _probe_metric_series(report, f"{spec.metric}{suffix}") + if probe is None: + continue + full_metric, unit, values = probe + if len(values) > 1: + series = MetricSeries(spec.name, full_metric, spec.category, unit, values) + break + + if series is None: + skipped.append(spec) + continue + + collected.append(series) + + return collected, skipped + + +def _format_value(value): + if value == 0: + return "0" + abs_value = abs(value) + if abs_value >= 1e12: + return f"{value / 1e12:.2f} T" + if abs_value >= 1e9: + return f"{value / 1e9:.2f} G" + if abs_value >= 1e6: + return f"{value / 1e6:.2f} M" + if abs_value >= 1e3: + return f"{value / 1e3:.2f} K" + if abs_value >= 1: + return f"{value:.1f}" + return f"{value:.2f}" + + +def _format_with_unit(value, unit): + if not unit: + return _format_value(value) + return f"{_format_value(value)} {unit}" + + +def plot_pm(report, metrics, save=False): + """Plot curated PM metrics as shared-x subplots in a light theme.""" + import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpec + + if not metrics: + print("No curated metrics had time-series data in the report.") + return + + bg_fig = "#ffffff" + bg_row = "#f6f8fb" + text_primary = "#1f2937" + text_secondary = "#6b7280" + text_header = "#111827" + grid_color = "#d7deea" + border = "#c7d0dd" + + wave_colors = { + "Overview": "#7c8aa5", + "SM": "#4f87c2", + "L1": "#2f9d8f", + "L2": "#dd8452", + "DRAM": "#c95d63", + "Interconnect": "#8c6bb1", + } + + category_rank = {category: index for index, category in enumerate(CATEGORY_ORDER)} + metrics = sorted(metrics, key=lambda item: (category_rank.get(item.category, 99), item.name)) + + row_h = 0.55 + label_w = 3.6 + plot_w = 14.0 + fig_w = label_w + plot_w + fig_h = 
row_h * len(metrics) + 0.6 + + fig = plt.figure(figsize=(fig_w, fig_h), facecolor=bg_fig) + gs = GridSpec( + len(metrics), + 1, + figure=fig, + left=label_w / fig_w, + right=0.97, + top=1 - 0.45 / fig_h, + bottom=0.35 / fig_h, + hspace=0.18, + ) + axes = [fig.add_subplot(gs[i, 0]) for i in range(len(metrics))] + + prev_category = None + for idx, metric in enumerate(metrics): + ax = axes[idx] + values = np.array(metric.values) + x = np.arange(len(values)) + wave_color = wave_colors.get(metric.category, "#5b9bd5") + + ax.set_facecolor(bg_row) + ax.fill_between(x, values, alpha=0.35, color=wave_color, linewidth=0) + ax.plot(x, values, linewidth=0.8, color=wave_color) + + ax.set_xlim(0, len(values) - 1) + if metric.unit == "%": + ax.set_ylim(0, 100) + else: + ymax = np.max(values) + ax.set_ylim(0, ymax * 1.15 if ymax > 0 else 1) + + ax.grid(True, axis="both", color=grid_color, linewidth=0.5, alpha=0.85) + ax.tick_params(axis="both", colors=text_secondary, labelsize=6, length=0) + + if idx < len(metrics) - 1: + ax.tick_params(axis="x", labelbottom=False) + else: + ax.set_xlabel("Sample Index", fontsize=8, color=text_secondary) + + ymin_v, ymax_v = ax.get_ylim() + ax.set_yticks([ymin_v, ymax_v]) + ax.set_yticklabels([_format_value(ymin_v), _format_value(ymax_v)], fontsize=6, color=text_secondary) + + peak = np.max(values) + ax.text( + 1.005, + 0.5, + _format_with_unit(peak, metric.unit), + transform=ax.transAxes, + fontsize=7, + color=text_secondary, + va="center", + ha="left", + family="monospace", + ) + + for spine in ax.spines.values(): + spine.set_color(border) + spine.set_linewidth(0.5) + + if metric.category != prev_category: + cat_y = ax.get_position().y1 + 0.008 + fig.text( + 0.005, + cat_y, + f" {metric.category}", + fontsize=8.5, + fontweight="bold", + color=text_header, + va="bottom", + family="sans-serif", + transform=fig.transFigure, + bbox=dict(boxstyle="square,pad=0.15", facecolor="#e9eef5", edgecolor="none"), + ) + prev_category = metric.category + + label_y = (ax.get_position().y0 + ax.get_position().y1) / 2 + fig.text( + label_w / fig_w - 0.012, + label_y, + metric.name, + fontsize=7.5, + color=text_primary, + va="center", + ha="right", + family="sans-serif", + transform=fig.transFigure, + ) + + fig.text( + 0.5, + 1 - 0.15 / fig_h, + f"PM Sampling - {report}", + fontsize=11, + fontweight="bold", + color=text_header, + ha="center", + va="top", + family="sans-serif", + transform=fig.transFigure, + ) + + if save: + out_path = report.replace(".ncu-rep", ".pm_sampling.png") + fig.savefig(out_path, dpi=150, facecolor=bg_fig, bbox_inches="tight", pad_inches=0.2) + print(f"Saved: {out_path}") + plt.close(fig) + else: + plt.show() + + +def main(): + parser = argparse.ArgumentParser(description="NCU PM Sampling plotter") + parser.add_argument("report", nargs="?", default="mega-moe-kk.3.ncu-rep", help="Path to .ncu-rep file") + parser.add_argument("--chip", default="gb100", help="Chip name used for `ncu --query-metrics`") + parser.add_argument("--interactive", action="store_true", help="Open an interactive Qt window instead of saving a PNG") + args = parser.parse_args() + + if args.interactive: + matplotlib.use("QtAgg") + else: + matplotlib.use("Agg") + + resolved_specs, missing_specs = resolve_metric_specs(args.chip) + if missing_specs: + print(f"Skipped {len(missing_specs)} curated metrics not available on {args.chip}.") + for spec in missing_specs: + print(f" missing: {spec.name} -> {spec.metric}") + + metric_series, skipped_specs = collect_metric_series(args.report, resolved_specs) + 
if skipped_specs: + print(f"Skipped {len(skipped_specs)} curated metrics with no time-series data in {args.report}.") + for spec in skipped_specs: + print(f" no series: {spec.name} -> {spec.metric}") + + plot_pm(args.report, metric_series, save=not args.interactive) + + +if __name__ == "__main__": + main() diff --git a/deep-gemm/scripts/readme_example.py b/deep-gemm/scripts/readme_example.py deleted file mode 100644 index c3915b03..00000000 --- a/deep-gemm/scripts/readme_example.py +++ /dev/null @@ -1,49 +0,0 @@ -# /// script -# dependencies = [ -# "numpy", -# "torch", -# "kernels" -# ] -# /// - - -# CUDA_HOME=/usr/local/cuda-12.9 uv run scripts/readme_example.py -import torch -from kernels import get_local_kernel, get_kernel -from pathlib import Path - -# deep_gemm = get_local_kernel(Path("build"), "deep_gemm") -deep_gemm = get_kernel("drbh/deep-gemm", version=1) - -m, n, k = 256, 1024, 512 -device = "cuda" - -a = torch.randn((m, k), device=device, dtype=torch.bfloat16) -b = torch.randn((n, k), device=device, dtype=torch.bfloat16) -ref = a @ b.T - - -def compare(name, result, ref): - cos = torch.nn.functional.cosine_similarity( - result.float().flatten(), ref.float().flatten(), dim=0 - ) - diff = (result.float() - ref.float()).abs().max().item() - print(f"[{name}] shape: {m}x{n}x{k}, cosine_sim: {cos.item():.6f}, max_diff: {diff:.4f}") - - -# --- cuBLASLt GEMM (works on any GPU) --- -d = torch.empty((m, n), device=device, dtype=torch.bfloat16) -deep_gemm.cublaslt_gemm_nt(a, b, d) -compare("cuBLASLt BF16", d, ref) - -# --- FP8 GEMM (requires SM90+ / Hopper+) --- -arch = torch.cuda.get_device_capability()[0] -if arch >= 9: - # SFA: per-row (1, 128), SFB: per-block (128, 128) — SM90 recipe - a_fp8 = deep_gemm.utils.per_token_cast_to_fp8(a, use_ue8m0=False) - b_fp8 = deep_gemm.utils.per_block_cast_to_fp8(b, use_ue8m0=False) - d_fp8 = torch.empty((m, n), device=device, dtype=torch.bfloat16) - deep_gemm.fp8_gemm_nt(a_fp8, b_fp8, d_fp8) - compare("FP8 1D2D", d_fp8, ref) -else: - print(f"[FP8 GEMM] Skipped: requires SM90+ (Hopper), detected SM{arch}x") diff --git a/deep-gemm/scripts/run_ncu_mega_moe.sh b/deep-gemm/scripts/run_ncu_mega_moe.sh new file mode 100755 index 00000000..4324575c --- /dev/null +++ b/deep-gemm/scripts/run_ncu_mega_moe.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +set -e + +# parse num-processes, output_dir and separate python args +num_processes=8 +output_dir=work +python_args=() +for ((arg_idx = 1; arg_idx <= $#; ++arg_idx)); do + arg="${!arg_idx}" + case "$arg" in + --num-processes) + python_args+=("$arg") + if ((arg_idx < $#)); then + ((arg_idx++)) + num_processes="${!arg_idx}" + python_args+=("$num_processes") + fi + ;; + -h|--help) + echo "Usage: $0 [--num-processes N] [--output DIR] [python args...]" + exit 0 + ;; + --num-processes=*) + num_processes="${arg#*=}" + python_args+=("$arg") + ;; + -o|--output) + if ((arg_idx < $#)); then + ((arg_idx++)) + output_dir="${!arg_idx}" + fi + ;; + --output=*) + output_dir="${arg#*=}" + ;; + *) + python_args+=("$arg") + ;; + esac +done + +echo "Python Args: ${python_args[*]}" +echo "Num Processes: $num_processes" +echo "Output Dir: $output_dir" +mkdir -p $output_dir + +export DG_JIT_WITH_LINEINFO=1 # for source counters + +echo "Warm up JIT cache" +python tests/test_mega_moe.py --ncu-profile-only "${python_args[@]}" + +sleep 2 + +ncu_args=( + --config-file off + --force-overwrite + --kernel-name sm100_fp8_fp4_mega_moe_impl + --import-source yes + --replay-mode application + --section PmSampling + --section SourceCounters + --rule 
LocalMemoryUsage + --launch-skip 0 + --launch-count 1 + --lockstep-kernel-launch + --communicator tcp + --clock-control none + --pm-sampling-interval 1000 + --pm-sampling-max-passes 1 + --disable-pm-warp-sampling + --communicator-tcp-num-peers "$num_processes" + --kill yes + --app-replay-buffer memory +) + +echo "Run Job" + +for ((i = 0; i < num_processes; ++i)); do + ncu ${ncu_args[@]} -o "${output_dir%/}/mega-moe.$i" \ + python tests/test_mega_moe.py \ + --local-rank-idx=$i \ + --ncu-profile-only \ + "${python_args[@]}" & +done + +echo "Waiting" +wait +echo "Done" diff --git a/deep-gemm/setup.py b/deep-gemm/setup.py index 6199d7c3..c4d74ae9 100644 --- a/deep-gemm/setup.py +++ b/deep-gemm/setup.py @@ -68,7 +68,7 @@ def get_package_version(): cmd = ['git', 'rev-parse', '--short', 'HEAD'] revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() - except (subprocess.CalledProcessError, FileNotFoundError, OSError): + except Exception: revision = '+local' return f'{public_version}{revision}' @@ -172,6 +172,7 @@ def run(self): wheel_url, wheel_filename = get_wheel_url() print(f'Try to download wheel from URL: {wheel_url}') + # noinspection PyBroadException try: with urllib.request.urlopen(wheel_url, timeout=1) as response: with open(wheel_filename, 'wb') as out_file: diff --git a/deep-gemm/tests/generators.py b/deep-gemm/tests/generators.py index ee22e515..989e984e 100644 --- a/deep-gemm/tests/generators.py +++ b/deep-gemm/tests/generators.py @@ -8,7 +8,8 @@ align, ceil_div, per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8, per_token_cast_to_fp4, transpose_packed_fp4, - get_mk_alignment_for_contiguous_layout + get_mk_alignment_for_contiguous_layout, + set_mk_alignment_for_contiguous_layout ) @@ -107,7 +108,7 @@ def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator: def get_psum_layout_usage() -> tuple: - return (False, True) if get_arch_major() == 10 else (False, ) + return True, False def enumerate_normal(dtype: torch.dtype) -> Generator: @@ -168,7 +169,7 @@ def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator: def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator: quant_config_list = QuantConfig.get_list_from_dtype(dtype) max_m = 4096 - m_group_list = [(6, 1024), (32, 192), (32, 50)] + m_group_list = [(32, 192), (6, 1024), (32, 20), (6, 20)] n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)] for kernel_type in get_kernel_types(dtype): for quant_config in quant_config_list: @@ -182,6 +183,7 @@ def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator: def enumerate_k_grouped_contiguous(dtype: torch.dtype): + gran_k_list = (128, ) if get_arch_major() == 9 else (32, 128) # Only K-major is supported for SM90 FP8 major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 and dtype == torch.float8_e4m3fn \ else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) @@ -189,26 +191,36 @@ def enumerate_k_grouped_contiguous(dtype: torch.dtype): for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64 ( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32 (16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16 - ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)] - yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group + if dtype == torch.bfloat16: + ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), 
get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)] + yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group + else: + for gran_k in gran_k_list: + set_mk_alignment_for_contiguous_layout(gran_k) + ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), gran_k) for _ in range(num_groups)] + yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group, gran_k def enumerate_sf_layout(): + gran_k_list = (128, ) if get_arch_major() == 9 else (32, 128) for use_ue8m0 in (False, True): for with_transpose in (True, False): for mn in (4096, 4097, 8192): for k in (128, 7168, 7296): for num_groups in (1, 2, 4): - yield mn, k, with_transpose, use_ue8m0, num_groups + for gran_k in gran_k_list: + set_mk_alignment_for_contiguous_layout(gran_k) + yield mn, k, with_transpose, use_ue8m0, num_groups, gran_k def enumerate_k_grouped_sf_layout(): - alignment = get_mk_alignment_for_contiguous_layout() - assert alignment % 128 == 0 + gran_k_list = (128, ) if get_arch_major() == 9 else (32, 128) for mn in (4096, 7168): for num_groups, avg_k in ((16, 2048), (8, 4096), (72, 384), (128, 256)): - ks = [align(int(random.uniform(0.7, 1.3) * avg_k), alignment) for _ in range(num_groups)] - yield mn, ks, num_groups + for gran_k in gran_k_list: + set_mk_alignment_for_contiguous_layout(gran_k) + ks = [align(int(random.uniform(0.7, 1.3) * avg_k), gran_k) for _ in range(num_groups)] + yield mn, ks, num_groups, gran_k def enumerate_transpose(): @@ -222,25 +234,24 @@ def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is use_ue8m0: bool, use_block_cast_for_fp8: bool = False): if is_fp4: x_fp4 = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k) - x = x_fp4 if major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1]) + return x_fp4 if major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1]) else: x_fp8 = per_block_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ else per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) - x = x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1]) - return x + return x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1]) def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool, use_ue8m0: bool, use_block_cast_for_fp8: bool = False): num_groups, mn, k = x.size() if is_fp4: - x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.uint8) if major.is_k_major() else \ - torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.uint8), + x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.int8) if major.is_k_major() else \ + torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.int8), torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)) for i in range(num_groups): x_i_fp4 = per_token_cast_to_fp4(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) x_fp4[0][i], x_fp4[1][i] = x_i_fp4 if major.is_k_major() else (transpose_packed_fp4(x_i_fp4[0]), x_i_fp4[1]) - x = x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1]) + return x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1]) else: x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(mn, gran_k), ceil_div(k, gran_k)), device='cuda', dtype=torch.float) if use_block_cast_for_fp8 \ @@ -248,8 +259,7 @@ def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: for i in range(num_groups): x_fp8[0][i], 
x_fp8[1][i] = per_block_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ else per_token_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) - x = x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1]) - return x + return x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1]) def generate_normal(m: int, n: int, k: int, @@ -325,7 +335,7 @@ def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor): last_psum_m = 0 for i in range(num_groups): x_psum[last_psum_m: psum_m[i]] = x[i, :psum_m[i] - last_psum_m] - last_psum_m = align(psum_m[i], 128) + last_psum_m = align(psum_m[i], get_mk_alignment_for_contiguous_layout()) return x_psum @@ -342,7 +352,7 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: psum_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) for j in range(num_groups): masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) - psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], 128)) + masked_m[j] + psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], get_mk_alignment_for_contiguous_layout())) + masked_m[j] assert masked_m.amax().item() <= max_m if use_bf16: @@ -356,8 +366,8 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], - use_ue8m0: bool = False, use_bf16: bool = False): - assert get_mk_alignment_for_contiguous_layout() % 128 == 0 + use_ue8m0: bool = False, use_bf16: bool = False, gran_k = 128): + assert get_mk_alignment_for_contiguous_layout() % gran_k == 0 k = sum(ks) a = torch.randn((k, m), device='cuda', dtype=torch.bfloat16) @@ -376,8 +386,8 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: Majo assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) return k, a, b, c, d, ref_d - a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0) - b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0) + a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0, gran_k=gran_k) + b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0, gran_k=gran_k) # Transpose for K Major A/B if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor): diff --git a/deep-gemm/tests/test_attention.py b/deep-gemm/tests/test_attention.py index b26cf673..479da5b5 100644 --- a/deep-gemm/tests/test_attention.py +++ b/deep-gemm/tests/test_attention.py @@ -10,9 +10,9 @@ ignore_env, get_arch_major, test_filter ) -from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8 +from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8, per_token_cast_to_fp4, cast_back_from_fp4 -from generators import generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB +from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, reset_seed, MajorTypeAB def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]): @@ -53,40 +53,14 @@ def test_gemm_skip_head_mid() -> None: assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}' t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast), - 'fp8_gemm', suppress_kineto_output=True) + 'gemm_', suppress_kineto_output=True) print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s') + f'{t * 
1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s') print() -def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: - num_blocks, block_size, num_heads, head_dim = x.shape - assert num_heads == 1 - x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 - x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8) - x_fp8[ :, : block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(dtype=torch.uint8) - x_fp8[ :, block_size * head_dim :] = sf.view(num_blocks, block_size).view(dtype=torch.uint8) - return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) - - -def generate_cp_test_data(seq_len, seq_len_kv): - assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0 - chunk_size = seq_len // 2 - cp_size = seq_len_kv // seq_len - # Select an arbitrary CP rank - cp_id = cp_size // 3 - ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') - ke = torch.zeros(seq_len, dtype=torch.int, device='cuda') - for i in range(chunk_size): - ke[i] = cp_id * chunk_size + i - ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i - return ks, ke - - def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): seq_len_kv = kv.shape[0] @@ -113,92 +87,137 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, return logits, cost -@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 10) def test_mqa_logits(): + + # Helper functions + def generate_ks_ke_tests(seq_len: int, seq_len_kv: int, disable_cp: bool): + if disable_cp: + ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') + ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len) + return ks, ke + assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0 + chunk_size = seq_len // 2 + cp_size = seq_len_kv // seq_len + # Select an arbitrary CP rank + cp_id = cp_size // 3 + ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') + ke = torch.zeros(seq_len, dtype=torch.int, device='cuda') + for i in range(chunk_size): + ke[i] = cp_id * chunk_size + i + ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i + return ks, ke + + def enumerate_mqa_logits(): + for is_fp4 in ((True, False) if get_arch_major() == 10 else (False, )): + for logits_dtype in (torch.float, torch.bfloat16): + for compressed_logits, clean_logits in [(False, True), (True, False)]: + for seq_len in (2048, 4096): + for seq_len_kv in (4096, 8192): + for num_heads, head_dim in [(64, 128)]: + for disable_cp in (False, True): + yield is_fp4, logits_dtype, compressed_logits, clean_logits, seq_len, seq_len_kv, num_heads, head_dim, disable_cp + print('Testing FP8 MQA Logits:') - num_heads, head_dim = 64, 128 - for seq_len in (2048, 4096): - for compressed_logits in (False, True): - for seq_len_kv in (4096, 8192): - for disable_cp in (False, True): - q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16) - kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16) - weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32) - - if disable_cp: - ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') - ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len) - else: - ks, ke = generate_cp_test_data(seq_len, 
seq_len_kv) - - q_fp8 = q.to(torch.float8_e4m3fn) - kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False) - - if compressed_logits: - max_seqlen_k = (ke - ks).max().item() - logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False) - assert logits.size() == (seq_len, max_seqlen_k) - tmp = torch.full((seq_len, seq_len_kv), float('-inf'), device='cuda') - for i in range(seq_len): - tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]] - logits = tmp - else: - logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) - - do_check = (seq_len_kv < 32768) - if do_check: - ref_logits, ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - - ref_neginf_mask = (ref_logits == float('-inf')) - neginf_mask = (logits == float('-inf')) - assert torch.equal(neginf_mask, ref_neginf_mask) - - ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) - logits = logits.masked_fill(neginf_mask, 0) - diff = calc_diff(logits, ref_logits) - assert diff < 1e-3, f'{diff=}' - else: - ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True) - - tflops = 2 * ref_cost * num_heads * head_dim / 1e12 - if compressed_logits: - t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False), 'fp8_mqa_logits') - else: - t, clean_t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke), ('fp8_mqa_logits', 'clean_logits')) - clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke) - print(f' > S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: ' - f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, ' - f'{(count_bytes(q_fp8, kv_fp8, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s', end='') - # noinspection PyUnboundLocalVariable - print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if not compressed_logits else '') + for is_fp4, logits_dtype, compressed_logits, clean_logits, seq_len, seq_len_kv, num_heads, head_dim, disable_cp in enumerate_mqa_logits(): + # Generate random inputs + q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32) + ks, ke = generate_ks_ke_tests(seq_len, seq_len_kv, disable_cp) + + # Calculate reference logits + ref_logits, ref_cost = ref_fp8_mqa_logits(q, kv, weights, ks, ke) + + # Quantize Q and KV to FP4 / FP8 + if is_fp4: + q_fp4 = per_token_cast_to_fp4(q.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + q_in = (q_fp4[0].view(seq_len, num_heads, head_dim // 2), q_fp4[1].view(seq_len, num_heads)) + q_simulated = cast_back_from_fp4(q_fp4[0], q_fp4[1], gran_k=32, use_packed_ue8m0=True).view(seq_len, num_heads, head_dim).to(torch.bfloat16) + + kv_fp4 = per_token_cast_to_fp4(kv.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + kv_in = (kv_fp4[0].view(seq_len_kv, head_dim // 2), kv_fp4[1].view(seq_len_kv)) + kv_simulated = cast_back_from_fp4(kv_fp4[0], kv_fp4[1], gran_k=32, use_packed_ue8m0=True).view(seq_len_kv, head_dim).to(torch.bfloat16) + else: + q_in = q.to(torch.float8_e4m3fn), None + q_simulated = q_in[0].to(torch.bfloat16) + kv_in = per_custom_dims_cast_to_fp8(kv, (0, ), False) + kv_simulated = (kv_in[0].float() * 
kv_in[1].unsqueeze(1)).to(torch.bfloat16)
+
+        # Calculate simulated logits from the dequantized inputs
+        simulated_logits, _ = ref_fp8_mqa_logits(q_simulated, kv_simulated, weights, ks, ke)
+
+        # Prepare kwargs
+        kernel_kwargs = dict(
+            q=q_in, kv=kv_in, weights=weights,
+            cu_seq_len_k_start=ks, cu_seq_len_k_end=ke,
+            clean_logits=clean_logits, max_seqlen_k=0,
+            logits_dtype=logits_dtype
+        )
+        if compressed_logits:
+            max_seqlen_k = (ke - ks).max().item()
+            kernel_kwargs['max_seqlen_k'] = max_seqlen_k
+
+        # Run kernel
+        logits = deep_gemm.fp8_fp4_mqa_logits(**kernel_kwargs)
+
+        # Post process for compressed logits
+        if compressed_logits:
+            assert logits.size() == (seq_len, max_seqlen_k)
+            tmp = torch.full((seq_len, seq_len_kv), float('-inf'), device='cuda')
+            for i in range(seq_len):
+                tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]
+            logits = tmp
+
+        # Validation
+        ref_neginf_mask = (ref_logits == float('-inf'))
+        neginf_mask = (logits == float('-inf'))
+        assert torch.equal(neginf_mask, ref_neginf_mask)
+
+        ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
+        simulated_logits = simulated_logits.masked_fill(ref_neginf_mask, 0)
+        logits = logits.masked_fill(ref_neginf_mask, 0)
+        diff = calc_diff(logits, ref_logits)
+        simulated_diff = calc_diff(logits, simulated_logits)
+        assert diff < (0.02 if is_fp4 else 1e-3), f"Diff: {diff}"
+        assert simulated_diff < 5e-6, f"Simulated Diff: {simulated_diff}"
+
+        # Profiling
+        tflops = 2 * ref_cost * num_heads * head_dim / 1e12
+        t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_mqa_logits(**kernel_kwargs), ('mqa_logits', 'clean_logits'))
+        clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke)
+
+        print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: '
+              f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, '
+              f'{(count_bytes(q_in, kv_in, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s', end='')
+        print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if clean_logits else '')
     print()
 
 
-def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
-                             weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor,
-                             max_model_len: int, is_context_lens_2d: bool):
-    batch_size, next_n, heads, dim = q.size()
+def ref_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
+                         weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor,
+                         max_model_len: int, use_2d_context_lens: bool):
+    batch_size, next_n, num_heads, dim = q.size()
     num_block, block_size, _, dim = kv_cache.size()
     logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32)
     context_lens = context_lens.tolist()
     for i in range(batch_size):
         context_len = context_lens[i]
-        q_offsets = torch.full((next_n, ), context_len, device='cuda', dtype=torch.int32) if is_context_lens_2d \
-            else torch.arange(context_len - next_n, context_len, device='cuda')
+        q_offsets = torch.full((next_n, ), context_len, device='cuda', dtype=torch.int32) if use_2d_context_lens \
+            else torch.arange(context_len - next_n, context_len, device='cuda')
         weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous()
         num_blocks = (context_len + block_size - 1) // block_size
         block_idxs = block_tables[i][:num_blocks]
         kv_slice = kv_cache[block_idxs]  # [num_blocks, block_size, kv_heads, dim]
         kx = kv_slice.permute(2, 3, 0, 1).reshape(kv_slice.size(2), dim, -1)  # [kv_heads, dim, total_tokens]
-        qx =
q[i].transpose(0, 1) # q[i]: [next_n, heads, dim] -> [heads, next_n, dim] - s = torch.matmul(qx, kx).to(logits.dtype) # [heads, next_n, dim] @ [1, dim, total_tokens] -> [heads, next_n, total_tokens] + qx = q[i].transpose(0, 1) # q[i]: [next_n, num_heads, dim] -> [num_heads, next_n, dim] + s = torch.matmul(qx, kx).to(logits.dtype) # [num_heads, next_n, dim] @ [1, dim, total_tokens] -> [num_heads, next_n, total_tokens] total_len = num_blocks * block_size k_offsets = torch.arange(0, total_len, device=q.device) mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) s = torch.where(mask[None, :, :], s, float('-inf')) # mask shape: [1, next_n, total_tokens] - s = torch.relu(s) * weight_slice[..., None] # weight_slice: [heads, next_n] -> [heads, next_n, 1] + s = torch.relu(s) * weight_slice[..., None] # weight_slice: [num_heads, next_n] -> [num_heads, next_n, 1] s = s.sum(dim=0) # [next_n, total_tokens] logits[i * next_n:(i + 1) * next_n, :total_len] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) @@ -206,70 +225,164 @@ def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor, def test_paged_mqa_logits(): - print('Testing FP8 Paged MQA Logits:') - max_model_len = 111 * 1000 - for is_context_lens_2d in (False, True): - for batch_size, next_n in [(64, 1), (64, 2), (128, 1)]: - for heads, index_dim in [(64, 128)]: - for avg_kv in (8192, 32768): - num_blocks, blocksize = max_model_len * 3, 64 - - q = torch.randn((batch_size, next_n, heads, index_dim), device='cuda', dtype=torch.bfloat16) - kv_cache = torch.randn((num_blocks, blocksize, 1, index_dim), device='cuda', dtype=torch.bfloat16) - weights = torch.randn((batch_size * next_n, heads), device='cuda', dtype=torch.float32) - q_fp8 = q.to(torch.float8_e4m3fn) - kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) - - context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size, )).cuda().to(torch.int32) - context_lens_list = context_lens.tolist() - max_block_len = (max(context_lens_list) + blocksize - 1) // blocksize * blocksize - block_tables = torch.zeros((batch_size, max_block_len), device='cuda', dtype=torch.int32) - - counter, block_idx_pool = 0, torch.randperm(num_blocks, device='cuda', dtype=torch.int32) - for i in range(batch_size): - num_blocks = ceil_div(context_lens_list[i], blocksize) - block_tables[i][:num_blocks] = block_idx_pool[counter: counter+num_blocks] - counter += num_blocks - - ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len, is_context_lens_2d) - positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1) - - if is_context_lens_2d: - context_lens_2d = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int() - context_lens_2d[:, next_n-1] = context_lens - schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens_2d, blocksize, deep_gemm.get_num_sms()) - logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens_2d, block_tables, schedule_metadata, max_model_len, clean_logits=False) - ref_neginf_mask = ~(positions < context_lens_2d.view(-1).unsqueeze(1)) - else: - schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms()) - logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True) - row_indices = torch.arange(batch_size * next_n, device='cuda') // 
next_n
-                    next_n_offset = torch.arange(batch_size * next_n, device='cuda') % next_n
-                    ref_neginf_mask = ~(positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze(1))
-                    neginf_mask = (logits == float('-inf'))
-                    assert torch.equal(neginf_mask, ref_neginf_mask)
-
-                    logits = logits.masked_fill(ref_neginf_mask, 0)
-                    ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
-                    diff = calc_diff(logits, ref_logits)
-                    assert diff < 1e-3, f"{diff=}"
-
-                    sum_lens = sum(context_lens.to(torch.int64))
-                    tflops = 2 * sum_lens * next_n * heads * index_dim / 1e12
-                    input_bytes = count_bytes(q_fp8, weights, context_lens) + sum_lens * (index_dim + 4) + (sum_lens / blocksize) * 4
-                    output_bytes = sum_lens * next_n * 4
-                    if is_context_lens_2d:
-                        t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens_2d, block_tables, schedule_metadata, max_model_len, clean_logits=False),
-                                         'fp8_paged_mqa_logits')
-                    else:
-                        t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True),
-                                                  ('fp8_paged_mqa_logits', 'clean_logits'))
-                    clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens)
-                    print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: '
-                          f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, '
-                          f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s', end='')
-                    # noinspection PyUnboundLocalVariable
-                    print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if not is_context_lens_2d else '')
+
+    # Helper functions
+    def kv_cache_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        num_blocks, block_size, num_heads, head_dim = x.shape
+        assert num_heads == 1
+        x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
+        sf = x_amax / 448.0
+        x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
+        x_cast_back = x_scaled.float() * sf
+
+        x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8)
+        x_fp8[:, :block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(torch.uint8)
+        x_fp8[:, block_size * head_dim:] = sf.view(num_blocks, block_size).view(torch.uint8)
+        return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4), x_cast_back.to(x.dtype)
+
+    def kv_cache_cast_to_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        num_blocks, block_size, num_heads, head_dim = x.shape
+        assert num_heads == 1 and head_dim == 128
+        x_scaled, sf = per_token_cast_to_fp4(x.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
+        x_cast_back = cast_back_from_fp4(x_scaled, sf, gran_k=32, use_packed_ue8m0=True).view(num_blocks, block_size, 1, head_dim)
+
+        x_fp4 = torch.empty((num_blocks, block_size * (head_dim // 2 + 4)), device=x.device, dtype=torch.uint8)
+        x_fp4[:, :block_size * head_dim // 2] = x_scaled.view(num_blocks, block_size * head_dim // 2).view(torch.uint8)
+        x_fp4[:, block_size * head_dim // 2:] = sf.view(num_blocks, block_size).view(torch.uint8)
+        return x_fp4.view(num_blocks, block_size, num_heads, head_dim // 2 + 4), x_cast_back.to(x.dtype)
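+    # (Editor's note) 448.0 above is the largest finite value of `float8_e4m3fn`, so
+    # `sf = amax / 448` maps each token's peak magnitude onto the top of the FP8 range;
+    # the FP4 variant packs two values per byte instead, hence `head_dim // 2` payload
+    # bytes plus 4 scale bytes per token.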
+    def enumerate_paged_mqa_logits():
+        arch_major = get_arch_major()
+        for is_varlen in ((True, False) if arch_major == 10 else (False, )):
+            for is_fp4 in ((True, False) if arch_major == 10 else (False, )):
+                for logits_dtype in (torch.float, torch.bfloat16):
+                    for block_kv in ((32, 64) if arch_major == 10 else (64, )):
+                        for use_2d_context_lens, clean_logits in [(True, False)]:
+                            for batch_size in (256, ):
+                                for next_n in ((1, ) if is_varlen else ((1, 2, 4, 5, 6) if arch_major == 10 else (1, 2))):
+                                    for max_tokens_per_batch in ((1, 4, 10) if is_varlen else (1, )):
+                                        for num_heads, head_dim in [(64, 128)]:
+                                            for avg_kv in (8192, 32768):
+                                                yield is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv
+
+
+    print('Testing FP8/FP4 Paged MQA Logits:')
+    max_model_len = 111 * 1024
+    num_total_blocks = max_model_len * 5
+
+    for is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits():
+        # Varlen: flatten raw_batch_size sequences with variable tokens into (batch_size, 1, ...)
+        raw_batch_size, raw_next_n = batch_size, next_n
+        if is_varlen:
+            tokens_per_seq = torch.randint(1, max_tokens_per_batch + 1, (raw_batch_size,), device='cuda', dtype=torch.int)
+            indices = torch.arange(raw_batch_size, device='cuda', dtype=torch.int).repeat_interleave(tokens_per_seq)
+            batch_size, next_n = tokens_per_seq.sum().item(), 1
+        else:
+            tokens_per_seq, indices = None, None
+
+        # Generate random inputs
+        q = torch.randn((batch_size, next_n, num_heads, head_dim), device='cuda', dtype=torch.bfloat16)
+        kv_cache = torch.randn((num_total_blocks, block_kv, 1, head_dim), device='cuda', dtype=torch.bfloat16)
+        weights = torch.randn((batch_size * next_n, num_heads), device='cuda', dtype=torch.float)
+        context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (raw_batch_size,), device='cuda', dtype=torch.int)
+
+        if is_varlen:
+            max_ctx_len_per_seq = context_lens + (tokens_per_seq - 1)
+        else:
+            max_ctx_len_per_seq = context_lens
+
+        # Assign block tables (per-sequence, sized by the largest ctx_len within the sequence)
+        seq_sum_lens = context_lens.sum().item()
+        num_blocks_per_query = ceil_div(max_ctx_len_per_seq, block_kv)
+        block_table = torch.empty((raw_batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int)
+        block_idx_pool = torch.randperm(num_total_blocks, device='cuda', dtype=torch.int)
+        offset = 0
+        for i, num_blocks in enumerate(num_blocks_per_query.tolist()):
+            block_table[i, :num_blocks] = block_idx_pool[offset : offset + num_blocks]
+            offset += num_blocks
+        if is_varlen:
+            context_lens = context_lens.repeat_interleave(tokens_per_seq)
+            offsets_within_seq = torch.cat([
+                torch.arange(n.item(), device='cuda', dtype=torch.int)
+                for n in tokens_per_seq
+            ])
+            context_lens = context_lens + offsets_within_seq
+            block_table = block_table.repeat_interleave(tokens_per_seq, dim=0)
+
+        # Calculate reference logits
+        ref_logits = ref_paged_mqa_logits(q, kv_cache, weights, context_lens, block_table, max_model_len, use_2d_context_lens)
+
+        # Quantize Q and KV cache to FP4 / FP8
+        if is_fp4:
+            q_fp4 = per_token_cast_to_fp4(q.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
+            q_in = (q_fp4[0].view(batch_size, next_n, num_heads, head_dim // 2), q_fp4[1].view(batch_size, next_n, num_heads))
+            q_simulated = cast_back_from_fp4(q_fp4[0], q_fp4[1], gran_k=32, use_packed_ue8m0=True).view(batch_size, next_n, num_heads, head_dim).to(torch.bfloat16)
+            kv_in, kv_simulated = kv_cache_cast_to_fp4(kv_cache)
+        else:
+            q_in = q.to(torch.float8_e4m3fn), None
+            q_simulated = q_in[0].to(torch.bfloat16)
+            kv_in, kv_simulated = kv_cache_cast_to_fp8(kv_cache)
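+        # (Editor's note) illustrative layout self-check: each KV token stores its
+        # payload bytes followed by a 4-byte scale, which is where the trailing
+        # `+ 4` in the cache's last dimension comes from.
+        assert kv_in.shape[-1] == ((head_dim // 2 + 4) if is_fp4 else (head_dim + 4))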
+        # Calculate simulated reference logits
+        simulated_logits = ref_paged_mqa_logits(q_simulated, kv_simulated, weights, context_lens, block_table, max_model_len, use_2d_context_lens)
+
+        # Prepare masks and context lengths with NextN
+        positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1)
+        if use_2d_context_lens:
+            if is_varlen:
+                # Varlen: context_lens is already per-token (shape [total_tokens]);
+                # just reshape to (total_tokens, 1) so each token keeps its own ctx_len.
+                context_lens_nextn = context_lens.view(-1, 1)
+            else:
+                context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int()
+                # Ensure last token matches actual length
+                context_lens_nextn[:, -1] = context_lens
+            ref_neginf_mask = ~(positions < context_lens_nextn.view(-1, 1))
+        else:
+            context_lens_nextn = context_lens
+            offsets = torch.arange(batch_size * next_n, device='cuda')
+            limits = (context_lens[offsets // next_n] - next_n + offsets % next_n).unsqueeze(1)
+            ref_neginf_mask = ~(positions <= limits)
+
+        # Run Kernel
+        kernel_kwargs = dict(
+            q=q_in, kv_cache=kv_in, weights=weights,
+            context_lens=context_lens_nextn, block_table=block_table,
+            schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms(), indices=indices),
+            max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype,
+            indices=indices,
+        )
+        logits = deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs)
+
+        # Validation
+        assert logits.dtype == logits_dtype
+        logits = logits.to(torch.float)
+
+        if clean_logits:
+            assert torch.equal(logits == float('-inf'), ref_neginf_mask), "Mask mismatch"
+
+        logits_masked = logits.masked_fill(ref_neginf_mask, 0)
+        ref_masked = ref_logits.masked_fill(ref_neginf_mask, 0)
+        simulated_masked = simulated_logits.masked_fill(ref_neginf_mask, 0)
+        diff = calc_diff(logits_masked, ref_masked)
+        simulated_diff = calc_diff(logits_masked, simulated_masked)
+        assert diff < (0.02 if is_fp4 else 1e-3), f"Diff: {diff}"
+        assert simulated_diff < 5e-6, f"Simulated Diff: {simulated_diff}"
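+        # (Editor's note) the 1-D branch above encodes NextN causality: row
+        # i = b * next_n + j may attend key positions p with
+        # p <= context_lens[b] - next_n + j; e.g. with context_len = 10 and
+        # next_n = 2, row j = 0 sees p <= 8 and row j = 1 sees p <= 9.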
+        # Profiling
+        sum_lens = context_lens.sum().item()
+        tflops_calc = 2 * sum_lens * next_n * num_heads * head_dim / 1e12
+        kv_bytes_per_token = head_dim / (2 if is_fp4 else 1) + 4
+        # KV is read once per sequence; for varlen sum_lens overcounts (per-token), so use seq_sum_lens
+        kv_sum_lens = seq_sum_lens if is_varlen else sum_lens
+        total_bytes = count_bytes(q, weights) + kv_sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize)
+
+        t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs), ('paged_mqa_logits', 'clean_logits'))
+        print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={raw_batch_size:3}, NextN={raw_next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: '
+              f'{tflops_calc / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, {total_bytes / t / 1e9:4.0f} GB/s', end='')
+        if is_varlen:
+            print(f' | Varlen, MaxTPB={max_tokens_per_batch}, NumTokens={batch_size}', end='')
+        print(f' | clean: {clean_t * 1e6:3.0f} us' if clean_logits else '')
     print()
 
@@ -280,6 +393,5 @@ def test_paged_mqa_logits():
 
     random.seed(0)
     test_gemm_skip_head_mid()
-    test_mqa_logits()
     test_paged_mqa_logits()
diff --git a/deep-gemm/tests/test_bf16.py b/deep-gemm/tests/test_bf16.py
index 1a3b0467..4e754477 100644
--- a/deep-gemm/tests/test_bf16.py
+++ b/deep-gemm/tests/test_bf16.py
@@ -11,7 +11,8 @@ from generators import (
     get_arch_major, layout_masked_to_psum, align,
     enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
-    generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
+    generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous,
+    get_mk_alignment_for_contiguous_layout
 )
 
@@ -56,6 +57,10 @@ def test_m_grouped_gemm_contiguous() -> None:
         major_opt = 'N' if major_a.is_k_major() else 'T'
         major_opt += 'T' if major_b.is_k_major() else 'N'
 
+        # Select best alignment
+        alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
+        deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
+
         for test_alias in (False, True):
             m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
                                                                               use_bf16=True, use_psum_layout=use_psum_layout)
@@ -65,8 +70,15 @@ def test_m_grouped_gemm_contiguous() -> None:
             b = b if major_b.is_k_major() else b.mT
             assert a[0].is_contiguous() and b[0].is_contiguous()
             getattr(deep_gemm, func_name)(a, b, d, grouped_layout, use_psum_layout=use_psum_layout)
-            diff = calc_diff(d, ref_d)
-            assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}'
+            if use_psum_layout:
+                for j in range(num_groups):
+                    start = 0 if j == 0 else align(grouped_layout[j - 1], get_mk_alignment_for_contiguous_layout())
+                    end = grouped_layout[j]
+                    diff = calc_diff(d[start : end], ref_d[start : end])
+                    assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}'
+            else:
+                diff = calc_diff(d, ref_d)
+                assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}'
 
         m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
                                                                           use_bf16=True, use_psum_layout=use_psum_layout)
@@ -91,6 +103,10 @@ def test_m_grouped_gemm_masked() -> None:
         sum_t, max_t = 0, 0
         sum_ops, sum_bytes = 0, 0
 
+        # Select best alignment
+        alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout(int(expected_m_per_group * 1.2))
+        deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
+
         for i in range(num_tests):
             a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k,
                                                                          use_bf16=True, use_psum_layout=use_psum_layout)
@@ -111,7 +127,7 @@ def test_func():
                 if masked_m[j].item() == 0:
                     continue
                 if use_psum_layout:
-                    d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]]
+                    d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], get_mk_alignment_for_contiguous_layout()): psum_m[j]]
                 else:
                     d_slice = d[j, :masked_m[j].item()]
                 diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()])
@@ -138,6 +154,9 @@ def test_k_grouped_gemm_contiguous() -> None:
     print('Testing k-grouped contiguous GEMM:')
 
+    # TODO: Support arbitrary alignment
+    deep_gemm.set_mk_alignment_for_contiguous_layout(128)
+
     for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.bfloat16):
         for test_empty_groups in (False, True):
             new_ks = copy.deepcopy(ks)
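[Editor's note] The psum-layout checks above slice each group's valid rows out of an
output buffer in which every group's start row is rounded up to the configured MK
alignment; as the checks use it, entry j of the layout is the end row of group j.
A minimal standalone sketch of the slicing rule, with hypothetical sizes and assuming
the `align(a, b)` helper exported by `deep_gemm.utils` (round a up to a multiple of b):

    from deep_gemm.utils import align

    alignment = 128
    grouped_layout = [100, 228, 400]  # end row of each group; starts are aligned up
    for j, end in enumerate(grouped_layout):
        start = 0 if j == 0 else align(grouped_layout[j - 1], alignment)
        print(f'group {j}: valid rows [{start}, {end})')  # rows before the next start are padding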
device="cuda", dtype=torch.bfloat16) - b = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) - d = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) - - deep_gemm.cublaslt_gemm_nt(a, b, d) - - ref = a @ b.T - cos = torch.nn.functional.cosine_similarity( - d.float().flatten(), ref.float().flatten(), dim=0 - ) - assert cos.item() > 0.99, f"cosine similarity too low: {cos.item()}" diff --git a/deep-gemm/tests/test_einsum.py b/deep-gemm/tests/test_einsum.py index b7979989..57f54592 100644 --- a/deep-gemm/tests/test_einsum.py +++ b/deep-gemm/tests/test_einsum.py @@ -99,7 +99,7 @@ def test_fp8_bhr_hdr_bhd(use_ue8m0: bool = True): deep_gemm.fp8_einsum('bhr,hdr->bhd', x_fp8, y_fp8, z) assert calc_diff(z, ref_z) < 1e-3 - t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhr,hdr->bhd', x_fp8, y_fp8, z), 'fp8_gemm', suppress_kineto_output=True) + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhr,hdr->bhd', x_fp8, y_fp8, z), 'gemm_', suppress_kineto_output=True) t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', f'{t * 1e6:4.0f} us | ' @@ -129,7 +129,7 @@ def test_fp8_bhd_hdr_bhr(use_ue8m0: bool = True): deep_gemm.fp8_einsum('bhd,hdr->bhr', x_fp8, y_fp8, z) assert calc_diff(z, ref_z) < 1e-3 - t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,hdr->bhr', x_fp8, y_fp8, z), 'fp8_gemm', suppress_kineto_output=True) + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,hdr->bhr', x_fp8, y_fp8, z), 'gemm_', suppress_kineto_output=True) t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', f'{t * 1e6:4.0f} us | ' @@ -157,7 +157,7 @@ def test_fp8_bhd_bhr_hdr(use_ue8m0: bool = True): deep_gemm.fp8_einsum('bhd,bhr->hdr', x_fp8, y_fp8, z, z, recipe=(1, 1, 128)) assert calc_diff(z, ref_z) < 1e-3 - t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,bhr->hdr', x_fp8, y_fp8, z, z, recipe=(1, 1, 128)), 'fp8_gemm', suppress_kineto_output=True) + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,bhr->hdr', x_fp8, y_fp8, z, z, recipe=(1, 1, 128)), 'gemm_', suppress_kineto_output=True) print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', f'{t * 1e6:4.0f} us | ' f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' diff --git a/deep-gemm/tests/test_fp8_fp4.py b/deep-gemm/tests/test_fp8_fp4.py index f7e3e1c4..4e9f54f7 100644 --- a/deep-gemm/tests/test_fp8_fp4.py +++ b/deep-gemm/tests/test_fp8_fp4.py @@ -13,11 +13,11 @@ from generators import ( KernelType, get_ue8m0_usage, layout_masked_to_psum, align, enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, - generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous, + get_mk_alignment_for_contiguous_layout ) -@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9) def test_gemm() -> None: print('Testing GEMM:') scores = [] @@ -45,7 +45,7 @@ def test_gemm() -> None: a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config) t = bench_kineto(lambda: deep_gemm.fp8_fp4_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b), - 'fp8_gemm', 
suppress_kineto_output=True) + 'gemm_', suppress_kineto_output=True) cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) \ if not quant_config.is_fp4_a and not quant_config.is_fp4_b else (0, 0) print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' @@ -68,6 +68,10 @@ def test_m_grouped_gemm_contiguous() -> None: disable_ue8m0_cast = not use_ue8m0 recipe, recipe_a, recipe_b = quant_config.get_recipes() + # Select best alignment + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + for test_alias in (False, True): m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, @@ -79,8 +83,15 @@ def test_m_grouped_gemm_contiguous() -> None: assert a[0].is_contiguous() and b[0].is_contiguous() getattr(deep_gemm, func_name)(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) - diff = calc_diff(d, ref_d) - assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + if use_psum_layout: + for j in range(num_groups): + start = 0 if j == 0 else align(grouped_layout[j - 1], get_mk_alignment_for_contiguous_layout()) + end = grouped_layout[j] + diff = calc_diff(d[start : end], ref_d[start : end]) + assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + else: + diff = calc_diff(d, ref_d) + assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, quant_config=quant_config) @@ -90,7 +101,7 @@ def test_func(): deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + t = bench_kineto(test_func, 'gemm_', suppress_kineto_output=True) print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}, psum={use_psum_layout}): ' f'{t * 1e6:4.0f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' @@ -112,6 +123,10 @@ def test_m_grouped_gemm_masked() -> None: sum_t, max_t = 0, 0 sum_ops, sum_bytes = 0, 0 + # Select best alignment + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout(int(expected_m_per_group * 1.2)) + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + for i in range(num_tests): a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, @@ -124,10 +139,10 @@ def test_m_grouped_gemm_masked() -> None: def test_func(): if use_psum_layout: deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, disable_ue8m0_cast=disable_ue8m0_cast, - use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group, + use_psum_layout=True, expected_m_for_psum_layout=int(expected_m_per_group * 1.2), recipe=recipe, recipe_a=recipe_a, 
recipe_b=recipe_b) else: - deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, + deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, int(expected_m_per_group * 1.2), disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) test_func() @@ -135,7 +150,7 @@ def test_func(): if masked_m[j].item() == 0: continue if use_psum_layout: - d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]] + d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], get_mk_alignment_for_contiguous_layout()): psum_m[j]] else: d_slice = d[j, :masked_m[j].item()] diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()]) @@ -143,7 +158,7 @@ def test_func(): # Test performance with fixed shapes valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + t = bench_kineto(test_func, 'gemm_', suppress_kineto_output=True) sum_t += t max_t = max(max_t, t) @@ -158,36 +173,36 @@ def test_func(): print() -@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9) def test_k_grouped_gemm_contiguous() -> None: print('Testing k-grouped contiguous GEMM:') k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \ else deep_gemm.k_grouped_fp8_gemm_tn_contiguous - for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.float8_e4m3fn): + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group, gran_k in enumerate_k_grouped_contiguous(torch.float8_e4m3fn): + recipe = (1, 1, gran_k) use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D) for test_empty_groups in (False, True): new_ks = copy.deepcopy(ks) if test_empty_groups and len(ks) > 1: new_ks[random.randint(0, num_groups - 1)] = 0 - k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0) + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0, gran_k=gran_k) new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') - k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c) + k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c, recipe=recipe) diff = calc_diff(d, ref_d) assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}' # Test performance - k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0) + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0, gran_k=gran_k) ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') # noinspection PyShadowingNames def test_func(): - k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c) + k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c, recipe=recipe) - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): ' + t = bench_kineto(test_func, 'gemm_', suppress_kineto_output=True) + print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}, gran_k={gran_k:3}): ' f'{t * 1e6:4.0f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') diff --git a/deep-gemm/tests/test_layout.py b/deep-gemm/tests/test_layout.py index 7875733a..a0d4a02e 100644 --- a/deep-gemm/tests/test_layout.py +++ b/deep-gemm/tests/test_layout.py @@ -1,6 +1,6 @@ 
import torch import random -from deep_gemm.testing import bench_kineto, count_bytes +from deep_gemm.testing import bench_kineto, count_bytes, get_arch_major from deep_gemm.utils import ( align, ceil_div, per_token_cast_to_fp8, per_channel_cast_to_fp8, @@ -43,9 +43,9 @@ def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) -> def test_sf_layout_kernels() -> None: print('Testing SF layout kernels:') - for mn, k, with_transpose, use_ue8m0, num_groups in enumerate_sf_layout(): + for mn, k, with_transpose, use_ue8m0, num_groups, gran_k in enumerate_sf_layout(): x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda') - x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0) + x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) fp32_sf = fp32_sf if num_groups == 1 else fp32_sf.view(num_groups, mn, -1) fp32_sf = fp32_sf if with_transpose else fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2) @@ -60,7 +60,7 @@ def test_sf_layout_kernels() -> None: else: impl, name = get_mn_major_tma_aligned_tensor, 'transpose' transposed_sf = get_mn_major_tma_aligned_tensor(fp32_sf) - tma_aligned_mn, sf_k = get_tma_aligned_size(mn, fp32_sf.element_size()), ceil_div(k, 128) + tma_aligned_mn, sf_k = get_tma_aligned_size(mn, fp32_sf.element_size()), ceil_div(k, gran_k) if num_groups > 1: assert transposed_sf.size(0) == num_groups assert transposed_sf.stride(0) == tma_aligned_mn * sf_k @@ -74,22 +74,22 @@ def test_sf_layout_kernels() -> None: except AssertionError as e: # Some cases may fallback to PyTorch impl t = 0 - print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}, use_ue8m0={int(use_ue8m0)}): ' + print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}, use_ue8m0={int(use_ue8m0)}, gran_k={gran_k:3}): ' f'{t * 1e6:4.0f} us | {count_bytes(fp32_sf, impl(fp32_sf)) / 1e9 / t if t else 0:4.0f} GB/s') print() def test_k_grouped_sf_layout_kernels() -> None: print('Testing k-grouped SF layout kernels:') - for mn, ks, num_groups in enumerate_k_grouped_sf_layout(): - sf_ks = [k // 128 for k in ks] - packed_sf_ks = [ceil_div(k, 512) for k in ks] + for mn, ks, num_groups, gran_k in enumerate_k_grouped_sf_layout(): + sf_ks = [k // gran_k for k in ks] + packed_sf_ks = [ceil_div(k, gran_k * 4) for k in ks] ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') x = torch.randn((sum(ks), mn), dtype=torch.bfloat16, device='cuda') - x, fp32_sf = per_channel_cast_to_fp8(x, use_ue8m0=True) + x, fp32_sf = per_channel_cast_to_fp8(x, use_ue8m0=True, gran_k=gran_k) # Correctness - packed_sf = get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks) + packed_sf = get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks, gran_k) split_packed_sf = packed_sf.split(packed_sf_ks) split_fp32_sf = fp32_sf.split(sf_ks) for i in range(num_groups): @@ -97,8 +97,8 @@ def test_k_grouped_sf_layout_kernels() -> None: assert torch.equal(split_packed_sf[i], ref_packed_sf), f'{i=}' # Performance - t = bench_kineto(lambda: get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks), 'pack_fp32_into_ue8m0') - print(f' > Perf ({num_groups=:3}, {mn=:5}, sum_k={sum(ks):5}):' + t = bench_kineto(lambda: get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks, gran_k), 'pack_fp32_into_ue8m0') + print(f' > Perf ({num_groups=:3}, {mn=:5}, sum_k={sum(ks):5}, gran_k={gran_k:3}):' f'{t * 1e6:4.0f} us | ' f'{count_bytes(fp32_sf, packed_sf, ks_tensor) / 
1e9 / t:4.0f} GB/s')
     print()
diff --git a/deep-gemm/tests/test_lazy_init.py b/deep-gemm/tests/test_lazy_init.py
index 5363b6db..17a3a121 100644
--- a/deep-gemm/tests/test_lazy_init.py
+++ b/deep-gemm/tests/test_lazy_init.py
@@ -1,3 +1,4 @@
+import argparse
 import torch
 import torch.multiprocessing as mp
 import deep_gemm
@@ -8,7 +9,11 @@ def main(local_rank: int):
 
 if __name__ == '__main__':
-    procs = [mp.Process(target=main, args=(i, ), ) for i in range(8)]
+    parser = argparse.ArgumentParser(description='Test lazy initialization')
+    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
+    args = parser.parse_args()
+
+    procs = [mp.Process(target=main, args=(i, ), ) for i in range(args.num_processes)]
     for p in procs:
         p.start()
     for p in procs:
diff --git a/deep-gemm/tests/test_mega_moe.py b/deep-gemm/tests/test_mega_moe.py
new file mode 100644
index 00000000..e74b65e5
--- /dev/null
+++ b/deep-gemm/tests/test_mega_moe.py
@@ -0,0 +1,295 @@
+import argparse
+import os
+import random
+import sys
+import torch
+import torch.distributed as dist
+from typing import Tuple
+
+import deep_gemm
+from deep_gemm.utils import per_token_cast_to_fp4, per_token_cast_to_fp8
+from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
+from deep_gemm.testing import bench_kineto
+
+
+def import_baseline():
+    # Load legacy implementations from third-party
+    deep_ep, tilelang_ops, do_bench, is_legacy_loaded = None, None, None, False
+    # noinspection PyBroadException
+    try:
+        import deep_ep
+        import importlib.util
+        from tilelang.profiler.bench import do_bench
+        spec = importlib.util.spec_from_file_location(
+            'tilelang_ops',
+            os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py'))
+        tilelang_ops = importlib.util.module_from_spec(spec)
+        sys.modules['tilelang_ops'] = tilelang_ops
+        spec.loader.exec_module(tilelang_ops)
+        is_legacy_loaded = True
+    except Exception as ex:
+        dist_print(f'Failed to load legacy code: {ex}, skipping baseline benchmarking', once_in_node=True)
+        dist_print(once_in_node=True)
+    return deep_ep, tilelang_ops, do_bench, is_legacy_loaded
+
+
+# TODO: skip the test for SM90
+# noinspection PyUnboundLocalVariable,PyShadowingNames
+def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
+    rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    torch.manual_seed(rank_idx)
+    random.seed(rank_idx)
+
+    # Settings
+    num_max_tokens_per_rank = args.num_max_tokens_per_rank
+    num_tokens = max(0, args.num_max_tokens_per_rank - random.randint(0, args.num_max_removed_tokens)) \
+        if args.num_tokens == 0 else args.num_tokens
+    hidden, intermediate_hidden = args.hidden, args.intermediate_hidden
+    num_experts, num_topk = args.num_experts, args.num_topk
+    num_experts_per_rank = num_experts // num_ranks
+    assert num_tokens <= num_max_tokens_per_rank
+
+    # Allocate symmetric memory
+    buffer = deep_gemm.get_symm_buffer_for_mega_moe(
+        group, num_experts,
+        num_max_tokens_per_rank, num_topk,
+        hidden, intermediate_hidden
+    )
+
+    # Create inputs
+    # noinspection PyGlobalUndefined
+    def create_inputs():
+        global x, topk_idx, topk_weights, l1_weights, l2_weights, transformed_l1_weights, transformed_l2_weights
+        global cumulative_local_expert_recv_stats_fused
+        global cumulative_local_expert_recv_stats_baseline
+        x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+        l1_weights = torch.randn(
+            (num_experts_per_rank, intermediate_hidden * 2, hidden), dtype=torch.bfloat16, device='cuda')
+        l2_weights = torch.randn(
+            (num_experts_per_rank, hidden, intermediate_hidden), dtype=torch.bfloat16, device='cuda')
+        scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda')
+        topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)
+        cumulative_local_expert_recv_stats_fused = torch.randint(
+            0, 100, (num_experts_per_rank, ), dtype=torch.int, device='cuda')
+        cumulative_local_expert_recv_stats_baseline = cumulative_local_expert_recv_stats_fused.clone()
+        if args.masked_ratio > 0:
+            rand_mask = torch.rand_like(topk_idx, dtype=torch.float)
+            topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1)
+            topk_weights.masked_fill_(topk_idx < 0, 0)
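+        # (Editor's note) `-1` in `topk_idx` is the dropped-selection sentinel: masked
+        # entries route to no expert and their weights are zeroed, so `--masked-ratio`
+        # simulates partially-dropped routing. The module-level globals let each call
+        # hand identical fresh inputs to both the fused and baseline paths below.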
+        # Check SF requirements
+        assert hidden % 128 == 0
+        assert intermediate_hidden % 128 == 0
+        assert l1_weights.shape[2] % 128 == 0 and l2_weights.shape[2] % 128 == 0
+
+        # Cast inputs to FP8 with per-32 UE8M0 SF
+        x = per_token_cast_to_fp8(x, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
+
+        # Cast grouped BF16 weights to FP4 with MN-major SF
+        # TODO: merge with `cast_fp8_fp4_with_major`
+        def cast_grouped_weights_to_fp4(bf16_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            num_groups, n, k = bf16_weights.shape
+            w = torch.empty((num_groups, n, k // 2), device='cuda', dtype=torch.int8)
+            w_sf = torch.empty((num_groups, n, k // 32), device='cuda', dtype=torch.float)
+            for i in range(num_groups):
+                w[i], w_sf[i] = per_token_cast_to_fp4(bf16_weights[i], use_ue8m0=True, gran_k=32)
+            w_sf = deep_gemm.transform_sf_into_required_layout(w_sf, n, k, (1, 32), num_groups)
+            return w, w_sf
+
+        l1_weights = cast_grouped_weights_to_fp4(l1_weights)
+        l2_weights = cast_grouped_weights_to_fp4(l2_weights)
+        transformed_l1_weights, transformed_l2_weights = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)
+
+    # Run fused mega MoE
+    # NOTES: copy x into buffer before each call because debug mode zeros the entire buffer
+    def run_fused():
+        buffer.x[:num_tokens].copy_(x[0])
+        buffer.x_sf[:num_tokens].copy_(x[1])
+        buffer.topk_idx[:num_tokens].copy_(topk_idx)
+        buffer.topk_weights[:num_tokens].copy_(topk_weights)
+
+        y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
+        # noinspection PyTypeChecker
+        deep_gemm.fp8_fp4_mega_moe(
+            y,
+            transformed_l1_weights, transformed_l2_weights,
+            buffer,
+            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_fused,
+            activation_clamp=args.activation_clamp,
+            fast_math=bool(args.fast_math)
+        )
+        return y, cumulative_local_expert_recv_stats_fused
+
+    dist_print('Config:', once_in_node=True)
+    dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True)
+    dist_print(f' > Hidden: {hidden}', once_in_node=True)
+    dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True)
+    dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True)
+    dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True)
+    dist_print(once_in_node=True)
+
+    # Only do NCU profiling
+    if args.ncu_profile_only:
+        create_inputs()
+        dist_print(f'Run fused kernel:', once_in_node=True)
+        run_fused()
+        dist_print(f' > Done, exiting', once_in_node=True)
+
+        # Destroy and exit
+        dist.barrier()
+        buffer.destroy()
+        dist.destroy_process_group()
+        return
+
+    # Non-overlapped baseline: EP dispatch + GEMM + EP combine
+    deep_ep, tilelang_ops, tilelang_bench, is_legacy_loaded = import_baseline()
+    alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
+    deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
+    ep_buffer = deep_ep.ElasticBuffer(
+        group,
+        num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden,
+        num_topk=num_topk, use_fp8_dispatch=True,
+        explicitly_destroy=True,
+        allow_multiple_reduction=False,
+        gpu_timeout_secs=10, cpu_timeout_secs=30
+    ) if is_legacy_loaded else None
+
+    def run_baseline():
+        recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
+            x, topk_idx=topk_idx, topk_weights=topk_weights,
+            cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_baseline,
+            num_experts=num_experts, expert_alignment=alignment,
+            do_cpu_sync=False, do_handle_copy=False,
+            do_expand=True, use_tma_aligned_col_major_sf=True,
+        )
+        n = recv_x[0].size(0)
+        l1_y = torch.empty((n, intermediate_hidden * 2), dtype=torch.bfloat16, device='cuda')
+        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
+            recv_x, l1_weights, l1_y, handle.psum_num_recv_tokens_per_expert,
+            use_psum_layout=True, recipe=(1, 1, 32))
+        # noinspection PyCallingNonCallable
+        l1_y = tilelang_ops.swiglu_apply_weight_to_fp8(
+            x=l1_y,
+            topk_weights=recv_topk_weights,
+            avail_tokens=handle.psum_num_recv_tokens_per_expert[-1],
+            num_per_channels=32,
+            use_col_major_scales=True,
+            round_scale=True,
+            ue8m0_scale=True,
+            output_bf16=False,
+            clamp_value=args.activation_clamp,
+            fast_math=bool(args.fast_math)
+        )
+        l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device='cuda')
+        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(
+            l1_y, l2_weights, l2_y, handle.psum_num_recv_tokens_per_expert,
+            use_psum_layout=True, recipe=(1, 1, 32))
+        return ep_buffer.combine(l2_y, handle=handle)[0], cumulative_local_expert_recv_stats_baseline
+
+    # Check correctness (must be bitwise identical)
+    num_correctness_tests = 1 if args.num_correctness_tests is None else args.num_correctness_tests
+    # noinspection PyBroadException
+    if is_legacy_loaded and num_correctness_tests > 0:
+        dist_print('Running correctness tests:', once_in_node=True)
+        for i in range(num_correctness_tests):
+            create_inputs()
+            for fused_result, baseline_result in zip(run_fused(), run_baseline()):
+                assert torch.equal(fused_result, baseline_result)
+            if (i + 1) % 100 == 0 or i == num_correctness_tests - 1:
+                dist_print(f' > Correctness test #{i + 1}/{num_correctness_tests} passed', once_in_node=True)
+        dist_print(once_in_node=True)
+    else:
+        create_inputs()
+
+    # Count local received tokens
+    gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
+    gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | \
+                      (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1
+    num_recv_tokens = (gathered_topk_idx != -1).sum().item()
+
+    # Benchmark
+    t_fused = bench_kineto(
+        run_fused, 'mega_moe',
+        barrier=lambda: ep_buffer.barrier(use_comm_stream=False) if ep_buffer else dist.barrier(),
+        trace_path=None if not args.dump_profile_traces else f'{args.dump_profile_traces}/mega_moe_rank{rank_idx}.json')
+    t_baseline = tilelang_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0
+
+    # TFLOPS: 3 matmuls (L1 left, L1 right, L2), each 2 * M * N * K
+    safe_div = lambda a, b: float('nan') if b == 0 else a / b
+    tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused)
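+    # (Editor's note) FLOP accounting spelled out: the L1 GEMM computes
+    # M x (2 * intermediate_hidden) x hidden and the L2 GEMM M x hidden x
+    # intermediate_hidden, so MACs = M * 3 * hidden * intermediate_hidden and
+    # FLOPs are twice that, matching the `* 3` above.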
+
+    # HBM bytes: weights (FP4 packed = 0.5 bytes) + activations (FP8 = 1 byte) + output (BF16 = 2 bytes)
+    num_touched_experts = torch.unique(gathered_topk_idx.flatten()).numel() - 1  # NOTES: minus 1 to exclude "-1"
+    num_hbm_bytes = (
+        num_touched_experts * intermediate_hidden * 2 * hidden // 2 +  # L1 weights (FP4)
+        num_touched_experts * hidden * intermediate_hidden // 2 +      # L2 weights (FP4)
+        num_recv_tokens * hidden +                                     # L1 acts read (FP8)
+        num_recv_tokens * intermediate_hidden +                        # L1 output write (FP8)
+        num_recv_tokens * intermediate_hidden +                        # L2 acts read (FP8)
+        num_recv_tokens * hidden * 2                                   # L2 output write (BF16)
+    )
+    hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused)
+
+    # NVLink bytes: dispatch pull + combine write-back
+    num_nvlink_bytes = num_recv_tokens * hidden * 3
+    nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused)
+
+    # Combine reduction (serial) time approximation
+    t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12
+
+    # Summary
+    approx_factor = t_fused / (t_fused - t_reduction)
+    dist_print('Performance:', once_in_node=True)
+    dist_print(f' > EP: {rank_idx:2}/{num_ranks} | '
+               f'{tflops:4.0f} TFLOPS | '
+               f'overlap: '
+               f'{tflops * approx_factor:4.0f} TFLOPS, '
+               f'HBM {hbm_gbs * approx_factor:4.0f} GB/s, '
+               f'NVL {nvlink_gbs * approx_factor:3.0f} GB/s | '
+               f'{t_fused * 1e6:4.0f} us, '
+               f'reduction: {t_reduction * 1e6:4.1f} us | '
+               f'{safe_div(t_baseline, t_fused):.2f}x legacy')
+
+    # Exit
+    dist.barrier()
+    buffer.destroy()
+    ep_buffer.destroy() if is_legacy_loaded else None
+    dist.destroy_process_group()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Test PyTorch symmetric memory')
+
+    # Resource settings
+    parser.add_argument('--ncu-profile-only', action='store_true', help='Only run profiling without correctness test')
+    parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
+
+    # Model settings
+    parser.add_argument('--num-max-tokens-per-rank', type=int, default=8192, help='Number of maximum tokens per rank')
+    parser.add_argument('--num-tokens', type=int, default=0, help='Number of tokens per rank (follow max minus removed if 0)')
+    parser.add_argument('--num-max-removed-tokens', type=int, default=0, help='Maximum number of tokens to remove')
+    parser.add_argument('--hidden', type=int, default=7168, help='Hidden size')
+    parser.add_argument('--intermediate-hidden', type=int, default=3072, help='Intermediate hidden size')
+    parser.add_argument('--activation-clamp', type=float, default=10, help='Clamp value for activation')
+    parser.add_argument('--num-experts', type=int, default=384, help='Number of experts')
+    parser.add_argument('--num-topk', type=int, default=6, help='Number of expert selections')
+    parser.add_argument('--masked-ratio', type=float, default=0.0, help='Mask some expert selections')
+    parser.add_argument('--fast-math', type=int, default=1, help='Enable fast math (0 or 1, default: 1)')
+
+    # Test settings
+    parser.add_argument('--num-correctness-tests', type=int, default=None, help='Pressure test')
+    parser.add_argument('--dump-profile-traces', type=str, default='', help='Dump profiling trace JSONs')
+    parser.add_argument('--local-rank-idx', type=int, default=None, help='Run as single process with this local rank (e.g. for NCU prof)')
+    args = parser.parse_args()
+
+    # Create dump trace directories
+    if args.dump_profile_traces:
+        os.makedirs(args.dump_profile_traces, exist_ok=True)
+
+    if args.local_rank_idx is not None:
+        # Single-process mode: each process is launched separately (e.g.
by NCU) + test(args.local_rank_idx, args.num_processes, args) + else: + # Launch tests + num_processes = args.num_processes + torch.multiprocessing.spawn(test, args=(num_processes, args), nprocs=num_processes) diff --git a/deep-gemm/tests/test_sanitizer.py b/deep-gemm/tests/test_sanitizer.py index b063e6c4..75ab10e6 100644 --- a/deep-gemm/tests/test_sanitizer.py +++ b/deep-gemm/tests/test_sanitizer.py @@ -21,7 +21,7 @@ torch.manual_seed(0) random.seed(0) -from tests.{module_name} import {func_name} +from {module_name} import {func_name} {func_name}() """ @@ -40,7 +40,7 @@ else: # Get all test functions except those related to cuBLAS files = [f for f in os.listdir(script_dir) if f.endswith('.py')] - exclude_files = ['test_sanitizer.py', 'generators.py'] + exclude_files = ['test_sanitizer.py', 'generators.py', 'test_mega_moe.py'] funcs = [ (module_name, name) for module_name in [os.path.splitext(f)[0] for f in files if f not in exclude_files] @@ -53,6 +53,7 @@ env['CUDA_LAUNCH_BLOCKING'] = '1' env['DG_JIT_PTXAS_CHECK'] = '1' env['DG_USE_NVIDIA_TOOLS'] = '1' + env['DG_USE_TEMP_CUBLASLT_WORKSPACE'] = '1' # Avoid holding CUDA tensor that crashes during shutdown env['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1' env['TORCH_SHOW_CPP_STACKTRACES'] = '1' diff --git a/deep-gemm/torch-ext/deep_gemm/_C.py b/deep-gemm/torch-ext/deep_gemm/_C.py new file mode 100644 index 00000000..8f2fd6df --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/_C.py @@ -0,0 +1,194 @@ +import torch + +from ._ops import ops + + +def set_num_sms(num_sms: int): + ops.set_num_sms(num_sms) + + +def get_num_sms() -> int: + return ops.get_num_sms() + + +def set_tc_util(tc_util: int): + ops.set_tc_util(tc_util) + + +def get_tc_util() -> int: + return ops.get_tc_util() + + +def set_ignore_compile_dims(value: bool): + ops.set_ignore_compile_dims(value) + + +def set_block_size_multiple_of(value): + if isinstance(value, tuple): + block_m, block_n = value + else: + block_m = block_n = value + ops.set_block_size_multiple_of(block_m, block_n) + + +def set_pdl(enable_pdl: bool): + ops.set_pdl(enable_pdl) + + +def get_pdl() -> bool: + return ops.get_pdl() + + +def set_mk_alignment_for_contiguous_layout(value: int): + ops.set_mk_alignment_for_contiguous_layout(value) + + +def get_mk_alignment_for_contiguous_layout() -> int: + return ops.get_mk_alignment_for_contiguous_layout() + + +def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int: + return ops.get_theoretical_mk_alignment_for_contiguous_layout( + 0 if expected_m is None else expected_m, + expected_m is not None, + ) + + +def get_tma_aligned_size(mn: int, element_size: int) -> int: + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks, gran_k +): + ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu") + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks_int, gran_k + ) + + +def transform_sf_into_required_layout( + sf, + mn, + k, + recipe, + num_groups=None, + is_sfa=None, + disable_ue8m0_cast=False, +): + if len(recipe) == 3: + r0, r1, r2 = recipe + recipe_len = 3 + elif len(recipe) == 2: + r0, r1 = recipe + r2 = 0 + recipe_len = 2 + else: + raise ValueError("recipe must have length 2 or 3") + + return 
ops.transform_sf_into_required_layout( + sf, + mn, + k, + r0, + r1, + r2, + recipe_len, + 0 if num_groups is None else num_groups, + num_groups is not None, + False if is_sfa is None else is_sfa, + is_sfa is not None, + disable_ue8m0_cast, + ) + + +def get_token_alignment_for_mega_moe() -> int: + return ops.get_token_alignment_for_mega_moe() + + +def get_symm_buffer_size_for_mega_moe( + num_ranks, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch=True, + activation="swiglu", +): + num_bytes = ops.get_symm_buffer_size_for_mega_moe( + num_ranks, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch, + activation, + ) + + def slice_input_buffers(buffer): + return tuple( + ops.get_symm_buffer_views_for_mega_moe( + buffer, + num_ranks, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch, + activation, + ) + ) + + return num_bytes, slice_input_buffers + + +def fp8_fp4_mega_moe( + y, + l1_weights, + l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer, + sym_buffer_ptrs, + rank_idx, + num_max_tokens_per_rank, + num_experts, + num_topk, + recipe, + activation, + activation_clamp, + fast_math, +): + l1_weights_data, l1_weights_sf = l1_weights + l2_weights_data, l2_weights_sf = l2_weights + r0, r1, r2 = recipe + ops.fp8_fp4_mega_moe( + y, + l1_weights_data, + l1_weights_sf, + l2_weights_data, + l2_weights_sf, + cumulative_local_expert_recv_stats, + sym_buffer, + sym_buffer_ptrs, + rank_idx, + num_max_tokens_per_rank, + num_experts, + num_topk, + r0, + r1, + r2, + activation, + activation_clamp, + fast_math, + ) diff --git a/deep-gemm/torch-ext/deep_gemm/__init__.py b/deep-gemm/torch-ext/deep_gemm/__init__.py index 8f0a7f80..d3acc4db 100644 --- a/deep-gemm/torch-ext/deep_gemm/__init__.py +++ b/deep-gemm/torch-ext/deep_gemm/__init__.py @@ -3,10 +3,10 @@ import torch # Import the compiled extension -from ._ops import ops, add_op_namespace_prefix +from ._ops import ops as _ops, add_op_namespace_prefix from . 
import utils -__version__ = "2.3.0" +__version__ = "2.5.0" # ── Register fake tensor implementations for torch.compile ────────────────── @@ -32,6 +32,7 @@ "m_grouped_bf16_gemm_nn_contiguous", "m_grouped_bf16_gemm_nt_masked", "fp8_gemm_nt_skip_head_mid", + "fp8_fp4_mega_moe", ]: @torch.library.register_fake(add_op_namespace_prefix(_op)) @@ -58,10 +59,41 @@ def get_tc_util() -> int: return ops.get_tc_util() +def set_ignore_compile_dims(value: bool): + ops.set_ignore_compile_dims(value) + + +def set_block_size_multiple_of(value): + if isinstance(value, tuple): + block_m, block_n = value + else: + block_m = block_n = value + ops.set_block_size_multiple_of(block_m, block_n) + + +def set_pdl(enable_pdl: bool): + ops.set_pdl(enable_pdl) + + +def get_pdl() -> bool: + return ops.get_pdl() + + +def set_mk_alignment_for_contiguous_layout(alignment: int): + ops.set_mk_alignment_for_contiguous_layout(alignment) + + def get_mk_alignment_for_contiguous_layout() -> int: return ops.get_mk_alignment_for_contiguous_layout() +def get_theoretical_mk_alignment_for_contiguous_layout(expected_m=None) -> int: + return ops.get_theoretical_mk_alignment_for_contiguous_layout( + 0 if expected_m is None else expected_m, + expected_m is not None, + ) + + # Layout utilities @@ -77,10 +109,12 @@ def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) -def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks, gran_k +): ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu") return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( - sf, ks_tensor, ks_int + sf, ks_tensor, ks_int, gran_k ) @@ -88,16 +122,20 @@ def transform_sf_into_required_layout( sf, mn, k, - recipe=None, - recipe_ab=None, + recipe, num_groups=None, - is_sfa=False, + is_sfa=None, disable_ue8m0_cast=False, ): - has_recipe = recipe is not None - r0, r1, r2 = recipe if has_recipe else (0, 0, 0) - has_recipe_ab = recipe_ab is not None - rab0, rab1 = recipe_ab if has_recipe_ab else (0, 0) + if len(recipe) == 3: + r0, r1, r2 = recipe + recipe_len = 3 + elif len(recipe) == 2: + r0, r1 = recipe + r2 = 0 + recipe_len = 2 + else: + raise ValueError("recipe must have length 2 or 3") has_ng = num_groups is not None ng = num_groups if has_ng else 0 return ops.transform_sf_into_required_layout( @@ -107,13 +145,11 @@ def transform_sf_into_required_layout( r0, r1, r2, - has_recipe, - rab0, - rab1, - has_recipe_ab, + recipe_len, ng, has_ng, - is_sfa, + False if is_sfa is None else is_sfa, + is_sfa is not None, disable_ue8m0_cast, ) @@ -593,8 +629,37 @@ def fp8_mqa_logits( ) -def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms): - return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms) +def fp8_fp4_mqa_logits( + q, + kv, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits=True, + max_seqlen_k=0, + logits_dtype=torch.float32, +): + if isinstance(q, tuple): + q_data, q_sf = q + else: + q_data, q_sf = q, None + kv_data, kv_sf = kv + return ops.fp8_fp4_mqa_logits( + q_data, + q_sf, + kv_data, + kv_sf, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits, + max_seqlen_k, + logits_dtype, + ) + + +def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices=None): + return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms, indices) def fp8_paged_mqa_logits( @@ -606,6 +671,7 @@ def fp8_paged_mqa_logits( schedule_meta, 
max_context_len, clean_logits=False, + indices=None, ): return ops.fp8_paged_mqa_logits( q, @@ -616,6 +682,38 @@ def fp8_paged_mqa_logits( schedule_meta, max_context_len, clean_logits, + indices, + ) + + +def fp8_fp4_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits=False, + logits_dtype=torch.float32, + indices=None, +): + if isinstance(q, tuple): + q_data, q_sf = q + else: + q_data, q_sf = q, None + return ops.fp8_fp4_paged_mqa_logits( + q_data, + q_sf, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits, + logits_dtype, + indices, ) @@ -642,6 +740,14 @@ def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None): ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns) +from .mega import ( + SymmBuffer, + get_symm_buffer_for_mega_moe, + transform_weights_for_mega_moe, + fp8_fp4_mega_moe, +) + + # Initialize the C++ runtime @@ -703,8 +809,21 @@ def _ensure_initialized(): global _initialized if _initialized: return + _ops.init(_lib_root, _find_cuda_home()) _initialized = True - ops.init(_lib_root, _find_cuda_home()) + + +class _InitializedOps: + def __init__(self, raw_ops): + self._raw_ops = raw_ops + + def __getattr__(self, name): + if name != "init": + _ensure_initialized() + return getattr(self._raw_ops, name) + + +ops = _InitializedOps(_ops) # Try to initialize eagerly, but don't fail if CUDA is not found diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/comm/barrier.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/comm/barrier.cuh new file mode 100644 index 00000000..eb9858d8 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/comm/barrier.cuh @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include +#include +#include + +namespace deep_gemm::comm { + +CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() { + // Perform cluster_sync with `barrier.cluster.arrive.relaxed` + // This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); +} + +template +CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope) { + // NOTES: the implementation idea is from `cooperative_groups::this_grid().sync()` + static constexpr uint32_t kFinishSumTag = 0x80000000u; + sync_scope(); + if (thread_idx == 0) { + const auto count_ptr = workspace.get_grid_sync_count_ptr(); + const auto old_value = ptx::atomic_add_rel( + count_ptr, sm_idx == 0 ? 
(kFinishSumTag - (kNumSMs - 1)) : 1); + uint32_t new_value; + do { + new_value = ptx::ld_acq(count_ptr); + } while (((new_value ^ old_value) & kFinishSumTag) == 0); + } + sync_scope(); +} + +template +CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace, + const layout::SymBuffer& sym_buffer, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope, + const bool& sync_prologue = true, + const bool& sync_epilogue = true) { + DG_STATIC_ASSERT(kNumRanks <= kNumThreads, "Insufficient threads"); + + // Grid sync before NVLink signaling + if (sync_prologue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); + + // NVLink cross-rank barrier, only SM 0 participates + if (sm_idx == 0) { + auto* counter_ptr = workspace.get_nvl_barrier_counter_ptr(); + const auto status = (*counter_ptr) & 3; + const auto signal_phase = status & 1, signal_sign = status >> 1; + auto* signal_ptr = workspace.get_nvl_barrier_signal_ptr(signal_phase); + + // Send signals to remote ranks + if (thread_idx < kNumRanks) + ptx::red_add_rel_sys(sym_buffer.map(signal_ptr, thread_idx), signal_sign ? -1 : 1); + sync_scope(); + + // Update status and wait arrival (with 30s timeout, at 2 GHz) + constexpr int64_t kNumTimeoutCycles = 30ll * 2000000000ll; + if (thread_idx == 0) { + ptx::red_add(counter_ptr, 1); + const int target = signal_sign ? 0 : static_cast(kNumRanks); + const auto start_clock = clock64(); + while (ptx::ld_acq_sys(signal_ptr) != target) { + if (clock64() - start_clock >= kNumTimeoutCycles) { + printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n", + sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag); + DG_DEVICE_ASSERT(false and "NVLink barrier timeout"); + } + } + } + } + + // Grid sync after NVLink completion + if (sync_epilogue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); +} + +} // namespace deep_gemm::comm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/compile.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/compile.cuh new file mode 100644 index 00000000..e93c43fb --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/compile.cuh @@ -0,0 +1,18 @@ +#pragma once + +#include + +#if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__)) or defined(__CUDACC_RTC__) or defined(__CLION_IDE__) +#define DG_IN_CUDA_COMPILATION +#endif + +#if defined(__NVCC__) || (defined(__clang__) and defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ __host__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#else +#define CUTLASS_HOST_DEVICE_NOINLINE +#define CUTLASS_DEVICE_NOINLINE +#endif diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/cute_tie.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/cute_tie.cuh index cd2aace7..a3a8b62a 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/cute_tie.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/cute_tie.cuh @@ -1,5 +1,7 @@ #pragma once +#include + namespace cute { struct ignore_t { diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/exception.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/exception.cuh new file mode 100644 index 00000000..78acf747 --- /dev/null +++ 
b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/exception.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include +#include + +#ifdef __CLION_IDE__ + +CUTLASS_HOST_DEVICE void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +#ifndef DG_UNIFIED_ASSERT +#ifdef DG_IN_CUDA_COMPILATION +#define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond) +#else +#define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond) +#endif +#endif diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/math.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/math.cuh new file mode 100644 index 00000000..0f0d2504 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/math.cuh @@ -0,0 +1,149 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm::math { + +/// Pointer operations +template +CUTLASS_HOST_DEVICE dtype_t* advance_ptr(void* ptr, const uint64_t num_bytes) { + return reinterpret_cast(static_cast(ptr) + num_bytes); +} + +/// Math functions +template +CUTLASS_HOST_DEVICE T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +CUTLASS_HOST_DEVICE T align(T a, T b) { + return (kDoCeilAlignment ? ceil_div(a, b) : (a / b)) * b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_min(T a, T b) { + return a < b ? 
a : b; +} + +template +CUTLASS_DEVICE void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +#ifdef DG_IN_CUDA_COMPILATION +CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + return __ffma2_rn(a, b, c); +#else + return make_float2( + __fmaf_rn(a.x, b.x, c.x), + __fmaf_rn(a.y, b.y, c.y) + ); +#endif +} + +CUTLASS_HOST_DEVICE float fast_rcp(const float& x) { + float ret; + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; +} + +/// Casting +template +CUTLASS_DEVICE int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +CUTLASS_DEVICE float fast_pow2(const int& x) { + uint32_t bits_x = (x + 127) << 23; + return *reinterpret_cast(&bits_x); +} + +CUTLASS_DEVICE int fast_log2_ceil(float x) { + const auto bits = *reinterpret_cast(&x); + const auto exp = bits >> 23; + const auto man = bits & ((1 << 23) - 1); + return exp - 127 + (man != 0); +} + +template +CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) { + DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0"); + const float2 finfo_factor = {1.0 / 448.0, 1.0 / 448.0}; + const auto scaled = __fmul2_rn(amax, finfo_factor); + const auto exp_x = fast_log2_ceil(scaled.x); + const auto exp_y = fast_log2_ceil(scaled.y); + sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x); + sf.y = fast_pow2(exp_y), sf_inv.y = fast_pow2(-exp_y); +} + +/// Reduction +CUTLASS_DEVICE uint32_t warp_inclusive_sum(uint32_t value, const uint32_t& lane_idx) { + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t synced = __shfl_up_sync(0xffffffff, value, offset); + if (lane_idx >= offset) + value += synced; + } + return value; +} + +// Operation functors +template struct ReduceSum { CUTLASS_DEVICE T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { CUTLASS_DEVICE T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { CUTLASS_DEVICE T operator()(T a, T b) const { return a < b ? 
a : b; } };
+template <typename T> struct ReduceAnd { CUTLASS_DEVICE T operator()(T a, T b) const { return a & b; } };
+template <typename T> struct ReduceOr { CUTLASS_DEVICE T operator()(T a, T b) const { return a | b; } };
+
+// Unified reduction function
+template <uint32_t kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
+CUTLASS_DEVICE T warp_reduce(T value, Op op) {
+    DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
+                     kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
+                     "Invalid number of lanes");
+    constexpr uint32_t mask = 0xffffffff;
+    if constexpr (kIntergroupReduce) {
+        if constexpr (kNumLanesPerGroup <= 1)  value = op(value, __shfl_xor_sync(mask, value, 1));
+        if constexpr (kNumLanesPerGroup <= 2)  value = op(value, __shfl_xor_sync(mask, value, 2));
+        if constexpr (kNumLanesPerGroup <= 4)  value = op(value, __shfl_xor_sync(mask, value, 4));
+        if constexpr (kNumLanesPerGroup <= 8)  value = op(value, __shfl_xor_sync(mask, value, 8));
+        if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
+    } else {
+        if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
+        if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
+        if constexpr (kNumLanesPerGroup >= 8)  value = op(value, __shfl_xor_sync(mask, value, 4));
+        if constexpr (kNumLanesPerGroup >= 4)  value = op(value, __shfl_xor_sync(mask, value, 2));
+        if constexpr (kNumLanesPerGroup >= 2)  value = op(value, __shfl_xor_sync(mask, value, 1));
+    }
+    return value;
+}
+
+// Convenience aliases
+template <uint32_t kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
+CUTLASS_DEVICE T warp_reduce_sum(T value) {
+    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce>(value, ReduceSum<T>{});
+}
+#endif
+
+} // namespace deep_gemm::math
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/tma_copy.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/tma_copy.cuh
new file mode 100644
index 00000000..2c5bf708
--- /dev/null
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/tma_copy.cuh
@@ -0,0 +1,92 @@
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+namespace deep_gemm::tma {
+
+template <uint32_t kSwizzleMode, uint32_t BLOCK_INNER, typename dtype_t>
+constexpr uint32_t get_inner_block_atom_size() {
+    return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t);
+}
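+// (Editor's note, illustrative) For 128B swizzling the atom is 128 / sizeof(dtype_t) elements
+// wide: 64 for BF16 and 128 for FP8. With `kSwizzleMode == 0` (no swizzling) the atom
+// degenerates to the full `BLOCK_INNER`, so the issue loops in `copy` below run exactly once.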
+
+template <typename dtype_t, uint32_t BLOCK_OUTER, uint32_t BLOCK_INNER, uint32_t kSwizzleMode, bool kIs3DTMA = false>
+CUTLASS_DEVICE void
+copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
+     dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx,
+     const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) {
+    DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
+                     static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
+    constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<kSwizzleMode, BLOCK_INNER, dtype_t>();
+
+    if constexpr (not kIs3DTMA) {
+        if (num_tma_multicast == 1) {
+            #pragma unroll
+            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
+                cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
+                                             static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
+                                             smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
+                                             inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
+            }
+        } else {
+            #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
+            // 2-CTA function will send signals to the leader CTA only
+            #pragma unroll
+            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
+                cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
+                                                  static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
+                                                  smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
+                                                  inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
+            }
+            #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
+            if (cute::block_rank_in_cluster() == 0) {
+                #pragma unroll
+                for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
+                    cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
+                                                           (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
+                                                           smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
+                                                           inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
+                }
+            }
+            #endif
+        }
+    } else {
+        if (num_tma_multicast == 1) {
+            #pragma unroll
+            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
+                cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
+                                             static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
+                                             smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
+                                             inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
+            }
+        } else {
+            #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
+            // 2-CTA function will send signals to the leader CTA only
+            #pragma unroll
+            for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
+                cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
+                                                  static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
+                                                  smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
+                                                  inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
+            }
+            #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
+            if (cute::block_rank_in_cluster() == 0) {
+                #pragma unroll
+                for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
+                    cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
+                                                           (1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
+                                                           smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
+                                                           inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
+                }
+            }
+            #endif
+        }
+    }
+}
+
+} // namespace deep_gemm::tma
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/types.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/types.cuh
new file mode 100644
index 00000000..e07df0af
--- /dev/null
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/types.cuh
@@ -0,0 +1,43 @@
+#pragma once
+
+#include
+
+namespace deep_gemm {
+
+enum class MmaKind
{ + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr CUTLASS_HOST_DEVICE int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr CUTLASS_HOST_DEVICE bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/utils.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/utils.cuh index 8fb6c2fc..3a5f7ad6 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/utils.cuh @@ -1,167 +1,24 @@ #pragma once -#include -#include #include -#include -#include -#include "cute_tie.cuh" +#include -#ifdef __CLION_IDE__ - -__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { - asm volatile("trap;"); -} - -#define printf host_device_printf -#endif - -#ifndef DG_DEVICE_ASSERT -#define DG_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) { \ - printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ - asm("trap;"); \ - } \ -} while (0) -#endif - -#ifndef DG_TRAP_ONLY_DEVICE_ASSERT -#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ -do { \ - if (not (cond)) \ - asm("trap;"); \ -} while (0) -#endif - -#ifndef DG_STATIC_ASSERT -#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) -#endif - -namespace deep_gemm { +namespace deep_gemm::utils { template struct PatternVisitor { FuncT func; - __device__ __host__ + CUTLASS_HOST_DEVICE explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} - __device__ __host__ - auto operator [](const uint32_t& i) { + CUTLASS_HOST_DEVICE + auto operator [](const uint32_t& i) const { return func(i); } }; -template -__device__ __host__ T ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { - return (a + b - 1) / b; -} - -template -__device__ __host__ T align(T a, T b) { - return ceil_div(a, b) * b; -} - -template -__device__ __host__ constexpr T constexpr_align(T a, T b) { - return constexpr_ceil_div(a, b) * b; -} - -template -__device__ __host__ constexpr T constexpr_gcd(T a, T b) { - return b == 0 ? 
a : constexpr_gcd(b, a % b); -} - -template -__forceinline__ __device__ void swap(T& a, T& b) { - T temp = a; - a = b; - b = temp; -} - -__forceinline__ __device__ uint32_t get_sm_idx() { - uint32_t sm_idx; - asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); - return sm_idx; -} - -__forceinline__ __device__ uint32_t get_lane_idx() { - uint32_t lane_id; - asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); - return lane_id; -} - -__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { - uint32_t ret; - asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float2 ld_shared(const float2* ptr) { - float2 ret; - asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float4 ld_shared(const float4* ptr) { - float4 ret; - asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { - uint4 ret; - asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ float ld_shared(const float* ptr) { - float ret; - asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); - return ret; -} - -__device__ __forceinline__ void st_shared(const float* ptr, float val) { - asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); -} - -__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { - asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); -} - -__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { - asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); -} - -__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { - asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); -} - -__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { - asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); -} - -__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) { - asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); -} - -template -__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { - auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); - return *reinterpret_cast(&bf16x2); -} - -__device__ __forceinline__ void prefetch_l1(void *ptr) { - asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); -} - template struct Vectorized { static auto zeros() { @@ -180,4 +37,14 @@ struct Vectorized { using vec_t = decltype(zeros()); }; -} // namespace `deep_gemm` +template +CUTLASS_DEVICE constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if constexpr (kNumCols <= 32) return 32; + if constexpr (kNumCols <= 64) return 64; + if constexpr (kNumCols <= 128) return 128; + if constexpr (kNumCols <= 256) return 256; + return 512; +} + +} // namespace deep_gemm::utils diff --git 
a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh new file mode 100644 index 00000000..bf0e460c --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh @@ -0,0 +1,137 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M waves + constexpr auto kNumMWaves = BLOCK_M / STORE_BLOCK_M; + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]); + + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = base_m_idx + w * STORE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(base_n_idx + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = tmem_base_addr + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = smem_base_ptr + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared( + smem_ptr, + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]) + ); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +} + +} // namespace deep_gemm::epilogue diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh new file mode 100644 index 00000000..f3f5351e --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh @@ -0,0 +1,144 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd_swap_ab(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& effective_m, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // NOTES: The epilogue requires a full warpgroup to read all 128 TMEM rows, + // implying STORE_BLOCK_N must be 128. 
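+    // (Editor's note, illustrative) Row accounting behind the assertion below: each of the
+    // 4 epilogue warps covers 32 TMEM rows -- one `32dp32b` load for FP32, or two 16-row
+    // `16dp256b` loads for BF16 -- so a full warpgroup reads 4 * 32 = 128 rows.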
+ DG_STATIC_ASSERT(STORE_BLOCK_N == 128, "STORE_BLOCK_N must be 128 to match TMEM rows"); + + // TMA checks + constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumSwizzleAtomRows = 8; + DG_STATIC_ASSERT(kSwizzleCDMode == 128, "TMA D must be 128B swizzled"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M blocks + const auto num_stores = effective_m / STORE_BLOCK_M; + for (uint32_t s = 0; s < num_stores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) { + uint32_t tmem_addr = tmem_base_addr + + s * STORE_BLOCK_M + // Store stage offset + i * kNumSwizzleAtomRows; // In-block offset + uint32_t values[kNumSwizzleAtomRows]; + + // Warps cooperatively write an atomic block to shared memory + DG_STATIC_ASSERT(STORE_BLOCK_N_ATOM % 32 == 0, "Invalid block sizes"); + constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32; + uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode; + uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode; + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + outer_atom_offset + inner_atom_offset; + + if constexpr (cute::is_same_v) { + // NOTES: Swizzling is not required in this case, but used here for consistency with other cases + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + uint32_t col = lane_idx / 4; + + #pragma unroll + for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) { + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes + + (lane_idx % 4) * sizeof(float); + ptx::st_shared(reinterpret_cast(smem_ptr), values[row]); + } + } else { + // Load from TMEM using `.16x256b` shape to satisfy STSM layout requirements + // Start from lane index 0 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + // Start from lane index 16 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Destination shared memory address + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + + // Store matrix with transposition + ptx::SM90_U32x4_STSM_T::copy(math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage 
needs to do this + if (s == num_stores - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ i) { + auto smem_ptr = smem_cd[tma_stage_idx] + i * STORE_BLOCK_M * STORE_BLOCK_N_ATOM; + uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M; + uint32_t n_idx = epilogue_type_t::apply_index_n(base_n_idx + i * STORE_BLOCK_N_ATOM); + + // Issue 2D or 3D TMA store + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx); + } + } + cute::tma_store_arrive(); + } + __syncwarp(); + } +} + +} // namespace deep_gemm::epilogue diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/transform.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/transform.cuh new file mode 100644 index 00000000..0266f4d4 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/epilogue/transform.cuh @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace deep_gemm::epilogue::transform { + +struct EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 and + kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +} // namespace deep_gemm::epilogue::transform diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 0227b3e8..a60e2de8 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -4,14 +4,18 @@ #include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_bf16_gemm_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, @@ -48,41 +53,31 @@ sm100_bf16_gemm_impl(int* grouped_layout, if constexpr (kWithAccumulation) DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); - // Configs + // MMA Configs constexpr uint32_t LAYOUT_AD_M = 128; - constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; - constexpr uint32_t kNumTMAStoreStages = 2; - DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); - DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); - DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode"); - - // Overwrite shape constants if the compiler 
gives - shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; - shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; - shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; - - // Utils - bool is_leader_cta = cute::block_rank_in_cluster() == 0; - const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - - // 2-CTA MMA + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 16; constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); - constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; - DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); - DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M; DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); // Share memory sizes - constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t); constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); @@ -91,41 +86,54 @@ sm100_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); // NOTES: Make sure we have enough shared memory for UMMA padding - static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); - DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); - - // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size - // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` - constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 
1 : 2;
+    static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
+    DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bounds for UMMA");
 
     // Real tensor memory size and offsets
-    constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
-    constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
+    constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * UMMA_N;
+    constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols>();
+    DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
+
+    // Synchronize the cluster before 2-CTA TMEM allocation
+    kNumMulticast > 1 ? cute::cluster_sync() : void();
+
+    // Utils
+    bool is_leader_cta = cute::block_rank_in_cluster() == 0;
+    const auto warp_idx = cutlass::canonical_warp_idx_sync();
+    const auto lane_idx = ptx::get_lane_idx();
 
     // Prefetch TMA descriptors at the very beginning
-    if (warp_idx == 0 and cute::elect_one_sync()) {
+    if (warp_idx == 0) {
         cute::prefetch_tma_descriptor(&tensor_map_a);
         cute::prefetch_tma_descriptor(&tensor_map_b);
         cute::prefetch_tma_descriptor(&tensor_map_cd);
     }
 
+    // Overwrite shape constants if the compiler provides them
+    shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
+    shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
+    shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
+
+    // Align to 1024 bytes for swizzle-128B
+    extern __shared__ __align__(1024) uint8_t smem_buffer[];
+
     // D/A/B shared memory
-    auto smem_cd = PatternVisitor([&](const uint32_t& i) {
+    auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
     });
-    auto smem_a = PatternVisitor([&](const uint32_t& i) {
+    auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
     });
-    auto smem_b = PatternVisitor([&](const uint32_t& i) {
+    auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
     });
 
     // Fill barriers
     auto barrier_start_ptr = reinterpret_cast<cutlass::arch::ClusterTransactionBarrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
-    auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
-    auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
-    auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
-    auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
+    auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
+    auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
+    auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
+    auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
     auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
 
     // Fill the tensor memory pointer
@@ -159,9 +167,13 @@ sm100_bf16_gemm_impl(int* grouped_layout,
     }
     kNumMulticast >
1 ? cute::cluster_sync() : __syncthreads(); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + auto scheduler = sched::Scheduler( + shape_m, shape_n, shape_k, grouped_layout); // Pipeline and TMA phases uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; @@ -178,16 +190,20 @@ sm100_bf16_gemm_impl(int* grouped_layout, // TMA load warp // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); // Compute offsets // NOTES: the group is always concatenated with the outer dimension - uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( shape_m, BLOCK_M, m_block_idx); - uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( shape_n, BLOCK_N, n_block_idx, m_block_idx); // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major @@ -195,14 +211,14 @@ sm100_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or kMajorA == cute::UMMA::Major::K, "Invalid major"); uint32_t k_idx = k_block_idx * BLOCK_K; - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); // Add 2 CTA offsets if constexpr (kNumMulticast > 1) { - m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); } @@ -210,16 +226,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); const uint32_t batch_idx = (kIsBatchedMM ? 
scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy( + tma::copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy( + tma::copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); // Arrive at full barriers @@ -235,17 +251,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, // MMA issue warp // NOTES: only the leader CTA will do this // Make instruction descriptor - // TODO: refactor `UMMA_M` calculation - constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); - constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); - auto instr_desc = cute::UMMA::make_instr_desc(); + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc() + : cute::UMMA::make_instr_desc(); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); // Merged stages only happens in NT normal GEMM cases constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -262,7 +277,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // UMMA and empty barrier arrival alias auto umma_arrive = [](const uint64_t* barrier) { @@ -279,36 +294,45 @@ sm100_bf16_gemm_impl(int* grouped_layout, // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting if (do_tmem_full_arrive) umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); }; + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + // Launch MMAs - const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait TMA arrival full_barriers[stage_idx]->wait(phase); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA in the leader CTA - using mma_t = cute::conditional_t; - const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); - const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + using mma_t = cute::conditional_t; + const auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); - mma_t::fma(a_desc, b_desc, - accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, - k_block_idx > 0 or k > 0, - runtime_instr_desc); + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + if (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); } } } + __syncwarp(); // Commit to the mbarrier object // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` @@ -319,15 +343,16 @@ sm100_bf16_gemm_impl(int* grouped_layout, if constexpr (kTensorCoreUtilControl < 100) { // For utilization control 
umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + __syncwarp(); // Wait for last UMMA to be done tensor_core_full_barrier->wait(tensor_core_phase); tensor_core_phase ^= 1; // Sleep for certain cycles - constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumUMMACycles = (2ull * UMMA_M * UMMA_N * BLOCK_K) / 8192ull; constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; - const auto& start_clock = clock64(); + const auto start_clock = clock64(); if (cute::elect_one_sync()) while (clock64() - start_clock < kNumDummyCycles) {} __syncwarp(); @@ -336,9 +361,9 @@ sm100_bf16_gemm_impl(int* grouped_layout, } // To safely deconstruct barriers, we need another round of waits - const auto& iter_idx = scheduler.current_iter - 1; + const auto iter_idx = scheduler.current_iter - 1; if (kNumMulticast > 1 and iter_idx >= 0) { - const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); } } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { @@ -348,19 +373,10 @@ sm100_bf16_gemm_impl(int* grouped_layout, // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. // NOTES: we also forbid two CTAs to share the same SM and its tensor memory - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); - - // TMA checks - constexpr uint32_t kNumBankGroupBytes = 16; - constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); - DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); - DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Share store pipeline between blocks uint32_t tma_stage_idx = 0; - auto advance_store_pipeline = [&]() { - tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; - }; // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { @@ -369,108 +385,47 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Wait UMMA arrival tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM - DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); - DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - - // Iterate over M waves - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - // Issue every swizzled atom and pipeline STSM and TMA store - constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; - #pragma unroll - for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { - // Wait shared memory to be released - if (epilogue_warp_idx == 0) - cute::tma_store_wait(); - cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); - - // The pipeline stage - const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; - const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; - - // Store into shared memory - #pragma unroll - for (uint32_t i = 
0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); - - // Reshape the atom in another view and swizzle - // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` - // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` - // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern - constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? (i) : (bank_group_index % 8); - col ^= row % (kSwizzleCDMode / 16); - - // Source and destination memory address - uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset - w * BLOCK_N + // Wave offset - s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset - auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer - epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - - // Load from tensor memory, store into shared memory - uint32_t values[kNumElemsPerBankGroup]; - if constexpr (cute::is_same_v) { - // For FP32 output, read and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, - values[0], values[1], values[2], values[3]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); - } else { - // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, - values[0], values[1], values[2], values[3], - values[4], values[5], values[6], values[7]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, - cast_into_bf16_and_pack(values[0], values[1]), - cast_into_bf16_and_pack(values[2], values[3]), - cast_into_bf16_and_pack(values[4], values[5]), - cast_into_bf16_and_pack(values[6], values[7])); - } - } - - // Notify tensor memory empty (only at the leader CTA) arrival ASAP - // NOTES: only the last stage needs to do this - if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { - tcgen05_before_thread_sync(); - tmem_empty_barriers[accum_stage_idx]->arrive(0u); - } - __syncwarp(); - - // Synchronize all threads and issue TMA - cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - if constexpr (kGemmType == GemmType::Batched) { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], - n_idx, m_idx, scheduler.current_group_idx); - } else { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); - } - cute::tma_store_arrive(); - } - } + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx< + (not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + 
epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, kNumTmemCols); } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh index 86303347..13bb0872 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -5,18 +5,19 @@ #include #include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__global__ void __launch_bounds__(kNumThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumThreads, 1) sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -30,7 +31,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); @@ -51,24 +52,24 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, } // Real tensor memory size and offsets - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); // Fill D/A/B - auto smem_cd = PatternVisitor([&](const uint32_t& i) { + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); }); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); auto tmem_full_barrier = barrier_start_ptr + 
(kNumStages * 2); // Fill the tensor memory pointer @@ -93,14 +94,17 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, __syncthreads(); // Block indices - const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); - const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M); const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; const uint32_t n_block_idx = mn_block_idx % num_n_blocks; const uint32_t m_block_idx = mn_block_idx / num_n_blocks; const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (warp_idx == 0) { // TMA load warp for (uint32_t s = 0; s < num_total_stages; ++ s) { @@ -115,8 +119,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Issue TMAs if (cute::elect_one_sync()) { - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); } // Arrive at full barriers @@ -134,8 +138,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, auto instr_desc = cute::UMMA::make_instr_desc(); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -147,14 +151,14 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, "Invalid MMA instruction shape"); // Wait tensor memory empty barrier arrival - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Launch MMAs for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait TMA arrival const auto& stage_idx = s % kNumStages; full_barriers[stage_idx]->wait((s / kNumStages) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA in the leader CTA const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); @@ -163,9 +167,11 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); - SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); } } @@ -180,7 +186,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`. 
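    // (Editor's note) A TMEM address packs the lane into bits 16-31 and the column into bits 0-15,
    // which is why a per-warp offset would otherwise be expressed as `(warp_idx * 32) << 16`.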
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory if (warp_idx == 2) - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // TMA checks constexpr uint32_t kNumBankGroupBytes = 16; @@ -191,7 +197,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Wait UMMA arrival tmem_full_barrier->wait(0); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); @@ -239,7 +245,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, values[0], values[1], values[2], values[3]); cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); } // Synchronize all threads and issue TMA diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh new file mode 100644 index 00000000..b8a99fd0 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh @@ -0,0 +1,457 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, + const uint32_t logits_stride, + const uint32_t* cu_seq_len_k_start, + const uint32_t* cu_seq_len_k_end, + logits_dtype_t* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQ = math::constexpr_align(BLOCK_Q * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(BLOCK_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQ = BLOCK_Q * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(BLOCK_KV == kNumMathWarpGroups * UMMA_M and BLOCK_KV % kNumUTCCPAlignedElems == 0, "Invalid `BLOCK_KV`"); + + // Shared memory 
configs
+    static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2);
+    static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * (kHeadDim / 2);
+    static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQ * sizeof(int);
+    static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * (kHeadDim / 2);
+    static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int);
+    static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
+
+    // Align to swizzling alignment bytes
+    extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
+    DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
+    DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
+
+    // Q and KV data on shared memory
+    auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
+        return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i;
+    });
+    auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
+        return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i;
+    });
+    const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages);
+    auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i);
+    });
+    auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i);
+    });
+    auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages +
+                                SMEM_WEIGHT_SIZE_PER_STAGE * i);
+    });
+
+    // Barriers and TMEM pointer on shared memory
+    const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]);
+    auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
+    auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
+    auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
+    auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
+    const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
+    auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; });
+    auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; });
+    auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2);
+
+    // Tensor memory configs
+    constexpr uint32_t kNumAccumTmemCols = BLOCK_Q * kNumHeads * kNumTmemStages;
+    constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols();
+    constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
+    constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQ / 32;
+    DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory columns");
+
+    // Initialize barriers
+    if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumQStages; ++ i) {
+            full_q_barriers[i]->init(1);
+            empty_q_barriers[i]->init(kNumMathThreads + 32);
+        }
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumKVStages; ++ i) {
+            full_kv_barriers[i]->init(1);
+            empty_kv_barriers[i]->init(1);
+
} + #pragma unroll + for (uint32_t i = 0; i < kNumTmemStages; ++i) { + full_tmem_barriers[i]->init(1); + empty_tmem_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + + // Allocate tensor memory + if (warp_idx == kSpecWarpStart + 2) + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + __syncthreads(); + + // Scheduler + const uint32_t num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + auto load_schedule = [&](const uint32_t& q_idx) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto row_idx = cute::min(q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cute::min(cu_seq_len_k_start[row_idx], seq_len_kv); + seq_k_end[i] = cute::min(cu_seq_len_k_end[row_idx], seq_len_kv); + start = cute::min(start, seq_k_start[i]); + end = cute::max(end, seq_k_end[i]); + } + // TMA alignment requirements for SF KV + start = start / 4 * 4; + return {start, math::ceil_div(end - start, BLOCK_KV)}; + }; + + // Make Q, KV and TMEM pipeline + auto make_pipeline = [](const uint32_t& num_stages) { + // Return current stage and phase, and advance pipeline by steps + return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple { + uint32_t current_idx = iter_idx; + iter_idx += step; + return {current_idx % num_stages, (current_idx / num_stages) & 1}; + }; + }; + auto advance_q_pipeline = make_pipeline(kNumQStages); + auto advance_kv_pipeline = make_pipeline(kNumKVStages); + auto advance_tmem_pipeline = make_pipeline(kNumTmemStages); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading Q + cutlass::arch::warpgroup_reg_dealloc(); + + // Enumerate Q blocks + if (cute::elect_one_sync()) { + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Wait Q consumer release + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_q[q_stage_idx], 0, q_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_idx * BLOCK_Q); + tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_idx * BLOCK_Q); + full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQ * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + + if (cute::elect_one_sync()) { + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_kv, 
reinterpret_cast(full_kv_barriers[kv_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx], 0, kv_start + kv_idx * BLOCK_KV); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx], + kv_start + kv_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQ / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize umma desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + 
cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma( + a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this into `deep_gemm/ptx/tcgen05.cuh` + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline. Without this, UMMA can consume + // kNumQStages Q blocks before math warps release any, causing a + // circular dependency: UMMA waits full_q -> TMA_Q waits empty_q + // -> Math waits full_tmem -> UMMA (already moved on). + empty_q_barriers[q_stage_idx]->arrive(); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = threadIdx.x; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr uint32_t N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[BLOCK_Q][kNumHeads]; + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + // TODO: optimize bank conflicts + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Calculate KV offset in advance + auto kv_offset = kv_start + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + 
ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + } + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + // TODO: optimize performance + const auto q_offset = (q_idx * BLOCK_Q + i) * static_cast(logits_stride); + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; + } else { + logits[q_offset + kv_offset] = result; + } + __syncwarp(); + } + } + + // Release last Q empty + empty_q_barriers[q_stage_idx]->arrive(); + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh new file mode 100644 index 00000000..d9add534 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh @@ -0,0 +1,510 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 
2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQAtom = math::constexpr_align(kNextNAtom * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(SPLIT_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQAtom = kNextNAtom * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == kNumMathWarpGroups * UMMA_M and SPLIT_KV % kNumUTCCPAlignedElems == 0, "Invalid `SPLIT_KV`"); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQAtom * sizeof(int); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i; + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i; + }); + const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages); + auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i); + }); + auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; }); + auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& 
i) { return tmem_barrier_ptr + kNumTmemStages + i; });
+    auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2);
+
+    // Tensor memory configs
+    constexpr uint32_t kNumAccumTmemCols = kNextNAtom * kNumHeads * kNumTmemStages;
+    constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols();
+    constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
+    constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQAtom / 32;
+    DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory columns");
+
+    // Initialize barriers
+    if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumQStages; ++ i) {
+            full_q_barriers[i]->init(1);
+            empty_q_barriers[i]->init(kNumMathThreads + 32);
+        }
+        cutlass::arch::fence_barrier_init();
+    }
+    if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumKVStages; ++ i) {
+            full_kv_barriers[i]->init(1);
+            empty_kv_barriers[i]->init(1);
+        }
+        cutlass::arch::fence_barrier_init();
+    }
+    if (warp_idx == kSpecWarpStart + 2) {
+        if (cute::elect_one_sync()) {
+            #pragma unroll
+            for (uint32_t i = 0; i < kNumTmemStages; ++i) {
+                full_tmem_barriers[i]->init(1);
+                empty_tmem_barriers[i]->init(128);
+            }
+            cutlass::arch::fence_barrier_init();
+        }
+        // Allocate tensor memory
+        cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
+    }
+    __syncthreads();
+
+    // Wait for primary kernel completion
+    cudaGridDependencySynchronize();
+
+    // Scheduler
+    constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
+    using Scheduler = sched::PagedMQALogitsScheduler;
+    DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
+
+    // Make Q, KV and TMEM pipeline
+    auto make_pipeline = [](const uint32_t& num_stages) {
+        // Return current stage and phase, and advance pipeline by steps
+        return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple {
+            uint32_t current_idx = iter_idx;
+            iter_idx += step;
+            return {current_idx % num_stages, (current_idx / num_stages) & 1};
+        };
+    };
+    auto advance_q_pipeline = make_pipeline(kNumQStages);
+    auto advance_kv_pipeline = make_pipeline(kNumKVStages);
+    auto advance_tmem_pipeline = make_pipeline(kNumTmemStages);
+
+    // Register reconfigurations
+    constexpr uint32_t kNumSpecializedRegisters = 56;
+    constexpr uint32_t kNumMathRegisters = 224;
+
+    if (warp_idx == kSpecWarpStart) {
+        // TMA warp for loading Q
+        cutlass::arch::warpgroup_reg_dealloc();
+
+        if (cute::elect_one_sync()) {
+            auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
+
+            // Persistently schedule over blocks
+            // Initialize outside valid range to indicate no previous task
+            uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
+            uint32_t q_atom_idx, _, __;
+            while (scheduler.fetch_next_task(q_atom_idx, _, __)) {
+                // Issue TMA Q when (q_idx, atom_idx) changes
+                if (q_atom_idx != last_q_atom_idx) {
+                    // Wait Q consumer release
+                    CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
+                    empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
+
+                    // Issue TMA Q
+                    const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx);
+                    cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]),
+                                                 static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL),
+                                                 smem_q[q_stage_idx], 0, q_token_idx * kNumHeads);
+                    tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx);
+                    tma::copy(&tensor_map_weights,
full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx); + full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + last_q_atom_idx = q_atom_idx; + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + // Persistently schedule over blocks + uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage; + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, num_kv; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, num_kv)) { + // Reset block table cache on kv restart + if (q_atom_idx != last_q_atom_idx) + kv_block_idx_ptr = 32; + last_q_atom_idx = q_atom_idx; + + // Coalesced load of block table + if (kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? block_table[block_table_offset + kv_idx + lane_idx] : 0; + } + __syncwarp(); + + // Broadcast KV block indices + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `SPLIT_KV`"); + + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + + // Issue TMA KV + if (cute::elect_one_sync()) { + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim / 2) * i, + 0, 0, kv_block_idx[i]); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + // Wait TMA Q arrivals + uint32_t q_stage_idx, q_phase; + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE(advance_q_pipeline(), q_stage_idx, q_phase); + 
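+                // (Editor's sketch: each call to `advance_q_pipeline()` above
+                // returns (iter % kNumQStages, (iter / kNumQStages) & 1) and then
+                // advances `iter`. With kNumQStages == 2 the successive
+                // (stage, phase) pairs are (0,0), (1,0), (0,1), (1,1), (0,0), ...
+                // Consumers wait on `full` barriers with `phase` and producers on
+                // `empty` barriers with `phase ^ 1`, so the parity they test flips
+                // exactly once per wrap of the stage ring.)
+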
+ // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQAtom / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + } + last_q_atom_idx = q_atom_idx; + + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize UMMA desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma(a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this PTX into headers + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + 
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[kNextNAtom][kNumHeads]; + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + bool is_paired_atom = false; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + + // Release last Q empty + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrivals + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + float4 raw = ptx::ld_shared((float4*)(smem_weights[q_stage_idx] + i * kNumHeads + j)); + weights[i][j + 0] = raw.x; + weights[i][j + 1] = raw.y; + weights[i][j + 2] = raw.z; + weights[i][j + 3] = raw.w; + } + } + + // Check if this atom pairs two tokens from the same sequence + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2); + } + } + last_q_atom_idx = q_atom_idx; + + // Calculate KV offset in advance + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + + // Only loop over valid iterations + #pragma unroll + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride)] = result; + __syncwarp(); + } + + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + }; + + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % 
kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); + } + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh new file mode 100644 index 00000000..0bc6a3fe --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh @@ -0,0 +1,514 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // MMA Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // SF configs + constexpr uint32_t kNumUTCCPAlignedElems = 128; + constexpr uint32_t SF_BLOCK_M = math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = math::constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + DG_STATIC_ASSERT((kGemmType != GemmType::KGroupedContiguous) or kGranKA == kGranKB, "K-grouped SF requires kGranKA == kGranKB"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 
16 : cute::min(BLOCK_M, LAYOUT_AD_M);
+    constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
+    constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads : STORE_BLOCK_M;
+    DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
+
+    // Shared memory sizes
+    constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
+    constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
+    constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
+    constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
+    constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
+    constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
+    DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
+                     "Shared memory of A/B must be aligned to 1024 bytes");
+    // NOTES: Make sure we have enough shared memory for UMMA padding
+    constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
+    DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bounds for UMMA");
+
+    // Tensor memory size and offsets
+    constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
+    constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
+    constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
+    constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols();
+    constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
+    constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
+    DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
+
+    // Synchronize the cluster before 2-CTA TMEM allocation
+    kNumMulticast > 1 ? cute::cluster_sync() : void();
+
+    // Utils
+    const bool is_leader_cta = cute::block_rank_in_cluster() == 0;
+    const auto warp_idx = cutlass::canonical_warp_idx_sync();
+    const auto lane_idx = ptx::get_lane_idx();
+
+    // Prefetch TMA descriptors at the very beginning
+    if (warp_idx == 0) {
+        cute::prefetch_tma_descriptor(&tensor_map_a);
+        cute::prefetch_tma_descriptor(&tensor_map_b);
+        cute::prefetch_tma_descriptor(&tensor_map_sfa);
+        cute::prefetch_tma_descriptor(&tensor_map_sfb);
+        cute::prefetch_tma_descriptor(&tensor_map_cd);
+    }
+
+    // Overwrite shape constants when given at compile time
+    shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
+    shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
+    shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
+    const auto shape_sfa_k = math::ceil_div(shape_k, kGranKA * 4);
+    const auto shape_sfb_k = math::ceil_div(shape_k, kGranKB * 4);
+
+    // Align to 1024 bytes for swizzle-128B
+    extern __shared__ __align__(1024) uint8_t smem_buffer[];
+
+    // D/A/B shared memory
+    auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
+    });
+    auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
+    });
+    auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
+    });
+
+    // SFA/SFB shared memory
+    auto sf_start_ptr = reinterpret_cast(smem_b[kNumStages]);
+    auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) {
+        return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
+    });
+    auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) {
+        return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
+    });
+
+    // Barriers and tensor memory pointer
+    auto barrier_start_ptr = reinterpret_cast(smem_sfb[kNumStages]);
+    auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
+    auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
+    auto with_sf_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
+    auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
+    auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
+    auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
+
+    // Initialize barriers
+    if (warp_idx == 1 and cute::elect_one_sync()) {
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumStages; ++ i) {
+            // Arrive at all CTAs
+            full_barriers[i]->init(1);
+            empty_barriers[i]->init(1);
+            // Arrive only at the leader CTA
+            with_sf_full_barriers[i]->init(kNumMulticast * 32);
+        }
+        #pragma unroll
+        for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
+            // Arrive at all CTAs
+            tmem_full_barriers[i]->init(1);
+            // Arrive only at the leader CTA
+            tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
+        }
+
+        // Make initialized barrier visible in async proxy
+        cutlass::arch::fence_barrier_init();
+    } else if (warp_idx == 2) {
+        // Allocate tensor memory
+        Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
+    }
+    kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
+
+    // Wait for primary kernel completion
+    cudaGridDependencySynchronize();
+
+    // Block scheduler
+    uint32_t m_block_idx, n_block_idx;
+    auto scheduler = sched::Scheduler(
+        shape_m, shape_n, shape_k, grouped_layout);
+
+    // Pipeline and TMA phases
+    uint32_t stage_idx = 0, phase = 0;
+    auto advance_pipeline = [&](uint32_t& k_block_idx) {
+        ++ k_block_idx;
+
+        // Flip phases only when wrapping back to the first stage
+        stage_idx = stage_idx == kNumStages - 1 ?
0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 
1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + uint32_t sfa_m_idx = m_block_idx * BLOCK_M; + uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>( + shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)); + tma::copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = scheduler.template get_global_idx( + shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx); + tma::copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc_block_scaled() + : cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + + // Launch MMAs + const auto num_total_k_blocks = 
math::ceil_div(scheduler.current_shape_k, BLOCK_K); + #pragma unroll 4 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // Do SF copy at certain stages + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + + // Issue UMMA + using mma_t = cute::conditional_t< + kNumMulticast == 1, ptx::SM100_MMA_MXF8F6F4_SS, ptx::SM100_MMA_MXF8F6F4_2x1SM_SS>; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto runtime_instr_desc = kSwapAB ? 
+ mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfb_id, sfa_id): + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + a_desc.lo = mma::sm100::advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + if constexpr (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFA, kTmemStartColOfSFB); + } + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. 
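+        // (Editor's note: the trap assert below doubles as a base-address check --
+        // every TMEM column offset used in this kernel (`accum_stage_idx * UMMA_N`,
+        // `kTmemStartColOfSFA`, `kTmemStartColOfSFB`) is absolute, so the
+        // allocation must start at column 0; a non-zero base would silently
+        // shift all accumulator reads.)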
+ // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + ptx::tcgen05_after_thread_sync(); + + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } + } + } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? 
cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh new file mode 100644 index 00000000..b2adc6c7 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh @@ -0,0 +1,1380 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t STORE_BLOCK_M, + uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm100_fp8_fp4_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights_sf) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::TMEM::Allocator2Sm; + + // Template checks + DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of MMA epilogue and combine threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + + // Thread indices + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const uint32_t 
sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights_sf); + } + + // Workspaces + const auto workspace = layout::Workspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + // Token and buffer layouts + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Registered inputs + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxTokensPerRank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumMaxTokensPerRank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, kNumMaxTokensPerRank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, kNumMaxTokensPerRank, + input_topk_idx_buffer.get_end_ptr()); + + // SF and its buffer configs + constexpr uint32_t kGranK = 32; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); + DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "No padding is needed for SFB"); + + // UTCCP 4x32 transpose index mapping within each 128-element group + const auto transform_sf_token_idx = [](const uint32_t& token_idx_in_expert) { + const uint32_t idx = token_idx_in_expert % BLOCK_M; + return token_idx_in_expert / BLOCK_M * SF_BLOCK_M + + (idx & ~127u) + (idx & 31u) * 4 + ((idx >> 5) & 3u); + }; + + // L1 inputs + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxPoolTokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, kNumMaxPoolTokens, + l1_sf_buffer.get_end_ptr()); + + // L2 inputs + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + l1_topk_weights_buffer.get_end_ptr() + ); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + l2_token_buffer.get_end_ptr() + ); + + // Combine inputs + const auto combine_token_buffer = layout::Buffer( + 
bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, + l2_sf_buffer.get_end_ptr() + ); + + // Data types + // NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1) + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + + // MMA configs + // NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; + constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + + // Swizzle configs + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t kSwizzleCDMode = 128; + DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); + + // Epilogue configs + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Shared memory + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + // Shared memory sizes + // NOTES: FP8 CD output for L1 (2 TMA stages, BLOCK_N/2 post-SwiGLU), BF16 output for L2 (no TMA, a single stage) + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + constexpr uint32_t SMEM_CD_L1_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages; + constexpr uint32_t SMEM_CD_L2_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? 
+        SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE;
+    constexpr uint32_t SMEM_CD_L1_SIZE_PER_STAGE = SMEM_CD_L1_SIZE / kNumTMAStoreStages;
+    constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE =
+        SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
+    DG_STATIC_ASSERT(SMEM_CD_SIZE % kSharedMemoryAlignment == 0 and
+                     SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and
+                     SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0,
+                     "Shared memory of CD/A/B must be aligned to 1024 bytes");
+
+    // Tensor memory size
+    constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
+    constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
+    constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
+    constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
+    constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
+    constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
+    DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
+
+    // Assign shared memory for dispatch warps
+    const auto smem_expert_count = reinterpret_cast<uint32_t*>(smem_buffer);
+    const auto smem_send_buffers = layout::Buffer(
+        fp8_token_layout, kNumDispatchWarps, 1,
+        math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE));
+
+    // GEMM shared memory: C/D, A, B
+    // NOTES: GEMM shared memory starts after the dispatch region, aligned to 1024 bytes
+    auto smem_gemm_base = math::advance_ptr(
+        smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE
+    );
+
+    // D/A/B shared memory
+    auto smem_cd = utils::PatternVisitor([=](const uint32_t& i) {
+        return math::advance_ptr(smem_gemm_base, i * SMEM_CD_L1_SIZE_PER_STAGE);
+    });
+    auto smem_cd_l2 = smem_cd[0];
+    auto smem_a = utils::PatternVisitor([=](const uint32_t& i) {
+        return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
+    });
+    auto smem_b = utils::PatternVisitor([=](const uint32_t& i) {
+        return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
+    });
+
+    // SF shared memory: SFA and SFB per pipeline stage
+    auto sf_start_ptr = math::advance_ptr(smem_gemm_base,
+        SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
+    auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) {
+        return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
+    });
+    auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) {
+        return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
+    });
+
+    // Epilogue amax reduction shared memory
+    auto smem_amax_reduction = reinterpret_cast<float2*>(smem_sfb[kNumStages]);
+
+    // Barriers and tensor memory pointer
+    auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_amax_reduction + STORE_BLOCK_M * kNumEpilogueWarps / 2);
+    auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
+    auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + i); });
+    auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages + i); });
+    auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + i); });
+    auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages + i); });
+    auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + i); });
+    auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2);
+
+    // A cluster sync is essential for 2-CTA tensor memory allocation
+    comm::cluster_sync_with_relaxed_arrive();
+
+    // Initialization
+    if (warp_idx == 0) {
+        // Clean shared memory
+        if (cute::elect_one_sync())
+            ptx::st_shared_bulk(smem_expert_count, kNumExperts * sizeof(uint32_t));
+    } else if (warp_idx == 1) {
+        // Init m-barriers for dispatch
+        #pragma unroll
+        for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32)
+            dispatch_barriers[i]->init(1);
+        cutlass::arch::fence_barrier_init();
+    } else if (warp_idx == 2) {
+        // Init GEMM barriers
+        if (cute::elect_one_sync()) {
+            #pragma unroll
+            for (uint32_t i = 0; i < kNumStages; ++ i) {
+                // Arrive at all CTAs
+                full_barriers[i]->init(2 * 2);
+                empty_barriers[i]->init(1);
+            }
+            #pragma unroll
+            for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
+                // Arrive at all CTAs
+                tmem_full_barriers[i]->init(1);
+                // Arrive only at the leader CTA
+                tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads);
+            }
+            #pragma unroll
+            for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i)
+                combine_barriers[i]->init(1);
+        }
+        cutlass::arch::fence_barrier_init();
+    } else if (warp_idx == 3) {
+        // Allocate tensor memory
+        Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
+    }
+    // NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`,
+    // and `barrier.cluster.wait.aligned` is by default `.acquire`
+    comm::cluster_sync_with_relaxed_arrive();
+
+    // Task scheduler
+    auto scheduler = sched::MegaMoEScheduler<
+        BLOCK_M, BLOCK_N, BLOCK_K,
+        L1_SHAPE_N, L1_SHAPE_K,
+        L2_SHAPE_N, L2_SHAPE_K,
+        kNumExpertsPerRank,
+        kNumExpertsPerWave,
+        kNumSMs, kNumRanks>(workspace);
+
+    // MMA pipeline and TMA phases
+    uint32_t stage_idx = 0, phase = 0;
+    auto advance_pipeline = [&](uint32_t& k_block_idx) {
+        ++ k_block_idx;
+
+        // Flip the phase only when wrapping back to the first stage
+        stage_idx = stage_idx == kNumStages - 1 ?
+            0 : stage_idx + 1;
+        phase ^= stage_idx == 0;
+    };
+
+    // Intra-SM barrier indices
+    constexpr uint32_t kDispatchBarrierIdx = 0;
+    constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1;
+    constexpr uint32_t kEpilogueFullBarrierIdx = 2;
+    constexpr uint32_t kEpilogueWGBarrierStartIdx = 3;
+
+    // NVLink barrier tags
+    constexpr uint32_t kBeforeDispatchPullBarrierTag = 1;
+    constexpr uint32_t kBeforeCombineReduceBarrierTag = 2;
+    constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3;
+
+    // Adjust registers
+    constexpr uint32_t kNumDispatchRegisters = 48;
+    constexpr uint32_t kNumNonEpilogueRegisters = 40;
+    constexpr uint32_t kNumEpilogueRegisters = 208;
+    DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads +
+                     kNumNonEpilogueRegisters * kNumNonEpilogueThreads +
+                     kNumEpilogueRegisters * kNumEpilogueThreads <= 64512,
+                     "Too many registers");
+
+    // Grid sync index assignments (dispatch and epilogue use separate counters to avoid conflicts)
+    constexpr uint32_t kDispatchGridSyncIndex = 0;
+    constexpr uint32_t kEpilogueGridSyncIndex = 1;
+
+    // Different warp roles
+    if (warp_idx < kNumDispatchWarps) {
+        // Adjust registers
+        cutlass::arch::warpgroup_reg_dealloc<kNumDispatchRegisters>();
+
+        // Dispatch warps
+        DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk");
+        constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk;
+        const auto read_topk_idx = [&](const auto& process) {
+            // TODO: figure out better unrolling
+            // Currently a plain `unroll` is better than `unroll 8`
+            #pragma unroll
+            for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp;
+                 i < num_tokens;
+                 i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) {
+                // Allocate slots for each token-topk
+                int expert_idx = -1;
+                if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) {
+                    expert_idx = static_cast<int>(
+                        __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx));
+                    if (expert_idx >= 0)
+                        process(i * kNumTopk + lane_idx, expert_idx);
+                }
+                __syncwarp();
+            }
+        };
+
+        // Count experts' tokens
+        read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) {
+            atomicAdd_block(smem_expert_count + expert_idx, 1);
+        });
+        ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
+
+        // Get SM offset (~6.5 us)
+        #pragma unroll
+        for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) {
+            const uint64_t send_value = (1ull << 32) | static_cast<uint64_t>(smem_expert_count[i]);
+            smem_expert_count[i] = static_cast<uint32_t>(
+                ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value));
+        }
+        ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
+
+        // Write source indices (~2 us with 512 tokens)
+        read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) {
+            const auto dst_rank_idx = expert_idx / kNumExpertsPerRank;
+            const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1);
+            const auto dst_ptr = workspace.get_src_token_topk_idx_ptr(
+                expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx);
+            *sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx;
+        });
+
+        // Grid sync
+        comm::grid_sync(
+            workspace, sm_idx, thread_idx,
+            [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }
+        );
+
+        // Write expert count
+        if (sm_idx == 0) {
+            #pragma unroll
+            for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) {
+                const auto dst_rank_idx = i / kNumExpertsPerRank;
+                const auto dst_local_expert_idx = i % kNumExpertsPerRank;
+                const auto expert_status = *workspace.get_expert_send_count_ptr(i);
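+                // [Illustrative note] Each dispatch SM atomically added `(1ull << 32) | its_count`
+                // above, so `expert_status` is assumed to pack the number of contributing SMs in
+                // its high word and the total token count in its low word: e.g. 3 SMs sending
+                // 5 + 2 + 9 tokens yield 0x0000000300000010, and `& 0xffffffff` below recovers 16.
+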
*sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Barrier before pulling + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* After the grid sync above, there is no more writes by other SMs (except 0) */ false, + /* After the NVLink barrier, there is a grid sync */ true + ); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Pull token data and SF from remote ranks into local L1 buffer + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + // Cache expert token counts in registers (same pattern as scheduler) + scheduler.fetch_expert_recv_count(); + + // Per-rank counts for current expert (re-loaded when expert changes) + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + // Advance expert until within the range + int old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++ current_expert_idx >= kNumExpertsPerRank) + break; + + // Update pool block offset for the new expert + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + + // Move start and end to the next expert + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(current_expert_idx); + } + + // Finish all tokens + if (current_expert_idx >= kNumExpertsPerRank) + break; + + // Load per-rank counts when expert changes + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t j = i * 32 + lane_idx; + // TODO: this is not coalesced + stored_rank_count[i] = j < kNumRanks ? 
+ static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; + } + } + + // Round-robin rank selection via iterative min-peeling + uint32_t current_rank_in_expert_idx; + uint32_t remaining[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = stored_rank_count[i]; + uint32_t offset = 0; + uint32_t token_idx_in_expert = token_idx - expert_start_idx; + uint32_t slot_idx = token_idx_in_expert; + uint32_t token_idx_in_rank; + while (true) { + // Compute active count and min across all ranks + // NOTES: reduce within each lane first, then warp-reduce once + uint32_t num_actives_in_lane = 0; + uint32_t min_in_lane = 0xffffffff; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + num_actives_in_lane += remaining[i] > 0; + if (remaining[i] > 0) + min_in_lane = cute::min(min_in_lane, remaining[i]); + } + const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); + const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); + + // Hit in the current round + const uint32_t num_round_tokens = length * num_active_ranks; + if (slot_idx < num_round_tokens) { + const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; + uint32_t num_seen_ranks = 0; + current_rank_in_expert_idx = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); + const uint32_t num_active_lanes = __popc(mask); + if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) + current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); + num_seen_ranks += num_active_lanes; + } + token_idx_in_rank = offset + (slot_idx / num_active_ranks); + break; + } + + // Move into the next round + slot_idx -= num_round_tokens; + offset += length; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= cute::min(remaining[i], length); + } + + // Read source token-topk index (written by remote dispatch via NVLink) + const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( + current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); + const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; + + // TMA load token from remote rank into shared memory + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + // Load and store SF (overlaps with TMA token load) + constexpr uint32_t kNumSFUint32 = kHidden / 128; + DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + const auto sf_pool_token_idx = expert_pool_block_offset * SF_BLOCK_M + + transform_sf_token_idx(token_idx_in_expert); + #pragma unroll + for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFUint32, 32u); ++ i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFUint32) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; + } + __syncwarp(); + + // Store weights and token data + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + 
+                token_idx_in_expert;
+            if (cute::elect_one_sync()) {
+                // Load weights
+                const auto weight = *sym_buffer.map(
+                    input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx,
+                    current_rank_in_expert_idx);
+                *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight;
+
+                // Wait for TMA token load to complete
+                ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden);
+                ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase);
+
+                // Store token to local L1 buffer via TMA
+                ptx::tma_store_1d(
+                    l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(),
+                    pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes());
+
+                // Write source metadata for combine write-back
+                *workspace.get_token_src_metadata_ptr(pool_token_idx) =
+                    {current_rank_in_expert_idx, src_token_idx, src_topk_idx};
+
+                // Wait for token TMA store to complete
+                cute::tma_store_arrive();
+                ptx::tma_store_wait<0>();
+                ptx::red_add_rel(
+                    workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1);
+            }
+            __syncwarp();
+        }
+
+        // Clean the workspace for the next use, and also accumulate stats
+        // NOTES: this overlaps with the combine reduction epilogue
+        ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
+
+        DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count");
+        if (sm_idx == 0) {
+            // SM 0: clear expert send count
+            #pragma unroll
+            for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads)
+                *workspace.get_expert_send_count_ptr(i) = 0;
+        } else {
+            // Other SMs: clean blocks
+            for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) {
+                // Read expert token count before clearing
+                const auto num_recv_tokens = static_cast<uint32_t>(
+                    *workspace.get_expert_recv_count_sum_ptr(i));
+                const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M);
+
+                // Compute expert pool block offset
+                expert_pool_block_offset = scheduler.get_pool_block_offset(i);
+
+                // Wait until all threads have read the count
+                ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
+
+                // Clean expert token count, and add cumulative results
+                DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps");
+                if (warp_idx == 0) {
+                    *workspace.get_expert_recv_count_sum_ptr(i) = 0;
+                } else if (warp_idx == 1) {
+                    if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr)
+                        ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast<int>(num_recv_tokens));
+                    __syncwarp();
+                }
+
+                // Clean per-rank token count
+                for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads)
+                    *workspace.get_expert_recv_count_ptr(j, i) = 0;
+                __syncwarp();
+
+                // Clean L1 and L2 arrival state
+                for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) {
+                    *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0;
+                    *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0;
+                }
+                __syncwarp();
+            }
+        }
+
+        // Wait for all ranks to finish cleaning
+        comm::nvlink_barrier(
+            workspace, sym_buffer, sm_idx, thread_idx,
+            [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); },
+            /* Before the NVLink barrier, there is a grid sync */ true,
+            /* No sync is needed at the end of the kernel */ false
+        );
+    } else if (warp_idx == kNumDispatchWarps) {
+        // Adjust registers
+        cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
+
+        // GEMM TMA load warp for tokens with SFA
+        scheduler.for_each_block([&](const sched::BlockPhase& block_phase,
+                                     const uint32_t& local_expert_idx,
+                                     const
uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts : &tensor_map_l1_acts; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K; + const auto shape_sfa_k = math::ceil_div(shape_k, kGranK * 4u); + + // Compute pool block offset for this expert + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + + // Wait the entire token arrival for linear 1 + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = scheduler.template get_valid_m(); + while (ptx::ld_acq(ptr) != expected); + } else { + // The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival + // NOTES: Originally we wait blocks on-demand to overlap L1 calculation + // with L2, but this optimization is negative when `num_experts_per_wave` + // guarantees L1's completion when L2 starts. So we remove it. + // In the future, if `num_experts_per_wave` is not large enough + // due to small `num_experts_per_rank`, we may need to add it back or add a switch + DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes"); + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts + // to avoid undefined behavior when `num_k_blocks == 32` + const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1; + while (ptx::ld_acq_gpu(ptr) != expected); + } + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute token offset from pool block index + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; + uint32_t sfa_k_idx = k_block_idx; + + // Add 2 CTA offsets for non-leader CTA + if (not is_leader_cta) + m_idx += scheduler.template get_valid_m() / 2; + + // TMA copy tokens and SFA, then arrive at full barrier + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 1) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for weights with SF + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_b_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + const auto tensor_map_sfb_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights_sf : &tensor_map_l1_weights_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? 
L2_SHAPE_K : L1_SHAPE_K; + const auto shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N; + const auto shape_sfb_k = math::ceil_div(shape_k, kGranK * 4u); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute weight offset + uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx; + + // TMA copy weights with SF + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); + tma::copy( + tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 2) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM MMA issue warp (only the leader CTA will run) + if (is_leader_cta) { + // Make instruction descriptor with block scaling + // NOTES: always swap A/B + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< + b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, + cute::UMMA::Major::K, cute::UMMA::Major::K + >(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Dynamic update of UMMA N based on effective M + mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m()); + + // Wait tensor memory empty barrier arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + constexpr uint16_t kCTAMask = (1 << 2) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Launch MMAs + #pragma unroll 2 + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA load completion + full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // UTCCP copy SFA and SFB to TMEM + using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + + // Issue UMMA + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const auto runtime_instr_desc = + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); + a_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( + b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_k_blocks - 1); + } 
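+                // [Illustrative note] With `kNumEpilogueStages == 2`, the accumulator ping-pongs
+                // through tensor memory: iterations 0, 1, 2, 3 map to (stage, phase) =
+                // (0, 0), (1, 0), (0, 1), (1, 1), so this MMA warp can fill one `UMMA_N`-column
+                // slot while the epilogue warps drain the other.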
+ }); + + // To safely deconstruct barriers, we need another round of waits + if (current_iter_idx > 0) { + const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1; + tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx); + } + } + } else if (warp_idx == kNumDispatchWarps + 3) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // GEMM epilogue warps + const auto epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const auto epilogue_wg_idx = epilogue_warp_idx / 4; + const auto epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const auto warp_idx_in_wg = epilogue_warp_idx % 4; + DG_STATIC_ASSERT((kNumDispatchWarps + kNumMMANonEpilogueWarps) % 4 == 0 and + kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps"); + + // TODO: support effective block M + // NOTES: + // - 2 warpgroups divide the whole BM into BM / 2 + // - 4 warps divide the whole BN into BN / 4 + // - BM / 2 is further divided into stored blocks, i.e. with `STORE_BLOCK_M` size + // - `STORE_BLOCK_M` in further divided into `ATOM_M` + constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t ATOM_M = 8; + constexpr uint32_t kNumBankGroupBytes = 16u; + constexpr uint32_t kNumAtomsPerStore = STORE_BLOCK_M / ATOM_M; + DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M"); + DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M"); + DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M"); + DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N"); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Wait UMMA arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_phase); + ptx::tcgen05_after_thread_sync(); + + // Compute offsets + // NOTES: use shuffle here to let NVCC know warp divergence won't happen + const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m(), 0); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t n_idx = n_block_idx * BLOCK_N; + + if (block_phase == sched::BlockPhase::Linear1) { + // Unified L1 epilogue: SwiGLU in-place using granularity 8 interleaved weights + // With `SM100_TMEM_LOAD_16dp256b1x`, gate/up pairs are: + // (values[0], values[2]), (values[1], values[3]), + // (values[4], values[6]), (values[5], values[7]) + float stored_cached_weight = 0; + + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + 
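+                    // [Illustrative note] Per element pair this store block is assumed to compute
+                    // y = silu(gate) * up * topk_weight with silu(x) = x / (1 + exp(-x)); e.g.
+                    // gate = 1, up = 2, weight = 0.5 gives 0.731 * 2 * 0.5 = 0.731 before the
+                    // FP8 cast with a UE8M0 scale factor.
+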
// Early break if the entire store block is beyond the valid token range + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + // Iterate all atoms in the store block + float2 swiglu_values[kNumAtomsPerStore * 2]; + float2 amax_values[kNumAtomsPerStore]; + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + const uint32_t j = s * kNumAtomsPerStore + i; + + // Load weights from global into register cache per 32 tokens + DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size"); + if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) { + stored_cached_weight = *l1_topk_weights_buffer + .get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx) + .get_base_ptr(); + } + + // Load weights from register cache + const float2 weights = { + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 0), + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 1) + }; + + // Load from TMEM + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Signal tensor memory consumed on the last atom + if (j == WG_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Apply SwiGLU: silu(gate) * up + // Gate/up pairs: (0, 2), (1, 3), (4, 6), (5, 7) + auto fp32_values = reinterpret_cast(values); + #pragma unroll + for (uint32_t k = 0; k < 2; ++ k) { + auto bf16_gate = __float22bfloat162_rn(make_float2(fp32_values[k * 4], fp32_values[k * 4 + 1])); + auto bf16_up = __float22bfloat162_rn(make_float2(fp32_values[k * 4 + 2], fp32_values[k * 4 + 3])); + + // Clamp + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { + bf16_gate = __hmin2(bf16_gate, {kActivationClamp, kActivationClamp}); + bf16_up = __hmax2(bf16_up, {-kActivationClamp, -kActivationClamp}); + bf16_up = __hmin2(bf16_up, {kActivationClamp, kActivationClamp}); + } + + // SwiGLU + auto gate = __bfloat1622float2(bf16_gate); + auto neg_gate_exp = make_float2( + kFastMath ? __expf(-gate.x) : expf(-gate.x), + kFastMath ? 
__expf(-gate.y) : expf(-gate.y)); + const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp); + if constexpr (kFastMath) { + gate = __fmul2_rn(gate, {math::fast_rcp(denom.x), math::fast_rcp(denom.y)}); + } else { + gate = {gate.x / denom.x, gate.y / denom.y}; + } + const auto up = __bfloat1622float2(bf16_up); + swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights); + } + + // Amax reduction + amax_values[i].x = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].x), cute::abs(swiglu_values[i * 2 + 1].x)), + math::ReduceMax()); + amax_values[i].y = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].y), cute::abs(swiglu_values[i * 2 + 1].y)), + math::ReduceMax()); + if (lane_idx < 4) + smem_amax_reduction[epilogue_warp_idx * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx] = amax_values[i]; + __syncwarp(); + } + + // Wait shared memory release from previous TMA store + // And fence `smem_amax_reduction` + const uint32_t tma_stage_idx = s % kNumTMAStoreStages; + ptx::tma_store_wait(); + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Cast to FP8 E4M3 and store into shared memory + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + // Reduce amax + const float2 wp_amax = + smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4]; + amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); + amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); + + // Calculate SF + float2 sf, sf_inv; + math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv); + + // Cast + const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); + const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); + const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y)); + + // STSM + uint32_t row = lane_idx; + uint32_t col = warp_idx_in_wg; + const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N + + i * ATOM_M * L1_OUT_BLOCK_N + + row * L1_OUT_BLOCK_N + + (col ^ (row / 2)) * kNumBankGroupBytes; + ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr); + + // Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout) + // Only one warp per pair writes (both hold the same SF after cross-warp reduce) + // Each lane < 4 holds SF for 2 rows (sf.x and sf.y) + if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { + const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; + const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; + const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); + const auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + // NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4 + // NOTES: originally there was: + // - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2 + // - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)` + // We find out that + // 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside + // 2. 
`lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside + // This reduce the number of computation instructions. + const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + __builtin_assume(token_base_idx < BLOCK_M); + const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; + const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; + sf_base_ptr[sf_addr] = + (*reinterpret_cast(&sf.x) >> 23); + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = + (*reinterpret_cast(&sf.y) >> 23); + } + __syncwarp(); + } + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Issue TMA store after all atoms in this store block + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N, + out_n_idx, + m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); + cute::tma_store_arrive(); + } + __syncwarp(); + } + + // Notify L2 + // TODO: less epilogue sync scope + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(L2_SHAPE_K <= 64 * L1_OUT_BLOCK_N, "L2 shape K is too large"); + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx + ); + } + __syncwarp(); + } else { + DG_STATIC_ASSERT(STORE_BLOCK_M % 8 == 0, "Invalid store M"); + constexpr uint32_t kNumRowsPerWarp = STORE_BLOCK_M / 8; + + // L2 BF16 epilogue: write GEMM output to remote combine buffer via NVLink + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + // TODO: check performance + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / ATOM_M; ++ i) { + // Load from TMEM using .16x256b shape to satisfy STSM layout requirements + // Start from lane index 0 and 16 + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Wait shared memory release from previous NVLink store + // NOTES: skip for the first store block since the prior full barrier already ensures completion + if (i == 0 and s > 0) + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Signal tensor memory consumed + if (s == WG_BLOCK_M / STORE_BLOCK_M - 1 and i == STORE_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Store into shared memory + // NOTES: only use first 16 lanes for address + // NOTES: 2 warps share a 
BF16 swizzle atom + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (warp_idx_in_wg / 2) * STORE_BLOCK_M * kSwizzleCDMode + + i * ATOM_M * kSwizzleCDMode + + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + ptx::SM90_U32x4_STSM_T::copy( + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr + ); + } + + // Wait shared memory ready + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Write into remote buffers + // One warp per row, now the layout is different from shared memory storing + const uint32_t row_in_atom = (warp_idx_in_wg * 2 + lane_idx / 16) % ATOM_M; + const uint32_t bank_group_idx = lane_idx % 8; + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_store = j * 8 + warp_idx_in_wg * 2 + lane_idx / 16; + const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + row_in_store; + + // Skip padding rows beyond the actual token count for this expert + if (m_idx_in_block >= valid_m) + break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + // Read from shared memory + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (lane_idx % 16 / 8) * STORE_BLOCK_M * kSwizzleCDMode + + row_in_store * kSwizzleCDMode + + (bank_group_idx ^ row_in_atom) * kNumBankGroupBytes; + const auto packed = ptx::ld_shared(reinterpret_cast(smem_ptr)); + + // Write into remote + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + const auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * static_cast(sizeof(nv_bfloat16)) + (lane_idx % 16) * static_cast(sizeof(float4))); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + + // Ensure the next epilogue safe to use shared memory + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + }); + + // Deallocate tensor memory + // NOTES: must be called by the same logical warp ID on both CTAs + if (epilogue_warp_idx == 0) + Allocator().free(0, kNumTmemCols); + + // NVLink barrier (grid sync + cross-rank signal + grid sync): ~4 us + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } + ); + + // Barrier with dispatch warps, so that they can do clean workspace + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Combine: reduce top-k results and write back + // NOTES: reuse shared memory from start up to the barriers + // 1 token, 1 topk latency: ~3 us + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); + + // 3 slots of chunk is needed: 2 load stages and 1 store + constexpr uint32_t kNumChunkSlots = 3; + constexpr uint32_t kNumMaxRegistersForBuffer = 128; + + // NOTES: either 1 or 2 chunks for simplicity + 
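+    // [Illustrative note] e.g. under a hypothetical kHidden = 4096, kNumHiddenBytes = 8192:
+    // the register test 4096 <= 32 * kNumMaxRegistersForBuffer (128) passes, so one chunk is
+    // used whenever kNumChunkSlots * kNumEpilogueWarps * 8192 bytes also fit below the barrier
+    // region; otherwise the token is halved into two 4096-byte chunks.
+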
// NOTES: Restrict on both smem and register + constexpr uint32_t kNumChunks = + kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2; + constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); + constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; + DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); + DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); + DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); + DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); + DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements (one per lane)"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); + + // Verify combined shared memory budget at runtime + DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( + reinterpret_cast(barrier_start_ptr) - smem_buffer)); + + // Per-warp buffer: 2 stage load buffers + 1 store buffer + const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { + return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); + }); + const auto combine_store_buffer = math::advance_ptr(smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); + + // Per-warp barriers + auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return combine_barriers[i + epilogue_warp_idx * 2]; + }); + + // Iterate over all tokens + uint32_t combine_phase = 0; + uint32_t load_stage_idx = 0; + for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; + token_idx < num_tokens; + token_idx += kNumSMs * kNumEpilogueWarps) { + // Read top-k slot indices: each lane reads one slot, then broadcast via exchange + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + const int stored_topk_slot_idx = lane_idx < kNumTopk ? 
+ const int stored_topk_slot_idx = lane_idx < kNumTopk ?
+ static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1;
+ const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0);
+
+ // Iterate over all chunks
+ for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) {
+ const uint32_t chunk_byte_offset = chunk * kNumChunkBytes;
+
+ // Move the mask and load
+ uint32_t mask = total_mask;
+ const auto move_mask_and_load = [&](const uint32_t& i) {
+ if (mask) {
+ // Move
+ const uint32_t slot_idx = __ffs(mask) - 1;
+ mask ^= 1 << slot_idx;
+
+ // Load
+ if (cute::elect_one_sync()) {
+ const auto src_ptr = math::advance_ptr(
+ combine_token_buffer.get_rank_buffer(slot_idx)
+ .get_data_buffer(token_idx).get_base_ptr(),
+ chunk_byte_offset);
+ ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes);
+ ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes);
+ }
+ __syncwarp();
+ return true;
+ }
+ return false;
+ };
+
+ // Load the first selection
+ bool do_reduce = move_mask_and_load(load_stage_idx);
+
+ // Accumulate all top-k contributions for this chunk in float registers
+ float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {};
+ while (do_reduce) {
+ // Prefetch the next top-k into the buffer while the current one is being accumulated
+ do_reduce = move_mask_and_load(load_stage_idx ^ 1);
+
+ // Accumulate
+ combine_load_barriers[load_stage_idx]->wait(combine_phase);
+ #pragma unroll
+ for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) {
+ const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx];
+ const auto bf16_values = reinterpret_cast(&uint4_values);
+ #pragma unroll
+ for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l)
+ ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]);
+ }
+ combine_phase ^= load_stage_idx;
+ load_stage_idx ^= 1;
+ }
+
+ // Cast
+ #pragma unroll
+ for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) {
+ uint4 casted;
+ auto casted_bf16 = reinterpret_cast(&casted);
+ #pragma unroll
+ for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l)
+ casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]);
+
+ // Wait for shared memory release, then write
+ if (j == 0) {
+ ptx::tma_store_wait<0>();
+ __syncwarp();
+ }
+ ptx::st_shared(combine_store_buffer + j * 32 + lane_idx,
+ casted.x, casted.y, casted.z, casted.w);
+ }
+ __syncwarp();
+
+ // TMA-store the token chunk
+ if (cute::elect_one_sync()) {
+ cute::tma_store_fence();
+ ptx::tma_store_1d(
+ math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset),
+ combine_store_buffer, kNumChunkBytes);
+ cute::tma_store_arrive();
+ }
+ __syncwarp();
+ }
+ }
+ }
+#else
+ if (blockIdx.x == 0 and threadIdx.x == 0)
+ DG_DEVICE_ASSERT(false and "This kernel only supports sm_100f");
+#endif
+}
+
+} // namespace deep_gemm
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh
index 45a603ad..7ce008e5 100644
--- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh
@@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
 auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
 DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
+ if (kNumMulticast > 1)
+ cute::cluster_sync();
+
 // Initialize barriers
 if (warp_idx == 1
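// --- Illustrative sketch (plain C++, not part of the patch): the ballot/ffs
// walk used by `move_mask_and_load` above. The warp builds a mask of valid
// top-k slots, then consumes one set bit per call: find-first-set picks the
// lowest slot and an XOR clears it. Host stand-in using __builtin_ffs (GCC/Clang):
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t mask = 0b10100110u;   // lanes whose stored top-k slot is valid
    while (mask) {
        const uint32_t slot_idx = __builtin_ffs(static_cast<int>(mask)) - 1;
        mask ^= 1u << slot_idx;                              // consume the lowest set bit
        std::printf("issue load for slot %u\n", slot_idx);   // kernel: TMA load
    }
    return 0;
}
// --- end of sketch ---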
and cute::elect_one_sync()) { #pragma unroll @@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, } } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, kNumTmemCols); } + + // Deallocate tensor memory + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh index 180a308b..e6744f59 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -6,27 +6,31 @@ #include #include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; -using namespace deep_gemm::sm100; - template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, - const uint32_t max_seqlen_k, const uint64_t stride_logits, + const uint32_t max_seqlen_k, const uint32_t stride_logits, uint32_t* cu_seq_len_k_start, uint32_t* cu_seq_len_k_end, - float* logits, + logits_dtype_t* logits, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, @@ -35,26 +39,26 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` // Q should be load only at once for a block - const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); // Types using Barrier = cutlass::arch::ClusterTransactionBarrier; - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warp_in_group_idx = warp_idx % 4; - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; // Prefetch TMA descriptors DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); cute::prefetch_tma_descriptor(&tensor_map_kv_scales); cute::prefetch_tma_descriptor(&tensor_map_weights); } - __syncwarp(); // Shared memory configs // NOTES: weight may be unaligned @@ -62,7 +66,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * 
sizeof(float); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); // Align to 512 bytes for swizzle-64B extern __shared__ __align__(512) uint8_t smem_buffer[]; @@ -75,19 +79,19 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); // Data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); @@ -95,76 +99,77 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // TMA barriers auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); - auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); - auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); // Tensor memory allocation auto tmem_ptr_in_smem = 
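// --- Illustrative sketch (plain C++, not part of the patch): a minimal
// stand-in for `utils::PatternVisitor`, which wraps an index-to-pointer lambda
// so the staged shared-memory regions above can be indexed like arrays without
// materializing a pointer table; the layout constants here are hypothetical.
#include <cstdint>
#include <cstdio>

template <typename Func>
struct PatternVisitor {
    Func func;
    explicit PatternVisitor(Func f) : func(f) {}
    auto operator[](const uint32_t& i) { return func(i); }
};

int main() {
    static uint8_t smem[4096];        // stand-in for the shared-memory buffer
    constexpr uint32_t kStageSize = 1024;
    auto smem_stage = PatternVisitor([&](const uint32_t& i) {
        return smem + i * kStageSize;   // base pointer of pipeline stage i
    });
    std::printf("%p %p\n", static_cast<void*>(smem_stage[0]),
                static_cast<void*>(smem_stage[3]));
    return 0;
}
// --- end of sketch ---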
reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2); // Initialize barriers DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); - const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32)); - const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1)); - if (is_tma_load_warp and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); + empty_q_barriers[i]->init(kNumMathThreads + 32); } #pragma unroll for (uint32_t i = 0; i < kNumKVStages; ++ i) { full_kv_barriers[i]->init(1); empty_kv_barriers[i]->init(kNumMathThreads); } - #pragma unroll - for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { - full_umma_barriers[i]->init(1); - empty_umma_barriers[i]->init(128); - } - - // Make initialized barrier visible in async proxy cutlass::arch::fence_barrier_init(); - } else if (is_umma_warp) { + } + if (warp_idx == kSpecWarpStart + 1) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } // Allocate tensor memory cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); } __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 24; - constexpr uint32_t kNumMathRegisters = 240; + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; // Block scheduler - uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; - const auto& get_next_block_q_idx = [&]() -> cute::tuple { - return {block_q_idx + gridDim.x, q_iter_idx + 1}; + uint32_t block_q_idx = sm_idx, q_iter_idx = 0; + const auto get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + kNumSMs, q_iter_idx + 1}; }; uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; - const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); - seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); - seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cu_seq_len_k_start[q_idx]; + seq_k_end[i] = cu_seq_len_k_end[q_idx]; start = min(start, min(seq_k_start[i], seq_len_kv)); end = max(end, min(seq_k_end[i], seq_len_kv)); } + // TMA alignment requirements for SF KV start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase - start, ceil_div(end - start, BLOCK_KV)}; // Task info + start, math::ceil_div(end - start, BLOCK_KV)}; // Task info }; // KV pipeline uint32_t num_total_kv_blocks = 0; - const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { return { (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase @@ -177,13 +182,16 @@ void 
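// --- Illustrative sketch (plain C++, not part of the patch): the stage/phase
// arithmetic behind `load_schedule` / `get_kv_pipeline` above. A monotonic
// iteration index maps to a ring-buffer stage (idx % kNumStages) and a parity
// phase ((idx / kNumStages) & 1) that flips on every wrap; mbarrier waits are
// keyed on that parity.
#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t kNumStages = 3;
    for (uint32_t iter = 0; iter < 8; ++iter) {
        const uint32_t stage = iter % kNumStages;
        const uint32_t phase = (iter / kNumStages) & 1;
        std::printf("iter %u -> stage %u, phase %u\n", iter, stage, phase);
    }
    return 0;
}
// --- end of sketch ---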
sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; - if (is_tma_load_warp) { + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { cutlass::arch::warpgroup_reg_dealloc(); // Prefetch - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + const auto issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); }; if (cute::elect_one_sync() and block_q_idx < num_q_blocks) @@ -209,10 +217,10 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); // Issue TMA KV - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } num_total_kv_blocks += num_kv_blocks; @@ -221,11 +229,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } } - } else if (is_umma_warp) { + } else if (warp_idx == kSpecWarpStart + 1) { cutlass::arch::warpgroup_reg_dealloc(); // Require full allocation - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Make UMMA desc auto instr_desc = cute::UMMA::make_instr_descwait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { - auto a_desc = make_umma_desc( + auto a_desc = mma::sm100::make_umma_desc( smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); - auto b_desc = make_umma_desc( + auto b_desc = mma::sm100::make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); } @@ -266,23 +274,37 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, } num_total_kv_blocks += num_kv_blocks; + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline + empty_q_barriers[q_stage_idx]->arrive(); + // Jump to the next block CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } - } else if (warp_idx >= kNumMathThreads / 32) { + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { cutlass::arch::warpgroup_reg_dealloc(); - } 
else if (warp_idx < kNumMathThreads / 32) { + } else if (warp_idx < kSpecWarpStart) { cutlass::arch::warpgroup_reg_alloc(); // Offsets - const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - const auto& warp_offset = warp_idx * 32; - const auto& v_offset = lane_idx; + const auto tmem_start = warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; - // Preload weights - constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads); - float weights[BLOCK_Q][kNumWeightsInReg]; - DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + // Local register buffers + float weights[BLOCK_Q][kNumHeads]; while (block_q_idx < num_q_blocks) { CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); @@ -293,9 +315,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Read weights #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) { - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); - } + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); } // Compute over KV blocks @@ -307,82 +329,59 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset); + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); // Wait UMMA arrival full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); // Reduce over the head dim and store - const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; - static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + const auto kv_offset = kv_start + kv_block_idx * BLOCK_KV + math_thread_idx; DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q; - DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems"); - uint32_t shifted_accum[kNumLDTMElems]; - auto tmem_load = [&](auto... 
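// --- Illustrative sketch (C++20 host code, not part of the patch): how the
// `tmem_load` helper above turns a compile-time element count into the
// fixed-arity register list that a single TMEM load instruction expects.
// `fake_copy` is a hypothetical stand-in for cute::SM100_TMEM_LOAD_*::copy.
#include <cstdint>
#include <utility>

template <typename... U32>
void fake_copy(uint32_t addr, U32&... regs) {
    uint32_t v = addr;
    ((regs = v++), ...);   // pretend each destination register gets one element
}

template <int N>
void tmem_load(uint32_t tmem_addr, float* accum) {
    static_assert(N == 32 or N == 64, "Unsupported TMEM load size");
    [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        fake_copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
    }(std::make_index_sequence<N>{});
}

int main() {
    float accum[32];
    tmem_load<32>(0x100, accum);   // expands into one 32-register call
    return 0;
}
// --- end of sketch ---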
Is) { - if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 128) { - cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); - - tcgen05_before_thread_sync(); - empty_umma_barriers[warpgroup_idx]->arrive(); - #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + // Load accumulator from TMEM + float accum[kNumHeads]; + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + } + // Accumulate weighted ReLU in parallel auto sum_0 = make_float2(0, 0); auto sum_1 = make_float2(0, 0); - const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + const auto transform = [&](const uint32_t& j, const float2& sum) { auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); auto b = make_float2(weights[i][j], weights[i][j + 1]); return __ffma2_rn(a, b, sum); }; #pragma unroll - for (int j = 0; j < kNumWeightsInReg; j += 4) { - sum_0 = transform_reg(j, sum_0); - sum_1 = transform_reg(j + 2, sum_1); - } - - const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), - ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); - return __ffma2_rn(a, b, sum); - }; - - #pragma unroll - for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { - sum_0 = transform_smem(j, sum_0); - sum_1 = transform_smem(j + 2, sum_1); + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); } auto sum = __fadd2_rn(sum_0, sum_1); - float result = scale_kv * (sum.x + sum.y); + auto result = static_cast(scale_kv * (sum.x + sum.y)); // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast(stride_logits); if constexpr (kIsCompressedLogits) { - if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result; + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; } else { - logits[q_idx * stride_logits + kv_offset + v_offset] = result; + logits[q_offset + kv_offset] = result; } + __syncwarp(); } } num_total_kv_blocks += num_kv_blocks; @@ -393,12 +392,12 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Jump to the next block CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); } - } - // Free tensor memory - __syncthreads(); - if (is_tma_load_warp) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } } } // namespace deep_gemm diff --git 
a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh index 7058c40f..9a5bddbf 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -6,56 +6,65 @@ #include #include +#include +#include +#include #include -#include -#include - -#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; -using namespace deep_gemm::sm100; - template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, - const uint64_t logits_stride, const uint64_t block_table_stride, - const uint32_t* context_lens, float* logits, - const uint32_t* block_table, const uint32_t* schedule_meta, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { using Barrier = cutlass::arch::ClusterTransactionBarrier; - // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; // Prefetch TMA descriptors DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); cute::prefetch_tma_descriptor(&tensor_map_kv_scales); cute::prefetch_tma_descriptor(&tensor_map_weights); } - __syncwarp(); + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 
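// --- Illustrative sketch (plain C++, not part of the patch): the kNextN-to-atom
// decomposition above. Tokens are processed in atoms of up to 2; a non-varlen
// odd kNextN >= 3 is padded to even, with the trailing half-atom zero-filled by
// TMA out-of-bounds reads.
#include <cstdint>
#include <cstdio>

constexpr uint32_t constexpr_ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    constexpr bool kIsVarlen = false;
    for (uint32_t next_n = 1; next_n <= 5; ++next_n) {
        const bool pad_odd = not kIsVarlen and (next_n % 2 == 1) and next_n >= 3;
        const uint32_t atom = (kIsVarlen or next_n >= 2) ? 2 : 1;
        std::printf("kNextN=%u -> atom=%u, num_atoms=%u, padded=%d\n",
                    next_n, atom, constexpr_ceil_div(next_n, atom), int(pad_odd));
    }
    return 0;
}
// --- end of sketch ---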
2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); // Shared memory configs static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; - static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); - static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; @@ -63,43 +72,40 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Q and KV data on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); }); constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); // Barriers and TMEM pointer on shared memory const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; - auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); - auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return 
umma_barrier_ptr + i; }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); - constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; + constexpr uint32_t kNumTmemCols = kNextNAtom * kNumHeads * kNumMathWarpGroups; DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); - const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); - const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); - const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); // Initialize barriers - if (is_tma_load_warp and cute::elect_one_sync()) { + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumQStages; ++ i) { full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); + empty_q_barriers[i]->init(kNumMathThreads + 32); } #pragma unroll for (uint32_t i = 0; i < kNumKVStages; ++ i) { @@ -108,7 +114,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } cutlass::arch::fence_barrier_init(); } - if (is_umma_warp) { + if (warp_idx == kSpecWarpStart + 1) { if (cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { @@ -123,79 +129,92 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 40; - constexpr uint32_t kNumMathRegisters = 232; + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); // Scheduler constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + using Scheduler = sched::PagedMQALogitsScheduler; DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); // Q and KV pipeline - const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase }; - const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase }; - uint32_t q_iter_idx = 0, kv_iter_idx = 0; // UMMA settings // Construct instruction with layout D constexpr uint32_t UMMA_M = 128; constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); - constexpr uint32_t UMMA_N = kNextN * kNumHeads; + constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); - if (is_tma_load_warp) { - // TMA warp-group for loading data + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading data cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) { if (cute::elect_one_sync()) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], 
smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; - // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none - uint32_t q_idx = batch_size, kv_idx, num_kv; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx, num_kv; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; bool fetched_next_task; // Prefetch the first Q - if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) - issue_tma_q(0, next_q_idx), q_iter_idx = 1; + if ((fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_atom_idx), q_iter_idx = 1; - int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_ptr = 32; uint32_t kv_block_idx_storage; while (fetched_next_task) { - // Prefetch next Q when current Q changes - bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); - q_idx = next_q_idx; + // Prefetch next Q when (q, atom) changes + const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size); + bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance); + + if (q_atom_idx != next_q_atom_idx) + kv_block_idx_ptr = 32; + + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; num_kv = next_num_kv; // Read KV block index - // TODO: deal with `-1`? - if (kv_idx == 0 or kv_block_idx_ptr == 32) { + // TODO(xuzhean): consider -1 + if (kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; - kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? 
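// --- Illustrative sketch (plain C++, not part of the patch): the warp-wide
// block-table cache above. Each lane holds one cached entry (32 per refill);
// `kv_block_idx_ptr == 32` marks the cache empty, and entries are handed out
// via __shfl_sync in the kernel. Host stand-in with an array for the 32 lanes:
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t block_table[128];
    for (uint32_t i = 0; i < 128; ++i) block_table[i] = 1000 + i;

    uint32_t cache[32];    // one register per lane in the kernel
    uint32_t ptr = 32;     // 32 == empty, forces the first refill
    for (uint32_t kv = 0; kv < 128; ++kv) {
        if (ptr == 32) {   // refill: one global load per lane
            for (uint32_t lane = 0; lane < 32; ++lane)
                cache[lane] = block_table[kv + lane];
            ptr = 0;
        }
        const uint32_t kv_block = cache[ptr++];   // __shfl_sync in the kernel
        if (kv % 48 == 0) std::printf("kv %u -> block %u\n", kv, kv_block);
    }
    return 0;
}
// --- end of sketch ---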
block_table[block_table_offset + kv_idx + lane_idx] : 0; } + __syncwarp(); DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); // Wait Q consumer release and issue TMA Q if (prefetch_q) { CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); - issue_tma_q(q_stage_idx, q_idx + 1); + issue_tma_q(q_stage_idx, q_atom_idx + next_advance); } - int kv_block_idx[kNumBlocksPerSplit]; + uint32_t kv_block_idx[kNumBlocksPerSplit]; #pragma unroll - for (int i = 0; i < kNumBlocksPerSplit; ++ i) + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); kv_block_idx_ptr += kNumBlocksPerSplit; @@ -205,45 +224,53 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, if (cute::elect_one_sync()) { #pragma unroll - for (int i = 0; i < kNumBlocksPerSplit; ++ i) { - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, - 0, 0, 1, kv_block_idx[i]); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, - 0, kv_block_idx[i]); + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) { + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); } full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } // Fetch next task - fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv); } - } else if (is_umma_warp) { + } else if (warp_idx == kSpecWarpStart + 1) { cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Require full allocation - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); // Make UMMA desc auto instr_desc = cute::UMMA::make_instr_desc(); auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - uint32_t q_idx = batch_size, kv_idx; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; uint32_t umma_phase = 1; - while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - if (q_idx != next_q_idx) { + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + if (q_atom_idx != next_q_atom_idx) { + // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); full_q_barriers[q_stage_idx]->wait(q_phase); } - q_idx = next_q_idx; + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; + // Wait KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); @@ -251,12 +278,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; 
++ i) { empty_umma_barriers[i]->wait(umma_phase); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { - auto a_desc = make_umma_desc( + auto a_desc = mma::sm100::make_umma_desc( smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); - auto b_desc = make_umma_desc( + auto b_desc = mma::sm100::make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); } @@ -264,29 +291,46 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, } umma_phase ^= 1; } - } else if (is_math_warp) { - // Math warp-groups for WGMMA + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; // Offsets - const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - const uint32_t thread_idx = threadIdx.x; + const auto math_warpgroup_idx = warpgroup_idx; + const auto tmem_start = math_warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; - // Weights - constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? 
kNumHeads : cute::min(48, kNumHeads)); - float weights[kNextN][kNumWeightsInReg]; - DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + // Local register buffers + float weights[kNextNAtom][kNumHeads]; - // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none - uint32_t q_idx = batch_size, kv_idx; - uint32_t next_q_idx, next_kv_idx, next_num_kv; + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; uint32_t umma_phase = 0; + bool is_paired_atom = false; - while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - // Current Q changes - if (q_idx != next_q_idx) { - // Release Last Q empty + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + // Q or atom changes + if (q_atom_idx != next_q_atom_idx) { + // Release last Q empty if (q_iter_idx > 0) empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); @@ -296,30 +340,34 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Read weights #pragma unroll - for (uint32_t i = 0; i < kNextN; ++ i) { - for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2); } } - // Get current Q and KV index - q_idx = next_q_idx; + // Get current task indices + q_atom_idx = next_q_atom_idx; kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV; - // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` // Wait TMA KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); // Wait UMMA arrival - full_umma_barriers[warpgroup_idx]->wait(umma_phase); - tcgen05_after_thread_sync(); + full_umma_barriers[math_warpgroup_idx]->wait(umma_phase); + ptx::tcgen05_after_thread_sync(); umma_phase ^= 1; // Release KV empty @@ -327,72 +375,65 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Reduce over the head dim and store DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; - uint32_t shifted_accum[kNumLDTMElems]; - DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); - auto tmem_load = [&](auto... 
Is) { - if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 128) { - cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); - - tcgen05_before_thread_sync(); - empty_umma_barriers[warpgroup_idx]->arrive(); - - #pragma unroll - for (uint32_t i = 0; i < kNextN; ++ i) { - auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); - - auto sum_0 = make_float2(0, 0); - auto sum_1 = make_float2(0, 0); - const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(weights[i][j], weights[i][j + 1]); - return __ffma2_rn(a, b, sum); - }; + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + float accum[kNumHeads]; #pragma unroll - for (int j = 0; j < kNumWeightsInReg; j += 4) { - sum_0 = transform_reg(j, sum_0); - sum_1 = transform_reg(j + 2, sum_1); + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(scale_kv * (sum.x + sum.y)); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride) + math_thread_idx] = result; + __syncwarp(); } - const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { - auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); - auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), - ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); - return __ffma2_rn(a, b, sum); - }; - - #pragma unroll - for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { - sum_0 = transform_smem(j, sum_0); - sum_1 = transform_smem(j + 2, sum_1); - } - - auto sum = __fadd2_rn(sum_0, sum_1); - float result = scale_kv * (sum.x + sum.y); + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[math_warpgroup_idx]->arrive(); + }; - // Store into the global memory - // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + thread_idx] = result; + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); } } - } else { - cutlass::arch::warpgroup_reg_dealloc(); - } - // Free tensor memory - __syncthreads(); - if (is_umma_warp) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + // Free tensor memory + 
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } } } // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh index 4e4ff21d..aaf7fd9a 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -4,20 +4,22 @@ #include -#include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm100; - template -__device__ __forceinline__ +CUTLASS_DEVICE uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { // Calculate the index of the bank group to be written in the atom - const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + const auto bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); // Reshape the atom in another view and swizzle // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` @@ -37,7 +39,7 @@ template -__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +CUTLASS_GLOBAL void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -58,7 +60,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); // Align to 1024 bytes for swizzle-128B extern __shared__ __align__(1024) uint8_t smem_buffer[]; @@ -70,7 +72,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); // Real tensor memory size and offsets - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); // Prefetch TMA descriptors at the very beginning if (warp_idx == 0 and cute::elect_one_sync()) { @@ -82,20 +84,20 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Data on shared memory (layout as ordered below) // Fill D/A/B pointers auto smem_cd = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); - auto empty_cast_barriers = PatternVisitor([=](const uint32_t& 
i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; // Fill the tensor memory pointer @@ -121,7 +123,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } __syncthreads(); - constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); @@ -131,6 +133,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const uint32_t m_offset = shape_m * k_split_idx; const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Dispatch warps into different roles if (warp_idx < kNumMMAThreads / 32) { // TMA load warp @@ -145,8 +150,8 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, uint32_t k_idx = k_offset + s * BLOCK_K; // Issue TMAs - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); // Arrive at full barriers constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; @@ -168,7 +173,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); const uint32_t& b_desc_lo = lane_idx < kNumStages ? 
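// --- Illustrative sketch (plain C++, not part of the patch): the split-K
// partitioning computed above. kNumKBlocks is spread over kNumSplits; the
// first kRemainKBlocks splits each take one extra K block, so all K blocks
// are covered exactly once. The concrete shape is hypothetical.
#include <cassert>
#include <cstdint>

int main() {
    constexpr uint32_t kNumKBlocks = 23, kNumSplits = 4;
    constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;   // 5
    constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;        // 3
    uint32_t total = 0;
    for (uint32_t split = 0; split < kNumSplits; ++split)
        total += kNumKBlocksPerSplit + (split < kRemainKBlocks ? 1 : 0); // 6,6,6,5
    assert(total == kNumKBlocks);
    return 0;
}
// --- end of sketch ---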
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; // Checks for MMA instructions @@ -185,7 +190,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const auto& stage_idx = s % kNumStages; const auto& cast_stage_idx = s % kNumCastStages; full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Issue UMMA const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); @@ -194,7 +199,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); } @@ -218,7 +223,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Wait UMMA arrival tmem_full_barrier->wait(0); - tcgen05_after_thread_sync(); + ptx::tcgen05_after_thread_sync(); // Load from tensor memory into registers, and write shared memory with STSM DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); @@ -239,7 +244,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, values[0], values[1], values[2], values[3]); cutlass::arch::fence_view_async_tmem_load(); if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); if constexpr (BLOCK_M == 64) __syncwarp(); } @@ -290,9 +295,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, #pragma unroll for (uint32_t i = 0; i < kNumLoads; i += 2) { auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); - sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], - uint32_values[0][i + 1], uint32_values[1][i + 1], - smem_ptr); + ptx::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); } // Wait tensor memory empty @@ -321,15 +326,15 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, cutlass::arch::fence_view_async_tmem_store(); // Arrive for issuing MMAs - tcgen05_before_thread_sync(); + ptx::tcgen05_before_thread_sync(); full_cast_barriers[cast_stage_idx]->arrive(); } // Intra-warp reduction and write back #pragma unroll for (uint32_t u = 0; u < 2; ++ u) { - const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); - const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + const auto reduced_sum = math::warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; if (lane_idx % 4 == 0 and m_idx < shape_m) sqr_sum[m_offset + m_idx] = reduced_sum; } diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh index 7a77e4e8..84a149eb 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -11,14 +11,19 @@ #include #include +#include #include -#include 
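// --- Illustrative sketch (plain C++, not part of the patch): the butterfly
// reduction behind `math::warp_reduce_sum<4>` used by the prenorm epilogue
// above. Two XOR-shuffle rounds (delta = 2, then 1) leave every lane of a
// 4-lane group holding the group sum; the kernel then writes from lanes with
// lane_idx % 4 == 0. Host stand-in with an array in place of the 32 lanes:
#include <cassert>
#include <cstdio>

int main() {
    float lanes[32];
    for (int i = 0; i < 32; ++i) lanes[i] = float(i);
    for (int delta = 2; delta >= 1; delta /= 2) {   // __shfl_xor_sync rounds
        float next[32];
        for (int i = 0; i < 32; ++i) next[i] = lanes[i] + lanes[i ^ delta];
        for (int i = 0; i < 32; ++i) lanes[i] = next[i];
    }
    assert(lanes[0] == 0.0f + 1.0f + 2.0f + 3.0f);  // each 4-lane group summed
    std::printf("group 0 sum = %g\n", lanes[0]);
    return 0;
}
// --- end of sketch ---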
            if (lane_idx % 4 == 0 and m_idx < shape_m)
                sqr_sum[m_offset + m_idx] = reduced_sum;
        }
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh
index 7a77e4e8..84a149eb 100644
--- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh
@@ -11,14 +11,19 @@
 #include
 #include
+#include
 #include
-#include
-#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
 namespace deep_gemm {
-using namespace deep_gemm::sm90;
-
 template
-__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
+CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
 sm90_bf16_gemm_impl(int* grouped_layout,
                     uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                     const __grid_constant__ cute::TmaDescriptor tensor_map_a,
@@ -51,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
     constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
     // Types
-    using WGMMA = typename BF16MMASelector::type;
+    using WGMMA = typename mma::sm90::BF16MMASelector::type;
     using Barrier = cutlass::arch::ClusterTransactionBarrier;
     DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
@@ -61,7 +66,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
     shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
     // Shared memory
-    static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u);
+    static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u);
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
@@ -71,7 +76,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
     // Configs
     const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
-    const uint32_t lane_idx = get_lane_idx();
+    const uint32_t lane_idx = ptx::get_lane_idx();
     // Prefetch TMA descriptors at the very beginning
     if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
@@ -88,17 +93,17 @@ sm90_bf16_gemm_impl(int* grouped_layout,
     // D/A/B shared memory
     auto smem_d = reinterpret_cast(smem_buffer);
-    auto smem_a = PatternVisitor([&](const uint32_t& i) {
+    auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
     });
-    auto smem_b = PatternVisitor([&](const uint32_t& i) {
+    auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
     });
     // Fill barriers
     auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
-    auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
-    auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
+    auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
+    auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
     // Initialize barriers
     if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
@@ -119,9 +124,12 @@ sm90_bf16_gemm_impl(int* grouped_layout,
     constexpr uint32_t kNumTMARegisters = 48;
     constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224;
+    // Wait for primary kernel completion
+    cudaGridDependencySynchronize();
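+    // NOTES: `cudaGridDependencySynchronize()` only takes effect when the kernel is
+    // launched with programmatic dependent launch (PDL); a minimal host-side sketch
+    // (illustrative only, not part of this patch):
+    //   cudaLaunchAttribute attr{};
+    //   attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    //   attr.val.programmaticStreamSerializationAllowed = 1;
+    //   cudaLaunchConfig_t cfg{grid_dim, block_dim, smem_size, stream, &attr, 1};
+    //   cudaLaunchKernelEx(&cfg, kernel, args...);
+    // Under a normal launch the call is effectively a no-op.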
+
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
-    auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout);
+    auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout);
     // Pipeline and TMA phases
     uint32_t stage_idx = 0, phase = 0;
@@ -151,7 +159,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
             const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
             DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
-            const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
+            const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
             for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
                 // Wait consumer release
                 empty_barriers[stage_idx]->wait(phase ^ 1);
                 constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
                 auto& full_barrier = *full_barriers[stage_idx];
-                const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx);
-                const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
+                const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx);
+                const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
                 DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
-                uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
+                uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
                     shape_k, BLOCK_K, k_block_idx, m_block_idx);
-                uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
+                uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
                     shape_k, BLOCK_K, k_block_idx, m_block_idx);
                 // Issue TMAs
                 constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
                 const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
                 if constexpr (kMajorA == cute::UMMA::Major::K)
-                    tma_copy(
+                    tma::copy(
                         &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx);
                 if constexpr (kMajorA == cute::UMMA::Major::MN)
-                    tma_copy(
+                    tma::copy(
                         &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx);
                 if constexpr (kMajorB == cute::UMMA::Major::K)
-                    tma_copy(
+                    tma::copy(
                         &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx);
                 if constexpr (kMajorB == cute::UMMA::Major::MN)
-                    tma_copy(
+                    tma::copy(
                         &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx);
-
                 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
             }
         }
@@ -203,8 +210,8 @@ sm90_bf16_gemm_impl(int* grouped_layout,
         // Merged stages only happens in NT normal GEMM cases
         constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
-        auto a_desc = make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0);
-        auto b_desc = make_gmma_desc(smem_b[0], 0, 0);
+        auto a_desc = mma::sm90::make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0);
+        auto b_desc = mma::sm90::make_gmma_desc(smem_b[0], 0, 0);
         const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
         const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
@@ -229,10 +236,10 @@ sm90_bf16_gemm_impl(int* grouped_layout,
         };
         // TODO: remove some useless computation for unaligned Ms
-        const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
+        const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
         for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
-            const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
-            const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
+            const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
+            const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
             // Wait TMA arrivals
             full_barriers[stage_idx]->wait(phase);
             // Commit WGMMA instructions
             #pragma unroll
             for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
-                warpgroup_fence_operand(accum[i]);
-            warpgroup_arrive();
+                ptx::warpgroup_fence_operand(accum[i]);
+            ptx::warpgroup_arrive();
             #pragma unroll
             for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                 auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
                 #pragma unroll
                 for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
-                    const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
-                    a_desc.reg32_[0] = advance_gmma_desc_lo(
+                    const uint32_t atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
+                    a_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo(
                         a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K);
-                    b_desc.reg32_[0] = advance_gmma_desc_lo(
+                    b_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo(
                         b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K);
                     WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1);
                 }
             }
-            warpgroup_commit_batch();
+            ptx::warpgroup_commit_batch();
             #pragma unroll
             for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
-                warpgroup_fence_operand(accum[i]);
-            warpgroup_wait<0>();
+                ptx::warpgroup_fence_operand(accum[i]);
+            ptx::warpgroup_wait<0>();
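+            // NOTES: this is the standard WGMMA sequence: fence the accumulator
+            // registers, `warpgroup_arrive`, issue the MMAs, commit the batch, then
+            // `warpgroup_wait<0>` so every in-flight WGMMA retires before the empty
+            // barriers below release the producer stages.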
             // Notify barrier arrival
             empty_barrier_arrive(stage_idx);
@@ -324,7 +331,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
                     }
                     // NOTES: only 16 lanes' addresses are used
-                    SM90_U32x2_STSM_N::copy(
+                    ptx::SM90_U32x2_STSM_N::copy(
                         __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
                         __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
                         smem_ptr
@@ -341,8 +348,8 @@ sm90_bf16_gemm_impl(int* grouped_layout,
                 auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
                 #pragma unroll
                 for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
-                    st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
-                    st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
+                    ptx::st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
+                    ptx::st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
                 }
             }
         }
@@ -350,7 +357,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
             cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
             // Use TMA store to write back to global memory
-            const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
+            const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
             DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
             if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
                 auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh
index 191a4fe2..7c344296 100644
--- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh
@@ -4,26 +4,32 @@
 #include
 #include
+#include
 #include
-#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
 namespace deep_gemm {
-using namespace deep_gemm::sm90;
-
 template
-__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
+CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
 sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
                           const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                           const __grid_constant__ cute::TmaDescriptor tensor_map_b,
                           float *d) {
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Types
-    using WGMMA = typename BF16MMASelector::type;
+    using WGMMA = typename mma::sm90::BF16MMASelector::type;
     using Barrier = cutlass::arch::ClusterTransactionBarrier;
     DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
@@ -33,7 +39,7 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
     // Configs
     const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
-    const uint32_t lane_idx = get_lane_idx();
+    const uint32_t lane_idx = ptx::get_lane_idx();
     DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
     DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
     DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
@@ -48,17 +54,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
     // Align to 1024 bytes for swizzle-128B
     // Fill shared memory pointers
     extern __shared__ __align__(1024) uint8_t smem_buffer[];
-    auto smem_a = PatternVisitor([&](const uint32_t& i) {
+    auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
     });
-    auto smem_b = PatternVisitor([&](const uint32_t& i) {
+    auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
     });
     // Fill barriers
     auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
-    auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
-    auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
+    auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
+    auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
     // Initialize barriers
     if (warp_idx == 1 and cute::elect_one_sync()) {
@@ -80,14 +86,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
     constexpr uint32_t kNumMathRegisters = 232;
     // Block indices
-    const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
-    const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
+    const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N);
+    const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M);
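+    // NOTES: a single flat `blockIdx.x` encodes (sk_block_idx, mn_block_idx);
+    // e.g. with num_mn_blocks == 4, blockIdx.x == 9 gives mn_block_idx == 1 and
+    // sk_block_idx == 2 (see the decomposition just below).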
     const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
     const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
     const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
     const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
     const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
+    // Wait for primary kernel completion
+    cudaGridDependencySynchronize();
+
     if (warp_idx >= kNumMathThreads / 32) {
         // TMA warp-group for loading data
         cutlass::arch::warpgroup_reg_dealloc();
@@ -98,18 +107,18 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
             #pragma unroll
             for (uint32_t s = 0; s < num_total_stages; ++ s) {
                 // Wait consumer release
-                const auto& stage_idx = s % kNumStages;
+                const auto stage_idx = s % kNumStages;
                 empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
                 auto& full_barrier = *full_barriers[stage_idx];
-                const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
-                const uint32_t& k_idx = sk_idx % SHAPE_K;
-                const uint32_t& s_idx = sk_idx / SHAPE_K;
+                const uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
+                const uint32_t k_idx = sk_idx % SHAPE_K;
+                const uint32_t s_idx = sk_idx / SHAPE_K;
                 constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16);
-                tma_copy(
+                tma::copy(
                     &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
-                tma_copy(
+                tma::copy(
                     &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
                 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
             }
@@ -125,32 +134,32 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
         // Launch MMAs
         for (uint32_t s = 0; s < num_total_stages; ++ s) {
             // Wait TMA arrivals
-            const auto& stage_idx = s % kNumStages;
+            const auto stage_idx = s % kNumStages;
             full_barriers[stage_idx]->wait((s / kNumStages) & 1);
             // Commit WGMMA instructions
             #pragma unroll
             for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
-                warpgroup_fence_operand(accum[i]);
-            warpgroup_arrive();
+                ptx::warpgroup_fence_operand(accum[i]);
+            ptx::warpgroup_arrive();
             #pragma unroll
             for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
-                auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
-                auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
+                auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
+                auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
                 WGMMA::wgmma(desc_a, desc_b, accum, 1);
             }
-            warpgroup_commit_batch();
+            ptx::warpgroup_commit_batch();
             #pragma unroll
             for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
-                warpgroup_fence_operand(accum[i]);
-            warpgroup_wait<0>();
+                ptx::warpgroup_fence_operand(accum[i]);
+            ptx::warpgroup_wait<0>();
             // Notify barrier arrival at the last warpgroup wave
             empty_barriers[stage_idx]->arrive();
         }
-        const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
-        const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
+        const auto row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
+        const auto col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
         #pragma unroll
         for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
             if (col + i * 8 >= SHAPE_N)
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh
index cdd28fcb..195d431f 100644
--- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh
@@ -6,18 +6,26 @@
 #include
 #include
+#include
 #include
 #include
 #include
+#include
+#include
 #include
-#include
-#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
 namespace deep_gemm {
-using namespace deep_gemm::sm90;
-
 template
-__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
+CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
 sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
                         int* grouped_layout,
                         cute::TmaDescriptor* tensor_map_buffer,
@@ -45,7 +53,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
     DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
     // Types
-    using WGMMA = typename FP8MMASelector::type;
+    using WGMMA = typename mma::sm90::FP8MMASelector::type;
     using Barrier = cutlass::arch::ClusterTransactionBarrier;
     DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
@@ -55,13 +63,13 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
     shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
     // Shared memory
-    static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0);
+    static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 2 : 0);
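+    // NOTES: two descriptor slots (one for A, one for B) replace the former four:
+    // the double-buffered next/current tensor-map scheme is removed in favour of
+    // updating a single descriptor pair in place (see the release/acquire pair below).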
     static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
     static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
-    static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
+    static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
     DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
     // Configs
@@ -83,47 +91,41 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
     DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
     // Tensor maps on shared and global memory
-    auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) {
-        return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * i);
-    });
-    auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) {
-        return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * (2 + i));
-    });
-    auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; });
-    auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });
+    auto smem_tensor_map_a = reinterpret_cast(smem_buffer);
+    auto smem_tensor_map_b = smem_tensor_map_a + 1;
+    auto gmem_tensor_map_a = tensor_map_buffer + blockIdx.x * 2;
+    auto gmem_tensor_map_b = gmem_tensor_map_a + 1;
     // Data on shared memory
     auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE);
-    auto smem_a = PatternVisitor([&](const uint32_t& i) {
-        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
+    auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
+        return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
     });
-    auto smem_b = PatternVisitor([&](const uint32_t& i) {
+    auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
     });
     constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
-    auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
+    auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
     });
-    auto smem_sfb = PatternVisitor([&](const uint32_t& i) {
+    auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
     });
     // Barriers on shared memory
     constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
-    auto full_barriers = PatternVisitor([&](const uint32_t& i) {
+    auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier))));
     });
-    auto empty_barriers = PatternVisitor([&](const uint32_t& i) {
+    auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier))));
     });
     if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
         // Load tensormap A/B to shared memory
         if constexpr (kGemmType == GemmType::KGroupedContiguous) {
-            *smem_tensor_map_a[0] = tensor_map_a_base;
-            *smem_tensor_map_a[1] = tensor_map_a_base;
-            *smem_tensor_map_b[0] = tensor_map_b_base;
-            *smem_tensor_map_b[1] = tensor_map_b_base;
+            *smem_tensor_map_a = tensor_map_a_base;
+            *smem_tensor_map_b = tensor_map_b_base;
         }
         // Initialize barriers
@@ -149,12 +151,15 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
     constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
     constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
+    // Wait for primary kernel completion
+    cudaGridDependencySynchronize();
+
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
-    auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout);
+    auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout);
     // TMA and MMA pipeline
-    const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple {
+    const auto get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple {
        return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
    };
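+    // NOTES: stage/phase bookkeeping example: with kNumStages == 4, iterations
+    // 0..3 run with phase 0 and iterations 4..7 with phase 1; consumers wait on
+    // `phase` for full barriers while producers wait on `phase ^ 1` for empty ones.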
    uint32_t iter_idx = 0;
@@ -165,9 +170,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
         // NOTES: only one thread (or warp) will be used
         if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
-            const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
-            const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
-            uint32_t last_group_idx = kNumGroups, sum_k = 0;
+            uint32_t last_group_idx = kNumGroups;
             // Persistently schedule over blocks
             while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
@@ -177,35 +180,27 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
                 const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                 const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                 DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
-
-                const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
-                const uint32_t& m_idx = m_block_idx * BLOCK_M;
-                const uint32_t& n_idx = n_block_idx * BLOCK_N;
-
-                if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) {
-                    const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1;
-                    const uint32_t& next_stage_idx = stage_idx ^ 1;
+
+                const uint32_t num_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
+                const uint32_t m_idx = m_block_idx * BLOCK_M;
+                const uint32_t n_idx = n_block_idx * BLOCK_N;
+
+                if (kGemmType == GemmType::KGroupedContiguous && last_group_idx != scheduler.current_group_idx) {
                     last_group_idx = scheduler.current_group_idx;
-                    // Prepare next tensor map
-                    sum_k += scheduler.current_shape_k;
-                    if (scheduler.next_group_idx < kNumGroups) {
-                        tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m);
-                        tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n);
-                        tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
-                        tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
-                        *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
-                        *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
-                        tensor_map_release_cta();
-                    }
-
-                    // Get current tensor map
-                    if (scheduler.current_num_valid_groups > 0) {
-                        tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]);
-                        tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]);
-                        current_tensor_map_a = gmem_tensor_map_a[stage_idx];
-                        current_tensor_map_b = gmem_tensor_map_b[stage_idx];
-                    }
+                    // Directly update current tensor map
+                    const uint64_t current_k_offset = scheduler.current_k_cumsum;
+                    ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_a, gmem_a_ptr + current_k_offset * shape_m);
+                    ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n);
+                    ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k);
+                    ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k);
+                    *(gmem_tensor_map_a) = *(smem_tensor_map_a);
+                    *(gmem_tensor_map_b) = *(smem_tensor_map_b);
+                    ptx::tensor_map_release_gpu();
+
+                    // Immediately acquire current tensor map
+                    ptx::tensor_map_acquire_gpu(gmem_tensor_map_a);
+                    ptx::tensor_map_acquire_gpu(gmem_tensor_map_b);
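+                    // NOTES: `tensor_map_release_gpu()` / `tensor_map_acquire_gpu()` are
+                    // assumed to wrap the `fence.proxy.tensormap::generic.release.gpu`
+                    // and `...acquire.gpu` PTX fences, making the in-place descriptor
+                    // update visible to the TMA issues below.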
                }
                #pragma unroll kNumPipelineUnrolls
@@ -216,12 +211,14 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
                     // Issue TMA
                     auto& full_barrier = *full_barriers[stage_idx];
-                    const uint32_t& k_idx = k_block_idx * BLOCK_K;
-                    const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
-                    tma_copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
-                    tma_copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
-                    tma_copy(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
-                    tma_copy(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
+                    const uint32_t k_idx = k_block_idx * BLOCK_K;
+                    const uint32_t sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
+                    const auto tensor_map_a_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_a : &tensor_map_a_base);
+                    const auto tensor_map_b_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_b : &tensor_map_b_base);
+                    tma::copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
+                    tma::copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
+                    tma::copy(tensor_map_a_ptr, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
+                    tma::copy(tensor_map_b_ptr, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
                     full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
                 }
             }
@@ -248,9 +245,9 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
         while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
             // Accumulation for WGMMA or CUDA promotion
             DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
-            const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
-            const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
-            const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K);
+            const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
+            const uint32_t current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
+            const uint32_t num_k_blocks = math::ceil_div(current_shape_k, BLOCK_K);
             float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
             float2 scales_b[WGMMA::kNumAccum / 4];
@@ -272,30 +269,30 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
                 // Read A scales
                 // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
-                auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0);
-                auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1);
+                auto scale_a_0 = ptx::ld_shared(smem_sfa[stage_idx] + r_0);
+                auto scale_a_1 = ptx::ld_shared(smem_sfa[stage_idx] + r_1);
                 // Read B scales
                 #pragma unroll
                 for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
-                    scales_b[i] = ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
+                    scales_b[i] = ptx::ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
                 // Commit WGMMA instructions
                 #pragma unroll
                 for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
-                    warpgroup_fence_operand(accum[i]);
-                warpgroup_arrive();
+                    ptx::warpgroup_fence_operand(accum[i]);
+                ptx::warpgroup_arrive();
                 #pragma unroll
                 for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
-                    auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
-                    auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
+                    auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
+                    auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
                     WGMMA::wgmma(desc_a, desc_b, accum, k);
                 }
-                warpgroup_commit_batch();
+                ptx::warpgroup_commit_batch();
                 #pragma unroll
                 for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
-                    warpgroup_fence_operand(accum[i]);
-                warpgroup_wait<0>();
+                    ptx::warpgroup_fence_operand(accum[i]);
+                ptx::warpgroup_wait<0>();
                 // Notify barrier arrival
                 empty_barrier_arrive(stage_idx);
@@ -318,12 +315,12 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
             cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
             // Store to D shared memory
-            const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2);
-            const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2);
+            const auto smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2);
+            const auto smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2);
             #pragma unroll
             for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
-                st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
-                st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
+                ptx::st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
+                ptx::st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
             }
             cute::tma_store_fence();
             cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
index 9247304c..aa412484 100644
--- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
@@ -10,17 +10,21 @@
 #include
 #include
-#include
+#include
 #include
-#include
-#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
 namespace deep_gemm {
-using namespace deep_gemm::sm90;
-
 template
-__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
+CUTLASS_DEVICE void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
     if (num_former_iters == kNumFormerIters) {
         func(cute::Int{});
         return;
@@ -35,12 +39,12 @@ template
-__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
+CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
 sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
                         uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_a,
@@ -50,10 +54,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
 #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
     // Scaling checks
     DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
-    DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
+    DG_STATIC_ASSERT(
+        math::constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or
+        (math::constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too many B scales in a single block");
     // Types
-    using WGMMA = typename FP8MMASelector::type;
+    using WGMMA = typename mma::sm90::FP8MMASelector::type;
     using Barrier = cutlass::arch::ClusterTransactionBarrier;
     DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
@@ -64,23 +70,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
     // Shared memory
     static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
-    static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u);
+    static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u);
     static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
     static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
-    static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
-    const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K);
-    const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K);
-    const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
+    static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
+    const uint32_t shape_k_scales = math::ceil_div(shape_k, BLOCK_K);
+    const uint32_t shape_n_sfb = math::ceil_div(shape_n, BLOCK_K);
+    const uint32_t smem_sfb_size = math::align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
     // NOTES: Make sure we have enough shared memory for WGMMA padding
     static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
     DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
     // Configs
-    const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K);
+    const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K);
     const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
-    const uint32_t lane_idx = get_lane_idx();
+    const uint32_t lane_idx = ptx::get_lane_idx();
     // Prefetch TMA descriptors at the very beginning
     if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
@@ -97,22 +103,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
     // Data on shared memory
     auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
-    auto smem_a = PatternVisitor([&](const uint32_t& i) {
+    auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
     });
-    auto smem_b = PatternVisitor([&](const uint32_t& i) {
+    auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
     });
     constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
-    auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
+    auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
     });
     auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
     // Fill barriers
     auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size);
-    auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
-    auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
+    auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
+    auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
     // Initialize barriers
     DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
@@ -136,9 +142,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
     constexpr uint32_t kNumTMARegisters = 40;
     constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232;
+    // Wait for primary kernel completion
+    cudaGridDependencySynchronize();
+
     // Block scheduler
     uint32_t m_block_idx, n_block_idx;
-    auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout);
+    auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout);
     // Pipeline and TMA phases
     uint32_t stage_idx = 0, phase = 0;
@@ -177,15 +186,15 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
                 constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
                 auto& full_barrier = *full_barriers[stage_idx];
                 const uint32_t k_idx = k_block_idx * BLOCK_K;
-                tma_copy(&tensor_map_a, &full_barrier,
+                tma::copy(&tensor_map_a, &full_barrier,
                          smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
                          num_tma_multicast_a, batch_idx);
-                tma_copy(&tensor_map_sfa, &full_barrier,
-                         smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx),
+                tma::copy(&tensor_map_sfa, &full_barrier,
+                         smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx),
                          num_tma_multicast_a);
                 // Issue TMA B
-                tma_copy(&tensor_map_b, &full_barrier,
+                tma::copy(&tensor_map_b, &full_barrier,
                          smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx),
                          num_tma_multicast_b, batch_idx);
                 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
@@ -206,8 +215,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
         const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
         const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
-        auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
-        auto b_desc = make_smem_desc(smem_b[0], 1);
+        auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
+        auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1);
         const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
         const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
@@ -225,14 +234,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
         // Load B scales with math warp-groups
         // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
         if (threadIdx.x >= 32) {
-            auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
+            auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
             const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
             const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
             auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
             #pragma unroll
             for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
-                st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb));
+                ptx::st_shared(smem_sfb + i, i < shape_k_scales ? local_sfb[i * stride_k_sfb] : local_sfb[(i - shape_k_scales) * stride_k_sfb + stride_n_sfb]);
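+                // NOTES: the explicit `__ldg` hint from the old line is dropped; a plain
+                // indexed load is assumed equivalent here since `local_sfb` is read-only
+                // in this kernel.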
         }
         cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
@@ -259,22 +268,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
             // Skip useless computations
             if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
                 // The compiler must know the dynamic variable `num_former_iters`'s real value
-                constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
-                constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
+                constexpr bool kShouldOptimize = BLOCK_K / math::constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
+                constexpr uint32_t kGap = math::constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
                 constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
                 // Dispatch `num_former_iters` and launch MMAs
                 dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
                     #pragma unroll 8
                     for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
-                        const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
-                        const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
+                        const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
+                        const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
                         // Read B scales
-                        float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1;
+                        float scale_b_0 = ptx::ld_shared(smem_sfb + k_block_idx), scale_b_1;
                         // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
                         if constexpr (not kMustUseUniformedScaleB)
-                            scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales);
+                            scale_b_1 = ptx::ld_shared(smem_sfb + k_block_idx + shape_k_scales);
                         // Wait TMA arrivals
                         full_barriers[stage_idx]->wait(phase);
@@ -286,25 +295,25 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
                         // Read A scales
                         // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
-                        auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
-                        auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
+                        auto scale_a_0 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
+                        auto scale_a_1 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
                         // Commit WGMMA instructions
                         #pragma unroll
                         for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
-                            warpgroup_fence_operand(accum[i]);
-                        warpgroup_arrive();
+                            ptx::warpgroup_fence_operand(accum[i]);
+                        ptx::warpgroup_arrive();
                        #pragma unroll
                         for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                             a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
                             b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
                             WGMMA::wgmma(a_desc, b_desc, accum, k);
                         }
-                        warpgroup_commit_batch();
+                        ptx::warpgroup_commit_batch();
                         #pragma unroll
                         for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
-                            warpgroup_fence_operand(accum[i]);
-                        warpgroup_wait<0>();
+                            ptx::warpgroup_fence_operand(accum[i]);
+                        ptx::warpgroup_wait<0>();
                         // Notify barrier arrival at the last warpgroup wave
                         if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
@@ -325,7 +334,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
                         #pragma unroll
                         for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                             // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
-                            const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters;
+                            const bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
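+                            // NOTES: FP8 promotion happens here in FP32: each accumulator
+                            // fragment is scaled by the product of its A row scale and B
+                            // column-block scale, roughly
+                            //   final_accum[j] += sfa(row) * sfb(col_block) * accum[j];
+                            // `predicate` picks between the two B-scale columns a block can span.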
                             shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
                             shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
                             shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
@@ -399,7 +408,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
                     }
                     // NOTES: only 16 lanes' addresses are used
-                    SM90_U32x2_STSM_N::copy(
+                    ptx::SM90_U32x2_STSM_N::copy(
                         __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
                         __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
                         smem_ptr
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh
index d58c7162..225af441 100644
--- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh
@@ -7,36 +7,31 @@
 #include
 #include
+#include
+#include
 #include
-#include
+#include
+#include
+#include
+#include
+#include
+#include
 namespace deep_gemm {
-using namespace deep_gemm::sm90;
-
-// ReSharper disable once CppNotAllPathsReturnValue
-template
-static constexpr int to_swizzle_cute_type() {
-    DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
-    if constexpr (kHeadDim == 32)
-        return static_cast(cute::SM90::GMMA::LayoutType::B32);
-    if constexpr (kHeadDim == 64)
-        return static_cast(cute::SM90::GMMA::LayoutType::B64);
-    if constexpr (kHeadDim == 128)
-        return static_cast(cute::SM90::GMMA::LayoutType::B128);
-}
-
 template
-__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
+         uint32_t kNumSMs,
+         uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
+         typename logits_dtype_t>
+CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
 void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
-                         const uint32_t max_seqlen_k, const uint64_t stride_logits,
+                         const uint32_t max_seqlen_k, const uint32_t stride_logits,
                          uint32_t* cu_seq_len_k_start, uint32_t* cu_seq_len_k_end,
-                         float* logits,
+                         logits_dtype_t* logits,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_q,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
@@ -44,10 +39,10 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
     // TODO: consider TMA multicast
     // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
     // Q should be load only at once for a block
-    const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
+    const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
     // Types
-    using WGMMA = typename FP8MMASelector::type;
+    using WGMMA = typename mma::sm90::FP8MMASelector::type;
     using Barrier = cutlass::arch::ClusterTransactionBarrier;
     // Prefetch TMA descriptors
@@ -74,19 +69,19 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
     DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
     // Data on shared memory
-    auto smem_q = PatternVisitor([&](const uint32_t& i) {
+    auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
     });
-    auto smem_kv = PatternVisitor([&](const uint32_t& i) {
+    auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
             SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
     });
-    auto smem_weights = PatternVisitor([&](const uint32_t& i) {
+    auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages +
                                 SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
     });
-    auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
+    auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
         return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages +
                                 SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages +
                                 SMEM_KV_SCALE_SIZE_PER_STAGE * i);
@@ -94,13 +89,13 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
     // TMA barriers
     auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]);
-    auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
-    auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
-    auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
-    auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
+    auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
+    auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
+    auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
+    auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
     // Initialize barriers
-    const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
+    const bool is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
     if (is_tma_load_warp and cute::elect_one_sync()) {
         #pragma unroll
         for (uint32_t i = 0; i < kNumQStages; ++ i) {
@@ -123,38 +118,43 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
     constexpr uint32_t kNumMathRegisters = 112;
     // Block scheduler
-    uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
-    const auto& get_next_block_q_idx = [&]() -> cute::tuple {
-        return {block_q_idx + gridDim.x, q_iter_idx + 1};
+    const auto sm_idx = blockIdx.x;
+    uint32_t block_q_idx = sm_idx, q_iter_idx = 0;
+    const auto get_next_block_q_idx = [&]() -> cute::tuple {
+        return {block_q_idx + kNumSMs, q_iter_idx + 1};
    };
    uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
-    const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple {
+    const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple {
         uint32_t start = cute::numeric_limits::max();
         uint32_t end = cute::numeric_limits::min();
         #pragma unroll
         for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
-            const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
-            seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
-            seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
+            const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
+            seq_k_start[i] = cu_seq_len_k_start[q_idx];
+            seq_k_end[i] = cu_seq_len_k_end[q_idx];
             start = min(start, min(seq_k_start[i], seq_len_kv));
             end = max(end, min(seq_k_end[i], seq_len_kv));
         }
+        // TMA alignment requirements for SF KV
         start = start / 4 * 4;
         return {(q_iter_idx + q_iter_offset) % kNumQStages,       // Q pipeline stage
                 ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
-                start, ceil_div(end - start, BLOCK_KV)};          // Task info
+                start, math::ceil_div(end - start, BLOCK_KV)};    // Task info
     };
     // KV pipeline
     uint32_t num_total_kv_blocks = 0;
-    const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple {
+    const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple {
         return {
             (num_total_kv_blocks + kv_block_idx) % kNumKVStages,      // KV pipeline stage
             ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
         };
     };
+    // Wait for primary kernel completion
+    cudaGridDependencySynchronize();
+
     if (threadIdx.x >= kNumMathThreads) {
         // TMA warp-group for loading data
         cutlass::arch::warpgroup_reg_dealloc();
@@ -165,8 +165,8 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
             // Prefetch
             const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
-                tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
-                tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
+                tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
+                tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
                 full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
             };
             if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
@@ -192,9 +192,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
                     empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
                     // Issue TMA KV
-                    tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
+                    tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
                              smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
-                    tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
+                    tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
                              smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
                     full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
                }
@@ -212,7 +212,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
         const auto& thread_idx = threadIdx.x % kNumMathThreads;
         const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
         const auto& warpgroup_idx = warp_idx / 4;
-        const auto& lane_idx = get_lane_idx();
+        const auto& lane_idx = ptx::get_lane_idx();
         float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
         const auto& warp_offset = warp_idx * 16;
@@ -230,7 +230,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
             for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
                 #pragma unroll
                 for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
-                    weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
+                    weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
             }
             // Compute over KV blocks
@@ -242,29 +242,31 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
                 full_kv_barriers[kv_stage_idx]->wait(kv_phase);
                 // Read per-KV scales
-                float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
-                float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
+                float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
+                float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
                 // Issue WGMMA
                 DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
                 DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
                 #pragma unroll
                 for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
-                    warpgroup_fence_operand(accum[i]);
-                warpgroup_arrive();
+                    ptx::warpgroup_fence_operand(accum[i]);
+                ptx::warpgroup_arrive();
                 #pragma unroll
                 for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
-                    auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
-                                                 to_swizzle_cute_type(), 0, kHeadDim * 8);
-                    auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K,
-                                                 to_swizzle_cute_type(), 0, kHeadDim * 8);
+                    auto desc_a = mma::sm90::make_smem_desc(
+                        smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
+                        mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8);
+                    auto desc_b = mma::sm90::make_smem_desc(
+                        smem_q[q_stage_idx] + k * WGMMA::K,
+                        mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8);
                     WGMMA::wgmma(desc_a, desc_b, accum, k);
                 }
-                warpgroup_commit_batch();
+                ptx::warpgroup_commit_batch();
                 #pragma unroll
                 for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
-                    warpgroup_fence_operand(accum[i]);
-                warpgroup_wait<0>();
+                    ptx::warpgroup_fence_operand(accum[i]);
+                ptx::warpgroup_wait<0>();
                 // Release KV empty
                 empty_kv_barriers[kv_stage_idx]->arrive();
@@ -278,7 +280,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
                 #pragma unroll
                 for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
                     auto shifted_accum = accum + i * kNumAccumPerReduce;
-                    const auto& transform = [&](const uint32_t& j) {
+                    const auto transform = [&](const uint32_t& j) {
                         return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
                     };
@@ -302,16 +304,15 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
                     }
                     // Store into the global memory
-                    // NOTES: we have redundant writes here, consider more carefully
-                    const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
+                    const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast(stride_logits);
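+                    // NOTES: `stride_logits` is now 32-bit; the `static_cast` above is
+                    // assumed to widen the product to 64 bits so the row offset cannot
+                    // overflow for large logits tensors.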
(kIsCompressedLogits) { if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; + logits[q_offset + kv_offset + v_0_offset - seq_k_start[i]] = static_cast(v_0); if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) - logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; + logits[q_offset + kv_offset + v_1_offset - seq_k_start[i]] = static_cast(v_1); } else { - logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; - logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1; + logits[q_offset + kv_offset + v_0_offset] = static_cast(v_0); + logits[q_offset + kv_offset + v_1_offset] = static_cast(v_1); } } } diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh index 482a85a8..cc2592bb 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -6,133 +6,46 @@ #include #include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -template -__global__ __launch_bounds__(32, 1) -void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, - const uint32_t* context_lens, uint32_t* schedule_metadata) { - DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); - const uint32_t lane_idx = get_lane_idx(); - - uint32_t num_segs[kAlignedBatchSize / 32]; - #pragma unroll - for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { - const uint32_t q_idx = k * 32 + lane_idx; - const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); - const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0); - num_segs[k] = ceil_div(context_len, SPLIT_KV); - } - - __shared__ uint32_t prefix_sum[kAlignedBatchSize]; - uint32_t sum = 0; - #pragma unroll - for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { - uint32_t x = num_segs[k]; - #pragma unroll - for (uint32_t offset = 1; offset < 32; offset <<= 1) { - const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset); - x += (lane_idx >= offset ? y : 0); - } - x += sum; - prefix_sum[k * 32 + lane_idx] = x; - sum = __shfl_sync(0xffffffff, x, 31); - } - - const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs; - for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { - uint32_t seg_starts = sm_idx * q + min(sm_idx, r); - uint32_t q_idx = 0; - while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts) - ++ q_idx; - const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]); - __syncwarp(); - - schedule_metadata[sm_idx * 2] = q_idx; - schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; - } -} - -template -struct PagedMQALogitsScheduler { - uint32_t batch_size; - const uint32_t* context_lens; - - uint32_t current_q_idx, current_kv_idx; - uint32_t end_q_idx, end_kv_idx; - uint32_t current_num_kv; - - __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) { - const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); - return q_idx < batch_size ? 
ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0; - } - - __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, - const uint32_t* context_lens, const uint32_t* schedule_meta) { - this->batch_size = batch_size; - this->context_lens = context_lens; - - const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); - const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); - current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; - end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; - - current_num_kv = get_num_kv(current_q_idx); - } - - __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { - q_idx = current_q_idx; - kv_idx = current_kv_idx; - num_kv = current_num_kv; - - if (q_idx == end_q_idx and kv_idx == end_kv_idx) - return false; - - current_kv_idx += kNumBlocksPerSplit; - if (current_kv_idx >= current_num_kv) { - ++ current_q_idx; - current_kv_idx = 0; - current_num_kv = get_num_kv(current_q_idx); - } - - return true; - } - - __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const { - return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx; - } -}; - -using namespace deep_gemm::sm90; - template -__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) + uint32_t kNumTMAThreads, uint32_t kNumMathThreads, + typename logits_dtype_t> +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, - const uint64_t logits_stride, const uint64_t block_table_stride, - const uint32_t* context_lens, float* logits, - const uint32_t* block_table, const uint32_t* schedule_meta, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, const __grid_constant__ cute::TmaDescriptor tensor_map_q, const __grid_constant__ cute::TmaDescriptor tensor_map_kv, const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits"); + // Types - using WGMMA = typename FP8MMASelector::type; + using WGMMA = typename mma::sm90::FP8MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers - const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - const auto& warpgroup_idx = warp_idx / 4; - const auto& lane_idx = get_lane_idx(); + const auto warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); // Prefetch TMA descriptors static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; @@ -150,15 +63,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = 
math::constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + - constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + math::constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + - constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + math::constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; @@ -166,31 +79,31 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); // Q data and barriers on shared memory - auto smem_q = PatternVisitor([&](const uint32_t& i) { + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); }); auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); // Separate math warpgroups and tma load warps into KV groups // Each math warpgroup corresponds to a tma load warp - const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + const auto kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? 
(threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); // Per group KV data and barriers on shared memory - const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; - auto smem_kv = PatternVisitor([&](const uint32_t& i) { + const auto smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); }); auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); // Initialize barriers if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { @@ -218,15 +131,19 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, constexpr uint32_t kNumTMARegisters = 64; constexpr uint32_t kNumMathRegisters = 104; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Scheduler - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + auto scheduler = sched::PagedMQALogitsScheduler( + blockIdx.x, batch_size, context_lens, schedule_meta, indices); DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); // Q and KV pipeline - const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase }; - const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase }; uint32_t q_iter_idx = 0, kv_iter_idx = 0; @@ -237,10 +154,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, if (kv_group_idx >= kNumMathWarpGroups) return; - const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { if (kv_group_idx == 0 and cute::elect_one_sync()) { - tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx * kNextN); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; @@ -259,7 +176,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, while (fetched_next_task) { // Prefetch next Q when current Q changes - bool prefetch_q = (q_idx != next_q_idx and 
scheduler.exist_q_idx(next_q_idx + 1)); + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_atom_idx(next_q_idx + 1)); q_idx = next_q_idx; kv_idx = next_kv_idx; num_kv = next_num_kv; @@ -276,9 +193,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, if (kv_idx == 0 or kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? - __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + block_table[q_idx * static_cast(block_table_stride) + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)] : 0); } - const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + const auto kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); // Wait KV consumer release CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); @@ -286,10 +203,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // Issue TMA KV if (cute::elect_one_sync()) { - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -301,9 +218,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, cutlass::arch::warpgroup_reg_alloc(); float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; - const auto& sub_warp_offset = (warp_idx % 4) * 16; - const auto& v_0_offset = lane_idx / 4 + 0; - const auto& v_1_offset = lane_idx / 4 + 8; + const auto sub_warp_offset = (warp_idx % 4) * 16; + const auto v_0_offset = lane_idx / 4 + 0; + const auto v_1_offset = lane_idx / 4 + 8; // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none uint32_t q_idx = batch_size, kv_idx; @@ -326,7 +243,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, for (uint32_t i = 0; i < kNextN; ++ i) { #pragma unroll for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); } } @@ -335,7 +252,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + auto kv_offset = q_idx * kNextN * static_cast(logits_stride) + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` // Wait TMA KV arrival @@ -347,25 +264,29 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { - auto desc_a = 
make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); - auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_a = mma::sm90::make_smem_desc( + smem_kv[kv_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = mma::sm90::make_smem_desc( + smem_q[q_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); WGMMA::wgmma(desc_a, desc_b, accum, k); } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum[i]); // Read per-KV scales - float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); - float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); // Wait WGMMA - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); @@ -378,7 +299,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, #pragma unroll for (uint32_t i = 0; i < kNextN; ++ i) { auto shifted_accum = accum + i * kNumAccumPerReduce; - const auto& transform = [&](const uint32_t& j) { + const auto transform = [&](const uint32_t& j) { return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; }; @@ -396,15 +317,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // Inter-thread reduction #pragma unroll for (uint32_t j = 0; j < 2; ++ j) { - const auto& offset = static_cast(1u << j); + const auto offset = static_cast(1u << j); v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); } // Store into the global memory // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + v_0_offset] = v_0; - logits[kv_offset + i * logits_stride + v_1_offset] = v_1; + logits[kv_offset + i * static_cast(logits_stride) + v_0_offset] = static_cast(v_0); + logits[kv_offset + i * static_cast(logits_stride) + v_1_offset] = static_cast(v_1); } } } diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh index e3bf9847..93b14100 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -5,20 +5,23 @@ #include #include -#include +#include #include -#include +#include +#include +#include +#include +#include +#include namespace deep_gemm { -using namespace deep_gemm::sm90; - template -__device__ __forceinline__ +CUTLASS_DEVICE uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; - const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + const auto bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; @@ -35,7 +38,7 @@ template -__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +CUTLASS_GLOBAL 
void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, @@ -56,7 +59,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Utils const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); // Align to 1024 bytes for swizzle-128B extern __shared__ __align__(1024) uint8_t smem_buffer[]; @@ -76,17 +79,17 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // Data on shared memory (layout as ordered below) // Fill D/A/B pointers auto smem_cd = reinterpret_cast(smem_buffer); - auto smem_a = PatternVisitor([&](const uint32_t& i) { + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); }); - auto smem_b = PatternVisitor([&](const uint32_t& i) { + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { @@ -101,7 +104,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, } __syncthreads(); - constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); @@ -113,12 +116,15 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, constexpr uint32_t kNumTMARegisters = 40; constexpr uint32_t kNumMathRegisters = 256; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // TMA load warp if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { cutlass::arch::warpgroup_reg_dealloc(); for (uint32_t s = 0; s < num_total_stages; ++ s) { // Wait consumer release - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); // Compute offsets @@ -126,8 +132,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, uint32_t k_idx = k_offset + s * BLOCK_K; // Issue TMAs - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); // Arrive at full barriers constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; @@ -135,7 +141,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t 
shape_m, } for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { - const auto& stage_idx = s % kNumStages; + const auto stage_idx = s % kNumStages; empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); } } else if (warp_idx < kNumMathThreads / 32) { @@ -148,7 +154,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, constexpr uint32_t WGMMA_N = BLOCK_N; constexpr uint32_t WGMMA_K = 8; - using WGMMA = typename TF32MMASelector::type; + using WGMMA = typename mma::sm90::TF32MMASelector::type; float accum[WGMMA::kNumAccum] = {0}; constexpr uint32_t kNumBankGroupBytes = 16; @@ -196,14 +202,14 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; } - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); if (s > 0) empty_barriers[(s - 1) % kNumStages]->arrive(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; @@ -213,18 +219,19 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { #pragma unroll for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { - auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + auto b_desc = mma::sm90::make_smem_desc( + smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); } } - warpgroup_commit_batch(); + ptx::warpgroup_commit_batch(); #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); + ptx::warpgroup_fence_operand(accum[i]); } - const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); - const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + const auto& reduced_sum_0 = math::warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = math::warp_reduce_sum<4>(sqr_sum_acc_1); const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); if (lane_idx % 4 == 0) { @@ -233,7 +240,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, if (m_idx + 8 < shape_m) sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; } - warpgroup_wait<0>(); + ptx::warpgroup_wait<0>(); empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); // Write accum to shared memory @@ -260,8 +267,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, // 0/1 write to the same row, 2/3 write to another row auto values = reinterpret_cast(accum + i * 2); - st_shared(smem_ptr, values[0], values[1]); - st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + ptx::st_shared(smem_ptr, values[0], values[1]); + ptx::st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); } cute::tma_store_fence(); cutlass::arch::NamedBarrier::sync(128, 1); diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh index cc9e5e6b..2f66b980 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -3,21 +3,24 @@ #include #include -#include +#include +#include 
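// -----------------------------------------------------------------------------
// [Editor's note] The rewritten `smxx_clean_logits` below is templated on
// `logits_dtype_t`, so the 16-byte granularity of `SM90_BULK_COPY_S2G` becomes
// `kAlignment = 16 / sizeof(logits_dtype_t)` elements rather than a hard-coded
// 4 floats. Each row's kept range `[ks, ke)` is floored/ceiled to that
// granularity: the aligned outside spans are filled with `-inf` via bulk
// copies, and the few elements between the aligned and exact bounds are patched
// with scalar stores afterwards. A minimal sketch of the split, using the same
// flooring/ceiling arithmetic (helper name `align_kept_range` is illustrative):
//
//     #include <cstdint>
//
//     struct AlignedBounds { uint32_t aligned_ks, aligned_ke; };
//
//     template <typename logits_dtype_t>
//     constexpr AlignedBounds align_kept_range(uint32_t ks, uint32_t ke) {
//         constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t);
//         return {ks / kAlignment * kAlignment,                      // floor
//                 (ke + kAlignment - 1) / kAlignment * kAlignment};  // ceil
//     }
//     // Bulk -inf fills cover [block_left, aligned_ks) and [aligned_ke, block_right);
//     // scalar writes then clean up [aligned_ks, ks) and [ke, aligned_ke).
// -----------------------------------------------------------------------------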
namespace deep_gemm { -template -__global__ __launch_bounds__(kNumWarps * 32, 1) +template +CUTLASS_GLOBAL __launch_bounds__(kNumWarps * 32, 1) void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits, - const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { - const uint32_t& num_sms = gridDim.x; - const uint32_t& sm_idx = blockIdx.x; - const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - constexpr float neg_inf = -cute::numeric_limits::infinity(); + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, logits_dtype_t* logits) { + const uint32_t num_sms = gridDim.x; + const uint32_t sm_idx = blockIdx.x; + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t); + const logits_dtype_t neg_inf = -cute::numeric_limits::infinity(); // Allocate filled `-inf` shared memory - extern __shared__ __align__(1024) float smem_buffer[]; + extern __shared__ __align__(1024) logits_dtype_t smem_buffer[]; #pragma unroll for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) smem_buffer[i] = neg_inf; @@ -25,38 +28,42 @@ void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const __syncthreads(); // Assign sequence to each warp - const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx, - const uint32_t& start, const uint32_t& total) -> cute::tuple { - const auto& per = total / num, rem = total % num; - return {start + idx * per + min(idx, rem), per + (idx < rem)}; + const auto assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto per = total / num, rem = total % num; + return {start + idx * per + cute::min(idx, rem), per + (idx < rem)}; }; CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + if (cute::elect_one_sync()) { for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { - const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); - const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; - const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + const auto ks = cu_seq_len_k_start == nullptr ? 
0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { - const auto& right = min(left + BLOCK_KV, static_cast(stride_logits)); + const auto right = cute::min(left + BLOCK_KV, static_cast(stride_logits)); if (right <= ks or ke <= left) { - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(logits_dtype_t)); } else { if (left < aligned_ks) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(logits_dtype_t)); if (aligned_ke < right) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(logits_dtype_t)); } } } } + __syncwarp(); for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { - const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); - const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; - const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; for (uint32_t j = aligned_ks; j < ks; ++ j) logits[i * stride_logits + j] = neg_inf; for (uint32_t j = ke; j < aligned_ke; ++ j) diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh index bea70002..a977c554 100644 --- a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -1,13 +1,16 @@ #pragma once +#include #include +#include +#include namespace deep_gemm { template -__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { - typedef typename Vectorized::vec_t in_vec_t; +CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename utils::Vectorized::vec_t in_vec_t; constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; @@ -15,16 +18,19 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { extern __shared__ float smem_buffer[]; constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); - const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); // Shift into the block sf = sf + static_cast(blockIdx.y) * mn * SF_K; out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + // Wait for primary kernel completion + 
cudaGridDependencySynchronize(); + // Load for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { - auto in_vec = __ldg(local_sf + i); + auto in_vec = local_sf[i]; const auto& in_values = reinterpret_cast(&in_vec); const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; @@ -39,26 +45,29 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; - out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ptx::ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); } } // NOTES: the two kernels below always pack the K dimension template -__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { +CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { extern __shared__ uint32_t smem_buffer[]; // Shapes and strides - constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumPackedSFK = math::constexpr_ceil_div(SF_K, 4u); constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); - const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); // Shift into the group sf = sf + static_cast(blockIdx.y) * mn * SF_K; out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Load FP32 SFs DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); @@ -66,13 +75,13 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con const auto num_uint4 = num_values / 4; #pragma unroll for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { - const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); - st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + const auto& [x, y, z, w] = reinterpret_cast(local_sf)[i]; + ptx::st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); } // Fill unaligned values as well if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) - st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + ptx::st_shared(smem_buffer + unaligned_idx, local_sf[unaligned_idx]); __syncthreads(); // Pack into UE8M0 and store @@ -85,7 +94,7 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con #pragma unroll for (uint32_t j = 0; j < 4; ++ j) { const auto sf_k_idx = sf_k_pack_idx * 4 + j; - values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + values[j] = sf_k_idx < SF_K ? 
ptx::ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; } // Pack and store @@ -101,8 +110,9 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con template -__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, - const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { +CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k, + const uint32_t gran_k) { // Always packing the K dimension // NOTES: should also assert `mn % 4 == 0` at launch DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); @@ -120,11 +130,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, // Each warp is responsible for a packed row const auto warp_idx = threadIdx.x / 32; - const auto lane_idx = get_lane_idx(); + const auto lane_idx = ptx::get_lane_idx(); const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; if (warp_idx >= in_block_packed_sf_k) return; + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + // Make an offset on the input uint32_t input_offset = 0; if constexpr (kNumGroups > 1) { @@ -134,18 +147,18 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, #pragma unroll for (uint32_t i = 0; i < 4; ++ i) { const auto group_idx = lane_idx * 4 + i; - group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0; + group_ks[i] = group_idx < kNumGroups ? ks[group_idx] : 0; } __syncwarp(); // Make the offset sf_k = 0; - auto sum_packed_sf_k = 0; + uint32_t sum_packed_sf_k = 0; #pragma unroll for (uint32_t i = 0; i < kNumGroups; ++ i) { - const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / gran_k, i / 4); sf_k += sf_k_in_group; - sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + sum_packed_sf_k += math::ceil_div(sf_k_in_group, 4u); if (packed_sf_k_idx < sum_packed_sf_k) break; if (const auto remainder = sf_k_in_group % 4; remainder > 0) @@ -153,14 +166,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, } } - for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + for (uint32_t mn_idx = ptx::get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { // Load uint4 values[4]; #pragma unroll for (uint32_t j = 0; j < 4; ++ j) { values[j] = make_uint4(0, 0, 0, 0); if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) - values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + values[j] = reinterpret_cast(sf + sf_k_idx * mn)[mn_idx]; } // Pack and store diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/layout/mega_moe.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/layout/mega_moe.cuh new file mode 100644 index 00000000..13520c60 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/layout/mega_moe.cuh @@ -0,0 +1,260 @@ +#pragma once + +#include + +#include +#include + +namespace deep_gemm::layout { + +static constexpr int kNumCandidateBlockMs = 7; +static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192}; +static constexpr int kMaxCandidateBlockM = 192; +static constexpr int kMinCandidateBlockM = 8; +static constexpr int kLCMCandidateBlockM = 384; + +// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all 
possible BLOCK_M
+template <typename T>
+CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
+                                                        T num_experts_per_rank) {
+    const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank;
+    const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
+    return math::constexpr_align(
+        num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1),
+        static_cast<T>(kLCMCandidateBlockM));
+}
+
+// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M
+template <typename T>
+CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) {
+    return (num_max_pool_tokens / block_m) * math::constexpr_align(block_m, static_cast<T>(128));
+}
+
+// Per-token source metadata for combine write-back
+struct TokenSrcMetadata {
+    uint32_t rank_idx;
+    uint32_t token_idx;
+    uint32_t topk_idx;
+};
+
+struct Workspace {
+    void* base;
+    uint32_t num_ranks, num_experts;
+    uint32_t num_experts_per_rank;
+    uint32_t num_max_tokens_per_rank;
+    uint32_t num_max_recv_tokens_per_expert;
+
+    // Pool capacity: all local experts share a contiguous token pool
+    uint32_t num_max_pool_tokens;
+    uint32_t num_max_pool_blocks;
+
+    // For both grid barrier and NVLink barrier
+    static constexpr uint64_t kNumBarrierSignalBytes = 32;
+
+    CUTLASS_HOST_DEVICE
+    Workspace(void* base,
+              const uint32_t& num_ranks,
+              const uint32_t& num_experts,
+              const uint32_t& num_max_tokens_per_rank,
+              const uint32_t& num_topk):
+        base(base),
+        num_ranks(num_ranks), num_experts(num_experts),
+        num_max_tokens_per_rank(num_max_tokens_per_rank) {
+        num_experts_per_rank = num_experts / num_ranks;
+        num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
+        num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
+        num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
+    }
+
+    CUTLASS_HOST_DEVICE
+    uint64_t get_num_bytes() const {
+        uint64_t num_bytes = 0;
+
+        // Barrier
+        num_bytes += kNumBarrierSignalBytes;
+
+        // Expert send/recv count
+        num_bytes += num_experts * sizeof(uint64_t) * 2;
+
+        // Expert recv count sum
+        num_bytes += num_experts_per_rank * sizeof(uint64_t);
+
+        // L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask)
+        num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t);
+
+        // L2 block arrival mask
+        num_bytes += num_max_pool_blocks * sizeof(uint64_t);
+
+        // Dispatch pulling source token-topk
+        num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int);
+
+        // Combine push source indices
+        num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata);
+
+        // Align to TMA descriptor requirements
+        num_bytes = math::align(num_bytes, 16);
+        return num_bytes;
+    }
+
+    CUTLASS_HOST_DEVICE
+    void* get_end_ptr() const {
+        return math::advance_ptr(base, get_num_bytes());
+    }
+
+    // Grid sync counters: `kNumBarrierSignalBytes` layout
+    // [ 0..15]: 4 x `uint32_t` grid sync counters
+    // [16..19]: `uint32_t` NVLink barrier counter
+    // [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1)
+    static constexpr uint32_t kNumMaxGridSyncCounters = 4;
+
+    template <uint32_t kIndex>
+    CUTLASS_DEVICE
+    uint32_t* get_grid_sync_count_ptr() const {
+        DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds");
+        return static_cast<uint32_t*>(base) + kIndex;
+    }
+
+    CUTLASS_DEVICE
+    uint32_t* get_nvl_barrier_counter_ptr() const {
+        return static_cast<uint32_t*>(base) + kNumMaxGridSyncCounters;
+    }
+
+    CUTLASS_DEVICE
+    int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const {
+        // NOTES: the signal is signed, as we may subtract from it
+        return math::advance_ptr<int>(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int));
+    }
+
+    CUTLASS_DEVICE
+    uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const {
+        return math::advance_ptr<uint64_t>(base, kNumBarrierSignalBytes) + expert_idx;
+    }
+
+    CUTLASS_DEVICE
+    uint64_t* get_expert_recv_count_ptr(
+        const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const {
+        return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx;
+    }
+
+    CUTLASS_DEVICE
+    uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const {
+        return get_expert_send_count_ptr(num_experts * 2) + expert_idx;
+    }
+
+    CUTLASS_DEVICE
+    uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const {
+        const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank);
+        return reinterpret_cast<uint32_t*>(base) + pool_block_idx;
+    }
+
+    CUTLASS_DEVICE
+    uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const {
+        // Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned
+        const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u));
+        return reinterpret_cast<uint64_t*>(base) + pool_block_idx;
+    }
+
+    // For dispatch pulling
+    CUTLASS_DEVICE
+    uint32_t* get_src_token_topk_idx_ptr(
+        const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const {
+        const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks);
+        return reinterpret_cast<uint32_t*>(base) +
+               expert_idx * (num_ranks * num_max_recv_tokens_per_expert) +
+               rank_idx * num_max_recv_tokens_per_expert + token_idx;
+    }
+
+    // For combine usages
+    CUTLASS_DEVICE
+    TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const {
+        const auto base = reinterpret_cast<TokenSrcMetadata*>(get_src_token_topk_idx_ptr(num_experts_per_rank));
+        return base + pool_token_idx;
+    }
+};
+
+struct Data {
+    uint32_t num_bytes;
+    bool require_tma_alignment;
+    void* base;
+
+    CUTLASS_HOST_DEVICE
+    constexpr explicit Data(
+        const uint32_t& num_bytes,
+        const bool& require_tma_alignment = true,
+        void* base = nullptr) :
+        num_bytes(num_bytes), require_tma_alignment(require_tma_alignment), base(base) {
+        DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment);
+    }
+
+    template <typename dtype_t>
+    CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const {
+        return static_cast<dtype_t>(num_bytes);
+    }
+
+    template <typename dtype_t>
+    CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
+        return static_cast<dtype_t*>(base);
+    }
+
+    CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) {
+        base = ptr;
+    }
+};
+
+struct Buffer {
+    Data data_layout;
+    uint32_t num_ranks;
+    uint32_t num_max_tokens_per_rank;
+
+    void* base;
+
+    CUTLASS_HOST_DEVICE
+    Buffer(const Data& data_layout,
+           const uint32_t& num_ranks,
+           const uint32_t& max_num_tokens_per_rank,
+           void* base = nullptr) :
+        data_layout(data_layout),
+        num_ranks(num_ranks), num_max_tokens_per_rank(max_num_tokens_per_rank),
+        base(base) {}
+
+    CUTLASS_HOST_DEVICE
+    uint64_t get_num_bytes_per_rank() const {
+        return num_max_tokens_per_rank * data_layout.get_num_bytes<uint64_t>();
+    }
+
+    CUTLASS_HOST_DEVICE
+    uint64_t get_num_bytes() const {
+        return get_num_bytes_per_rank() * num_ranks;
+    }
+
+    template <typename dtype_t>
+    CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
+        return static_cast<dtype_t*>(base);
+    }
+
+    CUTLASS_HOST_DEVICE
+    void* get_end_ptr() const {
+        return math::advance_ptr(base, get_num_bytes());
+    }
+
+    CUTLASS_HOST_DEVICE
+    Buffer get_rank_buffer(const uint32_t& rank_idx) const {
+        return {
+            data_layout,
+            1, num_max_tokens_per_rank,
+            math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx)
+        };
+    }
+
+    CUTLASS_HOST_DEVICE
+    Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const {
+        DG_DEVICE_ASSERT(num_ranks == 1 or global);
+        return Data(
+            data_layout.num_bytes,
+            data_layout.require_tma_alignment,
+            math::advance_ptr(base, data_layout.get_num_bytes<uint64_t>() * token_idx)
+        );
+    }
+};
+
+} // namespace deep_gemm::layout
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh
new file mode 100644
index 00000000..7f11aabc
--- /dev/null
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh
@@ -0,0 +1,41 @@
+#pragma once
+
+#include
+
+namespace deep_gemm::layout {
+
+constexpr static uint32_t kNumMaxRanks = 72;
+
+template <uint32_t kNumRanks>
+struct SymBuffer {
+    int64_t base;
+    int64_t offsets[kNumMaxRanks];
+    uint32_t rank_idx;
+
+    DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks");
+
+    SymBuffer() = default;
+
+    template <typename Container>
+    explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) {
+        const auto size = static_cast<uint32_t>(c.size());
+        base = c[rank_idx];
+        for (uint32_t i = 0; i < kNumMaxRanks; ++ i)
+            offsets[i] = i < size ? (c[i] - base) : 0;
+    }
+
+#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__)
+    template <typename ptr_t>
+    CUTLASS_DEVICE ptr_t get_base_ptr() const {
+        return reinterpret_cast<ptr_t>(base);
+    }
+
+    template <typename ptr_t>
+    CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const {
+        int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast<int64_t>(ptr);
+        return *reinterpret_cast<ptr_t*>(&mapped_ptr);
+    }
+#endif
+};
+
+} // namespace deep_gemm::layout
diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/mma/sm100.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/mma/sm100.cuh
new file mode 100644
index 00000000..0c554f4c
--- /dev/null
+++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/mma/sm100.cuh
@@ -0,0 +1,151 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+#include
+
+namespace deep_gemm::mma::sm100 {
+
+/// Shared memory descriptor
+CUTLASS_DEVICE
+cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr,
+                                          const uint32_t& stride_byte_offset, const uint32_t& leading_byte_offset) {
+    cute::UMMA::SmemDescriptor desc;
+
+    // Set the version for SM100
+    desc.version_ = 1;
+
+    // Legacy mode
+    desc.lbo_mode_ = 0;
+
+    // Layout
+    desc.layout_type_ = static_cast<uint8_t>(layout);
+
+    // Start address
+    const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
+    desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
+
+    // Base offset
+    desc.base_offset_ = 0;
+
+    // SBO and LBO
+    desc.stride_byte_offset_ = stride_byte_offset >> 4;
+    desc.leading_byte_offset_ = leading_byte_offset >> 4;
+
+    return desc;
+}
+
+CUTLASS_DEVICE
+cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) {
+    // NOTES: the UTCCP layout is K-major by default
+    // Atom size: 8 x 128 bits
+    // {SBO, LBO} means the byte stride between atoms on {MN, K}
+    // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero
+    return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0);
+}
+
+CUTLASS_DEVICE
+void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) {
+    const auto uint_ptr =
cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +CUTLASS_DEVICE +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16; +} + +/// UMMA descriptors +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +CUTLASS_DEVICE +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : tma::get_inner_block_atom_size(); +} + +template +CUTLASS_DEVICE +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto layout_type = to_umma_layout_type(); + const auto num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = tma::get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + math::swap(stride_byte_offset, 
leading_byte_offset); + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +CUTLASS_DEVICE uint64_t make_runtime_instr_desc_with_sf_id( + cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptorBlockScaled& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptor& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +} // namespace deep_gemm::mma::sm100 diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/mma/sm90.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/mma/sm90.cuh new file mode 100644 index 00000000..2c061940 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/mma/sm90.cuh @@ -0,0 +1,293 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace deep_gemm::mma::sm90 { + +/// MMA +template +struct FP8MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N 
== 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return 
MMA_64x224x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
+        if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
+        if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
+        if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
+        if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
+    }
+
+    static constexpr auto select_type() {
+        return BF16MMA<decltype(select_mma()), N>();
+    }
+
+    using type = decltype(select_type());
+};
+
+template <typename MMA, int N_>
+struct TF32MMARS {
+    template <size_t... Idx>
+    CUTLASS_DEVICE static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
+        using namespace cute::SM90::GMMA;
+        MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
+    }
+
+    CUTLASS_DEVICE static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
+        call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<kNumAccum>{});
+    }
+
+    static constexpr int M = 64;
+    static constexpr int N = N_;
+    static constexpr int K = 8;
+    static constexpr int kNumAccum = M * N / 128;
+};
+
+template <int N, bool kUseRS>
+struct TF32MMASelector {
+    static constexpr auto select_mma() {
+        using namespace cute::SM90::GMMA;
+        if constexpr (kUseRS) {
+            if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
+            if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
+            DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
+        }
+    }
+
+    static constexpr auto select_type() {
+        if constexpr (kUseRS) {
+            return TF32MMARS<decltype(select_mma()), N>();
+        } else {
+            DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
+        }
+    }
+
+    using type = decltype(select_type());
+};
+
+/// Shared memory descriptor
+template <typename PointerType>
+CUTLASS_DEVICE cute::GmmaDescriptor
+make_smem_desc(PointerType smem_ptr, const int& layout_type,
+               const uint32_t& leading_byte_offset = 0,
+               const uint32_t& stride_byte_offset = 1024) {
+    // NOTES: the default LBO and SBO are for K-major types
+    cute::GmmaDescriptor desc;
+    const auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+    desc.bitfield.start_address_ = uint_ptr >> 4;
+    desc.bitfield.layout_type_ = layout_type;
+    desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
+    desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
+    desc.bitfield.base_offset_ = 0;
+    return desc;
+}
+
+template <typename dtype_t, uint32_t BLOCK_INNER, uint32_t kSwizzleMode>
+constexpr uint32_t get_inner_block_atom_size() {
+    return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t);
+}
+
+template <typename dtype_t, uint32_t BLOCK_MN, cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode>
+CUTLASS_DEVICE
+constexpr uint32_t get_gmma_desc_stride_k() {
+    return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size<dtype_t, BLOCK_MN, kSwizzleMode>();
+}
+
+// ReSharper disable once CppNotAllPathsReturnValue
+template <uint32_t kSwizzleMode>
+constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() {
+    DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
+                     kSwizzleMode == 32 or kSwizzleMode == 64 or
+                     kSwizzleMode == 128, "Invalid swizzling mode");
+
+    // Normal cases
+    if constexpr (kSwizzleMode == 0)   return cute::SM90::GMMA::LayoutType::INTERLEAVE;
+    if constexpr (kSwizzleMode == 16)  return cute::SM90::GMMA::LayoutType::INTERLEAVE;
+    if constexpr (kSwizzleMode == 32)  return cute::SM90::GMMA::LayoutType::B32;
+    if constexpr (kSwizzleMode == 64)  return cute::SM90::GMMA::LayoutType::B64;
+    if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128;
+}
+
+template <typename dtype_t, uint32_t BLOCK_MN, uint32_t BLOCK_K, cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode>
+CUTLASS_DEVICE
+uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) {
+    return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k<dtype_t, BLOCK_MN, kMajorMode, kSwizzleMode>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
+}
+
+template <typename dtype_t, uint32_t BLOCK_MN, uint32_t BLOCK_K, cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode>
+CUTLASS_DEVICE
+cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
+    const uint32_t stride_k = get_gmma_desc_stride_k<dtype_t, BLOCK_MN, kMajorMode, kSwizzleMode>();
+    const auto layout_type = to_gmma_layout_type<kSwizzleMode>();
+    constexpr uint32_t num_non_contiguous = 128 / 16;
+    if constexpr (kMajorMode == cute::UMMA::Major::K) {
+        // NOTES: for K-major layout, the swizzle must be 128B (and the atom index must be 0), as `BLOCK_K` is always 128
+        DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
+
+        // Atom size: 8 x `kSwizzleMode` (in bytes, on K)
+        // {SBO, LBO} means the byte stride between atoms on {MN, K}
+        // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
+        const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
+        const uint32_t leading_byte_offset = 0;
+        return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<int>(layout_type),
+                              leading_byte_offset, stride_byte_offset);
+    } else {
+        constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<dtype_t, BLOCK_MN, kSwizzleMode>();
+
+        // Must have no in-atom MN-idx
+        // NOTES: no worries about the runtime assert, as `mn_idx` is a constant at compilation time
+        DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
+        DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
+
+        // Atom size: `kSwizzleMode` (in bytes, on MN) x 8
+        // NOTES: `kSwizzleMode == 16` means non-swizzling but interleaving
+        // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
+        // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
+        uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
+        uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
+        if constexpr (kSwizzleMode == 16)
+            math::swap(stride_byte_offset, leading_byte_offset);
+        return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<int>(layout_type),
+                              leading_byte_offset, stride_byte_offset);
+    }
+}
+
+// ReSharper disable once CppNotAllPathsReturnValue
+template <uint32_t kHeadDim>
+static constexpr int to_swizzle_cute_type() {
+    DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
+    if constexpr (kHeadDim == 32)
+        return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
+    if constexpr (kHeadDim == 64)
+        return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
+    if constexpr (kHeadDim == 128)
+        return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
+}
+
+} // namespace
deep_gemm::mma::sm90 diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/ld_st.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/ld_st.cuh new file mode 100644 index 00000000..c3e03bec --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/ld_st.cuh @@ -0,0 +1,251 @@ +#pragma once + +#include +#include + +namespace deep_gemm::ptx { + +// Compatibility: 256 bits LD/ST instructions +#if defined(CUDART_VERSION) and CUDART_VERSION >= 13000 +using longlong4_t = longlong4_32a; +#define make_longlong4_t make_longlong4_32a +#else +struct alignas(32) longlong4_t { long long x, y, z, w; }; +CUTLASS_HOST_DEVICE longlong4_t make_longlong4_t( + const long long& x, const long long& y, const long long& z, const long long& w) { + return {x, y, z, w}; +} +#endif + +/// LD/ST matrix +// TODO: remove `struct` +struct SM90_U32x2_LDSM_N { + CUTLASS_DEVICE static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +struct SM90_U32x4_LDSM_N { + CUTLASS_DEVICE static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +template +struct SM90_U32x2_STSM_N { + CUTLASS_DEVICE static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +template +struct SM90_U32x4_STSM_T { + CUTLASS_DEVICE static void + copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), + *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; + asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), + "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); + } +}; + +template +struct SM100_U8x4_STSM_T { + __device__ __forceinline__ static void + copy(dtype_t src_0, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src = *reinterpret_cast(&src_0); + asm volatile("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src)); + } +}; + +template +struct SM100_U8x8_STSM_T { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +/// Shared memory +CUTLASS_DEVICE uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float2 ld_shared(const 
float2* ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); +} + +CUTLASS_DEVICE void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); +} + +CUTLASS_DEVICE void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +CUTLASS_DEVICE void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +CUTLASS_DEVICE void st_shared_bulk(void* smem_ptr, const uint32_t& num_bytes) { + // `size` must be 64-bit before PTX ISA 9.0 + asm volatile("st.bulk.weak.shared::cta [%0], %1, 0;" :: + "l"(__cvta_generic_to_shared(smem_ptr)), "l"(static_cast(num_bytes))); +} + +/// Global memory +CUTLASS_DEVICE uint64_t ld_volatile(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.volatile.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.gpu.global.b32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.sys.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +/// Atomics +CUTLASS_DEVICE uint64_t atomic_add(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint64_t atomic_add_sys(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.sys.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& value) { + uint32_t ret; + asm volatile("atom.release.gpu.global.add.u32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); + return ret; +} + +CUTLASS_DEVICE void 
red_add(const int* ptr, const int& value) { + asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_or_rel_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.sys.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_or_rel_gpu(uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.gpu.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_add_rel(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.release.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add_rel_sys(const int* ptr, const int& value) { + asm volatile("red.release.sys.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE int ld_acq_sys(const int* ptr) { + int ret; + asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq_sys(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_gpu(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.gpu.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +/// Predicated loads +CUTLASS_DEVICE longlong4_t ld_gez_pred(const longlong4_t* ptr, const int& pred) { + longlong4_t ret = make_longlong4_t(0, 0, 0, 0); + asm volatile( + "{\n\t" + " .reg .pred p;\n\t" + " setp.ge.s32 p, %5, 0;\n\t" + " @p ld.global.L2::256B.v4.s64 {%0, %1, %2, %3}, [%4];\n\t" + "}" + : "+l"(ret.x), "+l"(ret.y), "+l"(ret.z), "+l"(ret.w) + : "l"(ptr), "r"(pred) + : "memory"); + return ret; +} + +/// Prefetch +CUTLASS_DEVICE void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +} // namespace deep_gemm::ptx diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh new file mode 100644 index 00000000..528b3dd1 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh @@ -0,0 +1,168 @@ +#pragma once + +namespace deep_gemm::ptx { + +/// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm 
volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +/// Tensor memory operations +CUTLASS_DEVICE void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +CUTLASS_DEVICE void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +} // namespace deep_gemm::ptx diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/tma.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/tma.cuh new file mode 100644 index 00000000..1530a3ed --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/tma.cuh 
@@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +namespace deep_gemm::ptx { + +// Tensor-map instructions +CUTLASS_DEVICE void tensor_map_release_gpu() { + asm volatile ("fence.proxy.tensormap::generic.release.gpu;" ::: "memory"); +} + +CUTLASS_DEVICE void tensor_map_acquire_gpu(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +CUTLASS_DEVICE void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +CUTLASS_DEVICE void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +/// TMA instructions +CUTLASS_DEVICE void mbarrier_arrive( + cutlass::arch::ClusterTransactionBarrier* ptr) { + asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0]; \n\t" :: + "r"(static_cast(__cvta_generic_to_shared(ptr)))); +} + +CUTLASS_DEVICE void mbarrier_arrive_and_set_tx( + cutlass::arch::ClusterTransactionBarrier* ptr, const uint32_t& num_bytes) { + asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: + "r"(num_bytes), "r"(static_cast(__cvta_generic_to_shared(ptr)))); +} + +CUTLASS_DEVICE void mbarrier_wait_and_flip_phase( + cutlass::arch::ClusterTransactionBarrier* ptr, uint32_t& phase) { + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + "}" :: + "r"(static_cast(__cvta_generic_to_shared(ptr))), + "r"(phase), "r"(0x989680)); + phase ^= 1; +} + +CUTLASS_DEVICE void tma_load_1d( + const void* dst_ptr, const void* src_ptr, + cutlass::arch::ClusterTransactionBarrier* mbarrier_ptr, + const uint32_t& num_bytes, + const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_FIRST) { + // NOTES: normally, the loaded part will be evicted soon + asm volatile( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n" :: + "r"(static_cast(__cvta_generic_to_shared(dst_ptr))), + "l"(src_ptr), + "r"(num_bytes), + "r"(static_cast(__cvta_generic_to_shared(mbarrier_ptr))), + "l"(hint) + : "memory"); +} + +CUTLASS_DEVICE void tma_store_1d( + const void* dst_ptr, const void* src_ptr, const uint32_t& num_bytes, + const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_NORMAL) { + // NOTES: normally, the stored part will be used soon + asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n" :: + "l"(dst_ptr), + "r"(static_cast(__cvta_generic_to_shared(src_ptr))), + "r"(num_bytes), + "l"(hint) + : "memory"); +} 
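+
+// NOTES: a minimal usage sketch of the helpers above (illustrative only; the
+// function name `example_tma_load_and_wait` and the fixed byte count are
+// placeholders, not used elsewhere in this patch). One elected thread posts the
+// expected transaction size and issues the bulk 1D copy; waiters then spin on
+// the same mbarrier and flip their parity bit for the next pipeline round.
+CUTLASS_DEVICE void example_tma_load_and_wait(void* smem_buf, const void* gmem_src,
+                                              cutlass::arch::ClusterTransactionBarrier* mbarrier_ptr,
+                                              uint32_t& phase) {
+    constexpr uint32_t kNumBytes = 4096;
+    // Tell the mbarrier how many bytes to expect before it can complete
+    mbarrier_arrive_and_set_tx(mbarrier_ptr, kNumBytes);
+    // Issue the bulk copy; completion is signaled on the same mbarrier
+    tma_load_1d(smem_buf, gmem_src, mbarrier_ptr, kNumBytes);
+    // Block until the transaction lands, then flip the phase for reuse
+    mbarrier_wait_and_flip_phase(mbarrier_ptr, phase);
+}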
+ +template +__forceinline__ __device__ void tma_store_wait() { + // NOTES: this function does not have `.read` + asm volatile("cp.async.bulk.wait_group %0;" ::"n"(kNumRemainingWaits) : "memory"); +} + +CUTLASS_DEVICE +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier& mbarrier, + void* smem_ptr, const uint32_t& col_idx, const int4& row_idxs, const uint64_t& cache_hint) { + const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + const auto mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + +} // namespace deep_gemm::ptx diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/utils.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/utils.cuh new file mode 100644 index 00000000..5c27166b --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/utils.cuh @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include + +namespace deep_gemm::ptx { + +CUTLASS_DEVICE uint32_t get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +CUTLASS_DEVICE uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +} + +CUTLASS_DEVICE void sync_aligned(const uint32_t& num_threads, const uint32_t& barrier_idx) { + asm volatile("bar.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads)); +} + +CUTLASS_DEVICE void sync_unaligned(const uint32_t& num_threads, const uint32_t& barrier_idx) { + asm volatile("barrier.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads)); +} + +template +CUTLASS_DEVICE dtype_t exchange(dtype_t ptr, const uint32_t& src_lane_idx) { + DG_STATIC_ASSERT(sizeof(dtype_t) % sizeof(uint32_t) == 0, ""); + const auto send_int_values = reinterpret_cast(&ptr); + dtype_t recv_dtype; + auto recv_int_values = reinterpret_cast(&recv_dtype); + #pragma unroll + for (uint32_t i = 0; i < sizeof(dtype_t) / sizeof(uint32_t); ++ i) + recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], static_cast(src_lane_idx)); + return recv_dtype; +} + +CUTLASS_DEVICE void accumulate(float2& a, nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + // Use `add.rn.f32.bf16` instruction to perform fused (cast + add) operation on SM100 + asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.x) : "h"(*reinterpret_cast(&b.x))); + asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.y) : "h"(*reinterpret_cast(&b.y))); +#else + const auto [x, y] = __bfloat1622float2(b); + a.x += x, a.y += y; +#endif +} + +} // namespace deep_gemm::ptx diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/wgmma.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/wgmma.cuh new file mode 100644 index 00000000..8912a157 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/ptx/wgmma.cuh @@ -0,0 +1,25 @@ +#pragma once + +#include + +namespace deep_gemm::ptx { + +CUTLASS_DEVICE void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +CUTLASS_DEVICE void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +CUTLASS_DEVICE void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: 
"memory"); +} + +template +CUTLASS_DEVICE void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +} // namespace deep_gemm::ptx diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/gemm.cuh new file mode 100644 index 00000000..5cd50c66 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/gemm.cuh @@ -0,0 +1,300 @@ +#pragma once + +#include +#include + +namespace deep_gemm::sched { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto candidate: {8u, 16u}) { + const auto usage = kIsMulticastOnA ? + candidate * BLOCK_N + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for contiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + CUTLASS_DEVICE void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = grouped_layout[group_idx]; + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + CUTLASS_DEVICE explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, + const uint32_t& shape_k, int* grouped_layout = nullptr) { + num_m_blocks = math::ceil_div(shape_m, BLOCK_M); + num_n_blocks = math::ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = grouped_layout[0]; + num_m_blocks = math::ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + 
CUTLASS_DEVICE void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + CUTLASS_DEVICE uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, grouped_layout[m_block_idx * BLOCK_M]) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + // For swap A/B and psum layout only + CUTLASS_DEVICE uint32_t get_aligned_effective_m_in_block(const uint32_t& m_block_idx) const { + constexpr uint32_t UMMA_STEP_N = 16; + DG_STATIC_ASSERT(BLOCK_M % UMMA_STEP_N == 0, "Invalid alignment"); + if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) + return math::align(m_block_idx == last_psum_m / BLOCK_M + num_m_blocks - 1 ? 
current_psum_m - m_block_idx * BLOCK_M : BLOCK_M, UMMA_STEP_N); + return BLOCK_M; + } + + CUTLASS_DEVICE bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = math::ceil_div(static_cast(grouped_layout[current_group_idx]), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = math::align(current_psum_m, BLOCK_M); + current_psum_m = grouped_layout[current_group_idx]; + current_m_block_cumsum += num_m_blocks; + num_m_blocks = math::ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with block M + m_block_idx += last_psum_m / BLOCK_M; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (next_block_idx < (current_num_valid_groups + 1) * num_blocks) + break; + + // Move to check the next group + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += math::ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + CUTLASS_DEVICE bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType 
== GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto group_idx = grouped_layout[m_block_idx * BLOCK_M]; + const auto peer_group_idx = grouped_layout[(m_block_idx ^ 1) * BLOCK_M]; + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + CUTLASS_DEVICE bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return grouped_layout[m_offset + m_block_idx * BLOCK_M] >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < grouped_layout[current_group_idx]; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + return m_offset + m_block_idx * BLOCK_M < current_psum_m; + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm::sched diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh new file mode 100644 index 00000000..cdbecccd --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh @@ -0,0 +1,221 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace deep_gemm::sched { + +// Computation phase for the current block +enum class BlockPhase { + None = 0, + Linear1 = 1, + Linear2 = 2 +}; + +template +struct MegaMoEScheduler { + DG_STATIC_ASSERT(L1_SHAPE_N % BLOCK_N == 0, "Invalid shape"); + DG_STATIC_ASSERT(L2_SHAPE_N % BLOCK_N == 0, "Invalid shape"); + DG_STATIC_ASSERT(L1_SHAPE_K % BLOCK_K == 0, "Invalid shape"); + DG_STATIC_ASSERT(L2_SHAPE_K % BLOCK_K == 0, "Invalid shape"); + DG_STATIC_ASSERT(kNumExpertsPerRank % kNumExpertsPerWave == 0, "Invalid wave config"); + + // NOTES: N block counts must be even so that 2 adjacent CTAs in a cluster + // always land on the same m_block_idx with n_block_idx differing by 1 + DG_STATIC_ASSERT(kNumSMs % 2 == 0, "Number of SMs must be even for 2-CTA cluster"); + DG_STATIC_ASSERT(kNumL1BlockNs % 2 == 0, "L1 N block count must be even for 2-CTA cluster"); + DG_STATIC_ASSERT(kNumL2BlockNs % 2 == 0, "L2 N block count must be even for 2-CTA cluster"); + + // Arrival counts + const layout::Workspace& workspace; + + // Scheduler state + BlockPhase next_phase = BlockPhase::Linear1; + + // Current expert and block indices + uint32_t current_local_expert_idx = 0; + uint32_t current_num_tokens = 0; + uint32_t current_pool_block_offset = 0; + uint32_t block_idx = 0; + uint32_t m_block_idx = 0; + uint32_t n_block_idx = 0; + + // Pre-cached per-expert token counts (filled during `for_each_block` init) + // Layout: `stored_num_tokens_per_expert[i]` holds expert (i * 32 + lane_idx)'s count + uint32_t stored_num_tokens_per_expert[kNumExpertsPerLane] = {}; + + CUTLASS_DEVICE explicit MegaMoEScheduler(const layout::Workspace& workspace): workspace(workspace) { + block_idx = blockIdx.x; + } + + CUTLASS_DEVICE uint32_t get_wave_expert_end_idx() const { + return 
math::align(current_local_expert_idx + 1, kNumExpertsPerWave); + } + + CUTLASS_DEVICE uint32_t get_num_tokens(const uint32_t& expert_idx) const { + uint32_t valid_value; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + valid_value = (expert_idx == i * 32 + ptx::get_lane_idx()) ? + stored_num_tokens_per_expert[i] : valid_value; + } + return ptx::exchange(valid_value, expert_idx % 32); + } + + // Get pool block offset for a given expert index from a per-lane token count array + CUTLASS_DEVICE uint32_t get_pool_block_offset(const uint32_t& expert_idx) { + uint32_t num_blocks = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + if (i * 32 + ptx::get_lane_idx() < expert_idx) + num_blocks += math::ceil_div(stored_num_tokens_per_expert[i], BLOCK_M); + } + return __reduce_add_sync(0xffffffff, num_blocks); + } + + CUTLASS_DEVICE void advance_expert_idx() { + current_pool_block_offset += get_current_num_m_blocks(); + current_local_expert_idx += 1; + current_num_tokens = get_num_tokens(current_local_expert_idx); + } + + CUTLASS_DEVICE void set_expert_idx(const uint32_t& expert_idx) { + current_local_expert_idx = expert_idx; + current_num_tokens = get_num_tokens(expert_idx); + current_pool_block_offset = get_pool_block_offset(expert_idx); + } + + CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const { + return current_pool_block_offset; + } + + CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const { + return math::ceil_div(current_num_tokens, BLOCK_M); + } + + template + CUTLASS_DEVICE uint32_t get_valid_m() const { + const auto m = cute::min(current_num_tokens - m_block_idx * BLOCK_M, BLOCK_M); + return kDoUMMAAligned ? math::align(m, 16u) : m; + } + + CUTLASS_DEVICE bool fetch_next_l1_block() { + const auto wave_end_expert_idx = get_wave_expert_end_idx(); + while (current_local_expert_idx < wave_end_expert_idx) { + const auto num_m_blocks = get_current_num_m_blocks(); + m_block_idx = block_idx / kNumL1BlockNs; + if (m_block_idx < num_m_blocks) + return true; + + // Current expert is fully assigned, move to the next + block_idx -= num_m_blocks * kNumL1BlockNs; + advance_expert_idx(); + } + return false; + } + + CUTLASS_DEVICE bool fetch_next_l2_block() { + const auto wave_end_expert_idx = get_wave_expert_end_idx(); + while (current_local_expert_idx < wave_end_expert_idx) { + const auto num_m_blocks = get_current_num_m_blocks(); + if (block_idx < num_m_blocks * kNumL2BlockNs) { + m_block_idx = block_idx / kNumL2BlockNs; + return true; + } + + // Current expert is fully assigned, move to the next + block_idx -= num_m_blocks * kNumL2BlockNs; + advance_expert_idx(); + } + return false; + } + + // Core state machine: assigns the next block + CUTLASS_DEVICE cute::tuple get_next_block() { + while (true) { + if (current_local_expert_idx >= kNumExpertsPerRank) + break; + + if (next_phase == BlockPhase::Linear1) { + if (fetch_next_l1_block()) { + // Found a new L1 block + n_block_idx = block_idx - m_block_idx * kNumL1BlockNs; + // Jump to next block + block_idx += kNumSMs; + return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx}; + } else { + // L1 for the current wave is complete, transition to L2 + next_phase = BlockPhase::Linear2; + set_expert_idx(math::align(current_local_expert_idx - 1, kNumExpertsPerWave)); + } + } else { + if (fetch_next_l2_block()) { + // Found a new L2 block + n_block_idx = block_idx - m_block_idx * kNumL2BlockNs; + // Jump to next block + block_idx += kNumSMs; + return {BlockPhase::Linear2, 
current_local_expert_idx, m_block_idx, n_block_idx}; + } else { + // Move to L1 of the next wave + next_phase = BlockPhase::Linear1; + } + } + } + + // All waves and experts are fully processed + return {BlockPhase::None, 0, 0, 0}; + } + + CUTLASS_DEVICE void fetch_expert_recv_count() { + // NOTES: each lane caches experts at indices (i * 32 + lane_idx) + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + const auto expert_idx = i * 32 + ptx::get_lane_idx(); + uint64_t value = 0; + if (expert_idx < kNumExpertsPerRank) { + do { + value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx)); + } while (static_cast(value >> 32) != kNumSMs * kNumRanks); + } + stored_num_tokens_per_expert[i] = static_cast(value); + } + __syncwarp(); + } + + template + CUTLASS_DEVICE void for_each_block(Func&& func) { + // Wait for all expert counters to be finalized + fetch_expert_recv_count(); + + // Initialize current expert with 0 + set_expert_idx(0); + + // Iterate over all blocks + // TODO: add swizzle within expert waves for better L2 cache utilization + while (true) { + CUTE_TIE_DECL(get_next_block(), block_phase, current_local_expert_idx, m_block_idx, n_block_idx); + if (block_phase == BlockPhase::None) + break; + + func(block_phase, current_local_expert_idx, + block_phase == BlockPhase::Linear2 ? kNumL2BlockKs : kNumL1BlockKs, + m_block_idx, n_block_idx); + } + } +}; + +} // namespace deep_gemm::sched diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh new file mode 100644 index 00000000..548bbbc6 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh @@ -0,0 +1,239 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm::sched { + +template +CUTLASS_GLOBAL __launch_bounds__(32, 1) +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, + const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) { + DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + __shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize]; + __shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize]; + __shared__ uint32_t varlen_num_atoms_shared; + uint32_t num_items; + + if constexpr (kIsVarlen) { + if (lane_idx == 0) { + uint32_t t = 0, atom_count = 0; + while (t < batch_size) { + varlen_atom_token_start[atom_count] = t; + const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]); + varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t]; + t += is_paired ? 2 : 1; + ++ atom_count; + } + varlen_num_atoms_shared = atom_count; + } + __syncwarp(); + num_items = varlen_num_atoms_shared; + } else { + num_items = batch_size; + } + + // Compute num_segs and prefix sum + uint32_t num_segs[kAlignedBatchSize / 32]; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + const uint32_t q_idx = k * 32 + lane_idx; + uint32_t context_len; + if constexpr (kIsVarlen) { + context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0); + } else { + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + context_len = (q_idx < batch_size ? 
context_lens[lens_idx] : 0); + } + num_segs[k] = math::ceil_div(context_len, SPLIT_KV); + } + + __shared__ uint32_t prefix_sum[kAlignedBatchSize]; + uint32_t sum = 0; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + uint32_t x = num_segs[k]; + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t y = __shfl_up_sync(0xffffffff, x, offset); + x += (lane_idx >= offset ? y : 0); + } + x += sum; + prefix_sum[k * 32 + lane_idx] = x; + sum = __shfl_sync(0xffffffff, x, 31); + } + + // SM work distribution + if constexpr (kIsVarlen) { + const uint32_t total = sum; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = num_items; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t atom_idx = lo; + const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]); + const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size); + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } + } else { + const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1; + const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom); + const uint32_t total = sum * num_next_n_atoms; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = batch_size; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t q_idx = lo; + const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms); + const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]); + const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0; + const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0; + const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx; + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } + } +} + +// Conditional storage for varlen indices pointer (EBO: zero cost when unused) +template +struct IndicesStorage { + const uint32_t* indices; +}; + +template <> +struct IndicesStorage {}; + +template +struct PagedMQALogitsScheduler : IndicesStorage { + const uint32_t* context_lens; + uint32_t batch_size; + + uint32_t current_q_atom_idx, current_kv_idx; + uint32_t end_q_atom_idx, end_kv_idx; + uint32_t current_num_kv; + + CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 
2 : 1; + if constexpr (kPadOddN) { + return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom; + } else { + return q_atom_idx * kNextNAtom; + } + } + } + + CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + return q_atom_idx / kNumNextNAtoms; + } + } + + CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + const bool is_paired = (q_atom_idx + 1 < batch_size and + this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]); + const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx]; + return math::ceil_div(ctx_len, BLOCK_KV); + } else { + const uint32_t q_idx = q_atom_idx / kNumNextNAtoms; + const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); + return math::ceil_div(context_lens[lens_idx], BLOCK_KV); + } + } + + CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size, + const uint32_t* context_lens, + const uint32_t* schedule_meta, const uint32_t* indices) { + this->context_lens = context_lens; + this->batch_size = batch_size; + if constexpr (kIsVarlen) { + this->indices = indices; + } + + const auto current_pack = reinterpret_cast(schedule_meta)[sm_idx]; + const auto end_pack = reinterpret_cast(schedule_meta)[sm_idx + 1]; + current_q_atom_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; + end_q_atom_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; + + current_num_kv = get_num_kv(current_q_atom_idx); + } + + // Advance step in q_atom_idx space when moving to the next atom. + // Varlen: 1 or 2 depending on whether consecutive tokens share the same sequence. + // Non-varlen: always 1 (one atom unit). + CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const { + if constexpr (kIsVarlen) { + return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1; + } else { + return 1; + } + } + + // Whether num_kv should be refreshed after advancing to q_atom_idx. + // Varlen: always refresh (each atom may have a different context_len). + // Non-varlen: only at atom-group boundaries (atoms within a group share context_len). 
+ CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + return true; + } else { + return q_atom_idx % kNumNextNAtoms == 0; + } + } + + CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) { + q_atom_idx = current_q_atom_idx; + kv_idx = current_kv_idx; + num_kv = current_num_kv; + + if (current_q_atom_idx == end_q_atom_idx and current_kv_idx == end_kv_idx) + return false; + + current_kv_idx += kNumBlocksPerSplit; + if (current_kv_idx >= current_num_kv) { + current_kv_idx = 0; + current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx); + if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) { + current_num_kv = get_num_kv(current_q_atom_idx); + } + } + return true; + } + + CUTLASS_DEVICE bool exist_q_atom_idx(const uint32_t& q_atom_idx) const { + return q_atom_idx < end_q_atom_idx or (q_atom_idx == end_q_atom_idx and 0 < end_kv_idx); + } +}; + +} // namespace deep_gemm::sched diff --git a/deep-gemm/torch-ext/deep_gemm/legacy/__init__.py b/deep-gemm/torch-ext/deep_gemm/legacy/__init__.py new file mode 100644 index 00000000..cce39ec7 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/legacy/__init__.py @@ -0,0 +1,5 @@ +# All kernels may be deprecated in the future (or be rewritten in TileLang) +from .m_grouped_gemm import * +from .a_fused_m_grouped_gemm import * +from .a_fused_k_grouped_gemm import * +from .b_fused_k_grouped_gemm import * diff --git a/deep-gemm/torch-ext/deep_gemm/legacy/a_fused_k_grouped_gemm.py b/deep-gemm/torch-ext/deep_gemm/legacy/a_fused_k_grouped_gemm.py new file mode 100644 index 00000000..7b42f152 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/legacy/a_fused_k_grouped_gemm.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def a_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc =
tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k + tl.arange(0, BLOCK_SIZE_K) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + rows[None, :] * M + + b_ptrs = b_ptr + k_range[:, None].to(tl.int64) * N + n_range[None, :] + a = tl.load(a_ptrs, mask=(rows >= 0)[None, :] & m_mask, other=0) + b = tl.load(b_ptrs, mask=n_mask, other=0) + acc = tl.dot(a, b, acc) + + # Write back + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def a_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == b.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert b.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K_, M = a.shape + K, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + a_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/deep-gemm/torch-ext/deep_gemm/legacy/a_fused_m_grouped_gemm.py b/deep-gemm/torch-ext/deep_gemm/legacy/a_fused_m_grouped_gemm.py new file mode 100644 index 00000000..41b35d53 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/legacy/a_fused_m_grouped_gemm.py @@ -0,0 +1,92 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def a_fused_m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, m_row_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + 
n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # b block + rows = tl.load(m_row_indices_ptr + m_range).to(tl.int64) + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + k_mask = k_range < K + a_ptrs = a_ptr + rows[:, None] * K + k_range[None, :] + b_ptrs = b_ptr + batch_id * K * N + k_range[:, None] * (1 if IS_B_K_MAJOR else N) + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + a = tl.load(a_ptrs, mask=(rows >= 0)[:, None] & k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + d = acc.to(d_ptr.dtype.element_ty) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, d, mask=n_mask) + + +def a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + m_indices, m_row_indices = mappings + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous() or b.mT.is_contiguous()) and d.is_contiguous() + assert m_indices.is_contiguous() and m_row_indices.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 and d.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and m_row_indices.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and d.size(0) == m_indices.numel() and d.size(1) == r1 + assert m_indices.numel() == m_row_indices.numel() + assert m_indices.numel() % get_mk_alignment_for_contiguous_layout() == 0 + + if d.size(0) == 0: + return d + + M_, K = a.shape + B, K, N = r0, r2, r1 + M = m_indices.numel() + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), ) + a_fused_m_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, m_indices, m_row_indices, + M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def a_fused_m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, mappings) diff --git a/deep-gemm/torch-ext/deep_gemm/legacy/b_fused_k_grouped_gemm.py b/deep-gemm/torch-ext/deep_gemm/legacy/b_fused_k_grouped_gemm.py new file mode 100644 index 00000000..7df8741f --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/legacy/b_fused_k_grouped_gemm.py @@ -0,0 +1,86 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def b_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, 
GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + k_range[None, :] * M + b_ptrs = b_ptr + rows[:, None] * N + n_range[None, :] + a = tl.load(a_ptrs, mask=m_mask, other=0.0) + b = tl.load(b_ptrs, mask=(rows >= 0)[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == a.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K, M = a.shape + K_, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + b_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/deep-gemm/torch-ext/deep_gemm/legacy/m_grouped_gemm.py b/deep-gemm/torch-ext/deep_gemm/legacy/m_grouped_gemm.py new file mode 100644 index 00000000..e685a9ab --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/legacy/m_grouped_gemm.py @@ -0,0 +1,84 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + 
num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + # Empty tokens + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # Compute + a_ptrs = a_ptr + m_range[:, None].to(tl.int64) * K + tl.arange(0, BLOCK_SIZE_K)[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + b_ptrs = b_ptr + batch_id * K * N + \ + tl.arange(0, BLOCK_SIZE_K)[:, None].to(tl.int64) * (1 if IS_B_K_MAJOR else N) + \ + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + for k in range(0, K, BLOCK_SIZE_K): + k_mask = (k + tl.arange(0, BLOCK_SIZE_K)) < K + a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * (1 if IS_B_K_MAJOR else N) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, accumulator.to(d_ptr.dtype.element_ty), mask=n_mask) + + +def m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous() or b.mT.is_contiguous()) + assert m_indices.is_contiguous() and d.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and d.dtype == torch.bfloat16 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and a.size(0) == d.size(0) and r1 == d.size(1) + assert m_indices.numel() == a.size(0) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + M, K = a.shape + B, N, K_ = r0, r1, r2 + + # For Triton 2.0, a persistent kernel will lead to errors + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + m_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, m_indices, M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, m_indices) diff --git a/deep-gemm/torch-ext/deep_gemm/legacy/tune_options.py b/deep-gemm/torch-ext/deep_gemm/legacy/tune_options.py new file mode 100644 index 00000000..ed6a7f77 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/legacy/tune_options.py @@ -0,0 +1,27 @@ +from triton import Config +from .._C import get_mk_alignment_for_contiguous_layout + + +def get_config_smem_size(config: Config, elem_bytes: int = 2): + # NOTES: FP8 kernels will not use Triton, so by default we assume BF16 kernels + return (config.kwargs['BLOCK_SIZE_M'] + config.kwargs['BLOCK_SIZE_N']) * config.kwargs['BLOCK_SIZE_K'] * elem_bytes * config.num_stages + + +_gemm_configs = [ + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), +] + +# NOTES: we only consider A100 shared memory sizes here, as legacy kernels are only used for Ampere +_gemm_configs = list(filter(lambda x: get_config_smem_size(x) <= 166912, _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) + +get_m_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +get_k_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) diff --git a/deep-gemm/torch-ext/deep_gemm/mega/__init__.py b/deep-gemm/torch-ext/deep_gemm/mega/__init__.py new file mode 100644 index 00000000..670b409d --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/mega/__init__.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import torch +from typing import Tuple, Optional +from ..utils.math import align + +# noinspection PyBroadException +try: + # noinspection PyProtectedMember + import torch.distributed._symmetric_memory as symm_mem + import torch.distributed as dist +except Exception as exception: + print(f'Failed to load mega kernels, please check your PyTorch version: {exception}') + +from ..
import _C + + +class SymmBuffer: + def __init__(self, group: dist.ProcessGroup, + # MoE arguments + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu'): + self.group = group + self.num_experts = num_experts + self.num_max_tokens_per_rank = num_max_tokens_per_rank + self.num_topk = num_topk + self.hidden = hidden + self.intermediate_hidden = intermediate_hidden + + # Allocate a symmetric buffer + num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe( + group.size(), num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda') + self.handle = symm_mem.rendezvous(self.buffer, group=group) + self.buffer.zero_() + self.group.barrier() + torch.cuda.synchronize() + + # Create input buffer views + (self.x, self.x_sf, + self.topk_idx, self.topk_weights, + self.l1_acts, self.l1_acts_sf, + self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer) + + def destroy(self): + self.handle = None + self.buffer = None + self.group = None + self.x = None + self.x_sf = None + + +def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup, + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu') -> SymmBuffer: + # Token count must be aligned to block sizes + num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe()) + + return SymmBuffer( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + + +def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + # [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] 
instead of [gate | up] + def interleave(t, gran: int = 8) -> torch.Tensor: + g, n, *rest = t.shape + half = n // 2 + gate = t[:, :half].reshape(g, half // gran, gran, *rest) + up = t[:, half:].reshape(g, half // gran, gran, *rest) + return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest)) + + return interleave(l1_weights[0]), interleave(l1_weights[1]) + + +def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: + num_groups, mn, packed_sf_k = sf.shape + assert sf.dtype == torch.int and mn % 128 == 0 + result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k) + .transpose(2, 3) + .reshape(num_groups, mn, packed_sf_k)) + return torch.empty_like(sf).copy_(result) + + +def transform_weights_for_mega_moe( + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + # L1: interleave gate/up, then transpose SF for UTCCP + l1_interleaved = _interleave_l1_weights(l1_weights) + l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1])) + # L2: only transpose SF for UTCCP + l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1])) + return l1_weights, l2_weights + + +def fp8_fp4_mega_moe(y: torch.Tensor, + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor], + sym_buffer: SymmBuffer, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + recipe: Tuple[int, int, int] = (1, 1, 32), + activation: str = 'swiglu', + activation_clamp: Optional[float] = None, + fast_math: bool = True): + _C.fp8_fp4_mega_moe( + y, + l1_weights, l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + ) diff --git a/deep-gemm/torch-ext/deep_gemm/testing/bench.py b/deep-gemm/torch-ext/deep_gemm/testing/bench.py index 2c752da2..552b9aa1 100644 --- a/deep-gemm/torch-ext/deep_gemm/testing/bench.py +++ b/deep-gemm/torch-ext/deep_gemm/testing/bench.py @@ -1,6 +1,7 @@ import os import sys import torch +from typing import Callable, Optional def bench(fn, num_warmups: int = 5, num_tests: int = 10, @@ -78,7 +79,8 @@ def __exit__(self, *_): def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, trace_path: str = None, flush_l2: bool = True, - with_multiple_kernels: bool = False): + with_multiple_kernels: bool = False, + barrier: Optional[Callable] = None): assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) is_tuple = isinstance(kernel_names, tuple) @@ -96,14 +98,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, # Profile suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress with suppress(): - schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) - profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule, acc_events=True) with profiler: for i in range(2): for _ in range(num_tests): if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + if barrier is not None: + # NOTES: use a large kernel and a barrier to eliminate the 
unbalanced CPU launch overhead + # noinspection PyProtectedMember + torch.cuda._sleep(int(2e7)) # ~10ms + barrier() fn() + torch.cuda.synchronize() profiler.step() # Parse the profiling table @@ -111,7 +120,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names if not with_multiple_kernels: for name in kernel_names: - assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table {prof_lines}' # Save chrome traces if trace_path is not None: diff --git a/deep-gemm/torch-ext/deep_gemm/utils/__init__.py b/deep-gemm/torch-ext/deep_gemm/utils/__init__.py index e8f859a2..a0dc6f78 100644 --- a/deep-gemm/torch-ext/deep_gemm/utils/__init__.py +++ b/deep-gemm/torch-ext/deep_gemm/utils/__init__.py @@ -1,3 +1,4 @@ from . import math, layout from .layout import * from .math import * +from .dist import init_dist, uneven_all_gather diff --git a/deep-gemm/torch-ext/deep_gemm/utils/dist.py b/deep-gemm/torch-ext/deep_gemm/utils/dist.py new file mode 100644 index 00000000..426c3967 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/utils/dist.py @@ -0,0 +1,74 @@ +import inspect +import os +import torch +import torch.distributed as dist +from typing import Tuple + +_local_rank = None + + +def init_dist(local_rank: int, num_local_ranks: int) -> Tuple[int, int, dist.ProcessGroup]: + # NOTES: you may rewrite this function with your own cluster settings + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + node_rank = int(os.getenv('RANK', 0)) + + # Set local rank + global _local_rank + _local_rank = local_rank + + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') + dist.init_process_group(**params) + torch.set_default_device('cuda') + torch.cuda.set_device(local_rank) + + return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes))) + + +def uneven_all_gather(tensor: torch.Tensor, dim: int = 0, group: dist.ProcessGroup = None) -> torch.Tensor: + world_size = dist.get_world_size(group) + + # Exchange sizes + local_dim_size = torch.tensor([tensor.shape[dim]], device=tensor.device, dtype=torch.long) + all_dim_sizes = [torch.zeros_like(local_dim_size) for _ in range(world_size)] + dist.all_gather(all_dim_sizes, local_dim_size, group=group) + all_dim_sizes = [s.item() for s in all_dim_sizes] + max_dim_size = max(all_dim_sizes) + + # Pad + if tensor.shape[dim] < max_dim_size: + pad_shape = list(tensor.shape) + pad_shape[dim] = max_dim_size - tensor.shape[dim] + padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor_padded = torch.cat([tensor, padding], dim=dim) + else: + tensor_padded = tensor.contiguous() + + # All-gather + gathered = [torch.zeros_like(tensor_padded) for _ in range(world_size)] + dist.all_gather(gathered, tensor_padded, group=group) + + # Remove padding + trimmed = [ + torch.narrow(gathered[i], dim, 0, all_dim_sizes[i]) + for i in range(world_size) + ] + return torch.cat(trimmed, dim=dim) + + +def dist_print(s: 
str = '', once_in_node: bool = False) -> None: + global _local_rank + assert _local_rank is not None + if not once_in_node or _local_rank == 0: + print(s, flush=True) + dist.barrier() diff --git a/deep-gemm/torch-ext/deep_gemm/utils/layout.py b/deep-gemm/torch-ext/deep_gemm/utils/layout.py index a6bc29d9..6512c5ab 100644 --- a/deep-gemm/torch-ext/deep_gemm/utils/layout.py +++ b/deep-gemm/torch-ext/deep_gemm/utils/layout.py @@ -1,25 +1,21 @@ -from .._ops import ops - - -def get_mk_alignment_for_contiguous_layout(): - return ops.get_mk_alignment_for_contiguous_layout() - - -def get_tma_aligned_size(mn: int, element_size: int): - return ops.get_tma_aligned_size(mn, element_size).item() - - -def get_mn_major_tma_aligned_tensor(sf): - return ops.get_mn_major_tma_aligned_tensor(sf) - - -def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): - return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) - - -def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): - return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks) - - +try: + from .._C import ( + get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor + ) +except ImportError: + # Expected behavior for CUDA runtime versions before 12.1 + pass + +# Valid for all CUDA versions +from .._C import ( + set_mk_alignment_for_contiguous_layout, + get_mk_alignment_for_contiguous_layout, + get_theoretical_mk_alignment_for_contiguous_layout, +) + +# Some aliases get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/deep-gemm/torch-ext/deep_gemm/utils/math.py b/deep-gemm/torch-ext/deep_gemm/utils/math.py index c65026e5..f1582ed5 100644 --- a/deep-gemm/torch-ext/deep_gemm/utils/math.py +++ b/deep-gemm/torch-ext/deep_gemm/utils/math.py @@ -11,21 +11,30 @@ def align(x: int, y: int) -> int: def ceil_to_ue8m0(x: torch.Tensor): - assert x.view(-1).amax().item() > 0 - return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + bits = x.abs().float().view(torch.int) + exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).bool().int() + return (exp.clamp(1, 254) << 23).view(torch.float) -def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: +def pack_ue8m0_to_int(x: torch.Tensor): + assert x.dtype == torch.float and x.size(-1) % 4 == 0 + assert (x.view(torch.int) & ((1 << 23) - 1) == 0).all() + return (x.view(torch.int) >> 23).to(torch.uint8).view(torch.int) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape padded_n = align(n, gran_k) x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) x_padded[:, :n] = x - x_view = x_padded.view(m, -1, gran_k) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + x_view = x_padded.view(m, padded_n // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, padded_n // gran_k).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf + x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous() + return x_fp8, pack_ue8m0_to_int(sf) if use_packed_ue8m0
else sf def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: @@ -70,13 +79,14 @@ def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: code = idx.to(torch.uint8) sign = (x < 0) & (idx != 0) code = code | (sign.to(torch.uint8) << 3) - return code # uint8, 0..15 + return code.view(torch.int8) -def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: m, n = x.shape assert n % 2 == 0 + assert not use_packed_ue8m0 or use_ue8m0 padded_n = align(n, gran_k) x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) x_padded[:, :n] = x @@ -85,23 +95,49 @@ def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) - sf = x_amax / 6.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = x_view * (1.0 / sf.unsqueeze(2)) - codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # int8, (m, padded_n) codes2 = codes.view(m, padded_n // 2, 2) - packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 - return packed[:, :n // 2].contiguous(), sf + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # int8 + return packed[:, :n // 2].contiguous(), pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: - assert a.dtype == torch.uint8 + assert a.dtype == torch.int8 assert a.dim() == 2 m, n2 = a.shape n = n2 * 2 assert (m % 2) == 0 lo = a & 0x0F hi = (a >> 4) & 0x0F - codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes = torch.empty((m, n), device=a.device, dtype=torch.int8) codes[:, 0::2], codes[:, 1::2] = lo, hi codes_t = codes.transpose(0, 1).contiguous() codes2 = codes_t.view(n, m // 2, 2) out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) - return out.contiguous() \ No newline at end of file + return out.contiguous() + + +def _dequantize_from_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + fp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device, dtype=torch.float) + sign, value_idx = (x & 0x08) != 0, (x & 0x07).to(torch.int) + value = fp4_values[value_idx] + return torch.where(sign & (value_idx != 0), -value, value) + + +def unpack_ue8m0_from_int(packed_sf: torch.Tensor) -> torch.Tensor: + return (packed_sf.view(torch.uint8).to(torch.int) << 23).view(torch.float) + + +def cast_back_from_fp4(packed: torch.Tensor, sf: torch.Tensor, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> torch.Tensor: + m, n2 = packed.shape + n = n2 * 2 + if use_packed_ue8m0: + sf = unpack_ue8m0_from_int(sf) + unpacked = torch.zeros((m, n), dtype=torch.int8, device=packed.device) + unpacked[:, ::2] = packed & 0x0F + unpacked[:, 1::2] = (packed >> 4) & 0x0F + x_dequantized = _dequantize_from_fp4_e2m1(unpacked) + group_idx = torch.arange(n, device=packed.device) // gran_k + x_restored = x_dequantized * sf[:, group_idx] + return x_restored \ No newline at end of file diff --git a/deep-gemm/torch-ext/torch_binding.cpp b/deep-gemm/torch-ext/torch_binding.cpp index 82493aa5..1aa20034 100644 --- a/deep-gemm/torch-ext/torch_binding.cpp +++ b/deep-gemm/torch-ext/torch_binding.cpp @@ -21,10 +21,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("get_tc_util() 
-> int"); ops.impl("get_tc_util", &deep_gemm_get_tc_util); + ops.def("set_pdl(bool enable_pdl) -> ()"); + ops.impl("set_pdl", &deep_gemm_set_pdl); + + ops.def("get_pdl() -> bool"); + ops.impl("get_pdl", &deep_gemm_get_pdl); + + ops.def("set_ignore_compile_dims(bool ignore_compile_dims) -> ()"); + ops.impl("set_ignore_compile_dims", &deep_gemm_set_ignore_compile_dims); + + ops.def("set_block_size_multiple_of(int block_m, int block_n) -> ()"); + ops.impl("set_block_size_multiple_of", &deep_gemm_set_block_size_multiple_of); + + ops.def("set_mk_alignment_for_contiguous_layout(int alignment) -> ()"); + ops.impl("set_mk_alignment_for_contiguous_layout", + &deep_gemm_set_mk_alignment_for_contiguous_layout); + ops.def("get_mk_alignment_for_contiguous_layout() -> int"); ops.impl("get_mk_alignment_for_contiguous_layout", &deep_gemm_get_mk_alignment_for_contiguous_layout); + ops.def( + "get_theoretical_mk_alignment_for_contiguous_layout(" + "int expected_m, bool has_expected_m) -> int" + ); + ops.impl("get_theoretical_mk_alignment_for_contiguous_layout", + &deep_gemm_get_theoretical_mk_alignment_for_contiguous_layout); + // Layout ops (CUDA dispatch) ops.def( "get_tma_aligned_size(int mn, int element_size) -> Tensor" @@ -45,7 +68,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(" - "Tensor sf, Tensor ks_tensor, Tensor ks_int_tensor) -> Tensor" + "Tensor sf, Tensor ks_tensor, Tensor ks_int_tensor, int gran_k) -> Tensor" ); ops.impl("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", torch::kCUDA, &deep_gemm_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); @@ -53,10 +76,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "transform_sf_into_required_layout(" "Tensor sf, int mn, int k, " - "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " - "int recipe_ab_0, int recipe_ab_1, bool has_recipe_ab, " + "int recipe_0, int recipe_1, int recipe_2, int recipe_len, " "int num_groups, bool has_num_groups, " - "bool is_sfa, bool disable_ue8m0_cast) -> Tensor" + "bool is_sfa, bool has_is_sfa, bool disable_ue8m0_cast) -> Tensor" ); ops.impl("transform_sf_into_required_layout", torch::kCUDA, &deep_gemm_transform_sf_into_required_layout); @@ -260,6 +282,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("fp8_gemm_nt_skip_head_mid", torch::kCUDA, &deep_gemm_fp8_gemm_nt_skip_head_mid); + ops.def( + "fp8_fp4_mqa_logits(" + "Tensor q_data, Tensor? q_sf, Tensor kv_data, Tensor kv_sf, " + "Tensor weights, Tensor cu_seq_len_k_start, Tensor cu_seq_len_k_end, " + "bool clean_logits, int max_seqlen_k, ScalarType logits_dtype) -> Tensor" + ); + ops.impl("fp8_fp4_mqa_logits", torch::kCUDA, &deep_gemm_fp8_fp4_mqa_logits); + ops.def( "fp8_mqa_logits(" "Tensor q, Tensor kv_data, Tensor kv_sf, " @@ -270,21 +300,67 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "get_paged_mqa_logits_metadata(" - "Tensor context_lens, int block_kv, int num_sms) -> Tensor" + "Tensor context_lens, int block_kv, int num_sms, Tensor? indices) -> Tensor" ); ops.impl("get_paged_mqa_logits_metadata", torch::kCUDA, &deep_gemm_get_paged_mqa_logits_metadata); + ops.def( + "fp8_fp4_paged_mqa_logits(" + "Tensor q_data, Tensor? q_sf, Tensor fused_kv_cache, " + "Tensor weights, Tensor context_lens, " + "Tensor block_table, Tensor schedule_meta, " + "int max_context_len, bool clean_logits, ScalarType logits_dtype, " + "Tensor? 
indices) -> Tensor" + ); + ops.impl("fp8_fp4_paged_mqa_logits", torch::kCUDA, + &deep_gemm_fp8_fp4_paged_mqa_logits); + ops.def( "fp8_paged_mqa_logits(" "Tensor q, Tensor fused_kv_cache, " "Tensor weights, Tensor context_lens, " "Tensor block_table, Tensor schedule_meta, " - "int max_context_len, bool clean_logits) -> Tensor" + "int max_context_len, bool clean_logits, Tensor? indices) -> Tensor" ); ops.impl("fp8_paged_mqa_logits", torch::kCUDA, &deep_gemm_fp8_paged_mqa_logits); + // Mega MoE ops + ops.def("get_token_alignment_for_mega_moe() -> int"); + ops.impl("get_token_alignment_for_mega_moe", + &deep_gemm_get_token_alignment_for_mega_moe); + + ops.def( + "get_symm_buffer_size_for_mega_moe(" + "int num_ranks, int num_experts, int num_max_tokens_per_rank, " + "int num_topk, int hidden, int intermediate_hidden, " + "bool use_fp8_dispatch, str activation) -> int" + ); + ops.impl("get_symm_buffer_size_for_mega_moe", + &deep_gemm_get_symm_buffer_size_for_mega_moe); + + ops.def( + "get_symm_buffer_views_for_mega_moe(" + "Tensor buffer, int num_ranks, int num_experts, " + "int num_max_tokens_per_rank, int num_topk, int hidden, " + "int intermediate_hidden, bool use_fp8_dispatch, str activation) -> Tensor[]" + ); + ops.impl("get_symm_buffer_views_for_mega_moe", torch::kCUDA, + &deep_gemm_get_symm_buffer_views_for_mega_moe); + + ops.def( + "fp8_fp4_mega_moe(" + "Tensor! y, Tensor l1_weights, Tensor l1_weights_sf, " + "Tensor l2_weights, Tensor l2_weights_sf, " + "Tensor? cumulative_local_expert_recv_stats, Tensor sym_buffer, " + "int[] sym_buffer_ptrs, int rank_idx, int num_max_tokens_per_rank, " + "int num_experts, int num_topk, int recipe_0, int recipe_1, " + "int recipe_2, str activation, float? activation_clamp, " + "bool fast_math) -> ()" + ); + ops.impl("fp8_fp4_mega_moe", torch::kCUDA, &deep_gemm_fp8_fp4_mega_moe); + // Einsum ops (CUDA dispatch) ops.def( "einsum(str expr, Tensor a, Tensor b, Tensor! 
d, " diff --git a/deep-gemm/torch-ext/torch_binding.h b/deep-gemm/torch-ext/torch_binding.h index 82bc9012..4120093a 100644 --- a/deep-gemm/torch-ext/torch_binding.h +++ b/deep-gemm/torch-ext/torch_binding.h @@ -1,8 +1,9 @@ #pragma once -#include +#include "utils/torch_compat.hpp" #include #include +#include using Tensor = at::Tensor; @@ -18,11 +19,20 @@ int64_t deep_gemm_get_num_sms(); void deep_gemm_set_tc_util(int64_t tc_util); int64_t deep_gemm_get_tc_util(); +void deep_gemm_set_pdl(bool enable_pdl); +bool deep_gemm_get_pdl(); + +void deep_gemm_set_ignore_compile_dims(bool ignore_compile_dims); +void deep_gemm_set_block_size_multiple_of(int64_t block_m, int64_t block_n); + // ============================================================================ // Layout ops // ============================================================================ +void deep_gemm_set_mk_alignment_for_contiguous_layout(int64_t alignment); int64_t deep_gemm_get_mk_alignment_for_contiguous_layout(); +int64_t deep_gemm_get_theoretical_mk_alignment_for_contiguous_layout( + int64_t expected_m, bool has_expected_m); Tensor deep_gemm_get_tma_aligned_size(int64_t mn, int64_t element_size); @@ -31,14 +41,13 @@ Tensor deep_gemm_get_mn_major_tma_aligned_tensor(const Tensor& sf); Tensor deep_gemm_get_mn_major_tma_aligned_packed_ue8m0_tensor(const Tensor& sf); Tensor deep_gemm_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( - const Tensor& sf, const Tensor& ks_tensor, const Tensor& ks_int_tensor); + const Tensor& sf, const Tensor& ks_tensor, const Tensor& ks_int_tensor, int64_t gran_k); Tensor deep_gemm_transform_sf_into_required_layout( const Tensor& sf, int64_t mn, int64_t k, - int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, - int64_t recipe_ab_0, int64_t recipe_ab_1, bool has_recipe_ab, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, int64_t recipe_len, int64_t num_groups, bool has_num_groups, - bool is_sfa, bool disable_ue8m0_cast); + bool is_sfa, bool has_is_sfa, bool disable_ue8m0_cast); // ============================================================================ // GEMM ops - FP8/FP4 @@ -215,6 +224,13 @@ void deep_gemm_fp8_gemm_nt_skip_head_mid( int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, const std::string& compiled_dims, bool disable_ue8m0_cast); +Tensor deep_gemm_fp8_fp4_mqa_logits( + const Tensor& q_data, const std::optional& q_sf, + const Tensor& kv_data, const Tensor& kv_sf, + const Tensor& weights, + const Tensor& cu_seq_len_k_start, const Tensor& cu_seq_len_k_end, + bool clean_logits, int64_t max_seqlen_k, at::ScalarType logits_dtype); + Tensor deep_gemm_fp8_mqa_logits( const Tensor& q, const Tensor& kv_data, const Tensor& kv_sf, @@ -223,13 +239,42 @@ Tensor deep_gemm_fp8_mqa_logits( bool clean_logits, int64_t max_seqlen_k); Tensor deep_gemm_get_paged_mqa_logits_metadata( - const Tensor& context_lens, int64_t block_kv, int64_t num_sms); + const Tensor& context_lens, int64_t block_kv, int64_t num_sms, + const std::optional& indices); + +Tensor deep_gemm_fp8_fp4_paged_mqa_logits( + const Tensor& q_data, const std::optional& q_sf, + const Tensor& fused_kv_cache, + const Tensor& weights, const Tensor& context_lens, + const Tensor& block_table, const Tensor& schedule_meta, + int64_t max_context_len, bool clean_logits, at::ScalarType logits_dtype, + const std::optional& indices); Tensor deep_gemm_fp8_paged_mqa_logits( const Tensor& q, const Tensor& fused_kv_cache, const Tensor& weights, const Tensor& context_lens, const Tensor& block_table, 
const Tensor& schedule_meta, - int64_t max_context_len, bool clean_logits); + int64_t max_context_len, bool clean_logits, const std::optional<Tensor>& indices); + +int64_t deep_gemm_get_token_alignment_for_mega_moe(); +int64_t deep_gemm_get_symm_buffer_size_for_mega_moe( + int64_t num_ranks, int64_t num_experts, int64_t num_max_tokens_per_rank, + int64_t num_topk, int64_t hidden, int64_t intermediate_hidden, + bool use_fp8_dispatch, const std::string& activation); +std::vector<Tensor> deep_gemm_get_symm_buffer_views_for_mega_moe( + const Tensor& buffer, int64_t num_ranks, int64_t num_experts, + int64_t num_max_tokens_per_rank, int64_t num_topk, int64_t hidden, + int64_t intermediate_hidden, bool use_fp8_dispatch, const std::string& activation); +void deep_gemm_fp8_fp4_mega_moe( + const Tensor& y, + const Tensor& l1_weights, const Tensor& l1_weights_sf, + const Tensor& l2_weights, const Tensor& l2_weights_sf, + const std::optional<Tensor>& cumulative_local_expert_recv_stats, + const Tensor& sym_buffer, c10::List<int64_t> sym_buffer_ptrs, int64_t rank_idx, + int64_t num_max_tokens_per_rank, int64_t num_experts, int64_t num_topk, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, + const std::string& activation, const std::optional<double>& activation_clamp, + bool fast_math); // ============================================================================ // Einsum ops From 71aba80011dd066309d5e65c2ac085fb27131ed7 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Tue, 5 May 2026 23:53:40 +0530 Subject: [PATCH 2/7] upd --- deep-gemm/torch-ext/deep_gemm/__init__.py | 5 +++++ terraform.tfstate | 1 + 2 files changed, 6 insertions(+) create mode 100644 terraform.tfstate diff --git a/deep-gemm/torch-ext/deep_gemm/__init__.py b/deep-gemm/torch-ext/deep_gemm/__init__.py index d3acc4db..21fba2a6 100644 --- a/deep-gemm/torch-ext/deep_gemm/__init__.py +++ b/deep-gemm/torch-ext/deep_gemm/__init__.py @@ -2,6 +2,11 @@ import subprocess import torch +# Avoid holding a CUDA tensor in DeepGEMM's process-lifetime runtime singleton. +# In packaged/lazy-loaded use, that can outlive PyTorch's CUDA teardown and crash +# during interpreter shutdown. +os.environ.setdefault("DG_USE_TEMP_CUBLASLT_WORKSPACE", "1") + # Import the compiled extension from ._ops import ops as _ops, add_op_namespace_prefix from .
import utils diff --git a/terraform.tfstate b/terraform.tfstate new file mode 100644 index 00000000..c46a2ace --- /dev/null +++ b/terraform.tfstate @@ -0,0 +1 @@ +{"version":4,"terraform_version":"1.11.6","serial":1,"lineage":"80a098ad-0d68-c9f6-646b-54c23c125621","outputs":{},"resources":[],"check_results":null} From a98a57a534101c1a52e4fc401aeaf65e2f85d74b Mon Sep 17 00:00:00 2001 From: adarshxs Date: Tue, 5 May 2026 23:54:39 +0530 Subject: [PATCH 3/7] upd --- terraform.tfstate | 1 - 1 file changed, 1 deletion(-) delete mode 100644 terraform.tfstate diff --git a/terraform.tfstate b/terraform.tfstate deleted file mode 100644 index c46a2ace..00000000 --- a/terraform.tfstate +++ /dev/null @@ -1 +0,0 @@ -{"version":4,"terraform_version":"1.11.6","serial":1,"lineage":"80a098ad-0d68-c9f6-646b-54c23c125621","outputs":{},"resources":[],"check_results":null} From 82cf6cd32d5c110a58a9c75a792d520dec4c21d4 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Wed, 6 May 2026 00:13:16 +0530 Subject: [PATCH 4/7] upd --- deep-gemm/csrc/jit/compiler.hpp | 16 ++++-- deep-gemm/csrc/jit/include_parser.hpp | 74 +++++++++++++++++++++------ 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/deep-gemm/csrc/jit/compiler.hpp b/deep-gemm/csrc/jit/compiler.hpp index 265b787d..0bc7fec6 100644 --- a/deep-gemm/csrc/jit/compiler.hpp +++ b/deep-gemm/csrc/jit/compiler.hpp @@ -2,13 +2,13 @@ #include #include +#include #include #include #include #ifdef DG_ENABLE_NVRTC_COMPILER #include #endif -#include #include #include "../utils/exception.hpp" @@ -173,6 +173,14 @@ DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path); class NVCCCompiler final: public Compiler { std::filesystem::path nvcc_path; + static bool parse_nvcc_release(const std::string& output, int& major, int& minor) { + const std::string marker = "release "; + const auto pos = output.find(marker); + if (pos == std::string::npos) + return false; + return std::sscanf(output.c_str() + pos + marker.size(), "%d.%d", &major, &minor) == 2; + } + std::pair get_nvcc_version() const { DG_HOST_ASSERT(std::filesystem::exists(nvcc_path)); @@ -183,9 +191,7 @@ class NVCCCompiler final: public Compiler { // The version should be at least 12.3, for the best performance with 12.9 int major, minor; - std::smatch match; - DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))"))); - std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor); + DG_HOST_ASSERT(parse_nvcc_release(output, major, minor)); DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3"); if (major == 12 and minor < 9) printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n"); @@ -245,7 +251,7 @@ class NVCCCompiler final: public Compiler { // Check local memory usage if (get_env("DG_JIT_PTXAS_CHECK", 0)) - DG_HOST_ASSERT(not std::regex_search(output, std::regex(R"(Local memory used)"))); + DG_HOST_ASSERT(output.find("Local memory used") == std::string::npos); // Print PTXAS log if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) diff --git a/deep-gemm/csrc/jit/include_parser.hpp b/deep-gemm/csrc/jit/include_parser.hpp index 99f2663c..9e74bba5 100644 --- a/deep-gemm/csrc/jit/include_parser.hpp +++ b/deep-gemm/csrc/jit/include_parser.hpp @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -13,26 +12,69 @@ namespace deep_gemm { class IncludeParser { std::unordered_map> cache; + static bool is_include_space(const char& c) { + return c == ' ' or c == '\t' or c == '\r' or c == 
'\f' or c == '\v'; + } + + static void raise_non_standard_include( + const std::string& include_str, const std::filesystem::path& file_path) { + std::string error_info = fmt::format("Non-standard include: {}", include_str); + if (file_path != "") + error_info += fmt::format(" ({})", file_path.string()); + DG_HOST_UNREACHABLE(error_info); + } + static std::vector get_includes(const std::string& code, const std::filesystem::path& file_path = "") { std::vector includes; - const std::regex pattern(R"(#\s*include\s*[<"][^>"]+[>"])"); - std::sregex_iterator iter(code.begin(), code.end(), pattern); - const std::sregex_iterator end; // TODO: parse relative paths as well - for (; iter != end; ++ iter) { - const auto include_str = iter->str(); - const int len = include_str.length(); - if (include_str.substr(0, 10) == "#include <" and include_str[len - 1] == '>' and include_str[10] != ' ' and include_str[len - 2] != ' ') { - std::string filename = include_str.substr(10, len - 11); - if (filename.substr(0, 9) == "deep_gemm") // We only parse `` - includes.push_back(filename); - } else { - std::string error_info = fmt::format("Non-standard include: {}", include_str); - if (file_path != "") - error_info += fmt::format(" ({})", file_path.string()); - DG_HOST_UNREACHABLE(error_info); + size_t line_begin = 0; + while (line_begin < code.size()) { + auto line_end = code.find('\n', line_begin); + if (line_end == std::string::npos) + line_end = code.size(); + + auto pos = line_begin; + while (pos < line_end and is_include_space(code[pos])) + ++ pos; + + const auto directive_begin = pos; + if (pos < line_end and code[pos] == '#') { + ++ pos; + while (pos < line_end and is_include_space(code[pos])) + ++ pos; + + constexpr size_t kIncludeLen = 7; + if (line_end - pos >= kIncludeLen and code.compare(pos, kIncludeLen, "include") == 0) { + pos += kIncludeLen; + if (pos < line_end and not is_include_space(code[pos]) and code[pos] != '<' and code[pos] != '"') { + line_begin = line_end + (line_end < code.size()); + continue; + } + + while (pos < line_end and is_include_space(code[pos])) + ++ pos; + + if (pos < line_end and code[pos] == '<') { + const auto name_begin = pos + 1; + const auto name_end = code.find('>', name_begin); + if (name_end == std::string::npos or name_end > line_end or + name_begin == name_end or code[name_begin] == ' ' or code[name_end - 1] == ' ') { + raise_non_standard_include(code.substr(directive_begin, line_end - directive_begin), file_path); + } + + const auto filename = code.substr(name_begin, name_end - name_begin); + if (filename.substr(0, 9) == "deep_gemm") // We only parse `` + includes.push_back(filename); + } else if (pos < line_end and code[pos] == '"') { + const auto quote_end = code.find('"', pos + 1); + const auto include_end = quote_end == std::string::npos or quote_end > line_end ? 
line_end : quote_end + 1; + raise_non_standard_include(code.substr(directive_begin, include_end - directive_begin), file_path); + } + } } + + line_begin = line_end + (line_end < code.size()); } return includes; } From d12ee93d48a395e4ba0a7a984b22ad4abd70910c Mon Sep 17 00:00:00 2001 From: adarshxs Date: Wed, 6 May 2026 00:27:14 +0530 Subject: [PATCH 5/7] upd --- deep-gemm/csrc/jit/include_parser.hpp | 8 +++---- deep-gemm/csrc/jit/kernel_runtime.hpp | 9 ++++++-- .../csrc/jit_kernels/impls/runtime_utils.hpp | 4 +++- deep-gemm/csrc/utils/hash.hpp | 16 +++++++++----- deep-gemm/csrc/utils/system.hpp | 21 ++++++++++++------- 5 files changed, 39 insertions(+), 19 deletions(-) diff --git a/deep-gemm/csrc/jit/include_parser.hpp b/deep-gemm/csrc/jit/include_parser.hpp index 9e74bba5..3b372019 100644 --- a/deep-gemm/csrc/jit/include_parser.hpp +++ b/deep-gemm/csrc/jit/include_parser.hpp @@ -87,12 +87,12 @@ class IncludeParser { } std::string get_hash_value(const std::string& code, const bool& exclude_code = true) { - std::stringstream ss; + std::string hash_input; for (const auto& i: get_includes(code)) - ss << get_hash_value_by_path(library_include_path / i) << "$"; + hash_input += get_hash_value_by_path(library_include_path / i) + "$"; if (not exclude_code) - ss << "#" << get_hex_digest(code); - return get_hex_digest(ss.str()); + hash_input += "#" + get_hex_digest(code); + return get_hex_digest(hash_input); } std::string get_hash_value_by_path(const std::filesystem::path& path) { diff --git a/deep-gemm/csrc/jit/kernel_runtime.hpp b/deep-gemm/csrc/jit/kernel_runtime.hpp index 40597fb4..badb9392 100644 --- a/deep-gemm/csrc/jit/kernel_runtime.hpp +++ b/deep-gemm/csrc/jit/kernel_runtime.hpp @@ -56,15 +56,20 @@ class KernelRuntime final { const std::vector illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"}; const auto [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str())); DG_HOST_ASSERT(exit_code == 0); - std::istringstream iss(symbols); std::vector symbol_names; - for (std::string line; std::getline(iss, line); ) { + size_t line_begin = 0; + while (line_begin < symbols.size()) { + auto line_end = symbols.find('\n', line_begin); + if (line_end == std::string::npos) + line_end = symbols.size(); + const auto line = symbols.substr(line_begin, line_end - line_begin); if (line.find("STT_FUNC") == 0 and line.find("STO_ENTRY") != std::string::npos and std::none_of(illegal_names.begin(), illegal_names.end(), [&](const auto name) { return line.find(name) != std::string::npos; })) { const auto last_space = line.rfind(' '); symbol_names.push_back(line.substr(last_space + 1)); } + line_begin = line_end + (line_end < symbols.size()); } // Print symbols diff --git a/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp b/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp index 7aa87526..962b385d 100644 --- a/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp @@ -63,7 +63,9 @@ static std::string to_string(const at::ScalarType& dtype) { static std::string to_string(const float& v) { if (std::isfinite(v)) { - return fmt::format(R"({:a}f)", v); + char buffer[32]; + std::snprintf(buffer, sizeof(buffer), "%.9gf", v); + return buffer; } else if (std::isinf(v)) { return v > 0 ? 
"cute::numeric_limits::infinity()" : "-cute::numeric_limits::infinity()"; diff --git a/deep-gemm/csrc/utils/hash.hpp b/deep-gemm/csrc/utils/hash.hpp index 9efe6408..47035e96 100644 --- a/deep-gemm/csrc/utils/hash.hpp +++ b/deep-gemm/csrc/utils/hash.hpp @@ -25,11 +25,17 @@ static std::string get_hex_digest(const std::vector& data) { return z ^ (z >> 31); }; - std::ostringstream oss; - oss << std::hex << std::setfill('0') - << std::setw(16) << split_mix(state_0) - << std::setw(16) << split_mix(state_1); - return oss.str(); + static constexpr char kHex[] = "0123456789abcdef"; + std::string out(32, '0'); + const uint64_t states[] = {split_mix(state_0), split_mix(state_1)}; + for (size_t state_idx = 0; state_idx < 2; ++ state_idx) { + auto value = states[state_idx]; + for (int nibble = 15; nibble >= 0; -- nibble) { + out[state_idx * 16 + static_cast(nibble)] = kHex[value & 0x0f]; + value >>= 4; + } + } + return out; } static std::string get_hex_digest(const std::string& data) { diff --git a/deep-gemm/csrc/utils/system.hpp b/deep-gemm/csrc/utils/system.hpp index fda020be..a1c09ef4 100644 --- a/deep-gemm/csrc/utils/system.hpp +++ b/deep-gemm/csrc/utils/system.hpp @@ -88,13 +88,20 @@ static std::string get_uuid() { }()); static std::uniform_int_distribution dist; - std::stringstream ss; - ss << getpid() << "-" - << std::hex << std::setfill('0') - << std::setw(8) << dist(gen) << "-" - << std::setw(8) << dist(gen) << "-" - << std::setw(8) << dist(gen); - return ss.str(); + static constexpr char kHex[] = "0123456789abcdef"; + const auto append_hex_u32 = [](std::string& out, uint32_t value) { + for (int nibble = 7; nibble >= 0; -- nibble) { + out += kHex[(value >> (nibble * 4)) & 0x0f]; + } + }; + + std::string uuid = std::to_string(getpid()) + "-"; + append_hex_u32(uuid, dist(gen)); + uuid += "-"; + append_hex_u32(uuid, dist(gen)); + uuid += "-"; + append_hex_u32(uuid, dist(gen)); + return uuid; } static void safe_remove_all(const std::filesystem::path& path) { From 17595644ab323483348a0da2b69b5147410ccd5f Mon Sep 17 00:00:00 2001 From: adarshxs Date: Wed, 6 May 2026 00:45:49 +0530 Subject: [PATCH 6/7] upd --- deep-gemm/csrc/jit/compiler.hpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/deep-gemm/csrc/jit/compiler.hpp b/deep-gemm/csrc/jit/compiler.hpp index 0bc7fec6..9022f46f 100644 --- a/deep-gemm/csrc/jit/compiler.hpp +++ b/deep-gemm/csrc/jit/compiler.hpp @@ -40,6 +40,11 @@ class Compiler { std::string signature, flags; std::filesystem::path cache_dir_path; + static std::string get_cutlass_include_flags() { + const auto cutlass_include = get_env("DG_CUTLASS_INCLUDE"); + return cutlass_include.empty() ? 
"" : fmt::format("-I{} ", cutlass_include); + } + Compiler() { // Check `prepare_init` DG_HOST_ASSERT(not library_root_path.empty()); @@ -210,10 +215,10 @@ class NVCCCompiler final: public Compiler { // The override the compiler flags // Only NVCC >= 12.9 supports arch-specific family suffix const auto arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9); - flags = fmt::format("{} -I{} --gpu-architecture=sm_{} " + flags = fmt::format("{} -I{} {}--gpu-architecture=sm_{} " "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " "-O3 --expt-relaxed-constexpr --expt-extended-lambda", - flags, library_include_path.c_str(), arch); + flags, library_include_path.c_str(), get_cutlass_include_flags(), arch); } void compile(const std::string &code, const std::filesystem::path& dir_path, @@ -272,6 +277,7 @@ class NVRTCCompiler final: public Compiler { // Build include directories list std::string include_dirs; include_dirs += fmt::format("-I{} ", library_include_path.string()); + include_dirs += get_cutlass_include_flags(); include_dirs += fmt::format("-I{} ", (cuda_home / "include").string()); // Add PCH support for version 12.8 and above From 7c39dd9fd8cc2a13e1299794f33646ee108a9689 Mon Sep 17 00:00:00 2001 From: adarshxs Date: Wed, 6 May 2026 00:57:03 +0530 Subject: [PATCH 7/7] upd --- deep-gemm/torch-ext/deep_gemm/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/deep-gemm/torch-ext/deep_gemm/__init__.py b/deep-gemm/torch-ext/deep_gemm/__init__.py index 21fba2a6..8c4fe1c5 100644 --- a/deep-gemm/torch-ext/deep_gemm/__init__.py +++ b/deep-gemm/torch-ext/deep_gemm/__init__.py @@ -1,5 +1,6 @@ import os import subprocess +import sysconfig import torch # Avoid holding a CUDA tensor in DeepGEMM's process-lifetime runtime singleton. @@ -794,6 +795,14 @@ def _find_cuda_home() -> str: _include, # legacy layout: include/cutlass os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout ] + for _site_packages in { + sysconfig.get_paths().get("purelib"), + sysconfig.get_paths().get("platlib"), + }: + if _site_packages: + _cutlass_include_candidates.append( + os.path.join(_site_packages, "cutlass_library", "source", "include") + ) for _cutlass_include in _cutlass_include_candidates: if os.path.isdir(os.path.join(_cutlass_include, "cutlass")): os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include