feat: support fp8 kv cache and triton decoding kernels #29
Merged
Conversation
- Add KvCacheDType enum supporting bf16/fp16/fp32/fp8_e4m3/fp8_e5m2
- Add parse_kv_cache_dtype() to convert a string to a dtype
- Add get_fp8_dtype_for_storage() to get the FP8 dtype from the vLLM platform
- Add compute_fp8_scale() to compute the quantization scale using absmax
- Support FP8 storage via the uint8 + view(fp8_dtype) pattern
- Add helper functions for FP8 min/max bounds
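A minimal PyTorch sketch of the absmax-based scale computation and the uint8-backed FP8 storage pattern described above; function and variable names are illustrative, not necessarily the ones used in the repository.

```python
import torch

def compute_fp8_scale(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    """Per-head absmax scale, shape [num_kv_heads]."""
    fp8_max = torch.finfo(fp8_dtype).max
    # x: [num_tokens, num_kv_heads, head_dim] -> absmax over tokens and head_dim
    absmax = x.abs().amax(dim=(0, 2)).float()
    return (absmax / fp8_max).clamp(min=1e-12)

# FP8 cache is allocated as uint8 and reinterpreted via .view(fp8_dtype);
# both dtypes are 1 byte per element, so the view is a free reinterpretation.
kv_cache = torch.empty(1024, 8, 128, dtype=torch.uint8)
kv_cache_fp8 = kv_cache.view(torch.float8_e4m3fn)
```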
…che kernels

Core changes:
- Add kv_cache_dtype and k_scale/v_scale parameters to store/load wrappers
- Refactor store kernels to support FP8 quantization with a per-head scale:
  * store_kvcache_kernel_causal_lm: add FP8 quantization logic
  * store_kvcache_kernel_diffusion_lm: add FP8 quantization logic
  * store_kvcache_kernel_diffusion_lm_distinct: add FP8 quantization logic
- Refactor load_kvcache_kernel_kv to support FP8 dequantization:
  * Load FP8 values from the cache (uint8 storage + view to the FP8 dtype)
  * Dequantize using the per-head scale and cast to the output dtype
  * Support BF16/FP16/FP32 caches without quantization overhead
- Update store_kvcache_unified_layout() to handle the FP8 uint8->fp8 view
- Update store_kvcache_distinct_layout() to handle the FP8 uint8->fp8 view
- Update load_kvcache() to support a configurable output dtype (defaults to k_new.dtype)
- Use constexpr int constants instead of an enum in Triton kernels (Triton limitation)

Technical details:
- FP8 uses absmax-based quantization: value_fp8 = clamp(value_fp32 / scale, fp8_range)
- FP8 dequantization: value_out = (value_fp8.to(float32) * scale).to(output_dtype)
- Scale can be a scalar or a per-head vector [num_kv_heads]
- Maintains backward compatibility: defaults to BF16 when kv_cache_dtype is not specified
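The technical details above translate to the following PyTorch reference math, assuming a per-head scale of shape [num_kv_heads]; in the repository the actual work happens inside the Triton store/load kernels, so this is only a sketch of the formulas.

```python
import torch

def quantize_fp8(x, scale, fp8_dtype=torch.float8_e4m3fn):
    # x: [num_tokens, num_kv_heads, head_dim]; scale: [num_kv_heads]
    fp8_min = torch.finfo(fp8_dtype).min
    fp8_max = torch.finfo(fp8_dtype).max
    q = (x.float() / scale.view(1, -1, 1)).clamp(fp8_min, fp8_max)
    return q.to(fp8_dtype)

def dequantize_fp8(x_fp8, scale, out_dtype=torch.bfloat16):
    # value_out = (value_fp8.to(float32) * scale).to(output_dtype)
    return (x_fp8.float() * scale.view(1, -1, 1)).to(out_dtype)
```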
- Update the import from attention_v4 to the ops module
- Fix the function name from store_kvcache_unified to store_kvcache_unified_layout
- Add test_kv_cache_fp8_unified_roundtrip.py for the unified-layout FP8 store/load roundtrip
- Add test_kv_cache_fp8_distinct_roundtrip.py for the distinct-layout FP8 store test
- Test FP8 quantization/dequantization with per-head scales
- Verify roundtrip accuracy with atol=1e-1, rtol=1e-1 tolerance for FP8 precision
- Reduce num_warps from 4 to 1 to reduce shared memory usage
- Reduce num_unroll_cache from 4 to 2 to reduce shared memory usage
- Add comments explaining why BLOCK_M/BLOCK_N cannot be reduced
- Minor code formatting fix in kv_cache_kernels.py
… KV cache implementation
- Add kv_cache_dtype field to the Config class (default: bf16)
- Add _get_kv_cache_storage_info() helper to determine the storage dtype and itemsize
- Update allocate_kv_cache() in ModelRunnerForCausalLM to use kv_cache_dtype
- Update allocate_kv_cache() in ModelRunnerForDiffusionLM to use kv_cache_dtype
- Support FP8 KV cache allocation using the uint8 storage dtype
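An illustrative sketch of a helper like _get_kv_cache_storage_info(): FP8 variants are stored as uint8 (1 byte/element) and viewed as FP8 inside the kernels, while the other dtypes keep their native storage. Names and the exact return shape are assumptions based on the commit message.

```python
import torch

_KV_DTYPES = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}

def get_kv_cache_storage_info(kv_cache_dtype: str):
    if kv_cache_dtype.startswith("fp8"):
        storage_dtype = torch.uint8          # reinterpreted as fp8 in the kernels
    else:
        storage_dtype = _KV_DTYPES[kv_cache_dtype]
    itemsize = torch.empty((), dtype=storage_dtype).element_size()
    return storage_dtype, itemsize
```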
- Add kv_cache_dtype parameter passing in the attention layers (v4 and v5)
- Implement a running-max strategy for FP8 scale computation
- Pass the scale parameters to the store/load functions in the forward method
- Update ContextForCausalLM to support kv_cache_dtype
- Update ModelRunnerForCausalLM to pass kv_cache_dtype to the context

Changes:
- attention_v4.py: Add _get_kv_cache_dtype(), _update_and_compute_fp8_scales(), _get_fp8_scales_from_max() methods; update forward() to pass scales
- attention_v5.py: Same changes as attention_v4.py
- context.py: Add kv_cache_dtype field to ContextForCausalLM
- model_runner.py: Pass kv_cache_dtype to set_context_causal_lm() calls

All tests passed, including unit tests and FP8 roundtrip tests.
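A rough sketch of the running-max idea behind _update_and_compute_fp8_scales() / _get_fp8_scales_from_max(): track the largest per-head absmax seen so far and derive the FP8 scale from it each step. The class and attribute names here are hypothetical.

```python
import torch

class RunningMaxScale:
    def __init__(self, num_kv_heads: int, fp8_dtype=torch.float8_e4m3fn):
        self.fp8_max = torch.finfo(fp8_dtype).max
        self.running_absmax = torch.zeros(num_kv_heads)

    def update(self, k: torch.Tensor) -> None:
        # k: [num_tokens, num_kv_heads, head_dim]
        absmax = k.abs().amax(dim=(0, 2)).float()
        self.running_absmax = torch.maximum(self.running_absmax.to(absmax.device), absmax)

    def scales(self) -> torch.Tensor:
        # scale used by the store/load kernels; clamp avoids division by zero
        return (self.running_absmax / self.fp8_max).clamp(min=1e-12)
```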
- Fix store_kvcache calls to pass context as a keyword argument
- Resolves the 'got multiple values for argument' error when using the FP8 KV cache
- Verified with a full-pipeline test using the FP8 KV cache

Changes:
- attention_v4.py: Pass context as a keyword argument in the store_kvcache call
- attention_v5.py: Same fix as attention_v4.py
- test_fp8_kv_cache_pipeline.py: Add an integration test for the FP8 KV cache in the full pipeline

Test results:
- Successfully generated text using the FP8 KV cache (fp8_e4m3)
- All 3 test prompts generated correctly
- No errors in the FP8 quantization/dequantization path
- Add test_kv_cache_memory_usage.py to verify KV cache memory allocation
- Add test_kv_cache_speed_comparison.py to compare FP8 vs BF16 performance
- Verified that FP8 reduces per-block memory by 50% and allows 2x block allocation
- Performance tests show FP8 is comparable to BF16 in speed

Test results:
- FP8: 428 blocks × 7 MB/block = 2996 MB total
- BF16: 214 blocks × 14 MB/block = 2996 MB total
- FP8 throughput: 63.15 tok/s vs BF16: 56.27 tok/s (12% faster)
feat: kv cache fp8 support
…s; remove unused checker.py
…rom global memory fetching into fragment fetching
…ilable; error checking for CUDA graph capturing fixed.
… and WARP_SPECIALIZATION
- Fix the quantize function to support 2D input tensors
- Implement the FP8 unified store kernel and helper
- Implement FP8 load with Python-level dequantization
- Support both static and varlen decode modes
- Remove debug code
- Update documentation

Note: temp/ directory excluded from the commit
- Add FP8 distinct store kernel (Triton)
- Add FP8 distinct store helper with Python-level quantization
- Update store_kvcache_distinct_layout to support the FP8 strategy
- Extend _load_kvcache_fp8 to support the distinct layout
- Fix _load_kvcache_bf16 to handle the distinct-layout stride calculation
- Implement the distinct-layout decode path in attn_impl.py
- Add load_kvcache export to diffulex_kernel/__init__.py
- Add a test script for the distinct layout
- Update .gitignore to exclude the temp/ directory
…zation strategy support

- Rename dllm_flash_attn_prefill to _dllm_flash_attn_prefill_bf16
- Rename dllm_flash_attn_decode to _dllm_flash_attn_decode_bf16
- Add a new dllm_flash_attn_prefill wrapper that dynamically selects a kernel based on the quantization strategy
- Add a new dllm_flash_attn_decode wrapper that dynamically selects a kernel based on the quantization strategy
- Currently the FP8 strategy uses the BF16 kernel (FP8 kernels to be implemented later)
- Maintain backward compatibility with the same function signatures
- Tested: the BF16 path works correctly in end-to-end tests
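A sketch of the dispatch-wrapper pattern this commit describes: the public entry point keeps its signature and routes to a dtype-specific kernel from a small table. The stub below stands in for the renamed _dllm_flash_attn_decode_bf16 Triton kernel; the table keys and lookup logic are assumptions.

```python
def _dllm_flash_attn_decode_bf16(q, k_cache, v_cache, **kwargs):
    raise NotImplementedError("stand-in for the real BF16 decode kernel")

_DECODE_KERNELS = {
    "bf16": _dllm_flash_attn_decode_bf16,
    # FP8 currently reuses the BF16 kernel; dedicated FP8 kernels can be slotted in later.
    "fp8_e4m3": _dllm_flash_attn_decode_bf16,
    "fp8_e5m2": _dllm_flash_attn_decode_bf16,
}

def dllm_flash_attn_decode(q, k_cache, v_cache, *, kv_cache_dtype="bf16", **kwargs):
    kernel = _DECODE_KERNELS.get(kv_cache_dtype, _dllm_flash_attn_decode_bf16)
    return kernel(q, k_cache, v_cache, **kwargs)
```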
…and pull requests
Key optimizations:
1. Replace element-wise FP8->FP32->BF16 dequantization loops with T.copy for a vectorized cast
2. Fuse K_Scale into the score computation (avoids element-wise multiplication)
3. Fuse V_Scale into the cache-branch output (only affects the cache path, not V_new)

Performance improvement:
- FP8 decode throughput: ~11.9 tok/s -> ~24.4 tok/s (2x improvement)
- FP8/BF16 decode ratio: 0.759x (was ~0.38x)

Technical details:
- Removed the K_Cache_shared_fp8/V_Cache_shared_fp8 buffers and element-wise conversion loops
- Use T.copy(K_Cache[..], K_Cache_shared_bf16) for a direct FP8->BF16 cast
- Apply K_Scale[kv_head_idx] to acc_score_kvcache after the GEMM (before softmax)
- Apply V_Scale[kv_head_idx] to acc_score_kvcache before the V_Cache GEMM (cache branch only)
- Maintains numerical equivalence with the previous implementation
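A small PyTorch check (not the TileLang kernel) of the scale-fusion idea: multiplying the GEMM result by the per-head K scale matches dequantizing K element-wise before the GEMM, up to float rounding. Shapes and the scale value are illustrative.

```python
import torch

q = torch.randn(1, 128)                          # one query row, head_dim = 128
k_fp8 = torch.randn(64, 128).to(torch.float8_e4m3fn)
k_scale = torch.tensor(0.05)

# element-wise dequantization, then GEMM
score_a = q @ (k_fp8.float() * k_scale).t()
# GEMM on the raw cast values, scale fused afterwards
score_b = (q @ k_fp8.float().t()) * k_scale

torch.testing.assert_close(score_a, score_b, rtol=1e-4, atol=1e-4)
```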
Main changes:
1. Refactor the quantization module architecture:
   - Add a QuantizationConfig and registry system
   - Support quantization strategies for the KV cache and Attention-Q
   - Implement a strategy capability interface, removing hard-coded isinstance checks
   - Add AttnQQuantizationStrategy support (architecture layer; kernel to be implemented)
2. Rename the FP8 kernel:
   - dllm_flash_attn_decode_kernel_fp8 -> dllm_flash_attn_decode_kernel_bf16_q_fp8_kv
   - More accurately reflects what the kernel actually does (BF16 Q + FP8 KV)
3. Simplify the kernel implementation:
   - Remove the USE_KV_SHARED environment variable switch
   - Remove the fragment path, keeping only the shared-memory path
   - Simplify configuration management (from a dict to a single config object)
4. Testing and validation:
   - Add end-to-end tests covering the BF16 and BF16 + FP8 KV paths
   - All tests pass; text generation works correctly

Backward compatible: the existing API is unchanged and existing code needs no modification.
Merge updates from origin/main:
- Update the device list in README.md
- Update .gitignore to add cuda_cache/
- Update the GitHub workflows permission configuration

Keep README.md as the original version from the main branch, without the quantization-related documentation.
Feat/kv cache fp8 support
- Add a LinearQuantizationStrategy interface supporting weight + activation quantization
- Support layer-type-specific strategies (attn/mlp/other)
- Add a registry system for linear quantization strategies
- Add Config fields: linear_attn_weight_dtype, linear_mlp_weight_dtype, linear_attn_act_dtype, linear_mlp_act_dtype
- Integrate the factory to inject strategies into QuantizationContext
- Add dynamic dispatch in Linear.forward() based on quant_kind
- Tag Linear layers in models (dream/llada/sdar/fast_dllm_v2) with quant_kind
- Add placeholder (stub) strategies that raise NotImplementedError for non-bf16 dtypes
- Add unit tests for registry/factory/dispatch behavior
- Default bf16 behavior unchanged (fully backward compatible)

All non-bf16 paths currently raise NotImplementedError with clear error messages, providing a stable interface for future kernel/packed-weight implementations.
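A minimal sketch of the registry + dispatch idea described above, with non-bf16 strategies raising NotImplementedError as in the commit. The class and function names are assumptions, not the repository's exact API.

```python
from typing import Dict, Type
import torch.nn.functional as F

_LINEAR_STRATEGIES: Dict[str, Type["LinearQuantizationStrategy"]] = {}

def register_linear_strategy(dtype: str):
    def deco(cls):
        _LINEAR_STRATEGIES[dtype] = cls
        return cls
    return deco

class LinearQuantizationStrategy:
    def forward(self, x, weight, bias=None):
        raise NotImplementedError

@register_linear_strategy("bf16")
class Bf16LinearStrategy(LinearQuantizationStrategy):
    def forward(self, x, weight, bias=None):
        return F.linear(x, weight, bias)          # default path, unchanged behavior

@register_linear_strategy("fp8")
class Fp8LinearStrategyStub(LinearQuantizationStrategy):
    def forward(self, x, weight, bias=None):
        raise NotImplementedError("fp8 linear kernels not implemented yet")

def create_linear_strategy(dtype: str) -> LinearQuantizationStrategy:
    return _LINEAR_STRATEGIES[dtype]()
```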
…ar backends and comprehensive metrics collection
[Feat] Enhance Decoding Strategies for Easier Development and More Efficient Inference
- Delete the AttnQ strategy implementations (attn_q_bf16.py, attn_q_fp8_stub.py)
- Remove the AttnQQuantizationStrategy base class from strategy.py
- Remove attn_q-related methods from context.py (get_attn_q_strategy, set_attn_q_strategy)
- Remove attn_q registry functions from registry.py (register_attn_q_strategy, create_attn_q_strategy, registered_attn_q_dtypes)
- Remove attn_q exports from __init__.py
- Remove attn_q_dtype from config.py (ActivationQuantConfig)
- Remove attn_q strategy creation from factory.py
- Update kernel code (dllm_flash_attn.py) to use fixed BF16 for Q (removed get_attn_q_strategy calls)
- Remove the q_scale field from the _AttnMetaDataLike protocol
…port

# Conflicts:
#	diffulex/__init__.py
#	diffulex/engine/model_runner.py
#	diffulex_kernel/__init__.py
#	diffulex_kernel/python/dllm_flash_attn_kernels.py
#	test/python/test_linear_fp8.py
#	test/python/test_linear_quantization_module.py
#	test/python/test_quantization_e2e.py
#	test/python/test_quantization_module.py
#	test/python/test_quantization_paths.py
#	test/test_gptq_awq_strategies.py
- Fix the logic error in the scale/absmax conversion inside the update_scales method
- Now correctly converts the scale to an absmax before comparing and updating
- Matches vLLM's RunningMax implementation
- Add detailed comments explaining the update flow
- Update the quantization test scripts and config files
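A sketch of the corrected update flow: convert the incoming scale back to an absmax, take the running maximum in absmax space, then convert back to a scale, mirroring vLLM's RunningMax approach. Function names and the clamp value are illustrative.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def update_running_absmax(running_absmax: torch.Tensor, new_scale: torch.Tensor) -> torch.Tensor:
    new_absmax = new_scale * FP8_MAX                     # scale -> absmax before comparing
    return torch.maximum(running_absmax, new_absmax)

def scale_from_absmax(running_absmax: torch.Tensor) -> torch.Tensor:
    return (running_absmax / FP8_MAX).clamp(min=1e-12)   # absmax -> scale for the kernels
```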
- Remove the .cursor directory from git tracking
- Add .cursor/ to .gitignore to avoid committing it again by mistake
- Optimize W8A16 small-M decode: pad M<16 to 16 (instead of 64) and use block_M=16/32/64
- Add a w8a16_gemm_bias kernel with a fused bias epilogue (opt-in via DIFFULEX_W8A16_FUSE_BIAS)
- Add runtime profiling hooks for W8A16 (DIFFULEX_LINEAR_PROFILE) to track the M distribution and fallbacks
- Implement an FP8 KV varlen fused dequantization kernel (Triton) for the unified layout
- Add benchmark configs for the W4A8 and W8A8 quantization strategies
- Add profiling hooks for KV cache load timing (DIFFULEX_PROFILE_KVCACHE)
Main additions:
1. **Marlin/AllSpark INT8 W8A16 quantization strategy integration**:
   - New linear_marlin_int8_w8a16.py: implements a W8A16 quantization strategy based on the vLLM AllSpark kernel
   - New diffulex_kernel/csrc/marlin/: vendored AllSpark CUDA kernels from vLLM
     * allspark_qgemm_w8a16.cu: W8A16 fused GEMM kernel
     * allspark_repack.cu: N32K16 weight repack kernel
     * allspark_utils.cuh: utility functions and data structures
     * torch_bindings_marlin.cpp: PyTorch C++ bindings
   - New diffulex_kernel/python/marlin_ops.py: Python interface for JIT-compiling and loading the Marlin/AllSpark kernels
2. **Quantization strategy registry updates**:
   - Add a 'marlin' alias in registry.py (mapped to marlin_int8)
   - Import the new strategy in strategies/__init__.py
3. **Performance improvements**:
   - The Marlin W8A16 strategy significantly improves prefill throughput (from 4518.92 tok/s to 9520.91 tok/s, about 2.1x)
   - Decode throughput is close to the BF16 baseline (23.16 tok/s vs 23.36 tok/s)
   - Can be combined with the FP8 KV cache
4. **Other improvements**:
   - Optimized several quantization strategy implementations
   - Improved KV cache management
   - Enhanced the profiler
   - Added several benchmark configuration files
Linear Quantization Support
feat: integrate Marlin/AllSpark INT8 W8A16 quantization strategy
Main changes:
- Add GPTQ Marlin (W4A16) and AWQ Marlin (W4A16) quantization strategies
- Fix loader.py to correctly load gptq_marlin-format weights (supporting Marlin's repacked qweight and permuted scales)
- Modify quantize_model.py to support exporting the gptq_marlin format (symmetric quantization + Marlin repack/permute)
- Update linear.py:
  - Add an _offline_quant_bits buffer to store the quantization bit width
  - Add GPTQ runtime shuffle support (gptq_shuffle)
  - Add lazy repack support for GPTQ/AWQ Marlin (_maybe_prepare_offline_gptq_marlin/_awq_marlin)
  - Use the vLLM format uniformly (int32 packed, fp16 scales)
- Simplify the strategy files and remove duplicated code
- Remove the old AllSpark Marlin implementation files
- Add several benchmark configuration files (GPTQ/AWQ Marlin, one per bit width)
benchmark_results is a locally generated benchmarking artifact and should not be in the repository. This commit removes it as a regular deletion and relies on the benchmark_results/ rule in .gitignore to keep it from being committed again.
- Add quant-method=auto support: use auto-gptq / awq for real calibrated quantization
- Add calibration data parameters: --calib-text-file, --calib-num-samples, --calib-seq-len, etc.
- Implement _export_autogptq_to_vllm_weights: export vLLM-format weights from an auto-gptq quantized model
- Implement _export_awq_to_vllm_weights: export vLLM-format weights from an awq quantized model
- Keep the old quant-method=simple implementation for backward compatibility
- Fix the shape inference and TP sharding logic for gptq_marlin scales in loader.py
- Fix linear_gptq_marlin_w4a16.py to remove an unnecessary bf16->fp16 conversion
Main refactoring:
1. **diffulex/layer/linear.py** - greatly simplify the quantization logic (-197 lines):
   - Add `_forward_base()`: a unified forward dispatcher replacing the duplicated quantization branches in subclasses
   - Add `_build_offline_forward_kwargs()`: unified construction of offline quantization (GPTQ/AWQ) forward arguments
   - Add helper methods such as `_get_linear_strategy()`, `_offline_meta()`, `_infer_gptq_weight_bits()`
   - Fix the edge case in `LoRAMixin.merge_lora` where the base weight is None
   - Remove unused imports (marlin_zero_points, unpack_cols, marlin_make_empty_g_idx)
2. **diffulex/utils/loader.py** - improve performance and code structure:
   - Scan the safetensors files once to build a key_to_file index, avoiding repeated file I/O
   - Cache the `model.named_modules()` result instead of rebuilding the dict repeatedly
   - Add `_find_offline_capable_module()`: unified module lookup logic
   - Add `_load_tensors_for_prefix()`: centralized tensor loading that only opens the necessary files
   - Replace print() with logger.warning()/logger.exception() to standardize logging
3. **diffulex/engine/model_runner.py** - eliminate duplicated loops:
   - Cache the attention module list once in `allocate_kv_cache`
   - Replace repeated module traversal loops with `enumerate(attn_modules)`
4. **diffulex/utils/quantization/strategies/linear_int4_w4a16.py** - fix a missing implementation:
   - Add the `quantize_weight_for_kernel` method, fixing a W4A16 online-quantization runtime error
5. Delete the unused config file `gptq_marlin_w2_bf16kv_varlen.yml`

Testing: verified that W8A16 online quantization and GPTQ offline quantization work correctly.
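A rough sketch of the unified-dispatcher idea behind `_forward_base()`: one method resolves the bf16/strategy/offline path instead of duplicated branches in every subclass. The helper bodies below are stand-ins; the real implementations differ.

```python
import torch.nn.functional as F

class LinearBase:
    """Skeleton only: weight/bias and the three helpers stand in for the real ones."""

    def __init__(self, weight, bias=None):
        self.weight, self.bias = weight, bias

    def _get_linear_strategy(self):            # None -> plain bf16 path
        return None

    def _offline_meta(self):                   # non-None -> offline GPTQ/AWQ weights
        return None

    def _build_offline_forward_kwargs(self):
        return {}

    def _forward_base(self, x):
        strategy = self._get_linear_strategy()
        if strategy is None:
            return F.linear(x, self.weight, self.bias)
        if self._offline_meta() is not None:
            return strategy.forward_offline(x, **self._build_offline_forward_kwargs())
        return strategy.forward(x, self.weight, self.bias)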
- Change the final summary from the last step's instantaneous throughput to a true average (total tokens / total time)
- Add ms/step statistics to make performance analysis easier
- Fix the issue where only the last step's instantaneous value was shown instead of the average
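The corrected summary math, as a tiny illustration with made-up numbers: report total tokens over total wall time plus ms/step, instead of the last step's instantaneous rate.

```python
total_tokens = 12_800
total_time_s = 200.0
num_steps = 400

avg_tok_per_s = total_tokens / total_time_s         # 64.0 tok/s (true average)
ms_per_step = total_time_s * 1000.0 / num_steps     # 500.0 ms/step
print(f"{avg_tok_per_s:.2f} tok/s, {ms_per_step:.1f} ms/step")
```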
- Quantized linear: drop kwargs/pop/redundant availability checks, cache out_features and the necessary intermediate tensors
- Call vLLM CUDA ops directly (W8A8/GPTQ/AWQ/Marlin, etc.) to reduce Python glue overhead
- Handle qweight/scales layout and contiguity at load time to avoid repeating the work in forward
- Remove profiler record annotations from linear.py to keep the code clean
- Add trace/profile helper analysis scripts and related tests
… strategies

- Remove all .item() calls in LinearBase hot paths (GPU->CPU sync breaks graph capture)
- Add a Python-side meta cache (_offline_quant_*_py, _gptq_is_shuffled_py, etc.)
- Use in-place fill_() + Python mirrors for state updates
- Simplify linear quantization strategies for future CUDA Graph support
- Remove fast_path checks and redundant branching in linear_marlin_int8_w8a16
- Remove fast_path in linear_int8_w8a8 (unified vLLM path)
- Simplify linear_gptq_w4a16 (direct torch.ops._C.gptq_gemm call)
- Make linear_fp8_w8a16 use an explicit quant_scales parameter
- Fix FP8 weight layout: do not force contiguous for the transpose view (KxN, stride0 == 1)
- Remove profiler record_function wrappers (graph-friendly)

Net: -129 lines, cleaner codebase ready for CUDA Graph capture
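A sketch of the ".item()-free" pattern mentioned above: keep the authoritative flag on the device, update it in place with fill_(), and mirror it in a plain Python attribute so hot-path branching never forces a GPU->CPU sync. Names follow the commit message but the class itself is hypothetical.

```python
import torch

class OfflineQuantState:
    def __init__(self, device="cpu"):
        self._gptq_is_shuffled = torch.zeros((), dtype=torch.bool, device=device)
        self._gptq_is_shuffled_py = False   # Python mirror used for branching

    def mark_shuffled(self):
        self._gptq_is_shuffled.fill_(True)  # in-place update: safe under CUDA Graph capture
        self._gptq_is_shuffled_py = True

    def needs_shuffle(self) -> bool:
        # No .item() here: reading the Python mirror avoids a device sync.
        return not self._gptq_is_shuffled_py
```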
- Add a per-layer ForwardPlan to pre-resolve bf16/quant/offline paths and reduce per-call Python branching.
- Prefer direct torch.ops kernels (GPTQ/AWQ/Marlin) with static args for stable capture.
- Fix D2F static CUDA graph capture/replay metadata (token buckets + cu_seqlens) and add a profiler flag.
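An illustrative sketch of a per-layer ForwardPlan: resolve which path a layer will take once, at load time, so the per-call forward is a single indirect call with static arguments (friendlier to CUDA Graph capture). Field names are assumptions.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple

@dataclass
class ForwardPlan:
    kind: str                        # "bf16" | "quant" | "offline"
    fn: Callable[..., Any]           # resolved kernel / torch.ops entry point
    static_args: Tuple[Any, ...]     # pre-packed weights, scales, metadata

    def __call__(self, x):
        # per-call work is just one indirect call with pre-resolved arguments
        return self.fn(x, *self.static_args)
```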
- Fix a tensor shape mismatch bug in static + CUDA Graph decode mode (model_runner.py)
- Improve the bucket selection logic for variable token counts
- Add a safety fallback when the runtime batch exceeds the captured capacity
- Fix metadata buffer initialization and padding
- Add new static-mode benchmark configs:
  - awq_bf16kv_static.yml
  - gptq_marlin_w4_bf16kv_static.yml
  - gptq_marlin_w8_bf16kv_static.yml
- Update quantization strategies and loader utilities
- Update benchmark configurations for consistency
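A sketch of the bucket-selection-with-fallback idea: pick the smallest captured token bucket that fits the runtime batch, and fall back to eager execution when the batch exceeds the largest captured capacity. Bucket sizes are illustrative only.

```python
CAPTURED_BUCKETS = [16, 32, 64, 128]   # token counts captured as CUDA graphs (example values)

def select_bucket(num_tokens: int):
    for bucket in CAPTURED_BUCKETS:
        if num_tokens <= bucket:
            return bucket              # replay this graph; pad metadata up to `bucket`
    return None                        # exceeds captured capacity -> eager fallback
```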
- Remove the bench configs and quantization architecture docs added after v0.0.1
- Move W8A16/DP tuning knobs from environment variables into Config/strategy.configure
- Remove hard-coded local paths and default GPU settings from the examples/scripts, and fix syntax issues