9 changes: 4 additions & 5 deletions deep-gemm/CMakeLists.txt
@@ -3,8 +3,7 @@ cmake_minimum_required(VERSION 3.10)
project(deep_gemm LANGUAGES CXX CUDA)
set(CMAKE_VERBOSE_MAKEFILE ON)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi -Wno-deprecated-declarations")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
@@ -18,11 +17,11 @@ find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)

include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include/cccl ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

# The main Python API entrance
138 changes: 83 additions & 55 deletions deep-gemm/README.md
@@ -1,45 +1,27 @@
# DeepGEMM

DeepGEMM is a library designed for clean and efficient General Matrix Multiplications (GEMMs). It supports FP8 and BF16 (working in progress) for both normal and Mix-of-Experts (MoE) grouped scenarios. Written in CUDA, the library has no kernel compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module.
DeepGEMM is a unified, high-performance tensor core kernel library that brings together the key computation primitives of modern large language models — GEMMs (FP8, FP4, BF16), fused MoE with overlapped communication (Mega MoE), MQA scoring for the lightning indexer, HyperConnection (HC), and more — into a single, cohesive CUDA codebase. All kernels are compiled at runtime via a lightweight Just-In-Time (JIT) module, requiring no CUDA compilation during installation.

DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only a limited number of core kernel functions. This makes it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.
DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), but avoids heavy reliance on their templates or algebras. The library is designed for simplicity, with only a limited number of core kernel functions, making it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.

Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.

## News

- 2026.04.16: Mega MoE, FP8xFP4 GEMM, FP4 Indexer, PDL, faster JIT compilation and more.
- Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details.
- For Mega MoE benchmarks, refer to [#316](https://github.com/deepseek-ai/DeepGEMM/pull/316).
- 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2.
  - Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module.
  - NVRTC and post-compilation SASS optimization are both disabled.
  - NVRTC will be supported later.
  - As NVCC 12.9 automatically does FFMA interleaving, post-compilation optimizations are no longer supported.
  - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details.
- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may cause performance loss in some cases).
- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.

## Roadmap

- [x] More correctness tests for grouped-contiguous layout
- [x] Shared memory swizzling for output
- [x] MoE scheduler with TMA multicast compatibility
- [x] Fix TMA multicast compatibility for indivisible shapes
- [x] Skip useless computation on M
- [x] NVRTC as a faster compiler
- [x] Sanitizer for testing
- [x] Weight gradient kernels for dense models
- [x] Weight gradient kernels for MoE models
- [ ] Better `get_best_configs` modeling
- [ ] CUDA PDL support
- [ ] Larger TMA multicast size for some shapes
- [x] MMA template refactor with CUTLASS
- [x] Remove shape limitations on N and K
- [x] BF16 kernels
- [ ] Split/stream-k optimizations
- [ ] Ampere kernels
- [ ] Polish docs

## Quick start

### Requirements
@@ -48,9 +30,9 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- Python 3.8 or higher
- Compilers with C++20 support
- CUDA Toolkit:
  - CUDA 12.3 or higher for SM90
    - **We highly recommend 12.9 or higher for the best performance**
  - CUDA 12.9 or higher for SM100
- PyTorch 2.1 or higher
- CUTLASS 4.0 or higher (can be cloned as a Git submodule)
- `{fmt}` library (can be cloned as a Git submodule)
@@ -65,11 +47,6 @@ cd DeepGEMM
# Link some essential includes and build the CPP JIT module
cat develop.sh
./develop.sh

# Test all GEMM implements
python tests/test_layout.py
python tests/test_attention.py
python tests/test_core.py
```

### Installation
@@ -134,35 +111,78 @@ out_ij = out_ij.sum() # Scalar

For more details and the paged version `fp8_paged_mqa_logits`, please refer to `tests/test_attention.py`.

#### Mega MoE

Mega MoE fuses EP dispatch, linear 1 (FP8xFP4), SwiGLU, linear 2 (FP8xFP4), and EP combine into a single mega-kernel, overlapping NVLink communication with tensor core computation. It requires a multi-process launch with symmetric memory. Usage:

```python
# Allocate symmetric memory buffer
# NOTES: requires PyTorch >= 2.9
buffer = deep_gemm.get_symm_buffer_for_mega_moe(
    group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden
)

# Transform weights (FP4 with UE8M0 SF) into the required layout
transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)

# Copy inputs into the buffer before each call
# You may fuse these into previous kernels
buffer.x[:num_tokens].copy_(x_fp8)
buffer.x_sf[:num_tokens].copy_(x_sf)
buffer.topk_idx[:num_tokens].copy_(topk_idx)
buffer.topk_weights[:num_tokens].copy_(topk_weights)

# Run the fused mega MoE kernel
y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
deep_gemm.fp8_fp4_mega_moe(y, transformed_l1, transformed_l2, buffer)
```

For the full example with multi-process setup and benchmarking, please refer to `tests/test_mega_moe.py`.
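
As a sketch of the multi-process side, the snippet below shows one plausible per-rank setup. It assumes a standard `torch.distributed` NCCL group is an acceptable `group` argument and uses illustrative shape values; the authoritative setup is the one in `tests/test_mega_moe.py`.

```python
import os

import torch
import torch.distributed as dist
import deep_gemm

# One process per GPU, e.g. launched via `torchrun --nproc-per-node=8`
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
group = dist.group.WORLD  # assumption: the default world group is accepted

# Illustrative shapes; symmetric memory requires all ranks to agree on these
num_experts, num_max_tokens_per_rank = 256, 4096
num_topk, hidden, intermediate_hidden = 8, 7168, 2048

buffer = deep_gemm.get_symm_buffer_for_mega_moe(
    group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden
)
```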

#### Utilities

Besides the above kernels, the library provides some utility functions (a short usage sketch follows the list):

- `deep_gemm.set_num_sms` / `get_num_sms`: set/get the maximum SM count to use (`get_num_sms` returns the device SM count if unset)
- `deep_gemm.set_tc_util` / `get_tc_util`: set/get an approximated tensor core utilization ratio
- `deep_gemm.set_pdl` / `get_pdl`: enable/disable Programmatic Dependent Launch (PDL)
- `deep_gemm.set_mk_alignment_for_contiguous_layout` / `get_mk_alignment_for_contiguous_layout`: set/get the group-level M/K alignment for contiguous layout
- `deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout`: get the theoretical minimum M/K alignment
- `deep_gemm.set_ignore_compile_dims`: configure dimensions to ignore during JIT compilation
- `deep_gemm.set_block_size_multiple_of`: constrain block sizes to be multiples of a given value
- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into the required layout
- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size
- `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor
- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0)
- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel
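
Most of these are simple getters/setters on global state. A minimal sketch of how they compose, assuming the value types shown here (an integer SM count, a boolean PDL flag); the argument types are inferred from the descriptions above, not confirmed:

```python
import torch
import deep_gemm

# Leave a quarter of the SMs free, e.g. for a concurrent communication
# kernel (illustrative policy, not a recommendation)
total_sms = torch.cuda.get_device_properties(0).multi_processor_count
deep_gemm.set_num_sms(total_sms * 3 // 4)
assert deep_gemm.get_num_sms() == total_sms * 3 // 4

# Enable Programmatic Dependent Launch for back-to-back kernel launches
deep_gemm.set_pdl(True)  # assumption: takes a boolean
```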

The library also reads some environment variables, which may be useful (see the sketch after this list):

- General
  - `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default
  - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
- JIT cache
- `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default
- Compiler selection
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default
- `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME`
- `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default
- Compiler output
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default
- `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default
- `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default
- Debug and profiling
- `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default
- `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default
- `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default
- `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default
- `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default
- `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default
- Build options
- `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default
- `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default
- `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default
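
These variables are read by the JIT module at runtime; setting them in the launching shell always works, and setting them in Python before the first kernel is compiled is a common pattern. A small sketch (variable names are from the list above; the exact moment each variable is read is an assumption):

```python
import os

# Set before any deep_gemm kernel is compiled so the JIT module sees them
os.environ['DG_JIT_DEBUG'] = '1'       # verbose JIT logging
os.environ['DG_PRINT_CONFIGS'] = '1'   # print the config chosen per shape
os.environ['DG_JIT_CACHE_DIR'] = os.path.expanduser('~/.cache/dg_kernels')

import deep_gemm  # noqa: E402  (import after the environment is prepared)
```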

For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.

@@ -174,6 +194,14 @@ DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project

This code repository is released under [the MIT License](LICENSE).

---

## Citation

Vendored at commit 477618cd51baffca09c4b0b87e97c03fe827ef03.

```bibtex
@misc{deepgemm2025,
    title={DeepGEMM: clean and efficient BLAS kernel library on GPU},
    author={Chenggang Zhao and Zhean Xu and Liang Zhao and Jiashi Li and Chenhao Xu and Anyi Xu and Shengyu Liu and Kexing Zhou and Kuai Yu},
    year={2025},
    publisher={GitHub},
    howpublished={\url{https://github.com/deepseek-ai/DeepGEMM}},
}
```
79 changes: 54 additions & 25 deletions deep-gemm/build.toml
@@ -2,10 +2,10 @@
name = "deep-gemm"
license = "MIT"
backends = ["cuda"]
version = 1
version = 2

[general.cuda]
minver = "12.3"
minver = "12.8"

[torch]
src = [
@@ -16,46 +16,49 @@ pyext = ["py", "cuh", "hpp", "h"]

[kernel.deep_gemm]
backend = "cuda"
cuda-minver = "12.3"
cuda-minver = "12.8"
cuda-capabilities = ["9.0a"]
depends = ["torch", "cutlass_4_0"]
include = [
"csrc",
"deep_gemm/include",
]
cuda-flags = [
"-std=c++17",
"-std=c++20",
"-O3",
"-DNDEBUG",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"-Wno-deprecated-declarations",
]
src = [
# Compiled source
"csrc/impl.cu",
# API headers
"csrc/apis/attention.hpp",
"csrc/apis/einsum.hpp",
"csrc/apis/gemm.hpp",
"csrc/apis/hyperconnection.hpp",
"csrc/apis/layout.hpp",
"csrc/apis/mega.hpp",
"csrc/apis/runtime.hpp",
# JIT infrastructure headers
"csrc/jit/cache.hpp",
"csrc/jit/compiler.hpp",
"csrc/jit/device_runtime.hpp",
"csrc/jit/handle.hpp",
"csrc/jit/include_parser.hpp",
"csrc/jit/kernel_runtime.hpp",
# JIT kernel heuristics
"csrc/jit_kernels/heuristics/common.hpp",
"csrc/jit_kernels/heuristics/config.hpp",
"csrc/jit_kernels/heuristics/mega_moe.hpp",
"csrc/jit_kernels/heuristics/runtime.hpp",
"csrc/jit_kernels/heuristics/sm100.hpp",
"csrc/jit_kernels/heuristics/sm90.hpp",
"csrc/jit_kernels/heuristics/utils.hpp",
# JIT kernel implementations
"csrc/jit_kernels/impls/epilogue.hpp",
"csrc/jit_kernels/impls/runtime_utils.hpp",
"csrc/jit_kernels/impls/sm100_bf16_gemm.hpp",
"csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp",
"csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp",
"csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp",
"csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp",
"csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp",
"csrc/jit_kernels/impls/sm90_bf16_gemm.hpp",
@@ -65,10 +68,11 @@ src = [
"csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp",
"csrc/jit_kernels/impls/smxx_clean_logits.hpp",
"csrc/jit_kernels/impls/smxx_cublaslt.hpp",
"csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp",
"csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp",
"csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp",
"csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp",
"csrc/jit_kernels/impls/smxx_layout.hpp",
# Utility headers
"csrc/utils/compatibility.hpp",
"csrc/utils/exception.hpp",
"csrc/utils/format.hpp",
@@ -77,31 +81,56 @@ src = [
"csrc/utils/lazy_init.hpp",
"csrc/utils/math.hpp",
"csrc/utils/system.hpp",
"csrc/utils/torch_compat.hpp",
# Runtime JIT headers (deep_gemm/include)
"deep_gemm/include/deep_gemm/comm/barrier.cuh",
"deep_gemm/include/deep_gemm/common/compile.cuh",
"deep_gemm/include/deep_gemm/common/cute_tie.cuh",
"deep_gemm/include/deep_gemm/common/epilogue_utils.cuh",
"deep_gemm/include/deep_gemm/common/types.hpp",
"deep_gemm/include/deep_gemm/common/sm90_utils.cuh",
"deep_gemm/include/deep_gemm/common/exception.cuh",
"deep_gemm/include/deep_gemm/common/math.cuh",
"deep_gemm/include/deep_gemm/common/reduction.cuh",
"deep_gemm/include/deep_gemm/common/utils.cuh",
"deep_gemm/include/deep_gemm/common/tma_utils.cuh",
"deep_gemm/include/deep_gemm/common/sm100_utils.cuh",
"deep_gemm/include/deep_gemm/common/scheduler.cuh",
"deep_gemm/include/deep_gemm/impls/smxx_layout.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh",
"deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh",
"deep_gemm/include/deep_gemm/common/sm100_utils.cuh",
"deep_gemm/include/deep_gemm/common/sm90_utils.cuh",
"deep_gemm/include/deep_gemm/common/tma_copy.cuh",
"deep_gemm/include/deep_gemm/common/tma_utils.cuh",
"deep_gemm/include/deep_gemm/common/types.cuh",
"deep_gemm/include/deep_gemm/common/types.hpp",
"deep_gemm/include/deep_gemm/common/utils.cuh",
"deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh",
"deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh",
"deep_gemm/include/deep_gemm/epilogue/transform.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh",
"deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh",
"deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh",
"deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh",
"deep_gemm/include/deep_gemm/impls/smxx_layout.cuh",
"deep_gemm/include/deep_gemm/layout/mega_moe.cuh",
"deep_gemm/include/deep_gemm/layout/sym_buffer.cuh",
"deep_gemm/include/deep_gemm/mma/sm100.cuh",
"deep_gemm/include/deep_gemm/mma/sm90.cuh",
"deep_gemm/include/deep_gemm/ptx/ld_st.cuh",
"deep_gemm/include/deep_gemm/ptx/tcgen05.cuh",
"deep_gemm/include/deep_gemm/ptx/tma.cuh",
"deep_gemm/include/deep_gemm/ptx/utils.cuh",
"deep_gemm/include/deep_gemm/ptx/wgmma.cuh",
"deep_gemm/include/deep_gemm/scheduler/gemm.cuh",
"deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh",
"deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh",
]

[general.hub]