diff --git a/examples/matmul/this-sm100/matmul_readme.md b/examples/matmul/this-sm100/matmul_readme.md new file mode 100644 index 0000000..ad5e9dd --- /dev/null +++ b/examples/matmul/this-sm100/matmul_readme.md @@ -0,0 +1,52 @@ +The code is adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh and https://github.com/deepseek-ai/DeepGEMM/tree/main/deep_gemm/include/deep_gemm/common/. + +Based on the reference implementation, certain simplifications and modifications were made, resulting in a final codebase of approximately 400 lines (including both host and device code, with header dependencies limited to CUDA, CUTLASS, and CUTE—only leveraging their instruction-level abstractions). Key modifications include: + +- Removal of the additional unrolled tiling with `tile-size = num-stages` in the k-loop. The original design made the stage index a compile-time constant and flipped phases only at the end of the inner loop, reducing state update overhead. However, it required producer-consumer synchronization for residual k-blocks when the k-loop length was not divisible by `num-stages`, leading to pipeline bubbles and increased code complexity. The revised version maintains stage and phase as global states across both mn-loop and k-loop, updating them at the end of each k-iteration. +- Removal of `tcgen05.fence::after_thread_sync` and `tcgen05.fence::before_thread_sync` instructions after barrier-wait or before barrier-arrive for workers dependent on TMEM (UMMA warp, Epilogue warpgroup). Testing showed these had no observable impact on correctness, and they are not used in CUTLASS implementations. +- Unified interface for managing global pipeline states across three pipelines: SMEM-AB, TMEM-C, and SMEM-C. +- Restructured code with added comments. The overall structure now includes: + - TMA descriptor prefetching + - Shared memory buffer and barrier allocation + - Tensor memory buffer allocation + - Barrier initialization + - TMA-Load worker (CTA0/1-Warp0-Lane0) + - mn-loop (persistent) + - k-loop + - SMEM-AB-pipeline: wait for consumer-empty + - SMEM-AB-pipeline: producer-execute + - `cute::SM100_TMA_2SM_LOAD_2D` + - SMEM-AB-pipeline: commit producer-ready + - SMEM-AB-pipeline: advance to next producer stage + - UMMA worker (CTA0-Warp1-Lane0) + - mn-loop (persistent) + - TMEM-C-pipeline: wait for consumer-empty + - TMEM-C-pipeline: producer-execute + - k-loop + - SMEM-AB-pipeline: wait for producer-ready + - SMEM-AB-pipeline: consumer-execute + - `cute::SM100_MMA_F16BF16_2x1SM_SS` + - SMEM-AB-pipeline: commit consumer-empty + - SMEM-AB-pipeline: advance to next consumer stage + - TMEM-C-pipeline: commit producer-ready + - TMEM-C-pipeline: advance to next producer stage + - Epilogue worker (CTA0/1-Warp4/5/6/7) + - mn-loop (persistent) + - TMEM-C-pipeline: wait for producer-ready + - TMEM-C-pipeline: consumer-execute + - mn-inner-loop + - SMEM-C-pipeline: wait for consumer-empty + - SMEM-C-pipeline: producer-execute + - `cute::SM100_TMEM_LOAD_32dp32b8x` + - `st.shared.v4.u32` + - SMEM-C-pipeline: commit producer-ready + - SMEM-C-pipeline: wait for producer-ready + - SMEM-C-pipeline: consumer-execute + - `cute::SM90_TMA_STORE_2D` + - SMEM-C-pipeline: commit consumer-empty + - SMEM-C-pipeline: advance to next stage + - TMEM-C-pipeline: commit consumer-empty + - TMEM-C-pipeline: advance to next consumer stage +- Only the `MultiCast=2` case is retained, meaning UMMA is performed within a CTA-Pair. +- Only the `GemmType::Normal` case is supported; logic for GroupGEMM (primarily affecting scheduler code) has been removed. +- The tiling size is preset such that each CTA handles a computation of MxNxK = 256x256x64, and each CTA-Pair handles 512x256x64.