Merged
74 commits
e17bb0c
feat(kv-cache): add dtype protocol utilities for KV cache
Dec 15, 2025
4f31a78
feat(kv-cache): add multi-dtype support (BF16/FP16/FP32/FP8) to KV ca…
Dec 15, 2025
4b2f61c
fix(test): fix import path in test_dllm_kv_cache_store
Dec 22, 2025
286bcf7
test(kv-cache): add FP8 roundtrip tests for unified and distinct layouts
Dec 22, 2025
7589e92
fix(attention): optimize shared memory usage in prefix_prefill kernel
Dec 22, 2025
29e0bd2
merge: merge origin/main into feat/kv-cache-fp8-support, preserve FP8…
Dec 24, 2025
fdd4bd5
feat(kv-cache): add FP8 KV cache support in model_runner
Dec 24, 2025
7496582
feat(attention): integrate FP8 KV cache support in attention layers
Dec 24, 2025
7080972
fix(attention): fix parameter passing in store_kvcache calls
Dec 24, 2025
c26b135
test: add memory usage and speed comparison tests for FP8 KV cache
Dec 24, 2025
835c47f
Merge pull request #13 from luozixin2/feat/kv-cache-fp8-support
luozixin2 Dec 24, 2025
c7ee24b
feat(mode): sdar inference supported, decoding kv cache slot mapping …
drewjin Dec 25, 2025
1741805
feat: add test suite and utility functions for flash attention kernel…
drewjin Dec 25, 2025
0d75af5
feat(kernel): update the page_table fetch logics of decoding_kernel f…
drewjin Dec 25, 2025
191e706
fix: dllm_flash_attn_decode_kernel recompilation problem fixed
drewjin Dec 27, 2025
d2507ac
fix: all attn kernels available for inference, checking functions ava…
drewjin Dec 28, 2025
c06b7ef
fix: fix kernel compilation error on Hopper devices via disabling TMA…
drewjin Dec 28, 2025
8434932
test: add test cases for multiround decoding
drewjin Dec 28, 2025
535e296
feat(strategy): create fast-dllm-v2 strategy
drewjin Dec 29, 2025
2828858
feat(fp8-kv): implement FP8 KV cache with Python dequantization
Dec 29, 2025
90a518b
update .gitignore
drewjin Dec 29, 2025
b97af94
feat(fp8-kv): implement FP8 distinct layout store and load
Dec 29, 2025
714f915
feat(sequence): add new sub-block statuses and attributes to FDV2SubB…
drewjin Dec 29, 2025
746be44
feat(attention): add wrapper functions for prefill/decode with quanti…
Dec 29, 2025
39c0d7e
chore: update GitHub workflows to grant write permissions for issues …
drewjin Dec 29, 2025
6b1f450
perf: optimize FP8 KV decode kernel using T.copy and scale fusion
Dec 29, 2025
6a951ea
Refactor quantization module architecture and rename FP8 kernels
Dec 30, 2025
820eac4
merge: merge origin/main into feat/kv-cache-fp8-support
Dec 30, 2025
39cb15e
Merge pull request #21 from luozixin2/feat/kv-cache-fp8-support
luozixin2 Dec 30, 2025
65edadd
feat: add Linear layer quantization strategy framework
Dec 31, 2025
fc32954
feat: implement W8A16 Linear quantization strategy (int8 weight + bf1…
Dec 31, 2025
266ea93
perf: implement lazy cache for W8A16 Linear quantization strategy
Dec 31, 2025
64e4347
feat: implement W8A16 TileLang kernel for Linear quantization
Dec 31, 2025
039693c
Merge branch 'main' into feat/fast-dllm-v2
drewjin Dec 31, 2025
ea47276
feat: add warmup mechanism and performance comparison to test_text_generation.py
Dec 31, 2025
9ba300d
feat: implement load-time quantization and memory-saving for W8A16 Li…
Dec 31, 2025
ca3007c
Optimize W8A16 and W4A16 kernels: move per-channel scale from weight …
Dec 31, 2025
833b32c
Improve W8A8/W4A8 quality by using FP16 scales instead of BF16
Jan 1, 2026
f9a9e1a
chore: update pyproject.toml to add pandas and tilelang dependencies,…
drewjin Jan 5, 2026
6055b39
Merge branch 'feat/fast-dllm-v2' into feat/enhance-strategy
drewjin Jan 5, 2026
ba2801a
feat: implement Diffulex benchmark framework with support for multipl…
drewjin Jan 5, 2026
47b5e9d
feat: add logging capabilities and configuration management to Difful…
drewjin Jan 5, 2026
5aa3bf4
chore: add make.bat into the build scripts of docs
drewjin Jan 5, 2026
50f803d
chore: add offline evaluation script and update tilelang dependency
drewjin Jan 5, 2026
2e03ca7
bugfix: fix config dataclass mutable default and field propagation in…
drewjin Jan 5, 2026
4c5d860
bugfix: _dp_child_entry missing decoding_strategy
drewjin Jan 5, 2026
15704df
feat: introduce Diffulex Profiler for performance analysis with modul…
drewjin Jan 5, 2026
7e65c0b
bugfix: try to fix profiler bug, upload and sync first
drewjin Jan 6, 2026
5b8352f
Merge pull request #18 from drewjin/feat/enhance-strategy
drewjin Jan 6, 2026
c74b14b
Remove AttnQ quantization strategy support
Jan 12, 2026
f8aa715
Merge remote-tracking branch 'fork/main' into feat/kv-cache-fp8-support
Jan 12, 2026
67686e0
Merge branch 'zhijie-group:feat/kv-cache-fp8-support' into feat/kv-ca…
luozixin2 Jan 12, 2026
0d9dd96
Merge branch 'zhijie-group:v0.0.1' into v0.0.1
luozixin2 Jan 12, 2026
44fca07
Merge remote-tracking branch 'fork/v0.0.1' into feat/kv-cache-fp8-sup…
Jan 12, 2026
b4a4ed1
fix: fix scale update logic in the FP8 KV cache RunningMax strategy
Jan 13, 2026
7b15d65
chore: remove the .cursor directory and add it to .gitignore
Jan 13, 2026
9015510
Merge commit '67686e0' into feat/kv-cache-fp8-support
Jan 13, 2026
426b314
feat: optimize W8A16 decode and FP8 KV varlen path
Jan 14, 2026
dde9962
feat: integrate Marlin/AllSpark INT8 W8A16 quantization strategy
Jan 16, 2026
4a6e365
Merge pull request #23 from luozixin2/feat/kv-cache-fp8-support
luozixin2 Jan 16, 2026
3ec5e80
feat: integrate Marlin/AllSpark INT8 W8A16 quantization strategy
Jan 16, 2026
55b8b4d
Merge pull request #25 from luozixin2/feat/kv-cache-fp8-support
luozixin2 Jan 16, 2026
a925717
feat: support GPTQ Marlin and AWQ Marlin quantization formats
Jan 18, 2026
16d7892
chore: remove benchmark_results from the repository
Jan 18, 2026
a594135
Upgrade quantize_model.py to a real GPTQ/AWQ quantization path
Jan 18, 2026
8824ccd
refactor: streamline code structure and eliminate duplicated logic
Jan 18, 2026
23d377a
fix: correct the prefill/decode throughput averaging in bench
Jan 24, 2026
896b8df
perf: optimize the quantized linear fast path and remove profiler annotations
Jan 25, 2026
f6d0fa2
refactor: remove CUDA Graph blockers and simplify linear quantization…
Jan 25, 2026
7fba595
perf: cache linear forward dispatch for CUDA Graph
Jan 26, 2026
0d51145
Fix static+CUDA Graph mode and add benchmark configs
Jan 27, 2026
8ea8717
chore: clean up experimental configs and environment-variable dependencies
Jan 28, 2026
600eb4c
Merge pull request #26 from luozixin2/feat/kv-cache-fp8-support
luozixin2 Jan 28, 2026
b2f66f4
Merge branch 'main' into feat/kv-cache-fp8-support
drewjin Feb 14, 2026
8 changes: 7 additions & 1 deletion .gitignore
@@ -37,6 +37,7 @@ autotuner.log
Fast-dLLM
Discrete-Diffusion-Forcing
position_explanation.md
temp/
cuda_cache/

# IDE
@@ -50,4 +51,9 @@ kernel_diff_analysis.md
tilelang_optimization_analysis.md
boundary_check_comparison.md
GITHUB_ISSUE.md
Tilelang-failed_test_cases/
Tilelang-failed_test_cases/
# Benchmark results
benchmark_results/
benchmark_results_tmp/
# Cursor IDE files
.cursor/
46 changes: 42 additions & 4 deletions diffulex/__init__.py
@@ -1,4 +1,42 @@
from diffulex.diffulex import Diffulex
from diffulex.sampling_params import SamplingParams
# Import strategies to trigger registration
from diffulex import strategy # noqa: F401
"""Diffulex package root.

Keep this module lightweight so that importing submodules like
`diffulex.utils.quantization` does not eagerly import the full engine/kernel.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
# These are available for type checkers; runtime import is lazy via __getattr__.
from diffulex.diffulex import Diffulex as Diffulex # noqa: F401
from diffulex.sampling_params import SamplingParams as SamplingParams # noqa: F401
from diffulex.logger import get_logger as get_logger, setup_logger as setup_logger, LoggerMixin as LoggerMixin # noqa: F401


def __getattr__(name: str):
if name == "Diffulex":
# Only trigger heavy side-effect imports when users actually construct the engine.
# This keeps `import diffulex.utils.quantization` lightweight.
from diffulex import strategy as _strategy # noqa: F401
from diffulex.diffulex import Diffulex

return Diffulex
if name == "SamplingParams":
from diffulex.sampling_params import SamplingParams

return SamplingParams
if name == "get_logger":
from diffulex.logger import get_logger
return get_logger
if name == "setup_logger":
from diffulex.logger import setup_logger
return setup_logger
if name == "LoggerMixin":
from diffulex.logger import LoggerMixin
return LoggerMixin
raise AttributeError(name)


__all__ = ["Diffulex", "SamplingParams", "get_logger", "setup_logger", "LoggerMixin"]
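
The new package root defers heavy imports behind a module-level __getattr__. A minimal usage sketch of the intended behavior, based only on the code above (not a verified run):

# Importing a submodule stays lightweight: the engine and kernels are not loaded yet.
import diffulex.utils.quantization as quant

# Accessing diffulex.Diffulex triggers __getattr__, which first imports diffulex.strategy
# (registering strategies as a side effect) and only then the engine class itself.
import diffulex
engine_cls = diffulex.Diffulex  # heavy imports happen here, on first attribute access
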
7 changes: 5 additions & 2 deletions diffulex/attention/__init__.py
@@ -17,8 +17,11 @@ def __repr__(self):
def __getattr__(name):
"""Lazy import to avoid circular deps during module init."""
if name == "Attention":
from .attn_impl import Attention
return Attention
try:
from .attn_impl import Attention
return Attention
except Exception as e:
raise ImportError(f"Failed to import diffulex.attention.attn_impl.Attention: {e}")
if name == "fetch_attn_metadata":
return metadata.fetch_attn_metadata
raise AttributeError(f"module {__name__} has no attribute {name}")
46 changes: 45 additions & 1 deletion diffulex/attention/attn_impl.py
@@ -25,6 +25,9 @@ def __init__(
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
# Quantization scales (will be bound by ModelRunner if strategy requires them)
self.k_scale = None
self.v_scale = None

self.q_shape = {
'nh': self.num_heads,
@@ -53,6 +56,21 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
# Fast Store KV cache
if k_cache.numel() and v_cache.numel():
if attn_metadata.need_kv_cache_store:
# Update scales if quantization strategy requires them
if self.k_scale is not None and self.v_scale is not None:
from diffulex.utils.quantization.context import get_kv_cache_strategy
strategy = get_kv_cache_strategy()
if strategy is not None:
self.k_scale, self.v_scale = strategy.update_scales(
k, v, self.k_scale, self.v_scale,
self.num_kv_heads, k.device
)
# Pass scale to metadata if required by strategy
if strategy is not None:
strategy.maybe_set_attn_metadata_scales(
attn_metadata, k_scale=self.k_scale, v_scale=self.v_scale
)

store_kvcache = store_kvcache_unified_layout if is_unified_layout else store_kvcache_distinct_layout
store_kvcache(k, v, k_cache, v_cache, attn_metadata.slot_mapping, attn_metadata)

@@ -64,9 +82,35 @@
o = dllm_flash_attn_prefill(q, k, v, self.scale, attn_metadata)
else:
if is_unified_layout:
from diffulex.utils.quantization.context import get_kv_cache_strategy
strategy = get_kv_cache_strategy()
if strategy is not None:
# e.g. FP8: pass scales to metadata for kernel / load_kvcache to handle
strategy.maybe_set_attn_metadata_scales(
attn_metadata, k_scale=self.k_scale, v_scale=self.v_scale
)

o = dllm_flash_attn_decode(q, k, v, k_cache, v_cache, self.scale, attn_metadata)
else:
raise NotImplementedError("Distinct layout is not supported yet...")
# Distinct layout: use varlen mode with load_kvcache
from diffulex_kernel import load_kvcache
from diffulex.utils.quantization.context import get_kv_cache_strategy
strategy = get_kv_cache_strategy()
if strategy is not None:
# e.g. FP8: pass scales to metadata for load_kvcache to handle
strategy.maybe_set_attn_metadata_scales(
attn_metadata, k_scale=self.k_scale, v_scale=self.v_scale
)

# Distinct layout uses varlen mode
k_comb, v_comb = load_kvcache(k_cache, v_cache, attn_metadata, k, v)
from flash_attn import flash_attn_varlen_func
o = flash_attn_varlen_func(
q, k_comb, v_comb,
attn_metadata.cu_seqlens_q, attn_metadata.cu_seqlens_k,
attn_metadata.max_seqlen_q, attn_metadata.max_seqlen_k,
softmax_scale=self.scale, block_table=None
)

# Final reshape
return rearrange(o, 's nh hd -> s (nh hd)').contiguous()
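
The attention layer only touches the quantization strategy through the two calls visible above: update_scales(...) and maybe_set_attn_metadata_scales(...). Below is a minimal sketch of what a running-max FP8 strategy returned by get_kv_cache_strategy() could look like; the method names and call signatures are taken from the call sites in this diff, while the body (per-head absmax scaled to the FP8 E4M3 range) is an assumption, not the repository's actual implementation:

import torch

class RunningMaxFP8KVCacheStrategy:
    FP8_E4M3_MAX = 448.0  # assumed representable range for torch.float8_e4m3fn

    def update_scales(self, k, v, k_scale, v_scale, num_kv_heads, device):
        # Assumed layout: k, v are [num_tokens, num_kv_heads, head_dim];
        # scales are per-head tensors of shape [num_kv_heads].
        k_max = k.abs().amax(dim=(0, 2)).to(torch.float32)
        v_max = v.abs().amax(dim=(0, 2)).to(torch.float32)
        if k_scale is None or v_scale is None:
            k_scale = torch.ones(num_kv_heads, dtype=torch.float32, device=device)
            v_scale = torch.ones(num_kv_heads, dtype=torch.float32, device=device)
        # Running max: scales only grow, so previously stored FP8 blocks remain decodable.
        k_scale = torch.maximum(k_scale, k_max / self.FP8_E4M3_MAX)
        v_scale = torch.maximum(v_scale, v_max / self.FP8_E4M3_MAX)
        return k_scale, v_scale

    def maybe_set_attn_metadata_scales(self, attn_metadata, k_scale=None, v_scale=None):
        # Expose the scales to store/load kernels via the AttnMetaDataBase fields added in this PR.
        attn_metadata.k_scale = k_scale
        attn_metadata.v_scale = v_scale
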
3 changes: 3 additions & 0 deletions diffulex/attention/metadata.py
@@ -18,6 +18,9 @@ class AttnMetaDataBase:
attn_type: str = "block_attention"
diffusion_block_size: int = 32
decode_mode: str = "static"
k_scale: torch.Tensor | None = None # Quantization scale for K cache, shape [num_kv_heads]
v_scale: torch.Tensor | None = None # Quantization scale for V cache, shape [num_kv_heads]
q_scale: torch.Tensor | None = None # Quantization scale for Q, strategy-defined shape (e.g. [num_heads] or [1])

@property
def num_seqs(self) -> int:
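
These three fields carry per-head scales from the attention layer to the cache store/load path. A short sketch of how a Python-side dequantization helper (in the spirit of the "FP8 KV cache with Python dequantization" commit) could consume them; this is a hypothetical helper, not the repository's kernel:

import torch

def dequantize_kv(k_cache_fp8, v_cache_fp8, attn_metadata, out_dtype=torch.bfloat16):
    # k_cache_fp8 / v_cache_fp8 assumed shaped [..., num_kv_heads, head_dim];
    # k_scale / v_scale have shape [num_kv_heads] and broadcast over head_dim.
    k = k_cache_fp8.to(out_dtype) * attn_metadata.k_scale.to(out_dtype)[:, None]
    v = v_cache_fp8.to(out_dtype) * attn_metadata.v_scale.to(out_dtype)[:, None]
    return k, v
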
36 changes: 32 additions & 4 deletions diffulex/config.py
@@ -1,7 +1,10 @@
import os

from dataclasses import dataclass
from dataclasses import dataclass, field
from transformers import AutoConfig
from diffulex.logger import get_logger

logger = get_logger(__name__)


@dataclass
@@ -31,9 +34,10 @@ class Config:
master_addr: str = "localhost"
master_port: int = 2333
# Shared memory segment name for intra-TP RPC; must be unique per DP group.
shm_name: str = "diffuserve_shm"
shm_name: str = "diffulex_shm"
# Start device index for this TP group (set by DP launcher).
device_start: int = 0
device_ids: list[int] = field(default_factory=lambda: [])

enforce_eager: bool = False
hf_config: AutoConfig | None = None
@@ -42,6 +46,23 @@ class Config:
num_kvcache_blocks: int = -1
k_cache_hdim_split_factor_x: int = 8
kv_cache_layout: str = "unified" # "unified" or "distinct"
kv_cache_dtype: str = "bf16" # "bf16", "fp16", "fp32", "fp8_e4m3", "fp8_e5m2"
decode_mode: str | None = None # "static" or "varlen", None means auto-select based on kv_cache_dtype
# Attention-Q dtype (activation quantization). "bf16" default; "fp8" is a placeholder
# for future kernels (enabling it will currently raise NotImplementedError at runtime).
attn_q_dtype: str = "bf16"
# Linear quantization (weights + activations). All are placeholders for future kernels.
# Use "bf16" to disable quantization.
# Supported aliases (normalized in registry): bf16/int8/int4/fp8/fp8_e4m3/fp8_e5m2/gptq/awq.
linear_attn_weight_dtype: str = "bf16"
linear_mlp_weight_dtype: str = "bf16"
linear_attn_act_dtype: str = "bf16"
linear_mlp_act_dtype: str = "bf16"

# Kernel tuning knobs (avoid environment-variable based tuning in library code).
# Currently used by some W8A16 linear strategies.
linear_w8a16_quant_block_n: int = 256
linear_w8a16_allspark_cublas_m_threshold: int = 256

def __post_init__(self):
assert os.path.isdir(self.model)
@@ -56,9 +77,16 @@ def __post_init__(self):
if not self.lora_path:
raise ValueError("lora_path must be provided when use_lora is True")
if not os.path.exists(self.lora_path):
print(f"Warning: LoRA path {self.lora_path} does not exist")
logger.warning(f"LoRA path {self.lora_path} does not exist")

self.hf_config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
cfg_max_model_len = self.hf_config.max_position_embeddings if hasattr(self.hf_config, "max_position_embeddings") else self.hf_config.max_sequence_length
self.max_model_len = min(self.max_model_len, cfg_max_model_len)
assert self.max_num_batched_tokens >= self.max_model_len
assert self.max_num_batched_tokens >= self.max_model_len

if not self.device_ids:
import torch
# When CUDA_VISIBLE_DEVICES is set, PyTorch maps physical devices to logical device 0, 1, ...
# So we should use logical device indices (0, 1, ...) instead of physical device IDs
self.device_ids = list(range(torch.cuda.device_count()))
logger.info(f"Using CUDA devices: {self.device_ids}")
4 changes: 2 additions & 2 deletions diffulex/diffulex.py
@@ -4,7 +4,7 @@

class Diffulex:
def __new__(cls, model, **kwargs):
cfg = Config(model, **{k: v for k, v in kwargs.items() if k in Config.__dataclass_fields__.keys()})
if cfg.data_parallel_size > 1:
data_parallel_size = kwargs.get('data_parallel_size', 1)
if data_parallel_size > 1:
return DiffulexDPWorker(model, **kwargs)
return DiffulexTPWorker(model, **kwargs)
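
With this change, Diffulex.__new__ dispatches on the raw data_parallel_size kwarg instead of constructing a throwaway Config first. A small usage sketch of the dispatch, with placeholder path and device IDs:

from diffulex import Diffulex

# data_parallel_size > 1 returns a DiffulexDPWorker, which spawns one TP worker per DP rank;
# device_ids is the optional override consumed by DiffulexDPWorker (see dp_worker.py below).
engine = Diffulex(
    "/path/to/model",           # placeholder checkpoint directory
    data_parallel_size=2,
    tensor_parallel_size=2,
    device_ids=[0, 1, 2, 3],    # 2 DP groups x 2 TP ranks
)
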
23 changes: 16 additions & 7 deletions diffulex/engine/dp_worker.py
@@ -15,6 +15,9 @@
from diffulex.config import Config
from diffulex.engine.tp_worker import DiffulexTPWorker
from diffulex.sampling_params import SamplingParams
from diffulex.logger import get_logger

logger = get_logger(__name__)


def _dp_child_entry(config: Config, dp_idx: int, local_devices: list[int], conn):
@@ -25,11 +28,12 @@ def _dp_child_entry(config: Config, dp_idx: int, local_devices: list[int], conn)
faulthandler.enable(all_threads=True)
except Exception:
pass
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in local_devices)
# os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in local_devices)
cfg = Config(
model=config.model,
lora_path=config.lora_path,
model_name=config.model_name,
decoding_strategy=config.decoding_strategy,
mask_token_id=config.mask_token_id,
diffusion_block_size=config.diffusion_block_size,
accept_threshold=config.accept_threshold,
@@ -52,6 +56,7 @@ def _dp_child_entry(config: Config, dp_idx: int, local_devices: list[int], conn)
kv_cache_layout=config.kv_cache_layout,
)
setattr(cfg, "device_start", 0)
setattr(cfg, "device_ids", local_devices)

engine = DiffulexTPWorker(cfg.model, **{k: getattr(cfg, k) for k in cfg.__dataclass_fields__.keys() if k != "model"})

@@ -81,17 +86,23 @@ def _dp_child_entry(config: Config, dp_idx: int, local_devices: list[int], conn)
else:
conn.send(("err", f"unknown_cmd:{cmd}"))
except Exception as e:
# Include full traceback for easier debugging and also print to stderr as a fallback.
# Include full traceback for easier debugging and also log as a fallback.
tb = traceback.format_exc()
msg = f"{type(e).__name__}: {e}\n{tb}"
try:
conn.send(("err", msg))
except Exception:
pass
try:
print(f"[DP Child {dp_idx}] Unhandled exception:\n{msg}", file=sys.stderr, flush=True)
# Use logger for error reporting
child_logger = get_logger(f"diffulex.engine.dp_worker.child_{dp_idx}")
child_logger.error(f"[DP Child {dp_idx}] Unhandled exception:\n{msg}")
except Exception:
pass
# Final fallback to stderr
try:
print(f"[DP Child {dp_idx}] Unhandled exception:\n{msg}", file=sys.stderr, flush=True)
except Exception:
pass


class DiffulexDPWorker:
Expand All @@ -116,12 +127,10 @@ def __init__(self, model, **kwargs):
need_gpus = self.dp_size * cfg.tensor_parallel_size
assert len(vis) >= need_gpus, f"Require {need_gpus} GPUs (dp={self.dp_size}, tp={cfg.tensor_parallel_size}), visible {len(vis)}"

# Optional overrides: kwargs['device_ids'] or env D2F_DEVICE_MAP
# Optional overrides: kwargs['device_ids']
override = None
if 'device_ids' in kwargs and kwargs['device_ids']:
override = list(kwargs['device_ids'])
elif os.environ.get('D2F_DEVICE_MAP'):
override = [int(x) for x in os.environ['D2F_DEVICE_MAP'].split(',') if x.strip() != '']
if override is not None:
assert len(override) >= need_gpus, f"device_ids length {len(override)} < required {need_gpus}"
# All override devices must be in visible list