diff --git a/.gitignore b/.gitignore index 9db0810..76f8e70 100755 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ autotuner.log Fast-dLLM Discrete-Diffusion-Forcing position_explanation.md +temp/ cuda_cache/ # IDE @@ -50,4 +51,9 @@ kernel_diff_analysis.md tilelang_optimization_analysis.md boundary_check_comparison.md GITHUB_ISSUE.md -Tilelang-failed_test_cases/ \ No newline at end of file +Tilelang-failed_test_cases/ +# Benchmark results +benchmark_results/ +benchmark_results_tmp/ +# Cursor IDE files +.cursor/ diff --git a/diffulex/__init__.py b/diffulex/__init__.py index 23098a7..48701da 100755 --- a/diffulex/__init__.py +++ b/diffulex/__init__.py @@ -1,4 +1,42 @@ -from diffulex.diffulex import Diffulex -from diffulex.sampling_params import SamplingParams -# Import strategies to trigger registration -from diffulex import strategy # noqa: F401 +"""Diffulex package root. + +Keep this module lightweight so that importing submodules like +`diffulex.utils.quantization` does not eagerly import the full engine/kernel. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + # These are available for type checkers; runtime import is lazy via __getattr__. + from diffulex.diffulex import Diffulex as Diffulex # noqa: F401 + from diffulex.sampling_params import SamplingParams as SamplingParams # noqa: F401 + from diffulex.logger import get_logger as get_logger, setup_logger as setup_logger, LoggerMixin as LoggerMixin # noqa: F401 + + +def __getattr__(name: str): + if name == "Diffulex": + # Only trigger heavy side-effect imports when users actually construct the engine. + # This keeps `import diffulex.utils.quantization` lightweight. + from diffulex import strategy as _strategy # noqa: F401 + from diffulex.diffulex import Diffulex + + return Diffulex + if name == "SamplingParams": + from diffulex.sampling_params import SamplingParams + + return SamplingParams + if name == "get_logger": + from diffulex.logger import get_logger + return get_logger + if name == "setup_logger": + from diffulex.logger import setup_logger + return setup_logger + if name == "LoggerMixin": + from diffulex.logger import LoggerMixin + return LoggerMixin + raise AttributeError(name) + + +__all__ = ["Diffulex", "SamplingParams", "get_logger", "setup_logger", "LoggerMixin"] diff --git a/diffulex/attention/__init__.py b/diffulex/attention/__init__.py index a390a61..7e536f8 100644 --- a/diffulex/attention/__init__.py +++ b/diffulex/attention/__init__.py @@ -17,8 +17,11 @@ def __repr__(self): def __getattr__(name): """Lazy import to avoid circular deps during module init.""" if name == "Attention": - from .attn_impl import Attention - return Attention + try: + from .attn_impl import Attention + return Attention + except Exception as e: + raise ImportError(f"Failed to import diffulex.attention.attn_impl.Attention: {e}") if name == "fetch_attn_metadata": return metadata.fetch_attn_metadata raise AttributeError(f"module {__name__} has no attribute {name}") \ No newline at end of file diff --git a/diffulex/attention/attn_impl.py b/diffulex/attention/attn_impl.py index 9ec5f7f..01297ec 100644 --- a/diffulex/attention/attn_impl.py +++ b/diffulex/attention/attn_impl.py @@ -25,6 +25,9 @@ def __init__( self.scale = scale self.num_kv_heads = num_kv_heads self.k_cache = self.v_cache = torch.tensor([]) + # Quantization scales (will be bound by ModelRunner if strategy requires them) + self.k_scale = None + self.v_scale = None self.q_shape = { 'nh': self.num_heads, @@ -53,6 +56,21 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # Fast Store KV cache if k_cache.numel() and v_cache.numel(): if attn_metadata.need_kv_cache_store: + # Update scales if quantization strategy requires them + if self.k_scale is not None and self.v_scale is not None: + from diffulex.utils.quantization.context import get_kv_cache_strategy + strategy = get_kv_cache_strategy() + if strategy is not None: + self.k_scale, self.v_scale = strategy.update_scales( + k, v, self.k_scale, self.v_scale, + self.num_kv_heads, k.device + ) + # Pass scale to metadata if required by strategy + if strategy is not None: + strategy.maybe_set_attn_metadata_scales( + attn_metadata, k_scale=self.k_scale, v_scale=self.v_scale + ) + store_kvcache = store_kvcache_unified_layout if is_unified_layout else store_kvcache_distinct_layout store_kvcache(k, v, k_cache, v_cache, attn_metadata.slot_mapping, attn_metadata) @@ -64,9 +82,35 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o = dllm_flash_attn_prefill(q, k, v, self.scale, attn_metadata) else: if is_unified_layout: + from diffulex.utils.quantization.context import get_kv_cache_strategy + strategy = get_kv_cache_strategy() + if strategy is not None: + # e.g. FP8: pass scales to metadata for kernel / load_kvcache to handle + strategy.maybe_set_attn_metadata_scales( + attn_metadata, k_scale=self.k_scale, v_scale=self.v_scale + ) + o = dllm_flash_attn_decode(q, k, v, k_cache, v_cache, self.scale, attn_metadata) else: - raise NotImplementedError("Distinct layout is not supported yet...") + # Distinct layout: use varlen mode with load_kvcache + from diffulex_kernel import load_kvcache + from diffulex.utils.quantization.context import get_kv_cache_strategy + strategy = get_kv_cache_strategy() + if strategy is not None: + # e.g. FP8: pass scales to metadata for load_kvcache to handle + strategy.maybe_set_attn_metadata_scales( + attn_metadata, k_scale=self.k_scale, v_scale=self.v_scale + ) + + # Distinct layout uses varlen mode + k_comb, v_comb = load_kvcache(k_cache, v_cache, attn_metadata, k, v) + from flash_attn import flash_attn_varlen_func + o = flash_attn_varlen_func( + q, k_comb, v_comb, + attn_metadata.cu_seqlens_q, attn_metadata.cu_seqlens_k, + attn_metadata.max_seqlen_q, attn_metadata.max_seqlen_k, + softmax_scale=self.scale, block_table=None + ) # Final reshape return rearrange(o, 's nh hd -> s (nh hd)').contiguous() \ No newline at end of file diff --git a/diffulex/attention/metadata.py b/diffulex/attention/metadata.py index 75c290e..c9dff9f 100644 --- a/diffulex/attention/metadata.py +++ b/diffulex/attention/metadata.py @@ -18,6 +18,9 @@ class AttnMetaDataBase: attn_type: str = "block_attention" diffusion_block_size: int = 32 decode_mode: str = "static" + k_scale: torch.Tensor | None = None # Quantization scale for K cache, shape [num_kv_heads] + v_scale: torch.Tensor | None = None # Quantization scale for V cache, shape [num_kv_heads] + q_scale: torch.Tensor | None = None # Quantization scale for Q, strategy-defined shape (e.g. [num_heads] or [1]) @property def num_seqs(self) -> int: diff --git a/diffulex/config.py b/diffulex/config.py index 96af47c..f571f34 100755 --- a/diffulex/config.py +++ b/diffulex/config.py @@ -1,7 +1,10 @@ import os -from dataclasses import dataclass +from dataclasses import dataclass, field from transformers import AutoConfig +from diffulex.logger import get_logger + +logger = get_logger(__name__) @dataclass @@ -31,9 +34,10 @@ class Config: master_addr: str = "localhost" master_port: int = 2333 # Shared memory segment name for intra-TP RPC; must be unique per DP group. - shm_name: str = "diffuserve_shm" + shm_name: str = "diffulex_shm" # Start device index for this TP group (set by DP launcher). device_start: int = 0 + device_ids: list[int] = field(default_factory=lambda: []) enforce_eager: bool = False hf_config: AutoConfig | None = None @@ -42,6 +46,23 @@ class Config: num_kvcache_blocks: int = -1 k_cache_hdim_split_factor_x: int = 8 kv_cache_layout: str = "unified" # "unified" or "distinct" + kv_cache_dtype: str = "bf16" # "bf16", "fp16", "fp32", "fp8_e4m3", "fp8_e5m2" + decode_mode: str | None = None # "static" or "varlen", None means auto-select based on kv_cache_dtype + # Attention-Q dtype (activation quantization). "bf16" default; "fp8" is a placeholder + # for future kernels (enabling it will currently raise NotImplementedError at runtime). + attn_q_dtype: str = "bf16" + # Linear quantization (weights + activations). All are placeholders for future kernels. + # Use "bf16" to disable quantization. + # Supported aliases (normalized in registry): bf16/int8/int4/fp8/fp8_e4m3/fp8_e5m2/gptq/awq. + linear_attn_weight_dtype: str = "bf16" + linear_mlp_weight_dtype: str = "bf16" + linear_attn_act_dtype: str = "bf16" + linear_mlp_act_dtype: str = "bf16" + + # Kernel tuning knobs (avoid environment-variable based tuning in library code). + # Currently used by some W8A16 linear strategies. + linear_w8a16_quant_block_n: int = 256 + linear_w8a16_allspark_cublas_m_threshold: int = 256 def __post_init__(self): assert os.path.isdir(self.model) @@ -56,9 +77,16 @@ def __post_init__(self): if not self.lora_path: raise ValueError("lora_path must be provided when use_lora is True") if not os.path.exists(self.lora_path): - print(f"Warning: LoRA path {self.lora_path} does not exist") + logger.warning(f"LoRA path {self.lora_path} does not exist") self.hf_config = AutoConfig.from_pretrained(self.model, trust_remote_code=True) cfg_max_model_len = self.hf_config.max_position_embeddings if hasattr(self.hf_config, "max_position_embeddings") else self.hf_config.max_sequence_length self.max_model_len = min(self.max_model_len, cfg_max_model_len) - assert self.max_num_batched_tokens >= self.max_model_len \ No newline at end of file + assert self.max_num_batched_tokens >= self.max_model_len + + if not self.device_ids: + import torch + # When CUDA_VISIBLE_DEVICES is set, PyTorch maps physical devices to logical device 0, 1, ... + # So we should use logical device indices (0, 1, ...) instead of physical device IDs + self.device_ids = list(range(torch.cuda.device_count())) + logger.info(f"Using CUDA devices: {self.device_ids}") \ No newline at end of file diff --git a/diffulex/diffulex.py b/diffulex/diffulex.py index 08612ba..8a46e5a 100755 --- a/diffulex/diffulex.py +++ b/diffulex/diffulex.py @@ -4,7 +4,7 @@ class Diffulex: def __new__(cls, model, **kwargs): - cfg = Config(model, **{k: v for k, v in kwargs.items() if k in Config.__dataclass_fields__.keys()}) - if cfg.data_parallel_size > 1: + data_parallel_size = kwargs.get('data_parallel_size', 1) + if data_parallel_size > 1: return DiffulexDPWorker(model, **kwargs) return DiffulexTPWorker(model, **kwargs) \ No newline at end of file diff --git a/diffulex/engine/dp_worker.py b/diffulex/engine/dp_worker.py index 0f03004..7068b89 100755 --- a/diffulex/engine/dp_worker.py +++ b/diffulex/engine/dp_worker.py @@ -15,6 +15,9 @@ from diffulex.config import Config from diffulex.engine.tp_worker import DiffulexTPWorker from diffulex.sampling_params import SamplingParams +from diffulex.logger import get_logger + +logger = get_logger(__name__) def _dp_child_entry(config: Config, dp_idx: int, local_devices: list[int], conn): @@ -25,11 +28,12 @@ def _dp_child_entry(config: Config, dp_idx: int, local_devices: list[int], conn) faulthandler.enable(all_threads=True) except Exception: pass - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in local_devices) + # os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in local_devices) cfg = Config( model=config.model, lora_path=config.lora_path, model_name=config.model_name, + decoding_strategy=config.decoding_strategy, mask_token_id=config.mask_token_id, diffusion_block_size=config.diffusion_block_size, accept_threshold=config.accept_threshold, @@ -52,6 +56,7 @@ def _dp_child_entry(config: Config, dp_idx: int, local_devices: list[int], conn) kv_cache_layout=config.kv_cache_layout, ) setattr(cfg, "device_start", 0) + setattr(cfg, "device_ids", local_devices) engine = DiffulexTPWorker(cfg.model, **{k: getattr(cfg, k) for k in cfg.__dataclass_fields__.keys() if k != "model"}) @@ -81,7 +86,7 @@ def _dp_child_entry(config: Config, dp_idx: int, local_devices: list[int], conn) else: conn.send(("err", f"unknown_cmd:{cmd}")) except Exception as e: - # Include full traceback for easier debugging and also print to stderr as a fallback. + # Include full traceback for easier debugging and also log as a fallback. tb = traceback.format_exc() msg = f"{type(e).__name__}: {e}\n{tb}" try: @@ -89,9 +94,15 @@ def _dp_child_entry(config: Config, dp_idx: int, local_devices: list[int], conn) except Exception: pass try: - print(f"[DP Child {dp_idx}] Unhandled exception:\n{msg}", file=sys.stderr, flush=True) + # Use logger for error reporting + child_logger = get_logger(f"diffulex.engine.dp_worker.child_{dp_idx}") + child_logger.error(f"[DP Child {dp_idx}] Unhandled exception:\n{msg}") except Exception: - pass + # Final fallback to stderr + try: + print(f"[DP Child {dp_idx}] Unhandled exception:\n{msg}", file=sys.stderr, flush=True) + except Exception: + pass class DiffulexDPWorker: @@ -116,12 +127,10 @@ def __init__(self, model, **kwargs): need_gpus = self.dp_size * cfg.tensor_parallel_size assert len(vis) >= need_gpus, f"Require {need_gpus} GPUs (dp={self.dp_size}, tp={cfg.tensor_parallel_size}), visible {len(vis)}" - # Optional overrides: kwargs['device_ids'] or env D2F_DEVICE_MAP + # Optional overrides: kwargs['device_ids'] override = None if 'device_ids' in kwargs and kwargs['device_ids']: override = list(kwargs['device_ids']) - elif os.environ.get('D2F_DEVICE_MAP'): - override = [int(x) for x in os.environ['D2F_DEVICE_MAP'].split(',') if x.strip() != ''] if override is not None: assert len(override) >= need_gpus, f"device_ids length {len(override)} < required {need_gpus}" # All override devices must be in visible list diff --git a/diffulex/engine/model_runner.py b/diffulex/engine/model_runner.py index 4a2694b..316fd02 100755 --- a/diffulex/engine/model_runner.py +++ b/diffulex/engine/model_runner.py @@ -12,9 +12,16 @@ from diffulex.config import Config from diffulex.sampler import AutoSampler -from diffulex.engine.sequence import SequenceBase +from diffulex.engine.sequence import AutoSequence, SequenceBase +from diffulex.attention.metadata import set_warming_up, reset_warming_up from diffulex.model import AutoModelForDiffusionLM from diffulex.engine.strategy_registry import DiffulexStrategyRegistry +from diffulex.utils.quantization.factory import QuantizationStrategyFactory +from diffulex.utils.quantization.context import get_kv_cache_strategy +from diffulex.utils.quantization.strategies import NoQuantizationStrategy +from diffulex.logger import get_logger + +logger = get_logger(__name__) class ModelRunnerBase(ABC): @@ -30,8 +37,14 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]): # Initialize model, sampler, and kv cache init_method = f"tcp://{config.master_addr}:{config.master_port}" - dist.init_process_group("nccl", init_method, world_size=self.world_size, rank=rank) - device_id = (getattr(config, "device_start", 0) or 0) + rank + dist.init_process_group("nccl", init_method, world_size=self.world_size, rank=rank, device_id=config.device_ids[rank]) + # Choose CUDA device for this TP rank. + # config.device_ids is already a list of logical CUDA device indices (respecting CUDA_VISIBLE_DEVICES). + # Do NOT add rank again, otherwise rank 1 with device_ids=[0,1] becomes device 2. + if getattr(config, "device_ids", None): + device_id = config.device_ids[rank] + else: + device_id = (getattr(config, "device_start", 0) or 0) + rank assert 0 <= device_id < torch.cuda.device_count(), f"Invalid device_id {device_id}." torch.cuda.set_device(device_id) default_dtype = torch.get_default_dtype() @@ -41,8 +54,10 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]): torch.set_default_device(f"cuda:{device_id}") self.model = self.load_model(config) self.sampler = self.load_sampler(config) + # Initialize quantization context + QuantizationStrategyFactory.create_from_config(config) self.warmup_model() - self.allocate_kv_cache() # NOCHANGE + self.allocate_kv_cache() if not self.enforce_eager: self.capture_cudagraph() @@ -132,11 +147,28 @@ def load_model(self, config: Config): def load_sampler(self, config: Config): """Instantiate the sampler implementation; override to customize.""" return AutoSampler.from_config(config) + + def _prefill_warmup(self): + logger.info("Warming up prefill...") + max_num_batched_tokens, max_model_len = ( + self.config.max_num_batched_tokens, + self.config.max_model_len, + ) + num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs) + test_input_ids = [0] * max_model_len + seqs = [AutoSequence.create(config=self.config, token_ids=test_input_ids) for _ in range(num_seqs)] + self.run(seqs, True) + for seq in seqs: + seq.post_process() + torch.cuda.empty_cache() - @abstractmethod def warmup_model(self): - """Model-specific warmup logic.""" - pass + logger.info("Warming up model...") + set_warming_up(True) + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + self._prefill_warmup() + reset_warming_up() def allocate_kv_cache(self): config = self.config @@ -158,18 +190,19 @@ def allocate_kv_cache(self): else: raise AttributeError(f"Cannot determine head_dim from config: {type(hf_config)}") - dtype = ( - hf_config.torch_dtype - if hasattr(hf_config, "torch_dtype") and hf_config.torch_dtype - else torch.bfloat16 - ) + # Get storage dtype and itemsize from quantization strategy + strategy = get_kv_cache_strategy() + if strategy is None: + strategy = NoQuantizationStrategy() + storage_dtype, itemsize = strategy.get_storage_dtype() + block_bytes = ( 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim - * dtype.itemsize + * itemsize ) get_num_kvcache_blocks = ( lambda gpu_memory_utilization: int(total * gpu_memory_utilization - used - peak + current) @@ -181,28 +214,31 @@ def allocate_kv_cache(self): except Exception: gpu_memory_utilization = config.gpu_memory_utilization while num_kvcache_blocks <= 200: - print( - "Warning: GPU memory utilization " - f"{gpu_memory_utilization} is too low to allocate kv cache. " + logger.warning( + f"GPU memory utilization {gpu_memory_utilization} is too low to allocate kv cache. " "Automatically adding 0.05." ) gpu_memory_utilization += 0.05 num_kvcache_blocks = get_num_kvcache_blocks(gpu_memory_utilization) - print( + logger.info( f"Set gpu_memory_utilization to {gpu_memory_utilization:.2f} " "to allocate kv cache." ) config.gpu_memory_utilization = gpu_memory_utilization config.num_kvcache_blocks = num_kvcache_blocks - print( - "Allocated {num_blocks} blocks of size {block_size} for kv cache on rank {rank}.".format( - num_blocks=config.num_kvcache_blocks, - block_size=self.block_size, - rank=self.rank, - ) + logger.info( + f"Allocated {config.num_kvcache_blocks} blocks of size {self.block_size} " + f"for kv cache on rank {self.rank}." ) + # Cache the list of Attention-like modules once, to keep binding logic consistent + # across cache layout branches (and avoid duplicated traversal). + attn_modules = [ + m for m in self.model.modules() + if hasattr(m, "k_cache") and hasattr(m, "v_cache") + ] + if config.kv_cache_layout == "distinct": x = config.k_cache_hdim_split_factor_x self.k_cache = torch.zeros( @@ -212,6 +248,7 @@ def allocate_kv_cache(self): head_dim // x, self.block_size, x, + dtype=storage_dtype, ) self.v_cache = torch.zeros( hf_config.num_hidden_layers, @@ -219,13 +256,11 @@ def allocate_kv_cache(self): num_kv_heads, head_dim, self.block_size, + dtype=storage_dtype, ) - layer_id = 0 - for module in self.model.modules(): - if hasattr(module, "k_cache") and hasattr(module, "v_cache"): - module.k_cache = self.k_cache[layer_id] - module.v_cache = self.v_cache[layer_id] - layer_id += 1 + for layer_id, module in enumerate(attn_modules): + module.k_cache = self.k_cache[layer_id] + module.v_cache = self.v_cache[layer_id] elif config.kv_cache_layout == "unified": self.kv_cache = torch.zeros( 2, @@ -234,19 +269,43 @@ def allocate_kv_cache(self): self.block_size, num_kv_heads, head_dim, + dtype=storage_dtype, ) - layer_id = 0 - for module in self.model.modules(): - if hasattr(module, "k_cache") and hasattr(module, "v_cache"): - module.k_cache = self.kv_cache[0, layer_id] - module.v_cache = self.kv_cache[1, layer_id] - layer_id += 1 + for layer_id, module in enumerate(attn_modules): + module.k_cache = self.kv_cache[0, layer_id] + module.v_cache = self.kv_cache[1, layer_id] else: raise ValueError( "Unsupported kv_cache_layout: {layout}. Supported values are 'distinct' and 'unified'.".format( layout=config.kv_cache_layout ) ) + + # Allocate scale tensors if quantization strategy requires them + # Get device from cache (already allocated above) + if config.kv_cache_layout == "distinct": + device = self.k_cache.device + else: # unified + device = self.kv_cache.device + k_scale_init, v_scale_init = strategy.init_scales(num_kv_heads, device) + if k_scale_init is not None and v_scale_init is not None: + # Allocate scale tensors: [num_layers, num_kv_heads] + self.k_scale = torch.zeros( + hf_config.num_hidden_layers, num_kv_heads, + dtype=torch.float32, device=device + ) + self.v_scale = torch.zeros( + hf_config.num_hidden_layers, num_kv_heads, + dtype=torch.float32, device=device + ) + # Initialize with strategy's initial scale values + self.k_scale[:] = k_scale_init[None, :] + self.v_scale[:] = v_scale_init[None, :] + + # Bind scales to Attention modules + for layer_id, module in enumerate(attn_modules): + module.k_scale = self.k_scale[layer_id] + module.v_scale = self.v_scale[layer_id] def prepare_block_tables(self, seqs: list[SequenceBase]): max_len = max(len(seq.block_table) for seq in seqs) @@ -305,6 +364,16 @@ class AutoModelRunner(DiffulexStrategyRegistry): @classmethod def from_config(cls, config: Config, rank: int, event: Event | list[Event]): + # Ensure project root is in sys.path for spawn mode subprocesses + import sys + import os + if not any('diffulex_kernel' in p for p in sys.path): + # Try to find project root by locating diffulex package + diffulex_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + if os.path.basename(diffulex_path) == 'diffulex': + project_root = os.path.dirname(diffulex_path) + if project_root not in sys.path: + sys.path.insert(0, project_root) cls._MODULE_MAPPING: dict[str, RunnerFactory] candidates: list[str] = [] for attr in ("decoding_strategy",): diff --git a/diffulex/engine/tp_worker.py b/diffulex/engine/tp_worker.py index 474a884..e215568 100755 --- a/diffulex/engine/tp_worker.py +++ b/diffulex/engine/tp_worker.py @@ -14,6 +14,9 @@ from diffulex.engine.sequence import AutoSequence from diffulex.engine.scheduler import AutoScheduler, SchedulerBase from diffulex.engine.model_runner import AutoModelRunner +from diffulex.logger import get_logger + +logger = get_logger(__name__) class DiffulexTPWorker: @@ -72,6 +75,13 @@ async def add_request_async(self, prompt: str | list[int], sampling_params: Samp return await loop.run_in_executor(None, self.add_request, prompt, sampling_params) def step(self): + # Clear step-local activation quant cache (W8A8/W4A8, etc.) so we only reuse within a single step. + try: + from diffulex.utils.quantization.context import clear_act_quant_cache + clear_act_quant_cache() + except Exception: + # Quantization context may not be initialized in some paths; ignore. + pass seqs, is_prefill = self.scheduler.schedule() sample_output = self.model_runner.call("run", seqs, is_prefill) n_diff_steps = self.scheduler.postprocess(seqs, sample_output) @@ -124,21 +134,41 @@ def generate( sid = self.add_request(prompt, sp) seqid_to_idx[sid] = idx outputs = [None] * len(prompts) - prefill_throughput = decode_throughput = 0. + # Track token/time totals for correct average throughput reporting. + prefill_total_tokens = 0 + decode_total_tokens = 0 + prefill_total_time = 0.0 + decode_total_time = 0.0 + prefill_steps = 0 + decode_steps = 0 n_steps = 0 n_diff_steps = [-1] * len(prompts) while not self.is_finished(): - t = perf_counter() n_steps += 1 + t = perf_counter() output, num_tokens, is_prefill, cur_n_diff_steps, _ = self.step() + dt = perf_counter() - t + + # Accumulate totals to compute average throughput correctly. + if is_prefill: + prefill_steps += 1 + prefill_total_tokens += int(num_tokens) + prefill_total_time += float(dt) + else: + decode_steps += 1 + decode_total_tokens += int(num_tokens) + decode_total_time += float(dt) + if use_tqdm: - if is_prefill: - prefill_throughput = num_tokens / (perf_counter() - t) - else: - decode_throughput = num_tokens / (perf_counter() - t) + avg_prefill_throughput = ( + prefill_total_tokens / prefill_total_time if prefill_total_time > 0 else 0.0 + ) + avg_decode_throughput = ( + decode_total_tokens / decode_total_time if decode_total_time > 0 else 0.0 + ) pbar.set_postfix({ - "Prefill": f"{int(prefill_throughput)}tok/s", - "Decode": f"{int(decode_throughput)}tok/s", + "Prefill(avg)": f"{int(avg_prefill_throughput)}tok/s", + "Decode(avg)": f"{int(avg_decode_throughput)}tok/s", }) if cur_n_diff_steps: for seq_id, n_step in cur_n_diff_steps.items(): @@ -150,7 +180,34 @@ def generate( if use_tqdm: pbar.update(1) - print(f"Finished in {n_steps} steps, prefill throughput: {prefill_throughput:.2f} tok/s, decode throughput: {decode_throughput:.2f} tok/s") + avg_prefill_throughput = ( + prefill_total_tokens / prefill_total_time if prefill_total_time > 0 else 0.0 + ) + avg_decode_throughput = ( + decode_total_tokens / decode_total_time if decode_total_time > 0 else 0.0 + ) + avg_prefill_step_ms = ( + (prefill_total_time / prefill_steps) * 1000.0 if prefill_steps > 0 else 0.0 + ) + avg_decode_step_ms = ( + (decode_total_time / decode_steps) * 1000.0 if decode_steps > 0 else 0.0 + ) + logger.info( + "Finished in %d steps (prefill=%d, decode=%d). " + "Prefill: %d tok in %.2fs (avg %.2f tok/s, %.2f ms/step). " + "Decode: %d tok in %.2fs (avg %.2f tok/s, %.2f ms/step).", + n_steps, + prefill_steps, + decode_steps, + prefill_total_tokens, + prefill_total_time, + avg_prefill_throughput, + avg_prefill_step_ms, + decode_total_tokens, + decode_total_time, + avg_decode_throughput, + avg_decode_step_ms, + ) # Ensure all outputs are present assert all(toks is not None for toks in outputs), "Some sequences did not produce outputs" outputs = [{ diff --git a/diffulex/layer/linear.py b/diffulex/layer/linear.py index cf14eb9..fd16eb3 100755 --- a/diffulex/layer/linear.py +++ b/diffulex/layer/linear.py @@ -1,14 +1,517 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist +from diffulex.utils.quantization.context import get_linear_strategy + def divide(numerator, denominator): assert numerator % denominator == 0 return numerator // denominator +@dataclass +class _ForwardPlanSig: + """Signature for validating cached forward plans. + + We intentionally keep it small and Python-only so it is CUDA-graph friendly + (no `.item()` and no device sync). + """ + + device_type: str + device_index: int + x_dtype: torch.dtype + x_shape: tuple[int, ...] + has_bias: bool + mode: str # "bf16" | "quant" | "offline" + strategy_name: str + + +class _ForwardPlanBase: + sig: _ForwardPlanSig + + def __call__(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover + raise NotImplementedError + + +class _BF16Plan(_ForwardPlanBase): + def __init__( + self, + *, + sig: _ForwardPlanSig, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> None: + self.sig = sig + self._weight = weight + self._bias = bias + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self._weight, self._bias) + + +class _QuantInt8W8A16Plan(_ForwardPlanBase): + def __init__( + self, + *, + sig: _ForwardPlanSig, + strategy, + quant_kind: str, + qweight: torch.Tensor, + scales_1xn: torch.Tensor, + out_features: int, + bias: Optional[torch.Tensor], + ) -> None: + self.sig = sig + self._strategy = strategy + self._quant_kind = (quant_kind or "other").strip().lower() or "other" + self._qweight = qweight + self._scales_1xn = scales_1xn + self._out_features = int(out_features) + self._bias = bias + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self._strategy.linear_forward( + x, + self._qweight, + self._bias, + quant_kind=self._quant_kind, + quant_scales=self._scales_1xn, + out_features=self._out_features, + ) + + +class _QuantInt8W8A8Plan(_ForwardPlanBase): + def __init__( + self, + *, + sig: _ForwardPlanSig, + strategy, + quant_kind: str, + qweight: torch.Tensor, + scales_1xn: torch.Tensor, + out_features: int, + bias: Optional[torch.Tensor], + ) -> None: + self.sig = sig + self._strategy = strategy + self._quant_kind = (quant_kind or "other").strip().lower() or "other" + self._qweight = qweight + self._scales_1xn = scales_1xn + self._out_features = int(out_features) + self._bias = bias + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self._strategy.linear_forward( + x, + self._qweight, + self._bias, + quant_kind=self._quant_kind, + quant_scales=self._scales_1xn, + out_features=self._out_features, + ) + + +class _QuantGenericPlan(_ForwardPlanBase): + def __init__( + self, + *, + sig: _ForwardPlanSig, + strategy, + quant_kind: str, + weight: torch.Tensor, + scales: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> None: + self.sig = sig + self._strategy = strategy + self._quant_kind = (quant_kind or "other").strip().lower() or "other" + self._weight = weight + self._scales = scales + self._bias = bias + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self._strategy.linear_forward( + x, + self._weight, + self._bias, + quant_kind=self._quant_kind, + quant_scales=self._scales, + ) + + +class _OfflineGPTQPlan(_ForwardPlanBase): + def __init__( + self, + *, + sig: _ForwardPlanSig, + strategy, + quant_kind: str, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + g_idx: torch.Tensor, + weight_bits: int, + out_features: int, + in_features: int, + group_size: int, + bias: Optional[torch.Tensor], + ) -> None: + self.sig = sig + self._strategy = strategy + self._quant_kind = (quant_kind or "other").strip().lower() or "other" + self._qweight = qweight + self._qzeros = qzeros + self._scales = scales + self._g_idx = g_idx + self._weight_bits = int(weight_bits) + self._out_features = int(out_features) + self._in_features = int(in_features) + self._group_size = int(group_size) + self._bias = bias + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self._strategy.linear_forward( + x, + None, + self._bias, + quant_kind=self._quant_kind, + gptq_qweight=self._qweight, + gptq_qzeros=self._qzeros, + gptq_scales=self._scales, + gptq_g_idx=self._g_idx, + weight_bits=self._weight_bits, + use_v2_format=False, + out_features=self._out_features, + in_features=self._in_features, + group_size=self._group_size, + ) + + +class _OfflineAWQPlan(_ForwardPlanBase): + def __init__( + self, + *, + sig: _ForwardPlanSig, + strategy, + quant_kind: str, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + pack_factor: int, + out_features: int, + in_features: int, + group_size: int, + bias: Optional[torch.Tensor], + ) -> None: + self.sig = sig + self._strategy = strategy + self._quant_kind = (quant_kind or "other").strip().lower() or "other" + self._qweight = qweight + self._qzeros = qzeros + self._scales = scales + self._pack_factor = int(pack_factor) + self._out_features = int(out_features) + self._in_features = int(in_features) + self._group_size = int(group_size) + self._bias = bias + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self._strategy.linear_forward( + x, + None, + self._bias, + quant_kind=self._quant_kind, + awq_qweight=self._qweight, + awq_qzeros=self._qzeros, + awq_scales=self._scales, + pack_factor=self._pack_factor, + out_features=self._out_features, + in_features=self._in_features, + group_size=self._group_size, + ) + + +class _OfflineGPTQMarlinPlan(_ForwardPlanBase): + def __init__( + self, + *, + sig: _ForwardPlanSig, + strategy, + quant_kind: str, + qweight: torch.Tensor, + scales: torch.Tensor, + zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + in_features: int, + out_features: int, + group_size: int, + weight_bits: int, + tp_dim: Optional[int], + bias: Optional[torch.Tensor], + ) -> None: + self.sig = sig + self._strategy = strategy + self._quant_kind = (quant_kind or "other").strip().lower() or "other" + self._qweight = qweight + self._scales = scales + self._zp = zp + self._g_idx = g_idx + self._g_idx_sort_indices = g_idx_sort_indices + self._workspace = workspace + self._in_features = int(in_features) + self._out_features = int(out_features) + self._group_size = int(group_size) + self._weight_bits = int(weight_bits) + self._tp_dim = tp_dim + self._bias = bias + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self._strategy.linear_forward( + x, + None, + self._bias, + quant_kind=self._quant_kind, + qweight=self._qweight, + scales=self._scales, + zp=self._zp, + g_idx=self._g_idx, + g_idx_sort_indices=self._g_idx_sort_indices, + workspace=self._workspace, + in_features=self._in_features, + out_features=self._out_features, + group_size=self._group_size, + weight_bits=self._weight_bits, + tp_dim=self._tp_dim, + ) + + +class _OfflineAWQMarlinPlan(_ForwardPlanBase): + def __init__( + self, + *, + sig: _ForwardPlanSig, + strategy, + quant_kind: str, + qweight: torch.Tensor, + scales: torch.Tensor, + zp: torch.Tensor, + workspace: torch.Tensor, + in_features: int, + out_features: int, + group_size: int, + tp_dim: Optional[int], + bias: Optional[torch.Tensor], + ) -> None: + self.sig = sig + self._strategy = strategy + self._quant_kind = (quant_kind or "other").strip().lower() or "other" + self._qweight = qweight + self._scales = scales + self._zp = zp + self._workspace = workspace + self._in_features = int(in_features) + self._out_features = int(out_features) + self._group_size = int(group_size) + self._tp_dim = tp_dim + self._bias = bias + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self._strategy.linear_forward( + x, + None, + self._bias, + quant_kind=self._quant_kind, + qweight=self._qweight, + scales=self._scales, + zp=self._zp, + workspace=self._workspace, + in_features=self._in_features, + out_features=self._out_features, + group_size=self._group_size, + tp_dim=self._tp_dim, + ) + + +class _DirectGPTQGemmPlan(_ForwardPlanBase): + """Direct GPTQ GEMM plan (bypass Python strategy glue). + + This calls `torch.ops._C.gptq_gemm` directly with pre-resolved static args. + """ + + def __init__( + self, + *, + sig: _ForwardPlanSig, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + g_idx: torch.Tensor, + weight_bits: int, + out_features: int, + bias: Optional[torch.Tensor], + use_exllama: bool = True, + use_v2_format: bool = False, + cast_back_to_x_dtype: bool = True, + ) -> None: + self.sig = sig + self._qweight = qweight + self._qzeros = qzeros + self._scales = scales + self._g_idx = g_idx + self._weight_bits = int(weight_bits) + self._out_features = int(out_features) + self._bias = bias + self._use_exllama = bool(use_exllama) + self._use_v2_format = bool(use_v2_format) + self._cast_back = bool(cast_back_to_x_dtype) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + # vLLM GPTQ kernels expect FP16 activations. + x_in = x if x.dtype == torch.float16 else x.to(dtype=torch.float16) + x2 = x_in.reshape(-1, x_in.shape[-1]) if x_in.dim() != 2 else x_in + if not x2.is_contiguous(): + x2 = x2.contiguous() + + out = torch.ops._C.gptq_gemm( + x2, + self._qweight, + self._qzeros, + self._scales, + self._g_idx, + self._use_exllama, + self._use_v2_format, + self._weight_bits, + ) + if self._bias is not None: + out.add_(self._bias.to(dtype=out.dtype)) + out = out.reshape(x.shape[:-1] + (self._out_features,)) + if self._cast_back and out.dtype != x.dtype: + return out.to(dtype=x.dtype) + return out + + +class _DirectAWQGemmPlan(_ForwardPlanBase): + """Direct AWQ GEMM plan (bypass Python strategy glue).""" + + def __init__( + self, + *, + sig: _ForwardPlanSig, + awq_gemm, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + out_features: int, + bias: Optional[torch.Tensor], + split_k_iters: int = 1, + cast_back_to_x_dtype: bool = True, + ) -> None: + self.sig = sig + self._awq_gemm = awq_gemm + self._qweight = qweight + self._qzeros = qzeros + self._scales = scales + self._out_features = int(out_features) + self._bias = bias + self._split_k_iters = int(split_k_iters) + self._cast_back = bool(cast_back_to_x_dtype) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + # vLLM AWQ kernels expect FP16 activations. + x_in = x if x.dtype == torch.float16 else x.to(dtype=torch.float16) + reshaped_x = x_in.reshape(-1, x_in.shape[-1]) + if not reshaped_x.is_contiguous(): + reshaped_x = reshaped_x.contiguous() + + # vLLM AWQ GEMM entrypoints (C++ op and Triton fallback) use the same order: + # awq_gemm(input, qweight, scales, qzeros, split_k_iters) + out = self._awq_gemm(reshaped_x, self._qweight, self._scales, self._qzeros, self._split_k_iters) + if self._bias is not None: + out.add_(self._bias.to(dtype=out.dtype)) + out = out.reshape(x.shape[:-1] + (self._out_features,)) + if self._cast_back and out.dtype != x.dtype: + return out.to(dtype=x.dtype) + return out + + +class _DirectMarlinGemmPlan(_ForwardPlanBase): + """Direct Marlin GEMM plan (bypass Python strategy glue). + + This calls `torch.ops._C.gptq_marlin_gemm` directly with pre-resolved static args. + """ + + def __init__( + self, + *, + sig: _ForwardPlanSig, + qweight: torch.Tensor, + scales: torch.Tensor, + zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype_id: int, + n: int, + is_k_full: bool, + use_atomic_add: bool, + marlin_bias: Optional[torch.Tensor], + cast_back_to_x_dtype: bool = True, + ) -> None: + self.sig = sig + self._qweight = qweight + self._scales = scales + self._zp = zp + self._g_idx = g_idx + self._g_idx_sort_indices = g_idx_sort_indices + self._workspace = workspace + self._wtype_id = int(wtype_id) + self._n = int(n) + self._is_k_full = bool(is_k_full) + self._use_atomic_add = bool(use_atomic_add) + self._bias = marlin_bias + self._cast_back = bool(cast_back_to_x_dtype) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (int(self._n),) + m = int(reshaped_x.shape[0]) + k = int(reshaped_x.shape[1]) + out = torch.ops._C.gptq_marlin_gemm( + reshaped_x, + None, + self._qweight, + self._bias, + self._scales, + None, + None, + self._zp, + self._g_idx, + self._g_idx_sort_indices, + self._workspace, + self._wtype_id, + m, + int(self._n), + k, + self._is_k_full, + self._use_atomic_add, + True, # use_fp32_reduce + False, # is_zp_float + ) + out = out.reshape(out_shape) + if self._cast_back and out.dtype != x.dtype: + return out.to(dtype=x.dtype) + return out + + class LoRAMixin: """Mixin class to add LoRA support to existing linear layers.""" def __init_lora__(self, r: int = 0, lora_alpha: float = 1.0, lora_dropout: float = 0.0): @@ -42,9 +545,15 @@ def __init_lora__(self, r: int = 0, lora_alpha: float = 1.0, lora_dropout: float def merge_lora(self): """Merge LoRA weights into base weight.""" - if hasattr(self, 'r') and self.r > 0 and not self.merged: - self.weight.data += self.scaling * torch.mm(self.lora_B, self.lora_A) - self.merged = True + if not (hasattr(self, 'r') and self.r > 0 and not self.merged): + return + # If base weight is missing (e.g., quantized linear removed bf16 weight Parameter), + # we cannot merge in-place. Keep LoRA unmerged and apply via lora_forward. + weight = getattr(self, "weight", None) + if weight is None or not hasattr(weight, "data"): + return + self.weight.data += self.scaling * torch.mm(self.lora_B, self.lora_A) + self.merged = True def lora_forward(self, x: torch.Tensor, base_output: torch.Tensor) -> torch.Tensor: """Apply LoRA forward pass.""" @@ -63,13 +572,1421 @@ def __init__( input_size: int, output_size: int, tp_dim: int | None = None, + quant_kind: str = "other", ): super().__init__() self.input_size = input_size self.output_size = output_size + # Cache forward output features (avoid per-call inference). + # Subclasses with TP partitions should overwrite this after partition sizes are known. + self._forward_out_features: int = int(output_size) self.tp_dim = tp_dim + self.quant_kind = (quant_kind or "other").strip().lower() or "other" self.tp_rank = dist.get_rank() self.tp_size = dist.get_world_size() + # Quantized weight storage (W8A16 etc.). Empty by default. + # NOTE: We keep these as buffers so they move with the module and do not appear as Parameters. + self.register_buffer("quant_weight_int8", torch.empty(0, dtype=torch.int8), persistent=False) + self.register_buffer("quant_scales", torch.empty(0, dtype=torch.bfloat16), persistent=False) + # Cache a 1xN view of scales to avoid per-call view/shape handling on hot paths. + self.register_buffer("quant_scales_1xn", torch.empty(0, dtype=torch.bfloat16), persistent=False) + self.register_buffer("_weight_is_quantized", torch.tensor(False, dtype=torch.bool), persistent=False) + + # GPTQ/AWQ offline quantized weight storage (W4A16). + # NOTE(vLLM-format): + # - GPTQ: qweight int32 [K/pack, N], qzeros int32 [K/group, N/pack], + # scales fp16 [K/group, N], g_idx optional (usually empty when desc_act=False) + # - AWQ : qweight int32 [K, N/pack], qzeros int32 [K/group, N/pack], + # scales fp16 [K/group, N] + # + # Where pack = 32 / bits (bits=4 => pack=8), K=in_features, N=out_features. + self.register_buffer("gptq_qweight", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("gptq_qzeros", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("gptq_scales", torch.empty(0, dtype=torch.float16), persistent=False) + self.register_buffer("gptq_g_idx", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("awq_qweight", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("awq_qzeros", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("awq_scales", torch.empty(0, dtype=torch.float16), persistent=False) + # Metadata for offline quantized weights + self.register_buffer("_offline_quant_format", torch.empty(0, dtype=torch.int8), persistent=False) # 0=none, 1=gptq, 2=awq + # Bits for offline GPTQ/AWQ weights (needed for marlin-exported layouts where + # we cannot infer bits from packed tensor shapes). + self.register_buffer("_offline_quant_bits", torch.tensor(0, dtype=torch.int32), persistent=False) + self.register_buffer("_offline_quant_group_size", torch.tensor(128, dtype=torch.int32), persistent=False) + self.register_buffer("_offline_quant_out_features", torch.tensor(0, dtype=torch.int32), persistent=False) + self.register_buffer("_offline_quant_in_features", torch.tensor(0, dtype=torch.int32), persistent=False) + # GPTQ runtime prep state (vLLM requires gptq_shuffle before first gemm). + self.register_buffer("_gptq_is_shuffled", torch.tensor(False, dtype=torch.bool), persistent=False) + + # ---- vLLM Marlin variants (GPTQ/AWQ) one-time repack cache ---- + # These buffers are populated lazily when a *_marlin strategy is selected. + self.register_buffer("_gptq_marlin_is_prepared", torch.tensor(False, dtype=torch.bool), persistent=False) + self.register_buffer("gptq_marlin_qweight", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("gptq_marlin_scales", torch.empty(0, dtype=torch.float16), persistent=False) + self.register_buffer("gptq_marlin_zp", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("gptq_marlin_g_idx", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("gptq_marlin_g_idx_sort_indices", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("gptq_marlin_workspace", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("_awq_marlin_is_prepared", torch.tensor(False, dtype=torch.bool), persistent=False) + self.register_buffer("awq_marlin_qweight", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("awq_marlin_scales", torch.empty(0, dtype=torch.float16), persistent=False) + self.register_buffer("awq_marlin_zp", torch.empty(0, dtype=torch.int32), persistent=False) + self.register_buffer("awq_marlin_workspace", torch.empty(0, dtype=torch.int32), persistent=False) + + # ---- Python-side meta cache (CUDA Graph friendly) ---- + # Avoid `.item()` on CUDA tensors in hot paths (it introduces GPU->CPU sync and breaks graph capture). + self._weight_is_quantized_py: bool = False + # 0=none, 1=gptq, 2=awq + self._offline_quant_format_py: int = 0 + self._offline_quant_bits_py: int = 0 + self._offline_quant_group_size_py: int = 128 + self._offline_quant_out_features_py: int = 0 + self._offline_quant_in_features_py: int = 0 + self._gptq_is_shuffled_py: bool = False + self._gptq_marlin_is_prepared_py: bool = False + self._awq_marlin_is_prepared_py: bool = False + + # ---- Forward plan cache (for static/graph-friendly dispatch) ---- + # When enabled, we build a per-layer callable plan that fixes the runtime + # dispatch decisions (bf16 vs quant vs offline, and which concrete kernel path). + # This removes heavy Python branching from the hot path and makes CUDA graph + # capture more stable. + self._forward_plan_enabled: bool = False + self._forward_plan: Optional[_ForwardPlanBase] = None + + def _invalidate_forward_plan(self) -> None: + self._forward_plan = None + + @staticmethod + def _device_index(device: torch.device) -> int: + if device.type == "cuda" and device.index is not None: + return int(device.index) + return -1 + + def enable_forward_plan(self, enabled: bool = True) -> None: + """Enable/disable cached forward plan dispatch for this layer.""" + self._forward_plan_enabled = bool(enabled) + if not self._forward_plan_enabled: + self._invalidate_forward_plan() + + def build_forward_plan_for_static(self, example_x: torch.Tensor, bias: Optional[torch.Tensor]) -> None: + """Build a cached forward plan for a fixed static decode-step shape. + + This should be called during warmup/capture. After building, `_forward_base` + can execute with minimal Python overhead by invoking the cached plan. + """ + strategy = self._get_linear_strategy() + # Ensure we don't keep bf16 and quant weights both resident. + self._maybe_promote_weight_to_quantized_at_runtime(example_x, strategy) + + device = example_x.device + dev_idx = self._device_index(device) + has_bias = bias is not None + strategy_name = getattr(strategy, "name", "") if strategy is not None else "" + + # Offline quantized weights have highest priority. + if self.has_offline_quantized_weight(): + if strategy is None: + raise RuntimeError("Offline quantized weight is present but no linear strategy is configured.") + weight_format = getattr(strategy, "linear_weight_format", None) + out_features, in_features, group_size = self._offline_meta() + sig = _ForwardPlanSig( + device_type=device.type, + device_index=dev_idx, + x_dtype=example_x.dtype, + x_shape=tuple(int(x) for x in example_x.shape), + has_bias=has_bias, + mode="offline", + strategy_name=strategy_name, + ) + + if weight_format == "gptq": + self._maybe_prepare_offline_gptq(example_x) + bits = self._infer_gptq_weight_bits(in_features=in_features) + # Use already-correct g_idx buffer (can be empty), moved once to the example device. + g_idx = self.gptq_g_idx + if g_idx.device != device: + g_idx = g_idx.to(device=device, dtype=torch.int) + + # Prefer direct torch.ops entry point to bypass Python strategy glue. + if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "gptq_gemm"): + self._forward_plan = _DirectGPTQGemmPlan( + sig=sig, + qweight=self.gptq_qweight, + qzeros=self.gptq_qzeros, + scales=self.gptq_scales, + g_idx=g_idx, + weight_bits=bits, + out_features=out_features, + bias=bias, + use_exllama=True, + use_v2_format=False, + cast_back_to_x_dtype=True, + ) + else: + self._forward_plan = _OfflineGPTQPlan( + sig=sig, + strategy=strategy, + quant_kind=self.quant_kind, + qweight=self.gptq_qweight, + qzeros=self.gptq_qzeros, + scales=self.gptq_scales, + g_idx=g_idx, + weight_bits=bits, + out_features=out_features, + in_features=in_features, + group_size=group_size, + bias=bias, + ) + return + + if weight_format == "awq": + bits = int(self._offline_quant_bits_py) if int(self._offline_quant_bits_py) > 0 else 4 + pack_factor = 32 // max(1, bits) + # Prefer direct torch.ops entry point to bypass Python strategy glue. + awq_gemm = None + try: + if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "awq_gemm"): + awq_gemm = torch.ops._C.awq_gemm + except Exception: + awq_gemm = None + + if awq_gemm is not None: + self._forward_plan = _DirectAWQGemmPlan( + sig=sig, + awq_gemm=awq_gemm, + qweight=self.awq_qweight, + qzeros=self.awq_qzeros, + scales=self.awq_scales, + out_features=out_features, + bias=bias, + split_k_iters=1, + cast_back_to_x_dtype=True, + ) + else: + self._forward_plan = _OfflineAWQPlan( + sig=sig, + strategy=strategy, + quant_kind=self.quant_kind, + qweight=self.awq_qweight, + qzeros=self.awq_qzeros, + scales=self.awq_scales, + pack_factor=pack_factor, + out_features=out_features, + in_features=in_features, + group_size=group_size, + bias=bias, + ) + return + + if weight_format == "gptq_marlin": + self._maybe_prepare_offline_gptq_marlin(example_x) + bits = self._infer_gptq_weight_bits(in_features=in_features) + # Prefer direct torch.ops entry point to bypass Python strategy glue. + if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "gptq_marlin_gemm"): + try: + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( # type: ignore + marlin_is_k_full, + marlin_make_empty_g_idx, + should_use_atomic_add_reduce, + marlin_permute_bias, + ) + from vllm.scalar_type import scalar_types # type: ignore + except Exception: + marlin_is_k_full = None # type: ignore + marlin_make_empty_g_idx = None # type: ignore + should_use_atomic_add_reduce = None # type: ignore + marlin_permute_bias = None # type: ignore + scalar_types = None # type: ignore + + if scalar_types is None: + # Fall back to the strategy path if vLLM marlin utils are unavailable. + self._forward_plan = _OfflineGPTQMarlinPlan( + sig=sig, + strategy=strategy, + quant_kind=self.quant_kind, + qweight=self.gptq_marlin_qweight, + scales=self.gptq_marlin_scales, + zp=self.gptq_marlin_zp, + g_idx=self.gptq_marlin_g_idx, + g_idx_sort_indices=self.gptq_marlin_g_idx_sort_indices, + workspace=self.gptq_marlin_workspace, + in_features=in_features, + out_features=out_features, + group_size=group_size, + weight_bits=bits, + tp_dim=self.tp_dim, + bias=bias, + ) + return + + device = example_x.device + dev_key = self._device_index(device) + # Prefer already prepared tensors; if missing, use cached empties. + def _empty() -> torch.Tensor: + if marlin_make_empty_g_idx is not None: + return marlin_make_empty_g_idx(device) + return torch.empty((0,), device=device, dtype=torch.int32) + + g_idx = self.gptq_marlin_g_idx if self.gptq_marlin_g_idx.numel() > 0 else _empty() + g_idx_sort = ( + self.gptq_marlin_g_idx_sort_indices + if self.gptq_marlin_g_idx_sort_indices.numel() > 0 + else _empty() + ) + row_parallel = bool(self.tp_dim == 1) + has_g_idx = bool(g_idx.numel() > 0) + is_k_full = True if marlin_is_k_full is None else marlin_is_k_full(has_g_idx, row_parallel) + + marlin_bias = None + if bias is not None: + marlin_bias = marlin_permute_bias(bias) if marlin_permute_bias is not None else bias + + reshaped_x = example_x.reshape(-1, example_x.shape[-1]) + m = int(reshaped_x.shape[0]) + n = int(out_features) + k = int(reshaped_x.shape[1]) + use_atomic_add = False + if should_use_atomic_add_reduce is not None: + use_atomic_add = bool( + should_use_atomic_add_reduce(m=m, n=n, k=k, device=device, dtype=reshaped_x.dtype) + ) + + if bits == 4: + wtype = scalar_types.uint4b8 + elif bits == 8: + wtype = scalar_types.uint8b128 + else: + raise RuntimeError(f"gptq_marlin: unsupported weight_bits={bits} (expected 4 or 8)") + + self._forward_plan = _DirectMarlinGemmPlan( + sig=sig, + qweight=self.gptq_marlin_qweight, + scales=self.gptq_marlin_scales, + zp=self.gptq_marlin_zp, + g_idx=g_idx, + g_idx_sort_indices=g_idx_sort, + workspace=self.gptq_marlin_workspace, + wtype_id=wtype.id, + n=out_features, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + marlin_bias=marlin_bias, + cast_back_to_x_dtype=True, + ) + else: + self._forward_plan = _OfflineGPTQMarlinPlan( + sig=sig, + strategy=strategy, + quant_kind=self.quant_kind, + qweight=self.gptq_marlin_qweight, + scales=self.gptq_marlin_scales, + zp=self.gptq_marlin_zp, + g_idx=self.gptq_marlin_g_idx, + g_idx_sort_indices=self.gptq_marlin_g_idx_sort_indices, + workspace=self.gptq_marlin_workspace, + in_features=in_features, + out_features=out_features, + group_size=group_size, + weight_bits=bits, + tp_dim=self.tp_dim, + bias=bias, + ) + return + + if weight_format == "awq_marlin": + self._maybe_prepare_offline_awq_marlin(example_x) + if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "gptq_marlin_gemm"): + try: + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( # type: ignore + marlin_make_empty_g_idx, + should_use_atomic_add_reduce, + marlin_permute_bias, + ) + from vllm.scalar_type import scalar_types # type: ignore + except Exception: + marlin_make_empty_g_idx = None # type: ignore + should_use_atomic_add_reduce = None # type: ignore + marlin_permute_bias = None # type: ignore + scalar_types = None # type: ignore + + if scalar_types is None: + self._forward_plan = _OfflineAWQMarlinPlan( + sig=sig, + strategy=strategy, + quant_kind=self.quant_kind, + qweight=self.awq_marlin_qweight, + scales=self.awq_marlin_scales, + zp=self.awq_marlin_zp, + workspace=self.awq_marlin_workspace, + in_features=in_features, + out_features=out_features, + group_size=group_size, + tp_dim=self.tp_dim, + bias=bias, + ) + return + + device = example_x.device + empty = ( + marlin_make_empty_g_idx(device) + if marlin_make_empty_g_idx is not None + else torch.empty((0,), device=device, dtype=torch.int32) + ) + marlin_bias = None + if bias is not None: + marlin_bias = marlin_permute_bias(bias) if marlin_permute_bias is not None else bias + + reshaped_x = example_x.reshape(-1, example_x.shape[-1]) + m = int(reshaped_x.shape[0]) + n = int(out_features) + k = int(reshaped_x.shape[1]) + use_atomic_add = False + if should_use_atomic_add_reduce is not None: + use_atomic_add = bool( + should_use_atomic_add_reduce(m=m, n=n, k=k, device=device, dtype=reshaped_x.dtype) + ) + + self._forward_plan = _DirectMarlinGemmPlan( + sig=sig, + qweight=self.awq_marlin_qweight, + scales=self.awq_marlin_scales, + zp=self.awq_marlin_zp, + g_idx=empty, + g_idx_sort_indices=empty, + workspace=self.awq_marlin_workspace, + wtype_id=scalar_types.uint4.id, + n=out_features, + is_k_full=True, + use_atomic_add=use_atomic_add, + marlin_bias=marlin_bias, + cast_back_to_x_dtype=True, + ) + else: + self._forward_plan = _OfflineAWQMarlinPlan( + sig=sig, + strategy=strategy, + quant_kind=self.quant_kind, + qweight=self.awq_marlin_qweight, + scales=self.awq_marlin_scales, + zp=self.awq_marlin_zp, + workspace=self.awq_marlin_workspace, + in_features=in_features, + out_features=out_features, + group_size=group_size, + tp_dim=self.tp_dim, + bias=bias, + ) + return + + # If a new offline strategy is added, fall back to the generic runtime dispatcher. + raise RuntimeError( + f"Offline quantized weight is present but strategy weight_format={weight_format!r} is not supported by forward plan." + ) + + # Online/load-time quantized weights. + if self.has_quantized_weight(): + if strategy is None: + raise RuntimeError("Quantized weight is present but no linear strategy is configured.") + sig = _ForwardPlanSig( + device_type=device.type, + device_index=dev_idx, + x_dtype=example_x.dtype, + x_shape=tuple(int(x) for x in example_x.shape), + has_bias=has_bias, + mode="quant", + strategy_name=strategy_name, + ) + if getattr(strategy, "name", "") == "linear_int8_w8a16": + self._forward_plan = _QuantInt8W8A16Plan( + sig=sig, + strategy=strategy, + quant_kind=self.quant_kind, + qweight=self.quant_weight_int8, + scales_1xn=self.quant_scales_1xn, + out_features=self._forward_out_features, + bias=bias, + ) + return + if getattr(strategy, "name", "") == "linear_int8_w8a8": + self._forward_plan = _QuantInt8W8A8Plan( + sig=sig, + strategy=strategy, + quant_kind=self.quant_kind, + qweight=self.quant_weight_int8, + scales_1xn=self.quant_scales_1xn, + out_features=self._forward_out_features, + bias=bias, + ) + return + self._forward_plan = _QuantGenericPlan( + sig=sig, + strategy=strategy, + quant_kind=self.quant_kind, + weight=self.quant_weight_int8, + scales=self.quant_scales, + bias=bias, + ) + return + + # BF16 weights (no quant). + weight = getattr(self, "weight", None) + if weight is None: + raise RuntimeError("No quantized/offline weights are present but bf16 weight is missing.") + sig = _ForwardPlanSig( + device_type=device.type, + device_index=dev_idx, + x_dtype=example_x.dtype, + x_shape=tuple(int(x) for x in example_x.shape), + has_bias=has_bias, + mode="bf16", + strategy_name=strategy_name, + ) + self._forward_plan = _BF16Plan(sig=sig, weight=weight, bias=bias) + + def has_quantized_weight(self) -> bool: + return self._weight_is_quantized_py and self.quant_weight_int8.numel() > 0 and self.quant_scales.numel() > 0 + + def has_offline_quantized_weight(self) -> bool: + """Check if offline quantized weights (GPTQ/AWQ) are present.""" + if self._offline_quant_format_py == 1: # GPTQ + return ( + self.gptq_qweight.numel() > 0 + and self.gptq_qzeros.numel() > 0 + and self.gptq_scales.numel() > 0 + ) + elif self._offline_quant_format_py == 2: # AWQ + return ( + self.awq_qweight.numel() > 0 + and self.awq_qzeros.numel() > 0 + and self.awq_scales.numel() > 0 + ) + return False + + def set_offline_quantized_weight( + self, + format: str, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + *, + out_features: int, + in_features: int, + group_size: int = 128, + g_idx: Optional[torch.Tensor] = None, + ) -> None: + """Set offline quantized weights (GPTQ or AWQ format). + + Args: + format: "gptq" or "awq" + qweight/qzeros/scales: vLLM standard tensors (see notes above). + out_features: Output features (N) + in_features: Input features (K) + group_size: Group size for quantization (default: 128) + g_idx: Optional int32 tensor [in_features] for act-order (GPTQ only; usually empty) + """ + # NOTE: Offline quantized weights are typically loaded from safetensors on CPU. + # In Diffulex, the engine may move modules to CUDA before calling this method, + # so we must ensure tensors are moved to the module device here. + def _infer_module_device() -> torch.device: + w = getattr(self, "weight", None) + if isinstance(w, torch.Tensor): + return w.device + for p in self.parameters(recurse=False): + return p.device + for b in self.buffers(recurse=False): + return b.device + return torch.device("cpu") + + module_device = _infer_module_device() + + format = format.strip().lower() + if format not in ("gptq", "awq"): + raise ValueError(f"Unsupported offline quant format: {format}. Supported: 'gptq', 'awq'") + + # Infer bits/pack_factor from packed tensor shapes to support GPTQ W2/W4/W8. + # vLLM packing convention: + # - GPTQ: qweight [K/pack, N], qzeros [K/group, N/pack] + # - AWQ: qweight [K, N/pack], qzeros [K/group, N/pack] + # where pack = 32 / bits and bits must divide 32. + if format == "gptq": + if int(qweight.shape[0]) <= 0 or in_features % int(qweight.shape[0]) != 0: + raise ValueError( + "Cannot infer GPTQ pack_factor from qweight shape: " + f"in_features={in_features}, qweight.shape={tuple(qweight.shape)}" + ) + pack_factor = in_features // int(qweight.shape[0]) + else: # awq + if int(qweight.shape[1]) <= 0 or out_features % int(qweight.shape[1]) != 0: + raise ValueError( + "Cannot infer AWQ pack_factor from qweight shape: " + f"out_features={out_features}, qweight.shape={tuple(qweight.shape)}" + ) + pack_factor = out_features // int(qweight.shape[1]) + if 32 % pack_factor != 0: + raise ValueError( + f"Unsupported pack_factor={pack_factor} (requires 32%pack_factor==0) " + f"for offline format={format}. " + f"in_features={in_features}, out_features={out_features}, " + f"qweight.shape={tuple(qweight.shape)}, qzeros.shape={tuple(qzeros.shape)}, scales.shape={tuple(scales.shape)}" + ) + bits = 32 // pack_factor + if format == "awq" and bits != 4: + raise ValueError(f"AWQ 目前仅支持 4-bit(pack_factor=8),当前推断 bits={bits} (pack_factor={pack_factor})") + # Cache meta as Python primitives (graph-friendly). + self._offline_quant_bits_py = int(bits) + # Record bits for downstream kernels (esp. marlin path). + self._offline_quant_bits = torch.tensor(bits, dtype=torch.int32, device=module_device) + + if qweight.dtype != torch.int32: + raise TypeError(f"qweight must be int32 (vLLM format), got {qweight.dtype}") + if qzeros.dtype != torch.int32: + raise TypeError(f"qzeros must be int32 (vLLM format), got {qzeros.dtype}") + if scales.dtype not in (torch.float16, torch.bfloat16, torch.float32): + raise TypeError( + f"scales must be float16/bfloat16/float32 (vLLM format), got {scales.dtype}" + ) + if scales.dtype != torch.float16: + scales = scales.to(dtype=torch.float16) + + # Move to module device before validation/assignment. + if qweight.device != module_device: + qweight = qweight.to(device=module_device) + if qzeros.device != module_device: + qzeros = qzeros.to(device=module_device) + if scales.device != module_device: + scales = scales.to(device=module_device) + if g_idx is not None and g_idx.device != module_device: + g_idx = g_idx.to(device=module_device) + + # Make packed tensors contiguous once at load-time (avoid per-call checks/copies). + qweight = qweight.contiguous() + qzeros = qzeros.contiguous() + scales = scales.contiguous() + if g_idx is not None: + g_idx = g_idx.contiguous() + + # group_size == -1 means channelwise in some ecosystems; vLLM normalizes -1 to K. + group_size_norm = in_features if group_size == -1 else group_size + if group_size_norm <= 0 or (in_features % group_size_norm != 0): + raise ValueError( + f"Invalid group_size={group_size} for in_features={in_features}. " + "Expected group_size == -1 or a positive divisor of in_features." + ) + num_groups = in_features // group_size_norm + + if format == "gptq": + expected_qweight_shape = (in_features // pack_factor, out_features) + expected_qzeros_shape = (num_groups, out_features // pack_factor) + expected_scales_shape = (num_groups, out_features) + else: # awq + expected_qweight_shape = (in_features, out_features // pack_factor) + expected_qzeros_shape = (num_groups, out_features // pack_factor) + expected_scales_shape = (num_groups, out_features) + + if qweight.shape != expected_qweight_shape: + raise ValueError( + f"qweight shape mismatch: got {tuple(qweight.shape)}, expected {expected_qweight_shape}" + ) + if qzeros.shape != expected_qzeros_shape: + raise ValueError( + f"qzeros shape mismatch: got {tuple(qzeros.shape)}, expected {expected_qzeros_shape}" + ) + if scales.shape != expected_scales_shape: + raise ValueError( + f"scales shape mismatch: got {tuple(scales.shape)}, expected {expected_scales_shape}" + ) + + if format == "gptq": + self.gptq_qweight = qweight + self.gptq_qzeros = qzeros + self.gptq_scales = scales + if g_idx is not None and getattr(g_idx, "numel", lambda: 1)() == 0: + g_idx = None + if g_idx is not None: + if g_idx.shape != (in_features,): + raise ValueError( + f"g_idx shape mismatch: got {g_idx.shape}, expected ({in_features},)" + ) + if g_idx.dtype != torch.int32: + g_idx = g_idx.to(dtype=torch.int32) + self.gptq_g_idx = g_idx + else: + # Clear g_idx if not provided + self.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device) + self._offline_quant_format = torch.tensor(1, dtype=torch.int8, device=module_device) + self._gptq_is_shuffled = torch.tensor(False, dtype=torch.bool, device=module_device) + self._offline_quant_format_py = 1 + self._gptq_is_shuffled_py = False + else: # AWQ + self.awq_qweight = qweight + self.awq_qzeros = qzeros + self.awq_scales = scales + # AWQ doesn't use g_idx, clear it + self.gptq_qweight = torch.empty(0, dtype=torch.int32, device=module_device) + self.gptq_qzeros = torch.empty(0, dtype=torch.int32, device=module_device) + self.gptq_scales = torch.empty(0, dtype=torch.float16, device=module_device) + self.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device) + self._offline_quant_format = torch.tensor(2, dtype=torch.int8, device=module_device) + self._gptq_is_shuffled = torch.tensor(False, dtype=torch.bool, device=module_device) + self._offline_quant_format_py = 2 + self._gptq_is_shuffled_py = False + + # Reset marlin-prep caches (weights may have changed / moved). + self._gptq_marlin_is_prepared = torch.tensor(False, dtype=torch.bool, device=module_device) + self.gptq_marlin_qweight = torch.empty(0, dtype=torch.int32, device=module_device) + self.gptq_marlin_scales = torch.empty(0, dtype=torch.float16, device=module_device) + self.gptq_marlin_zp = torch.empty(0, dtype=torch.int32, device=module_device) + self.gptq_marlin_g_idx = torch.empty(0, dtype=torch.int32, device=module_device) + self.gptq_marlin_g_idx_sort_indices = torch.empty(0, dtype=torch.int32, device=module_device) + self.gptq_marlin_workspace = torch.empty(0, dtype=torch.int32, device=module_device) + self._awq_marlin_is_prepared = torch.tensor(False, dtype=torch.bool, device=module_device) + self.awq_marlin_qweight = torch.empty(0, dtype=torch.int32, device=module_device) + self.awq_marlin_scales = torch.empty(0, dtype=torch.float16, device=module_device) + self.awq_marlin_zp = torch.empty(0, dtype=torch.int32, device=module_device) + self.awq_marlin_workspace = torch.empty(0, dtype=torch.int32, device=module_device) + + self._offline_quant_group_size = torch.tensor(group_size, dtype=torch.int32, device=module_device) + self._offline_quant_out_features = torch.tensor(out_features, dtype=torch.int32, device=module_device) + self._offline_quant_in_features = torch.tensor(in_features, dtype=torch.int32, device=module_device) + # Python meta mirrors. + self._offline_quant_group_size_py = int(group_size) + self._offline_quant_out_features_py = int(out_features) + self._offline_quant_in_features_py = int(in_features) + self._gptq_marlin_is_prepared_py = False + self._awq_marlin_is_prepared_py = False + + # Drop bf16 weight Parameter if present (to free memory) + if "weight" in self._parameters: + self._parameters.pop("weight", None) + setattr(self, "weight", None) + + # Offline weights changed; cached forward plan is no longer valid. + self._invalidate_forward_plan() + + def _maybe_prepare_offline_gptq(self, x: torch.Tensor) -> None: + """Prepare vLLM GPTQ weights on first use (required gptq_shuffle).""" + if self._offline_quant_format_py != 1: + return + if self.gptq_qweight.numel() == 0: + return + if self._gptq_is_shuffled_py: + return + + # Lazy import to avoid pulling vLLM unless GPTQ offline weights are used. + try: + from vllm import _custom_ops as ops # type: ignore + except Exception as e: # pragma: no cover + raise RuntimeError( + "GPTQ offline 权重已加载,但无法导入 vLLM CUDA custom ops(vllm._custom_ops)。" + ) from e + + # vLLM uses torch.int for g_idx (can be empty when desc_act=False). + if self.gptq_g_idx.numel() == 0: + g_idx = torch.empty((0,), device=x.device, dtype=torch.int) + else: + g_idx = self.gptq_g_idx.to(device=x.device, dtype=torch.int) + + if self.gptq_qweight.device != x.device: + raise RuntimeError( + f"GPTQ qweight device mismatch: qweight on {self.gptq_qweight.device}, x on {x.device}. " + "请确保模型与输入在同一设备。" + ) + + # Infer weight_bits from packed qweight shape to support GPTQ W2/W4/W8. + # qweight: [K/pack_factor, N], where pack_factor = 32 / weight_bits. + in_features = int(self._offline_quant_in_features_py) + if in_features is None or in_features <= 0: + raise RuntimeError("GPTQ offline 权重已加载,但无法推断 in_features 以计算 weight_bits。") + if self.gptq_qweight.shape[0] <= 0 or in_features % int(self.gptq_qweight.shape[0]) != 0: + raise RuntimeError( + f"GPTQ qweight shape 不合法,无法推断 weight_bits: " + f"in_features={in_features}, qweight.shape={tuple(self.gptq_qweight.shape)}" + ) + pack_factor = in_features // int(self.gptq_qweight.shape[0]) + if 32 % pack_factor != 0: + raise RuntimeError( + f"GPTQ pack_factor={pack_factor} 不支持(需要 32 % pack_factor == 0)," + f"in_features={in_features}, qweight.shape={tuple(self.gptq_qweight.shape)}" + ) + weight_bits = 32 // pack_factor + ops.gptq_shuffle(self.gptq_qweight, g_idx, weight_bits) + # Do NOT create new tensors on hot paths; update in-place + python mirror. + self._gptq_is_shuffled.fill_(True) + self._gptq_is_shuffled_py = True + + def _maybe_prepare_offline_gptq_marlin(self, x: torch.Tensor) -> None: + """Prepare vLLM GPTQ Marlin weights on first use (repack + permute scales/zp). + + IMPORTANT: This path must NOT call `gptq_shuffle` (that is specific to gptq_gemm/exllama). + """ + if self._offline_quant_format_py != 1: + return + if self.gptq_qweight.numel() == 0: + return + if self._gptq_marlin_is_prepared_py: + return + + try: + from vllm import _custom_ops as ops # type: ignore + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( # type: ignore + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_permute_scales, + marlin_sort_g_idx, + ) + except Exception as e: # pragma: no cover + raise RuntimeError( + "GPTQ Marlin 需要 vLLM CUDA custom ops + marlin_utils,但当前环境不可用。" + ) from e + + device = x.device + if self.gptq_qweight.device != device: + raise RuntimeError( + f"GPTQ qweight device mismatch: qweight on {self.gptq_qweight.device}, x on {device}. " + "请确保模型与输入在同一设备。" + ) + + in_features = int(self._offline_quant_in_features_py) + out_features = int(self._offline_quant_out_features_py) + group_size = int(self._offline_quant_group_size_py) + if in_features <= 0 or out_features <= 0: + raise RuntimeError( + f"GPTQ Marlin: invalid feature sizes: in_features={in_features}, out_features={out_features}" + ) + + # Determine weight_bits. + # - Standard GPTQ layout: infer from qweight K packing. + # - Marlin-exported layout: bits cannot be inferred from qweight shape; use recorded bits. + weight_bits = int(self._offline_quant_bits_py) + if weight_bits <= 0: + if self.gptq_qweight.shape[0] <= 0 or in_features % int(self.gptq_qweight.shape[0]) != 0: + raise RuntimeError( + "GPTQ Marlin: cannot infer pack_factor from qweight shape: " + f"in_features={in_features}, qweight.shape={tuple(self.gptq_qweight.shape)}" + ) + pack_factor = in_features // int(self.gptq_qweight.shape[0]) + if 32 % pack_factor != 0: + raise RuntimeError( + f"GPTQ Marlin: unsupported pack_factor={pack_factor} (requires 32%pack_factor==0)" + ) + weight_bits = 32 // pack_factor + if weight_bits not in (4, 8): + raise RuntimeError( + f"GPTQ Marlin: only 4/8-bit are supported in this integration, got bits={weight_bits}" + ) + + # If loader already provided marlin-ready weights/scales (exported offline), + # skip repack/permute but still create workspace / g_idx metadata. + already_marlin_ready = ( + self.gptq_marlin_qweight.numel() > 0 + and self.gptq_marlin_scales.numel() > 0 + ) + if already_marlin_ready: + if self.gptq_marlin_qweight.device != device or self.gptq_marlin_scales.device != device: + raise RuntimeError( + "GPTQ Marlin: prepacked marlin tensors device mismatch: " + f"qweight on {self.gptq_marlin_qweight.device}, scales on {self.gptq_marlin_scales.device}, x on {device}." + ) + + # g_idx (act-order) handling: marlin expects sorted g_idx + sort indices; otherwise empty. + if self.gptq_g_idx.numel() > 0: + g_idx_sorted, g_idx_sort_indices = marlin_sort_g_idx(self.gptq_g_idx.to(device=device, dtype=torch.int32)) + self.gptq_marlin_g_idx = g_idx_sorted.contiguous() + self.gptq_marlin_g_idx_sort_indices = g_idx_sort_indices.contiguous() + else: + self.gptq_marlin_g_idx = marlin_make_empty_g_idx(device) + self.gptq_marlin_g_idx_sort_indices = marlin_make_empty_g_idx(device) + + # Workspace (internal locking mechanism). + self.gptq_marlin_workspace = marlin_make_workspace_new(device) + + if not already_marlin_ready: + # Repack qweight to marlin format. + self.gptq_marlin_qweight = ops.gptq_marlin_repack( + self.gptq_qweight.contiguous(), + perm=self.gptq_marlin_g_idx_sort_indices, + size_k=in_features, + size_n=out_features, + num_bits=weight_bits, + is_a_8bit=False, + ).contiguous() + + # Permute scales to marlin format. + self.gptq_marlin_scales = marlin_permute_scales( + self.gptq_scales.contiguous(), + size_k=in_features, + size_n=out_features, + group_size=group_size, + is_a_8bit=False, + ).contiguous() + + # GPTQ Marlin only supports symmetric weights (no runtime zero-points). + # Use empty zp to keep has_zp=False in the kernel. + self.gptq_marlin_zp = marlin_make_empty_g_idx(device) + + self._gptq_marlin_is_prepared.fill_(True) + self._gptq_marlin_is_prepared_py = True + + def _maybe_prepare_offline_awq_marlin(self, x: torch.Tensor) -> None: + """Prepare vLLM AWQ Marlin weights on first use (repack + permute scales/zp).""" + if self._offline_quant_format_py != 2: + return + if self.awq_qweight.numel() == 0: + return + if self._awq_marlin_is_prepared_py: + return + + try: + from vllm import _custom_ops as ops # type: ignore + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( # type: ignore + awq_to_marlin_zero_points, + marlin_make_workspace_new, + marlin_permute_scales, + ) + except Exception as e: # pragma: no cover + raise RuntimeError( + "AWQ Marlin 需要 vLLM CUDA custom ops + marlin_utils,但当前环境不可用。" + ) from e + + device = x.device + if self.awq_qweight.device != device: + raise RuntimeError( + f"AWQ qweight device mismatch: qweight on {self.awq_qweight.device}, x on {device}. " + "请确保模型与输入在同一设备。" + ) + + in_features = int(self._offline_quant_in_features_py) + out_features = int(self._offline_quant_out_features_py) + group_size = int(self._offline_quant_group_size_py) + if in_features <= 0 or out_features <= 0: + raise RuntimeError( + f"AWQ Marlin: invalid feature sizes: in_features={in_features}, out_features={out_features}" + ) + + # AWQ is 4-bit only. + pack_factor = out_features // int(self.awq_qweight.shape[1]) + if pack_factor != 8: + raise RuntimeError(f"AWQ Marlin: expected pack_factor=8 (W4), got pack_factor={pack_factor}") + weight_bits = 4 + num_groups = (in_features // (in_features if group_size == -1 else group_size)) + + self.awq_marlin_workspace = marlin_make_workspace_new(device) + + # Repack qweight to marlin format. + self.awq_marlin_qweight = ops.awq_marlin_repack( + self.awq_qweight.contiguous(), + size_k=in_features, + size_n=out_features, + num_bits=weight_bits, + is_a_8bit=False, + ).contiguous() + + # Permute scales to marlin format. + self.awq_marlin_scales = marlin_permute_scales( + self.awq_scales.contiguous(), + size_k=in_features, + size_n=out_features, + group_size=group_size, + is_a_8bit=False, + ).contiguous() + + # Convert zero-points to marlin format. + self.awq_marlin_zp = awq_to_marlin_zero_points( + self.awq_qzeros.contiguous(), + size_k=num_groups, + size_n=out_features, + num_bits=weight_bits, + is_a_8bit=False, + ).contiguous() + + self._awq_marlin_is_prepared.fill_(True) + self._awq_marlin_is_prepared_py = True + + def set_quantized_weight(self, quant_weight_int8: torch.Tensor, quant_scales: torch.Tensor) -> None: + # Support: + # - int8: int8/int4 weight-only quantization + # - float8: FP8 weight-only quantization (vLLM-aligned) + # - uint8: legacy FP8 storage (kept for backward compatibility) + fp8_dtypes: tuple[torch.dtype, ...] = tuple( + d + for d in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e5m2fnuz", None), + ) + if d is not None + ) + if quant_weight_int8.dtype not in (torch.int8, torch.uint8, *fp8_dtypes): + raise TypeError( + f"quant_weight_int8 must be int8/uint8/float8, got {quant_weight_int8.dtype}" + ) + # Store scales dtype depends on strategy: + # - W8A16/W4A16 kernels currently take bf16 scales. + # - W8A8/W4A8 paths are more sensitive to scale precision; keep scales at fp16. + # - FP8 W8A16 uses float32 scales. + # - FP8 W8A8 uses float16 scales. + try: + strategy = get_linear_strategy(self.quant_kind) + except Exception: + strategy = None + scale_dtype = torch.bfloat16 + force_weight_contig = True + if strategy is not None: + weight_format = getattr(strategy, "linear_weight_format", None) + act_format = getattr(strategy, "linear_act_format", None) + # FP8 W8A16 uses float32 scales + if weight_format in ("fp8_e4m3", "fp8_e5m2") and act_format == "bf16": + scale_dtype = torch.float32 + # Keep KxN transpose-view layout (do NOT force contiguous) for vLLM FP8 kernels. + force_weight_contig = False + # W8A8 int8 uses float32 [1, N] weight scales in vLLM cutlass_scaled_mm path. + elif weight_format == "int8" and act_format == "int8": + scale_dtype = torch.float32 + # vLLM CUTLASS scaled_mm expects int8 weight in KxN with stride(0)==1, + # which is typically produced as a transpose-view (non-contiguous). + # Do NOT force contiguous here; just avoid per-call conversions. + force_weight_contig = False + # FP8 W8A8 keeps float32 scales; also keep KxN transpose-view layout. + elif act_format in ("fp8_e4m3", "fp8_e5m2"): + scale_dtype = torch.float32 + force_weight_contig = False + # Other int8/int4 mixed paths use float16 scales by default. + elif act_format == "int8": + scale_dtype = torch.float16 + if quant_scales.dtype != scale_dtype: + quant_scales = quant_scales.to(dtype=scale_dtype) + # Make sure scales are contiguous once at load-time. + # NOTE: Some kernels require specific non-contiguous weight layouts (e.g., W8A8 KxN with stride(0)==1). + # We avoid per-call `is_contiguous/contiguous` checks while preserving required layouts. + if force_weight_contig: + quant_weight_int8 = quant_weight_int8.contiguous() + quant_scales = quant_scales.contiguous() + self.quant_weight_int8 = quant_weight_int8 + self.quant_scales = quant_scales + # 1xN view for fused kernels expecting 2D scales. + self.quant_scales_1xn = quant_scales if quant_scales.dim() == 2 else quant_scales.view(1, -1) + self._weight_is_quantized.fill_(True) + self._weight_is_quantized_py = True + # Quant buffers changed; cached forward plan is no longer valid. + self._invalidate_forward_plan() + + def _maybe_promote_weight_to_quantized_at_runtime( + self, + x: torch.Tensor, + strategy, + *, + expected_weight_formats: tuple[str, ...] = ("int8", "int4", "fp8_e4m3", "fp8_e5m2"), + ) -> None: + """Runtime safety net: if a Linear is configured for quantization but the bf16/fp16 + weight Parameter was not quantized+removed at load-time (e.g., due to sharded load + ordering), quantize once on first forward and drop the bf16 weight Parameter. + + This avoids keeping both bf16 weights and quantized weights resident on GPU. + """ + if strategy is None: + return + if self.has_offline_quantized_weight() or self.has_quantized_weight(): + return + weight_param = self._parameters.get("weight", None) + if weight_param is None: + return + weight_format = getattr(strategy, "linear_weight_format", None) + if weight_format not in expected_weight_formats: + return + if getattr(strategy, "name", "").startswith("linear_stub"): + return + w = getattr(self, "weight", None) + if w is None or getattr(w, "dtype", None) not in (torch.bfloat16, torch.float16): + return + try: + qweight, scales = strategy.quantize_weight_for_kernel(w.data, device=w.data.device) + except Exception: + return + self.set_quantized_weight(qweight, scales) + # Drop bf16 weight Parameter to free GPU memory. + self._parameters.pop("weight", None) + setattr(self, "weight", None) + + def _maybe_quantize_loaded_weight_param( + self, + param: nn.Parameter, + *, + loaded_shard_id: object = None, + expected_shard_ids: set[object] | None = None, + ) -> None: + """If current Linear is configured for quantization, quantize the loaded bf16 weight and drop the bf16 Parameter. + + This is called at the end of weight_loader(), after the shard copy is done. + Supports int8 (W8A16/W8A8), int4 (W4A16/W4A8), and FP8 (FP8 W8A16/FP8 W8A8) quantization. + """ + # Only process the real weight Parameter (ignore bias). + current_weight = self._parameters.get("weight", None) + if current_weight is None or current_weight is not param: + return + + # Some modules load the same weight parameter in multiple shards (e.g., QKV / merged linears). + # In that case, we must wait until all shards are loaded before quantizing/removing the bf16 Parameter, + # otherwise subsequent shard loads would fail (model.get_parameter can't find it). + if expected_shard_ids is not None: + if not hasattr(self, "_loaded_weight_shard_ids"): + self._loaded_weight_shard_ids: set[object] = set() + self._loaded_weight_shard_ids.add(loaded_shard_id) + if self._loaded_weight_shard_ids != expected_shard_ids: + return + + # Get strategy for this kind; default bf16 strategy should not trigger quantization. + strategy = get_linear_strategy(self.quant_kind) + if strategy is None: + return + weight_format = getattr(strategy, "linear_weight_format", None) + # NOTE: We intentionally do NOT require act_format == "bf16" here. + # For W8A8/W4A8/FP8 W8A8 we still want to quantize+drop the bf16 weight Parameter at load-time. + # But we must avoid doing this for the generic stub strategy (unsupported combos), + # otherwise we'd drop weights and then raise NotImplementedError at runtime. + if getattr(strategy, "name", "").startswith("linear_stub"): + return + + # Support int8/int4/FP8 weight formats (W8A16/W8A8, W4A16/W4A8, FP8 W8A16/FP8 W8A8). + if weight_format not in ("int8", "int4", "fp8_e4m3", "fp8_e5m2"): + return + + # Quantize on the same device as the loaded param (typically CUDA). + qweight, scales = strategy.quantize_weight_for_kernel(param.data, device=param.data.device) + self.set_quantized_weight(qweight, scales) + + # Drop bf16 weight Parameter to free GPU memory. + self._parameters.pop("weight", None) + # Keep attribute for compatibility, but ensure forward uses quant buffers. + setattr(self, "weight", None) + + def _get_linear_strategy(self): + """Return strategy for current `quant_kind` (or None). + + NOTE: do not swallow TypeError here; a wrong strategy type should fail fast. + """ + return get_linear_strategy(self.quant_kind) + + def _offline_meta(self) -> tuple[int, int, int]: + """Return (out_features, in_features, group_size) for offline GPTQ/AWQ.""" + return ( + int(self._offline_quant_out_features_py), + int(self._offline_quant_in_features_py), + int(self._offline_quant_group_size_py), + ) + + def _infer_gptq_weight_bits(self, *, in_features: int) -> int: + """Infer/return GPTQ weight bits for downstream kernels. + + Priority: + - use recorded bits (e.g., marlin-exported layouts), + - otherwise infer from qweight packing. + """ + bits = int(self._offline_quant_bits_py) + if bits > 0: + return bits + if self.gptq_qweight.numel() == 0: + raise RuntimeError("GPTQ bits 推断失败:gptq_qweight 为空。") + if self.gptq_qweight.shape[0] <= 0 or in_features % int(self.gptq_qweight.shape[0]) != 0: + raise RuntimeError( + f"GPTQ bits 推断失败:in_features={in_features}, qweight.shape={tuple(self.gptq_qweight.shape)}" + ) + pack_factor = in_features // int(self.gptq_qweight.shape[0]) + if 32 % pack_factor != 0: + raise RuntimeError(f"GPTQ bits 推断失败:pack_factor={pack_factor} 不满足 32%pack_factor==0") + return 32 // pack_factor + + def _maybe_int4_original_in_features_kwargs(self, strategy, x: torch.Tensor) -> Optional[dict]: + """Some int4 kernels need original K (before packing).""" + if strategy is None: + return None + if getattr(strategy, "linear_weight_format", None) == "int4": + return {"original_in_features": x.shape[1]} + return None + + def _build_offline_forward_kwargs(self, x: torch.Tensor, strategy) -> dict: + """Build kwargs for offline GPTQ/AWQ (including Marlin variants).""" + if strategy is None: + raise RuntimeError("Offline quantized weight is present but no linear strategy is configured.") + + format_val = int(self._offline_quant_format_py) + weight_format = getattr(strategy, "linear_weight_format", None) + out_features, in_features, group_size = self._offline_meta() + + meta = { + "out_features": out_features, + "in_features": in_features, + "group_size": group_size, + } + + if format_val == 1: # GPTQ + # IMPORTANT: only gptq_gemm needs gptq_shuffle; marlin variants require the original format. + if weight_format == "gptq": + self._maybe_prepare_offline_gptq(x) + return { + **meta, + "gptq_qweight": self.gptq_qweight, + "gptq_qzeros": self.gptq_qzeros, + "gptq_scales": self.gptq_scales, + "gptq_group_size": group_size, + # Always pass g_idx (can be empty). vLLM expects it for GPTQ kernels. + "gptq_g_idx": self.gptq_g_idx, + } + + if weight_format == "gptq_marlin": + self._maybe_prepare_offline_gptq_marlin(x) + bits = self._infer_gptq_weight_bits(in_features=in_features) + return { + **meta, + "gptq_weight_bits": bits, + "gptq_marlin_qweight": self.gptq_marlin_qweight, + "gptq_marlin_scales": self.gptq_marlin_scales, + "gptq_marlin_zp": self.gptq_marlin_zp, + "gptq_marlin_g_idx": self.gptq_marlin_g_idx, + "gptq_marlin_g_idx_sort_indices": self.gptq_marlin_g_idx_sort_indices, + "gptq_marlin_workspace": self.gptq_marlin_workspace, + } + + raise RuntimeError( + f"Offline GPTQ weights are present, but current strategy weight_format={weight_format!r} is not compatible." + ) + + if format_val == 2: # AWQ + if weight_format == "awq": + return { + **meta, + "awq_qweight": self.awq_qweight, + "awq_qzeros": self.awq_qzeros, + "awq_scales": self.awq_scales, + "awq_group_size": group_size, + } + + if weight_format == "awq_marlin": + self._maybe_prepare_offline_awq_marlin(x) + return { + **meta, + "awq_marlin_qweight": self.awq_marlin_qweight, + "awq_marlin_scales": self.awq_marlin_scales, + "awq_marlin_zp": self.awq_marlin_zp, + "awq_marlin_workspace": self.awq_marlin_workspace, + "awq_weight_bits": 4, + } + + raise RuntimeError( + f"Offline AWQ weights are present, but current strategy weight_format={weight_format!r} is not compatible." + ) + + raise RuntimeError(f"Unknown offline quant format: {format_val}") + + def _forward_base(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: + """Unified forward dispatcher for bf16 / online quant / offline GPTQ/AWQ.""" + if getattr(self, "_forward_plan_enabled", False): + plan = getattr(self, "_forward_plan", None) + if plan is None: + self.build_forward_plan_for_static(x, bias) + plan = getattr(self, "_forward_plan", None) + if plan is not None: + sig = plan.sig + dev = x.device + dev_idx = self._device_index(dev) + if ( + sig.device_type == dev.type + and sig.device_index == dev_idx + and sig.x_dtype == x.dtype + and sig.x_shape == tuple(int(v) for v in x.shape) + and sig.has_bias == (bias is not None) + ): + return plan(x) + # Static mode but shape/dtype changed: rebuild once and retry. + self.build_forward_plan_for_static(x, bias) + plan = getattr(self, "_forward_plan", None) + if plan is not None: + sig = plan.sig + if ( + sig.device_type == dev.type + and sig.device_index == dev_idx + and sig.x_dtype == x.dtype + and sig.x_shape == tuple(int(v) for v in x.shape) + and sig.has_bias == (bias is not None) + ): + return plan(x) + + strategy = self._get_linear_strategy() + # Runtime safety net: ensure we don't keep bf16+quant weights both resident. + self._maybe_promote_weight_to_quantized_at_runtime(x, strategy) + + # Offline quantized weights (GPTQ/AWQ) have higher priority. + if self.has_offline_quantized_weight(): + if strategy is None: + raise RuntimeError("Offline quantized weight is present but no linear strategy is configured.") + weight_format = getattr(strategy, "linear_weight_format", None) + out_features, in_features, group_size = self._offline_meta() + + # Avoid per-call kwargs dict construction on hot paths. + if weight_format == "gptq": + self._maybe_prepare_offline_gptq(x) + bits = self._infer_gptq_weight_bits(in_features=in_features) + return strategy.linear_forward( + x, + None, # weight not used for offline quantized weights + bias, + quant_kind=self.quant_kind, + gptq_qweight=self.gptq_qweight, + gptq_qzeros=self.gptq_qzeros, + gptq_scales=self.gptq_scales, + gptq_g_idx=self.gptq_g_idx, + weight_bits=bits, + use_v2_format=False, + out_features=out_features, + in_features=in_features, + group_size=group_size, + ) + + if weight_format == "awq": + # AWQ is 4-bit only in vLLM; bits stored in _offline_quant_bits. + bits = int(self._offline_quant_bits_py) if int(self._offline_quant_bits_py) > 0 else 4 + pack_factor = 32 // max(1, bits) + return strategy.linear_forward( + x, + None, + bias, + quant_kind=self.quant_kind, + awq_qweight=self.awq_qweight, + awq_qzeros=self.awq_qzeros, + awq_scales=self.awq_scales, + pack_factor=pack_factor, + out_features=out_features, + in_features=in_features, + group_size=group_size, + ) + + if weight_format == "gptq_marlin": + self._maybe_prepare_offline_gptq_marlin(x) + bits = self._infer_gptq_weight_bits(in_features=in_features) + return strategy.linear_forward( + x, + None, + bias, + quant_kind=self.quant_kind, + qweight=self.gptq_marlin_qweight, + scales=self.gptq_marlin_scales, + zp=self.gptq_marlin_zp, + g_idx=self.gptq_marlin_g_idx, + g_idx_sort_indices=self.gptq_marlin_g_idx_sort_indices, + workspace=self.gptq_marlin_workspace, + in_features=in_features, + out_features=out_features, + group_size=group_size, + weight_bits=bits, + tp_dim=self.tp_dim, + ) + + if weight_format == "awq_marlin": + self._maybe_prepare_offline_awq_marlin(x) + return strategy.linear_forward( + x, + None, + bias, + quant_kind=self.quant_kind, + qweight=self.awq_marlin_qweight, + scales=self.awq_marlin_scales, + zp=self.awq_marlin_zp, + workspace=self.awq_marlin_workspace, + in_features=in_features, + out_features=out_features, + group_size=group_size, + tp_dim=self.tp_dim, + ) + + # Fallback: compatibility for any remaining strategies. + kwargs = self._build_offline_forward_kwargs(x, strategy) + return strategy.linear_forward( + x, + None, + bias, + quant_kind=self.quant_kind, + **kwargs, + ) + + if self.has_quantized_weight(): + if strategy is None: + raise RuntimeError("Quantized weight is present but no linear strategy is configured.") + # Hot path: avoid per-call dict construction when possible. + extra_kwargs = self._maybe_int4_original_in_features_kwargs(strategy, x) + # W8A16(AllSpark) expects scales in 1xN layout and needs explicit N. + if getattr(strategy, "name", "") == "linear_int8_w8a16": + if extra_kwargs: + return strategy.linear_forward( + x, + self.quant_weight_int8, + bias, + quant_kind=self.quant_kind, + quant_scales=self.quant_scales_1xn, + out_features=self._forward_out_features, + **extra_kwargs, + ) + return strategy.linear_forward( + x, + self.quant_weight_int8, + bias, + quant_kind=self.quant_kind, + quant_scales=self.quant_scales_1xn, + out_features=self._forward_out_features, + ) + + # W8A8 expects scales in 1xN layout and is sensitive to weight layout (KxN stride0==1). + if getattr(strategy, "name", "") == "linear_int8_w8a8": + if extra_kwargs: + return strategy.linear_forward( + x, + self.quant_weight_int8, + bias, + quant_kind=self.quant_kind, + quant_scales=self.quant_scales_1xn, + out_features=self._forward_out_features, + **extra_kwargs, + ) + return strategy.linear_forward( + x, + self.quant_weight_int8, + bias, + quant_kind=self.quant_kind, + quant_scales=self.quant_scales_1xn, + out_features=self._forward_out_features, + ) + + if extra_kwargs: + return strategy.linear_forward( + x, + self.quant_weight_int8, + bias, + quant_kind=self.quant_kind, + quant_scales=self.quant_scales, + **extra_kwargs, + ) + return strategy.linear_forward( + x, + self.quant_weight_int8, + bias, + quant_kind=self.quant_kind, + quant_scales=self.quant_scales, + ) + + if strategy is None: + weight = getattr(self, "weight", None) + if weight is None: + raise RuntimeError("No strategy is configured and bf16 weight is missing.") + return F.linear(x, weight, bias) + + weight = getattr(self, "weight", None) + # NOTE: For offline-quantized strategies (e.g. GPTQ/AWQ/Marlin), the original + # bf16 weight may be intentionally removed after loading to save memory. + # In that case, the quantization strategy must be able to run without it. + kwargs = self._maybe_int4_original_in_features_kwargs(strategy, x) + if kwargs: + return strategy.linear_forward(x, weight, bias, quant_kind=self.quant_kind, **kwargs) + return strategy.linear_forward(x, weight, bias, quant_kind=self.quant_kind) def forward(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -85,8 +2002,9 @@ def __init__( r: int = 0, lora_alpha: float = 1.0, lora_dropout: float = 0.0, + quant_kind: str = "other", ): - LinearBase.__init__(self, input_size, output_size) + LinearBase.__init__(self, input_size, output_size, None, quant_kind) self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size)) self.weight.weight_loader = self.weight_loader if bias: @@ -99,9 +2017,10 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param.data.copy_(loaded_weight) + self._maybe_quantize_loaded_weight_param(param, loaded_shard_id=None, expected_shard_ids={None}) def forward(self, x: torch.Tensor) -> torch.Tensor: - base_out = F.linear(x, self.weight, self.bias) + base_out = self._forward_base(x, self.bias) return self.lora_forward(x, base_out) @@ -115,10 +2034,12 @@ def __init__( r: int = 0, lora_alpha: float = 1.0, lora_dropout: float = 0.0, + quant_kind: str = "other", ): - LinearBase.__init__(self, input_size, output_size, 0) + LinearBase.__init__(self, input_size, output_size, 0, quant_kind) self.input_size_per_partition = input_size self.output_size_per_partition = divide(output_size, self.tp_size) + self._forward_out_features = int(self.output_size_per_partition) self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size)) self.weight.weight_loader = self.weight_loader @@ -136,9 +2057,10 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) param_data.copy_(loaded_weight) + self._maybe_quantize_loaded_weight_param(param, loaded_shard_id=None, expected_shard_ids={None}) def forward(self, x: torch.Tensor) -> torch.Tensor: - base_out = F.linear(x, self.weight, self.bias) + base_out = self._forward_base(x, self.bias) return self.lora_forward(x, base_out) @@ -152,9 +2074,18 @@ def __init__( r: int = 0, lora_alpha: float = 1.0, lora_dropout: float = 0.0, + quant_kind: str = "other", ): self.output_sizes = output_sizes - super().__init__(input_size, sum(output_sizes), bias=bias, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + super().__init__( + input_size, + sum(output_sizes), + bias=bias, + r=r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + quant_kind=quant_kind, + ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int): param_data = param.data @@ -163,6 +2094,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank] param_data.copy_(loaded_weight) + expected = set(range(len(self.output_sizes))) + self._maybe_quantize_loaded_weight_param(param, loaded_shard_id=loaded_shard_id, expected_shard_ids=expected) class QKVParallelLinear(ColumnParallelLinear): @@ -177,6 +2110,7 @@ def __init__( r: int = 0, lora_alpha: float = 1.0, lora_dropout: float = 0.0, + quant_kind: str = "attn", ): self.head_size = head_size self.total_num_heads = total_num_heads @@ -186,7 +2120,7 @@ def __init__( self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) input_size = hidden_size output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size - super().__init__(input_size, output_size, bias, r, lora_alpha, lora_dropout) + super().__init__(input_size, output_size, bias, r, lora_alpha, lora_dropout, quant_kind=quant_kind) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str): param_data = param.data @@ -203,6 +2137,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank] param_data.copy_(loaded_weight) + self._maybe_quantize_loaded_weight_param(param, loaded_shard_id=loaded_shard_id, expected_shard_ids={"q", "k", "v"}) class RowParallelLinear(LinearBase, LoRAMixin): @@ -215,8 +2150,9 @@ def __init__( r: int = 0, lora_alpha: float = 1.0, lora_dropout: float = 0.0, + quant_kind: str = "other", ): - LinearBase.__init__(self, input_size, output_size, 1) + LinearBase.__init__(self, input_size, output_size, 1, quant_kind) self.input_size_per_partition = divide(input_size, self.tp_size) self.output_size_per_partition = output_size @@ -236,9 +2172,11 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) param_data.copy_(loaded_weight) + self._maybe_quantize_loaded_weight_param(param, loaded_shard_id=None, expected_shard_ids={None}) def forward(self, x: torch.Tensor) -> torch.Tensor: - y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None) + bias = self.bias if self.tp_rank == 0 else None + y = self._forward_base(x, bias) if self.tp_size > 1: dist.all_reduce(y) return self.lora_forward(x, y) diff --git a/diffulex/logger.py b/diffulex/logger.py new file mode 100644 index 0000000..821feac --- /dev/null +++ b/diffulex/logger.py @@ -0,0 +1,176 @@ +""" +Professional logging setup with colored output for Diffulex +""" + +import logging +import sys +from pathlib import Path +from typing import Optional + +try: + from rich.console import Console + from rich.logging import RichHandler + from rich.traceback import install as install_rich_traceback + from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn + RICH_AVAILABLE = True +except ImportError: + RICH_AVAILABLE = False + +try: + import colorama + from colorama import Fore, Style, init as init_colorama + COLORAMA_AVAILABLE = True + init_colorama(autoreset=True) +except ImportError: + COLORAMA_AVAILABLE = False + + +class ColoredFormatter(logging.Formatter): + """Custom formatter with color support""" + + if COLORAMA_AVAILABLE: + COLORS = { + 'DEBUG': Fore.CYAN, + 'INFO': Fore.GREEN, + 'WARNING': Fore.YELLOW, + 'ERROR': Fore.RED, + 'CRITICAL': Fore.RED + Style.BRIGHT, + } + else: + COLORS = {} + + RESET = Style.RESET_ALL if COLORAMA_AVAILABLE else '' + + def format(self, record): + log_color = self.COLORS.get(record.levelname, '') + record.levelname = f"{log_color}{record.levelname}{self.RESET}" + return super().format(record) + + +def setup_logger( + name: str = "diffulex", + level: int = logging.INFO, + log_file: Optional[str] = None, + use_rich: bool = True, +) -> logging.Logger: + """ + Setup a professional logger with colored output + + Args: + name: Logger name + level: Logging level + log_file: Optional log file path + use_rich: Whether to use rich library for better formatting + + Returns: + Configured logger + """ + logger = logging.getLogger(name) + logger.setLevel(level) + logger.handlers.clear() + logger.propagate = False # Prevent propagation to root logger to avoid duplicate output + + # Use Rich if available and requested + if use_rich and RICH_AVAILABLE: + console = Console(stderr=True) + handler = RichHandler( + console=console, + show_time=True, + show_path=False, + rich_tracebacks=True, + markup=True, + ) + handler.setFormatter(logging.Formatter( + "%(message)s", + datefmt="[%X]" + )) + logger.addHandler(handler) + + # Install rich traceback + install_rich_traceback(show_locals=True) + else: + # Fallback to colored console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(level) + + if COLORAMA_AVAILABLE: + formatter = ColoredFormatter( + '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + else: + formatter = logging.Formatter( + '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # Add file handler if specified + if log_file: + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler.setLevel(level) + file_formatter = logging.Formatter( + '%(asctime)s | %(levelname)-8s | %(name)s | %(funcName)s:%(lineno)d | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + return logger + + +def get_logger(name: str = "diffulex") -> logging.Logger: + """ + Get or create a logger + + Args: + name: Logger name + + Returns: + Logger instance + """ + logger = logging.getLogger(name) + if not logger.handlers: + # Setup default logger if not already configured + setup_logger(name) + # Ensure propagate is False to avoid duplicate output + logger.propagate = False + return logger + + +class LoggerMixin: + """Mixin class to add logger property to classes""" + + @property + def logger(self) -> logging.Logger: + """Get logger for this class""" + return get_logger(self.__class__.__module__) + + +# Add success method to logger +def _add_success_method(): + """Add success method to logging.Logger class""" + if RICH_AVAILABLE: + def success(self, message: str, *args, **kwargs): + """Log success message with rich formatting""" + self.info(f"[green]✓[/green] {message}", *args, **kwargs) + else: + def success(self, message: str, *args, **kwargs): + """Log success message""" + if COLORAMA_AVAILABLE: + self.info(f"{Fore.GREEN}✓{Style.RESET_ALL} {message}", *args, **kwargs) + else: + self.info(f"✓ {message}", *args, **kwargs) + + if not hasattr(logging.Logger, 'success'): + logging.Logger.success = success + + +# Initialize success method +_add_success_method() + diff --git a/diffulex/model/__init__.py b/diffulex/model/__init__.py index 12581e2..afb32cc 100644 --- a/diffulex/model/__init__.py +++ b/diffulex/model/__init__.py @@ -18,7 +18,8 @@ except Exception as e: # Skip modules that fail to import import warnings - warnings.warn(f"Failed to import {module_name}: {e}", ImportWarning) + # ImportWarning is ignored by default, which can hide real registration problems. + warnings.warn(f"Failed to import {module_name}: {e!r}", RuntimeWarning) __all__ = _model_modules.copy() diff --git a/diffulex/model/auto_model.py b/diffulex/model/auto_model.py index bb77064..744496e 100755 --- a/diffulex/model/auto_model.py +++ b/diffulex/model/auto_model.py @@ -14,6 +14,47 @@ class AutoModelForDiffusionLM: MODEL_MAPPING: dict[str, RegistryEntry] = {} + @classmethod + def _ensure_registry_populated(cls) -> None: + """Best-effort import of built-in models to populate the registry. + + This is intentionally defensive for multi-process / spawn execution where + import side-effects can be sensitive to initialization order. + """ + if cls.MODEL_MAPPING: + return + + try: + import importlib + import pkgutil + import warnings + + import diffulex.model as model_pkg + + excluded = {"auto_model", "__init__"} + for mod in pkgutil.iter_modules(model_pkg.__path__): + name = mod.name + if name in excluded or mod.ispkg: + continue + try: + importlib.import_module(f"diffulex.model.{name}") + except Exception as e: + warnings.warn( + f"Failed to import diffulex.model.{name} during registry auto-discovery: {e!r}", + RuntimeWarning, + ) + except Exception as e: + # Don't fail hard here; the caller will raise with available models. + try: + import warnings + + warnings.warn( + f"Model registry auto-discovery failed: {e!r}", + RuntimeWarning, + ) + except Exception: + pass + @classmethod def register( cls, @@ -77,10 +118,16 @@ def from_config(cls, config: Config): try: factory, use_full_config = cls.MODEL_MAPPING[config.model_name] except KeyError as err: - available = ", ".join(cls.available_models()) or "" - raise ValueError( - f"Model '{config.model_name}' is not registered. Available models: {available}." - ) from err + # Spawn/multi-process execution can hit this before side-effect imports + # have populated the registry. Try a best-effort discovery once. + cls._ensure_registry_populated() + if config.model_name in cls.MODEL_MAPPING: + factory, use_full_config = cls.MODEL_MAPPING[config.model_name] + else: + available = ", ".join(cls.available_models()) or "" + raise ValueError( + f"Model '{config.model_name}' is not registered. Available models: {available}." + ) from err if factory is None: raise ValueError(f"Model '{config.model_name}' is reserved but not implemented yet.") diff --git a/diffulex/model/config/dream/configuration_dream.py b/diffulex/model/config/dream/configuration_dream.py index 6a8c49d..ec83795 100755 --- a/diffulex/model/config/dream/configuration_dream.py +++ b/diffulex/model/config/dream/configuration_dream.py @@ -17,10 +17,10 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation -from transformers.utils import logging +from diffulex.logger import get_logger -logger = logging.get_logger(__name__) +logger = get_logger(__name__) class DreamConfig(PretrainedConfig): diff --git a/diffulex/model/config/fast_dllm_v2/configuration_fast_dllm_v2.py b/diffulex/model/config/fast_dllm_v2/configuration_fast_dllm_v2.py index ab484c6..0b373ac 100755 --- a/diffulex/model/config/fast_dllm_v2/configuration_fast_dllm_v2.py +++ b/diffulex/model/config/fast_dllm_v2/configuration_fast_dllm_v2.py @@ -17,10 +17,10 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation -from transformers.utils import logging +from diffulex.logger import get_logger -logger = logging.get_logger(__name__) +logger = get_logger(__name__) class FastdLLMV2Config(PretrainedConfig): diff --git a/diffulex/model/config/sdar/configuration_sdar.py b/diffulex/model/config/sdar/configuration_sdar.py index f201418..fed2675 100644 --- a/diffulex/model/config/sdar/configuration_sdar.py +++ b/diffulex/model/config/sdar/configuration_sdar.py @@ -3,10 +3,10 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation -from transformers.utils import logging +from diffulex.logger import get_logger -logger = logging.get_logger(__name__) +logger = get_logger(__name__) class SDARConfig(PretrainedConfig): diff --git a/diffulex/model/dream.py b/diffulex/model/dream.py index c7e3ac5..8398b0b 100755 --- a/diffulex/model/dream.py +++ b/diffulex/model/dream.py @@ -55,21 +55,25 @@ def __init__( hidden_size, self.total_num_heads * self.head_dim, bias=qkv_bias, + quant_kind="attn", ) self.k_proj = ColumnParallelLinear( hidden_size, self.total_num_kv_heads * self.head_dim, bias=qkv_bias, + quant_kind="attn", ) self.v_proj = ColumnParallelLinear( hidden_size, self.total_num_kv_heads * self.head_dim, bias=qkv_bias, + quant_kind="attn", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, + quant_kind="attn", ) self.rotary_emb = get_rope( self.head_dim, @@ -114,16 +118,19 @@ def __init__( hidden_size, intermediate_size, bias=False, + quant_kind="mlp", ) self.up_proj = ColumnParallelLinear( hidden_size, intermediate_size, bias=False, + quant_kind="mlp", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, + quant_kind="mlp", ) assert hidden_act == "silu" self.act_fn = SiluAndMul() diff --git a/diffulex/model/fast_dllm_v2.py b/diffulex/model/fast_dllm_v2.py index 126705b..e56db79 100755 --- a/diffulex/model/fast_dllm_v2.py +++ b/diffulex/model/fast_dllm_v2.py @@ -55,21 +55,25 @@ def __init__( hidden_size, self.total_num_heads * self.head_dim, bias=qkv_bias, + quant_kind="attn", ) self.k_proj = ColumnParallelLinear( hidden_size, self.total_num_kv_heads * self.head_dim, bias=qkv_bias, + quant_kind="attn", ) self.v_proj = ColumnParallelLinear( hidden_size, self.total_num_kv_heads * self.head_dim, bias=qkv_bias, + quant_kind="attn", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, + quant_kind="attn", ) self.rotary_emb = get_rope( self.head_dim, @@ -114,16 +118,19 @@ def __init__( hidden_size, intermediate_size, bias=False, + quant_kind="mlp", ) self.up_proj = ColumnParallelLinear( hidden_size, intermediate_size, bias=False, + quant_kind="mlp", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, + quant_kind="mlp", ) assert hidden_act == "silu" self.act_fn = SiluAndMul() diff --git a/diffulex/model/llada.py b/diffulex/model/llada.py index c3a5243..af29757 100755 --- a/diffulex/model/llada.py +++ b/diffulex/model/llada.py @@ -55,21 +55,25 @@ def __init__( hidden_size, self.total_num_heads * self.head_dim, bias=qkv_bias, + quant_kind="attn", ) self.k_proj = ColumnParallelLinear( hidden_size, self.total_num_kv_heads * self.head_dim, bias=qkv_bias, + quant_kind="attn", ) self.v_proj = ColumnParallelLinear( hidden_size, self.total_num_kv_heads * self.head_dim, bias=qkv_bias, + quant_kind="attn", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, + quant_kind="attn", ) self.rotary_emb = get_rope( self.head_dim, @@ -115,16 +119,19 @@ def __init__( hidden_size, intermediate_size, bias=False, + quant_kind="mlp", ) self.up_proj = ColumnParallelLinear( hidden_size, intermediate_size, bias=False, + quant_kind="mlp", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, + quant_kind="mlp", ) assert hidden_act == "silu" self.act_fn = SiluAndMul() diff --git a/diffulex/model/sdar.py b/diffulex/model/sdar.py index a733c45..d16750b 100644 --- a/diffulex/model/sdar.py +++ b/diffulex/model/sdar.py @@ -50,21 +50,25 @@ def __init__(self, config: SDARConfig) -> None: config.hidden_size, self.total_num_heads * self.head_dim, bias=bias, + quant_kind="attn", ) self.k_proj = ColumnParallelLinear( config.hidden_size, self.total_num_kv_heads * self.head_dim, bias=bias, + quant_kind="attn", ) self.v_proj = ColumnParallelLinear( config.hidden_size, self.total_num_kv_heads * self.head_dim, bias=bias, + quant_kind="attn", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=bias, + quant_kind="attn", ) # SDAR uses q/k per-head RMSNorm. @@ -116,9 +120,24 @@ class SDARMLP(nn.Module): def __init__(self, config: SDARConfig) -> None: super().__init__() - self.gate_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False) - self.up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False) - self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False) + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + quant_kind="mlp", + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + quant_kind="mlp", + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + quant_kind="mlp", + ) assert getattr(config, "hidden_act", "silu") == "silu" self.act_fn = SiluAndMul() diff --git a/diffulex/sampler/base.py b/diffulex/sampler/base.py index 56734ab..86d8b4a 100644 --- a/diffulex/sampler/base.py +++ b/diffulex/sampler/base.py @@ -7,6 +7,9 @@ from easydict import EasyDict as edict from diffulex.engine.sequence import SequenceBase +from diffulex.logger import get_logger + +logger = get_logger(__name__) class SamplerBase(nn.Module): @@ -103,7 +106,7 @@ def _fetch_last_logits(self, logits: torch.Tensor, seq: SequenceBase) -> torch.T def _shift_logits(self, logits, last_logit=None): if logits.shape[1] == 0: - print("Warning: logits sequence length is 0, returning empty logits") + logger.warning("Logits sequence length is 0, returning empty logits") raise Exception("logits sequence length is 0") shifted_logits = torch.zeros_like(logits) diff --git a/diffulex/sampler/dream.py b/diffulex/sampler/dream.py index 9f06340..1ff85c6 100644 --- a/diffulex/sampler/dream.py +++ b/diffulex/sampler/dream.py @@ -56,9 +56,11 @@ def forward(self, logits: torch.Tensor, temperatures: torch.Tensor, high_conf_indices = torch.where(initial_confidence > block.accept_threshold)[0] accepted_ids = high_conf_indices - true_local_ids_sub_map[str(block_id)] = [block.local_mask_token_ids[accepted_id] for accepted_id in accepted_ids.tolist()] - accepted_ids_sub_map[str(block_id)] = accepted_ids.tolist() - sampled_tokens_sub_map[str(block_id)] = sampled_tokens + # Avoid calling `.tolist()` on CUDA tensors directly (can trigger many per-element DtoH syncs). + accepted_ids_list = accepted_ids.to(device="cpu").tolist() + true_local_ids_sub_map[str(block_id)] = [block.local_mask_token_ids[i] for i in accepted_ids_list] + accepted_ids_sub_map[str(block_id)] = accepted_ids_list + sampled_tokens_sub_map[str(block_id)] = sampled_tokens.to(device="cpu").tolist() seq_idx = str(seq.seq_id) true_local_ids_map[seq_idx] = true_local_ids_sub_map diff --git a/diffulex/sampler/fast_dllm_v2.py b/diffulex/sampler/fast_dllm_v2.py index ec323b5..5726655 100644 --- a/diffulex/sampler/fast_dllm_v2.py +++ b/diffulex/sampler/fast_dllm_v2.py @@ -59,19 +59,15 @@ def forward(self, seqs: list[SequenceBase], logits: torch.Tensor, temperatures: if len(high_conf_indices) == 0: max_prob_idx = initial_confidence.argmax() - accepted_ids = torch.tensor([max_prob_idx], device=sampled_tokens.device, dtype=torch.long) + accepted_ids = max_prob_idx.view(1) else: max_prob_idx = initial_confidence.argmax() - accepted_ids = torch.unique(torch.cat([ - high_conf_indices, - torch.tensor([max_prob_idx], device=sampled_tokens.device, dtype=torch.long) - ])) + accepted_ids = torch.unique(torch.cat([high_conf_indices, max_prob_idx.view(1)])) - true_local_ids_sub_map[str(block_id)] = [ - block.local_mask_token_ids[accepted_id] for accepted_id in accepted_ids.tolist() - ] - accepted_ids_sub_map[str(block_id)] = accepted_ids.tolist() - sampled_tokens_sub_map[str(block_id)] = sampled_tokens + accepted_ids_list = accepted_ids.to(device="cpu").tolist() + true_local_ids_sub_map[str(block_id)] = [block.local_mask_token_ids[i] for i in accepted_ids_list] + accepted_ids_sub_map[str(block_id)] = accepted_ids_list + sampled_tokens_sub_map[str(block_id)] = sampled_tokens.to(device="cpu").tolist() seq_idx = str(seq.seq_id) true_local_ids_map[seq_idx] = true_local_ids_sub_map diff --git a/diffulex/sampler/llada.py b/diffulex/sampler/llada.py index 5202fa1..fd11f44 100644 --- a/diffulex/sampler/llada.py +++ b/diffulex/sampler/llada.py @@ -52,9 +52,10 @@ def forward(self, logits: torch.Tensor, temperatures: torch.Tensor, high_conf_indices = torch.where(initial_confidence > block.accept_threshold)[0] accepted_ids = high_conf_indices - true_local_ids_sub_map[str(block_id)] = [block.local_mask_token_ids[accepted_id] for accepted_id in accepted_ids.tolist()] - accepted_ids_sub_map[str(block_id)] = accepted_ids.tolist() - sampled_tokens_sub_map[str(block_id)] = sampled_tokens + accepted_ids_list = accepted_ids.to(device="cpu").tolist() + true_local_ids_sub_map[str(block_id)] = [block.local_mask_token_ids[i] for i in accepted_ids_list] + accepted_ids_sub_map[str(block_id)] = accepted_ids_list + sampled_tokens_sub_map[str(block_id)] = sampled_tokens.to(device="cpu").tolist() seq_idx = str(seq.seq_id) true_local_ids_map[seq_idx] = true_local_ids_sub_map diff --git a/diffulex/sampler/sdar.py b/diffulex/sampler/sdar.py new file mode 100644 index 0000000..8fc3896 --- /dev/null +++ b/diffulex/sampler/sdar.py @@ -0,0 +1,81 @@ +import torch + +from dataclasses import dataclass + +from diffulex.sampler.auto_sampler import AutoSampler +from diffulex.sampler.base import SamplerShiftLogits, SampleOutputBase +from diffulex.engine.sequence import SequenceBase + + +@dataclass +class SDARSampleOutputForDiffusionLM(SampleOutputBase): + pass + + +@AutoSampler.register("sdar") +class SDARSamplerForDiffusionLM(SamplerShiftLogits): + def forward(self, seqs: list[SequenceBase], logits: torch.Tensor, temperatures: torch.Tensor, + top_p=None, top_k=None, margin_confidence=False, neg_entropy=False, threshold=0.95): + attn_metadata = self.fetch_attn_metadata() + split_logits = torch.split( + logits, [len(seq) for seq in seqs] if attn_metadata.is_prefill + else [attn_metadata.diffusion_block_size] * len(seqs), dim=0 + ) + + accepted_ids_map = {} + sampled_tokens_map = {} + true_local_ids_map = {} + for temperature, seq, seq_logits in zip(temperatures, seqs, split_logits): + true_local_ids_sub_map = {} + accepted_ids_sub_map = {} + sampled_tokens_sub_map = {} + + last_logits = self._fetch_last_logits(seq_logits, seq) + + shifted_logits = self._shift_logits(seq_logits, last_logits) + + for block_id, block in enumerate(seq.diffusion_blocks): + if not block.is_active or sum(block.local_mask_tokens) == 0: + continue + + if len(block.global_mask_token_ids) == 0: + continue + + if attn_metadata.is_prefill: + mask_token_logits = shifted_logits[block.global_mask_token_ids, ...] + else: + mask_token_logits = shifted_logits[block.local_mask_token_ids, ...] + + confidence, sampled_tokens, initial_confidence = self.sample_tokens( + mask_token_logits, + temperature, + top_p=top_p, + top_k=top_k, + neg_entropy=(neg_entropy == "neg_entropy"), + margin_confidence=(margin_confidence == "margin_confidence") + ) + + high_conf_indices = torch.where(initial_confidence > threshold)[0] + + if len(high_conf_indices) == 0: + max_prob_idx = initial_confidence.argmax() + accepted_ids = max_prob_idx.view(1) + else: + max_prob_idx = initial_confidence.argmax() + accepted_ids = torch.unique(torch.cat([high_conf_indices, max_prob_idx.view(1)])) + + accepted_ids_list = accepted_ids.to(device="cpu").tolist() + true_local_ids_sub_map[str(block_id)] = [block.local_mask_token_ids[i] for i in accepted_ids_list] + accepted_ids_sub_map[str(block_id)] = accepted_ids_list + sampled_tokens_sub_map[str(block_id)] = sampled_tokens.to(device="cpu").tolist() + + seq_idx = str(seq.seq_id) + true_local_ids_map[seq_idx] = true_local_ids_sub_map + accepted_ids_map[seq_idx] = accepted_ids_sub_map + sampled_tokens_map[seq_idx] = sampled_tokens_sub_map + + return SDARSampleOutputForDiffusionLM( + true_local_ids_map=true_local_ids_map, + accepted_ids_map=accepted_ids_map, + sampled_tokens_map=sampled_tokens_map + ) \ No newline at end of file diff --git a/diffulex/strategy/block_diffusion/engine/model_runner.py b/diffulex/strategy/block_diffusion/engine/model_runner.py index d363ba4..61a4f99 100644 --- a/diffulex/strategy/block_diffusion/engine/model_runner.py +++ b/diffulex/strategy/block_diffusion/engine/model_runner.py @@ -5,6 +5,7 @@ from multiprocessing.synchronize import Event import torch +from tqdm import tqdm from diffulex.config import Config from diffulex.engine.sequence import SequenceBase @@ -23,24 +24,6 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]): self.mask_token_id = config.mask_token_id super().__init__(config, rank, event) - - def warmup_model(self): - print("Warming up model...") - set_warming_up(True) - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - max_num_batched_tokens, max_model_len = ( - self.config.max_num_batched_tokens, - self.config.max_model_len, - ) - num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs) - test_input_ids = [0] * max_model_len - seqs = [BDSequence(test_input_ids, config=self.config) for _ in range(num_seqs)] - self.run(seqs, True) - for seq in seqs: - seq.post_process() - torch.cuda.empty_cache() - reset_warming_up() def prepare_prefill(self, seqs: list[BDSequence]): input_ids: list[int] = [] @@ -87,7 +70,6 @@ def prepare_prefill(self, seqs: list[BDSequence]): slot_mapping.extend([-1] * self.block_size) block_tables = self.prepare_block_tables(seqs) - input_ids_tensor = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions_tensor = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) context_lens_tensor = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) @@ -145,7 +127,7 @@ def prepare_decode(self, seqs: list[BDSequence]): num_pages_storing = seq.num_page_blocks_in_active_diffusion_block total_num_pages = len(seq.block_table) for i in range(0, num_pages_storing): - start = seq.block_table[total_num_pages - num_pages_storing + i] * self.block_size + start = seq.block_table[(total_num_pages - 1) - num_pages_storing + i] * self.block_size end = start + self.block_size slot_mapping.extend(range(start, end)) @@ -174,24 +156,24 @@ def prepare_decode(self, seqs: list[BDSequence]): @torch.inference_mode() def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool): - if is_prefill or self.enforce_eager or input_ids.size(0) > 512: + if is_prefill or self.enforce_eager or input_ids.size(0) > 512 * self.diffusion_block_size: return self.model.compute_logits(self.model(input_ids, positions)) num_tokens = input_ids.size(0) - context = fetch_bd_attn_metadata() + attn_metadata = fetch_bd_attn_metadata() graph = self.graphs[next(x for x in self.graph_bs if x >= num_tokens)] graph_vars = self.graph_vars for key, value in graph_vars.items(): if key != "outputs": value.zero_() - num_seqs = len(context.context_lens) + num_seqs = len(attn_metadata.context_lens) graph_vars["input_ids"][:num_tokens] = input_ids graph_vars["positions"][:num_tokens] = positions - graph_vars["slot_mapping"][:num_tokens] = context.slot_mapping - graph_vars["context_lens"][:num_seqs] = context.context_lens - graph_vars["cu_seqlens_q"][:num_seqs + 1] = context.cu_seqlens_q - graph_vars["cu_seqlens_k"][:num_seqs + 1] = context.cu_seqlens_k - graph_vars["block_tables"][:num_seqs, : context.block_tables.size(1)] = context.block_tables + graph_vars["slot_mapping"][:num_tokens] = attn_metadata.slot_mapping + graph_vars["context_lens"][:num_seqs] = attn_metadata.context_lens + graph_vars["cu_seqlens_q"][:num_seqs + 1] = attn_metadata.cu_seqlens_q + graph_vars["cu_seqlens_k"][:num_seqs + 1] = attn_metadata.cu_seqlens_k + graph_vars["block_tables"][:num_seqs, : attn_metadata.block_tables.size(1)] = attn_metadata.block_tables graph.replay() return self.model.compute_logits(graph_vars["outputs"][:num_tokens]) @@ -205,6 +187,17 @@ def run(self, seqs: list[SequenceBase], is_prefill: bool) -> list[int]: @torch.inference_mode() def capture_cudagraph(self): + # Enable per-layer forward-plan dispatch to stabilize capture and minimize + # Python branching inside the captured region. + try: + from diffulex.layer.linear import LinearBase + for m in self.model.modules(): + if isinstance(m, LinearBase): + m.enable_forward_plan(True) + except Exception: + pass + + set_warming_up(True) config = self.config hf_config = config.hf_config max_num_seqs = min(self.config.max_num_seqs, 512) @@ -234,8 +227,8 @@ def capture_cudagraph(self): self.graph_bs.append(num_seqs * diffusion_block_size) self.graphs = {} self.graph_pool = None - - for num_tokens in reversed(self.graph_bs): + + for num_tokens in tqdm(reversed(self.graph_bs), desc="Capturing CUDA graphs"): num_seqs = num_tokens // diffusion_block_size graph = torch.cuda.CUDAGraph() @@ -272,3 +265,4 @@ def capture_cudagraph(self): block_tables=block_tables, outputs=outputs, ) + reset_warming_up() \ No newline at end of file diff --git a/diffulex/strategy/block_diffusion/engine/sequence.py b/diffulex/strategy/block_diffusion/engine/sequence.py index 936b242..f2c85a6 100644 --- a/diffulex/strategy/block_diffusion/engine/sequence.py +++ b/diffulex/strategy/block_diffusion/engine/sequence.py @@ -196,7 +196,7 @@ def extend_mask_tokens(self, extend_len: int) -> None: self.token_ids.extend([self.mask_token_id] * extend_len) def init_diffusion_blocks(self) -> None: - """Initialize diffusion blocks: prefix blocks are TO_CACHE, last block with mask tokens is ACTIVE.""" + """Initialize diffusion blocks: prefix blocks are `TO_CACHE`, last block with mask tokens is `ACTIVE`.""" self.prefix_len = len(self.token_ids) block_size = self.diffusion_block_size diff --git a/diffulex/strategy/d2f/engine/kvcache_manager.py b/diffulex/strategy/d2f/engine/kvcache_manager.py index f3eeb73..27591c6 100644 --- a/diffulex/strategy/d2f/engine/kvcache_manager.py +++ b/diffulex/strategy/d2f/engine/kvcache_manager.py @@ -14,17 +14,38 @@ class D2FKVCacheManager(KVCacheManagerBase): def __init__(self, config: Config): super().__init__(config) + def _required_kv_blocks(self, seq: "D2FSequence") -> int: + """How many KV-cache blocks this sequence needs *now* for cached+to-cache tokens. + + NOTE: In diffusion decoding, a single decode step may move multiple tokens into + "to_cache", which can cross multiple KV blocks. So we must ensure block_table + is large enough for all cached_or_caching tokens, not just append one block. + """ + n = seq.cached_or_caching_num_tokens + if n <= 0: + return 0 + # Need enough blocks to cover token indices [0, n-1]. + return (n + self.block_size - 1) // self.block_size + def can_append(self, seq: "D2FSequence") -> bool: - return len(self.free_block_ids) >= (seq.cached_or_caching_num_tokens % self.block_size == 1) + # We may need to allocate multiple blocks in one step (cached_or_caching can jump). + required = self._required_kv_blocks(seq) + missing = max(0, required - len(seq.block_table)) + return len(self.free_block_ids) >= missing def may_append(self, seq: "D2FSequence") -> None: if seq.cached_or_caching_num_tokens == 0: return block_table = seq.block_table if not block_table: + # Defensive: allocate() should have populated it for prefill/prompt, but don't crash here. return - last_block = self.blocks[block_table[-1]] - if seq.cached_or_caching_num_tokens // self.block_size == len(seq.block_table): + + required = self._required_kv_blocks(seq) + # Allocate enough KV blocks to cover all cached_or_caching tokens. + while len(block_table) < required: + last_block = self.blocks[block_table[-1]] + # Preserve the existing "finalize previous block hash" behavior before moving on. if last_block.hash == -1: prev_end_token = seq.cached_or_caching_num_tokens - seq.caching_num_tokens - 1 prev_block_idx = prev_end_token // self.block_size @@ -34,6 +55,15 @@ def may_append(self, seq: "D2FSequence") -> None: h = self.compute_hash(token_ids, prefix) last_block.update(h, token_ids) self.hash_to_block_id[h] = last_block.block_id + + if not self.free_block_ids: + raise RuntimeError( + "D2FKVCacheManager: insufficient free KV cache blocks to append: " + f"required={required}, current_len={len(block_table)}, " + f"cached_or_caching_num_tokens={seq.cached_or_caching_num_tokens}, " + f"block_size={self.block_size}." + ) + block_id = self.free_block_ids[0] self._allocate_block(block_id) block_table.append(block_id) \ No newline at end of file diff --git a/diffulex/strategy/d2f/engine/model_runner.py b/diffulex/strategy/d2f/engine/model_runner.py index 7d736ab..9a020a9 100644 --- a/diffulex/strategy/d2f/engine/model_runner.py +++ b/diffulex/strategy/d2f/engine/model_runner.py @@ -25,23 +25,26 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]): super().__init__(config, rank, event) - def warmup_model(self): - print("Warming up model...") - set_warming_up(True) - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - max_num_batched_tokens, max_model_len = ( - self.config.max_num_batched_tokens, - self.config.max_model_len, - ) - num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs) - test_input_ids = [0] * max_model_len - seqs = [D2FSequence(test_input_ids, config=self.config) for _ in range(num_seqs)] - self.run(seqs, True) - for seq in seqs: - seq.post_process() - torch.cuda.empty_cache() - reset_warming_up() + def _get_decode_mode(self) -> str: + """ + 统一选择 decode_mode 的逻辑: + 1. 如果 config.decode_mode 已设置,优先使用 config 的值 + 2. 否则,如果 kv_cache_dtype 是 FP8,自动切换到 "static" + 3. 否则,默认使用 "varlen" + """ + if self.config.decode_mode is not None: + return self.config.decode_mode + + # Auto-select based on kv_cache_dtype + decode_mode = "varlen" + try: + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + if parse_kv_cache_dtype(getattr(self.config, "kv_cache_dtype", "bf16")).is_fp8: + decode_mode = "static" + except Exception: + decode_mode = "varlen" + + return decode_mode def prepare_prefill(self, seqs: list[D2FSequence]): input_ids: list[int] = [] @@ -115,6 +118,7 @@ def prepare_prefill(self, seqs: list[D2FSequence]): ) ) + decode_mode = self._get_decode_mode() set_d2f_attn_metadata( True, cu_seqlens_q=cu_seqlens_q_tensor, @@ -129,7 +133,7 @@ def prepare_prefill(self, seqs: list[D2FSequence]): seq_lens=seq_lens, seq_lens_ts=seq_lens_ts, diffusion_block_size=self.diffusion_block_size, - decode_mode="varlen", + decode_mode=decode_mode, attn_type="full_attention", ) return input_ids_tensor, positions_tensor @@ -198,6 +202,21 @@ def get_step(diff_blk, begin_idx): cur_diffusion_block_start = 0 cur_diffusion_block_end = step start_idx += step + # IMPORTANT: + # We must have a KV-cache block allocated for this mem_block_idx. + # If not, this is almost always due to insufficient KV cache blocks + # (e.g. higher model/weight memory footprint leaves too few blocks). + if mem_block_idx >= len(seq.block_table): + raise RuntimeError( + "KV cache block allocation is insufficient during decode: " + f"mem_block_idx={mem_block_idx} requires block_table length >= {mem_block_idx + 1}, " + f"but got len(block_table)={len(seq.block_table)} (seq.num_blocks={seq.num_blocks}). " + "This usually means GPU memory utilization is too low to allocate enough KV cache " + f"blocks for this run (num_kvcache_blocks={getattr(self.config, 'num_kvcache_blocks', None)}, " + f"gpu_memory_utilization={getattr(self.config, 'gpu_memory_utilization', None)}). " + "Try increasing gpu_memory_utilization, reducing max_model_len/max_tokens/max_num_seqs, " + "or using a lower-memory weight quantization (e.g. int4)." + ) mem_block_start = ( seq.block_table[mem_block_idx] * self.block_size + context_len % seq.block_size @@ -241,6 +260,14 @@ def get_step(diff_blk, begin_idx): slot_mapping_tensor = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) context_lens_tensor = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) block_tables = self.prepare_block_tables(seqs) + # NOTE: + # - d2f decode supports "varlen" and "static" modes (see config.decode_mode). + # - For FP8 KV, the (varlen/distinct-layout) path uses `load_kvcache` which is expected to + # handle FP8 dequantization / scale application inside the fused operator (no Python-level dequant). + # - Performance can still differ between modes/kernels; when FP8 KV is enabled, prefer the + # best-supported kernel path on your stack (often "static"/unified-layout) and validate with profiling. + # - Allow manual override via config.decode_mode if specified. + decode_mode = self._get_decode_mode() set_d2f_attn_metadata( False, slot_mapping=slot_mapping_tensor, @@ -256,7 +283,7 @@ def get_step(diff_blk, begin_idx): kv_cache_layout=self.config.kv_cache_layout, need_kv_cache_store=need_kv_cache_store, diffusion_block_size=self.diffusion_block_size, - decode_mode="varlen", + decode_mode=decode_mode, attn_type="full_attention", ) return input_ids_tensor, positions_tensor @@ -265,20 +292,45 @@ def get_step(diff_blk, begin_idx): def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool): if is_prefill or self.enforce_eager or input_ids.size(0) > 512: return self.model.compute_logits(self.model(input_ids, positions)) - bs = input_ids.size(0) + num_tokens = input_ids.size(0) context = fetch_d2f_attn_metadata() - graph = self.graphs[next(x for x in self.graph_bs if x >= bs)] + candidates = [x for x in self.graph_bs if x >= num_tokens] + if not candidates: + # Safety: fall back if capture didn't include a large-enough bucket. + return self.model.compute_logits(self.model(input_ids, positions)) + bucket_tokens = candidates[0] + graph = self.graphs[bucket_tokens] graph_vars = self.graph_vars - for key, value in graph_vars.items(): - if key != "outputs": - value.zero_() - graph_vars["input_ids"][:bs] = input_ids - graph_vars["positions"][:bs] = positions - graph_vars["slot_mapping"][:bs] = context.slot_mapping - graph_vars["context_lens"][:bs] = context.context_lens - graph_vars["block_tables"][:bs, : context.block_tables.size(1)] = context.block_tables + # Safety: fall back if runtime batch exceeds captured metadata capacity. + num_seqs = int(context.context_lens.numel()) + max_num_seqs_for_graph = int(graph_vars["context_lens"].numel()) + if num_seqs > max_num_seqs_for_graph: + return self.model.compute_logits(self.model(input_ids, positions)) + + # Reset buffers to safe defaults (avoid "0" being interpreted as a valid index). + graph_vars["input_ids"].zero_() + graph_vars["positions"].zero_() + graph_vars["slot_mapping"].fill_(-1) + graph_vars["context_lens"].zero_() + graph_vars["block_tables"].fill_(-1) + graph_vars["input_ids"][:num_tokens] = input_ids + graph_vars["positions"][:num_tokens] = positions + graph_vars["slot_mapping"][:num_tokens] = context.slot_mapping + graph_vars["context_lens"][:num_seqs] = context.context_lens + # cu_seqlens are required by unified paged-attn decode kernels. + if getattr(context, "cu_seqlens_q", None) is not None: + # Pad to captured length so "extra" sequences become 0-length. + graph_vars["cu_seqlens_q"].fill_(int(num_tokens)) + graph_vars["cu_seqlens_q"][: num_seqs + 1] = context.cu_seqlens_q + if getattr(context, "cu_seqlens_k", None) is not None: + last_k = int(context.cu_seqlens_k[num_seqs].item()) + graph_vars["cu_seqlens_k"].fill_(last_k) + graph_vars["cu_seqlens_k"][: num_seqs + 1] = context.cu_seqlens_k + + bt_cols = min(int(graph_vars["block_tables"].size(1)), int(context.block_tables.size(1))) + graph_vars["block_tables"][:num_seqs, :bt_cols] = context.block_tables[:, :bt_cols] graph.replay() - return self.model.compute_logits(graph_vars["outputs"][:bs]) + return self.model.compute_logits(graph_vars["outputs"][:num_tokens]) def run(self, seqs: list[SequenceBase], is_prefill: bool) -> list[int]: input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) @@ -290,8 +342,118 @@ def run(self, seqs: list[SequenceBase], is_prefill: bool) -> list[int]: @torch.inference_mode() def capture_cudagraph(self): - """ - TODO: Varlen decoding does not support CUDA graph capture yet. - Can be implemented, but requires drastically high overhead. - """ - raise NotImplementedError("CUDA graph capture for DiffusionLM is not implemented yet.") + # Static-mode CUDA graph capture for D2F decode. + # + # NOTE: + # - This matches `run_model()`'s replay protocol: we only overwrite + # input_ids/positions/slot_mapping/context_lens/block_tables per step. + # - Varlen mode is intentionally not supported here (assume static flow). + from tqdm import tqdm + + # Enable per-layer forward-plan dispatch to stabilize capture and minimize + # Python branching inside the captured region. + try: + from diffulex.layer.linear import LinearBase + for m in self.model.modules(): + if isinstance(m, LinearBase): + m.enable_forward_plan(True) + except Exception: + pass + + set_warming_up(True) + config = self.config + hf_config = config.hf_config + diffusion_block_size = int(self.diffusion_block_size) + max_num_seqs = int(self.config.max_num_seqs) + # Graph path is only used when num_tokens <= 512. + # + # IMPORTANT: + # In D2F decode, `num_tokens` (sum of per-seq seqlen_q) is NOT guaranteed to equal + # `num_seqs * diffusion_block_size`. A single seq can contribute multiple diffusion blocks, + # so we must bucket by `num_tokens` directly and keep metadata tensors sized by + # `max_num_seqs_for_graph` (padding unused seqs to 0-length via cu_seqlens). + max_num_seqs_for_graph = max(1, min(max_num_seqs, 512)) + max_num_tokens = 512 + max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + + # Allocate graph buffers on the same device/dtype as the model. + try: + p0 = next(self.model.parameters()) + graph_device = p0.device + graph_dtype = p0.dtype + except StopIteration: + graph_device = torch.device("cuda") + graph_dtype = torch.float16 + + # Allocate max-size graph buffers. + input_ids = torch.zeros(max_num_tokens, dtype=torch.int64, device=graph_device) + positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=graph_device) + slot_mapping = torch.full((max_num_tokens,), -1, dtype=torch.int32, device=graph_device) + context_lens = torch.zeros(max_num_seqs_for_graph, dtype=torch.int32, device=graph_device) + block_tables = torch.full((max_num_seqs_for_graph, max_num_blocks), -1, dtype=torch.int32, device=graph_device) + outputs = torch.zeros(max_num_tokens, hf_config.hidden_size, dtype=graph_dtype, device=graph_device) + cu_seqlens_q = torch.zeros(max_num_seqs_for_graph + 1, dtype=torch.int32, device=graph_device) + cu_seqlens_k = torch.zeros(max_num_seqs_for_graph + 1, dtype=torch.int32, device=graph_device) + + # Capture bucketed graphs by total num_tokens. + self.graph_bs = [] + # Keep buckets aligned to diffusion_block_size for stable kernel shapes. + for t in range(diffusion_block_size, max_num_tokens + 1, diffusion_block_size): + self.graph_bs.append(int(t)) + self.graphs = {} + self.graph_pool = None + + for num_tokens in tqdm(reversed(self.graph_bs), desc="Capturing CUDA graphs"): + num_seqs = int(max_num_seqs_for_graph) + graph = torch.cuda.CUDAGraph() + # Fill placeholder metadata with valid monotonic cu_seqlens to satisfy kernel assertions. + # IMPORTANT: cu_seqlens_q must be non-decreasing and end at `num_tokens` + # (it is used to index into Q/slot_mapping which are length `num_tokens`). + # Use a simple placeholder: put all Q tokens into the first seq and make + # the remaining seqs 0-length. + cu_seqlens_q[: num_seqs + 1].fill_(int(num_tokens)) + cu_seqlens_q[0] = 0 + # Use a conservative max-seqlen for K to keep shapes stable; values are overwritten before replay. + cu_seqlens_k[: num_seqs + 1] = ( + torch.arange(num_seqs + 1, dtype=torch.int32, device=graph_device) * int(config.max_model_len) + ) + context_lens[:num_seqs].fill_(int(config.max_model_len)) + # Use a benign placeholder block table for the first seq. + block_tables[:1].zero_() + # For static decode, use placeholder metadata tensors; per-step values are copied + # into `graph_vars` before replay. + set_d2f_attn_metadata( + False, + slot_mapping=slot_mapping[:num_tokens], + context_lens=context_lens[:num_seqs], + cu_seqlens_q=cu_seqlens_q[: num_seqs + 1], + cu_seqlens_k=cu_seqlens_k[: num_seqs + 1], + max_seqlen_q=int(num_tokens), + max_seqlen_k=int(config.max_model_len), + block_tables=block_tables[:num_seqs], + kv_cache_layout=self.config.kv_cache_layout, + need_kv_cache_store=True, + diffusion_block_size=self.diffusion_block_size, + decode_mode="static", + attn_type="full_attention", + ) + outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens]) # warmup + with torch.cuda.graph(graph, self.graph_pool): + outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens]) # capture + if self.graph_pool is None: + self.graph_pool = graph.pool() + self.graphs[num_tokens] = graph + torch.cuda.synchronize() + reset_d2f_attn_metadata() + + self.graph_vars = dict( + input_ids=input_ids, + positions=positions, + slot_mapping=slot_mapping, + context_lens=context_lens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + block_tables=block_tables, + outputs=outputs, + ) + reset_warming_up() diff --git a/diffulex/strategy/d2f/engine/scheduler.py b/diffulex/strategy/d2f/engine/scheduler.py index a4b8f29..d362dda 100644 --- a/diffulex/strategy/d2f/engine/scheduler.py +++ b/diffulex/strategy/d2f/engine/scheduler.py @@ -5,6 +5,8 @@ from diffulex.engine.sequence import SequenceStatus from .sequence import D2FSequence +import torch + @AutoScheduler.register("d2f", is_default=True) class D2FScheduler(SchedulerBase): @@ -104,12 +106,17 @@ def postprocess( continue diffusion_block = seq.diffusion_blocks[int(block_id)] sampled_tokens = sampled_tokens_map.get(block_id, []) + # `sampled_tokens` may be a CUDA Tensor (legacy behavior) or list[int]. + # Converting per-token via `.item()` causes massive DtoH sync overhead. + # Convert once per block. + if isinstance(sampled_tokens, torch.Tensor): + sampled_tokens = sampled_tokens.tolist() true_local_ids = true_ids_map.get(block_id, []) for true_local_id, accepted_id in zip(true_local_ids, accepted_ids): - token = sampled_tokens[accepted_id] + token = int(sampled_tokens[accepted_id]) diffusion_block.modify_token(true_local_id, token) if ( - (not seq.ignore_eos and token.item() == self.eos) + (not seq.ignore_eos and token == self.eos) or seq.num_completion_tokens >= seq.max_tokens ): seq.meet_eos = True diff --git a/diffulex/strategy/d2f/engine/sequence.py b/diffulex/strategy/d2f/engine/sequence.py index db22bc8..7532ea8 100644 --- a/diffulex/strategy/d2f/engine/sequence.py +++ b/diffulex/strategy/d2f/engine/sequence.py @@ -117,7 +117,9 @@ def modify_token(self, local_token_id: int, modified_to: int) -> None: raise RuntimeError("Diffusion block is not attached to a sequence.") target_id = local_token_id + self.global_start_id assert self.seq.token_ids[target_id] == self.mask_token_id - self.seq.token_ids[target_id] = modified_to.item() # type: ignore[assignment] + # Hot path: avoid per-token CUDA -> CPU sync via Tensor.item(). + # `modified_to` should be a python int (or at least int-castable). + self.seq.token_ids[target_id] = int(modified_to) # type: ignore[assignment] self.seq.new_tokens += 1 diff --git a/diffulex/strategy/fast_dllm_v2/__init__.py b/diffulex/strategy/fast_dllm_v2/__init__.py new file mode 100644 index 0000000..845afa2 --- /dev/null +++ b/diffulex/strategy/fast_dllm_v2/__init__.py @@ -0,0 +1,14 @@ +"""Block Diffusion strategy component exports.""" +from __future__ import annotations + +from .engine.kvcache_manager import BDKVCacheManager +from .engine.model_runner import BDModelRunner +from .engine.scheduler import BDScheduler +from .engine.sequence import BDSequence + +__all__ = [ + "BDKVCacheManager", + "BDModelRunner", + "BDScheduler", + "BDSequence", +] diff --git a/diffulex/strategy/fast_dllm_v2/attention/metadata.py b/diffulex/strategy/fast_dllm_v2/attention/metadata.py new file mode 100644 index 0000000..7ae64b2 --- /dev/null +++ b/diffulex/strategy/fast_dllm_v2/attention/metadata.py @@ -0,0 +1,62 @@ +import torch + +from typing import List +from dataclasses import dataclass + +from diffulex.attention.metadata import AttnMetaDataBase +from diffulex.strategy.fast_dllm_v2.engine.sequence import FastDLLMV2Sequence + + +@dataclass +class FDV2AttnMetaData(AttnMetaDataBase): + seqs: List[FastDLLMV2Sequence] = None + kv_cache_layout: str = "unified" + need_kv_cache_store: bool = True + + def __post_init__(self): + if self.context_lens is not None and sum(self.context_lens) > 0: + self.total_lens = self.diffusion_block_size + self.context_lens + + +FDV2_ATTN_METADATA = FDV2AttnMetaData() + +def fetch_fdv2_attn_metadata() -> FDV2AttnMetaData: + return FDV2_ATTN_METADATA + +def set_fdv2_attn_metadata( + is_prefill: bool = False, + cu_seqlens_q: torch.Tensor | None = None, + cu_seqlens_k: torch.Tensor | None = None, + max_seqlen_q: int = 0, + max_seqlen_k: int = 0, + slot_mapping: torch.Tensor | None = None, + context_lens: torch.Tensor | None = None, + block_tables: torch.Tensor | None = None, + page_block_size: int = 32, + diffusion_block_size: int = 32, + decode_mode: str = "static", + attn_type: str = "full_attention", + kv_cache_layout: str = "unified", + need_kv_cache_store: bool = True, +) -> None: + global FDV2_ATTN_METADATA + FDV2_ATTN_METADATA = FDV2AttnMetaData( + is_prefill=is_prefill, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + slot_mapping=slot_mapping, + context_lens=context_lens, + block_tables=block_tables, + page_block_size=page_block_size, + diffusion_block_size=diffusion_block_size, + kv_cache_layout=kv_cache_layout, + need_kv_cache_store=need_kv_cache_store, + decode_mode=decode_mode, + attn_type=attn_type, + ) + +def reset_fdv2_attn_metadata() -> None: + global FDV2_ATTN_METADATA + FDV2_ATTN_METADATA = FDV2AttnMetaData() \ No newline at end of file diff --git a/diffulex/strategy/fast_dllm_v2/engine/kvcache_manager.py b/diffulex/strategy/fast_dllm_v2/engine/kvcache_manager.py new file mode 100644 index 0000000..94aeab6 --- /dev/null +++ b/diffulex/strategy/fast_dllm_v2/engine/kvcache_manager.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from diffulex.config import Config +from diffulex.engine.kvcache_manager import AutoKVCacheManager, KVCacheManagerBase + +if TYPE_CHECKING: + from .sequence import FDV2Sequence + + +@AutoKVCacheManager.register("fast_dllm_v2", is_default=True) +class FastDLLMV2KVCacheManager(KVCacheManagerBase): + def __init__(self, config: Config): + super().__init__(config) + + def can_append(self, seq: "FDV2Sequence") -> bool: + return len(self.free_block_ids) >= (seq.cached_or_caching_num_tokens % self.block_size == 1) + + def may_append(self, seq: "FDV2Sequence") -> None: + if seq.cached_or_caching_num_tokens == 0: + return + block_table = seq.block_table + if not block_table: + return + last_block = self.blocks[block_table[-1]] + if seq.cached_or_caching_num_tokens // self.block_size == len(seq.block_table): + if last_block.hash == -1: + prev_end_token = seq.cached_or_caching_num_tokens - seq.caching_num_tokens - 1 + prev_block_idx = prev_end_token // self.block_size + if prev_block_idx < seq.num_blocks: + token_ids: list[int] = seq.block(prev_block_idx) + prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 + h = self.compute_hash(token_ids, prefix) + last_block.update(h, token_ids) + self.hash_to_block_id[h] = last_block.block_id + block_id = self.free_block_ids[0] + self._allocate_block(block_id) + block_table.append(block_id) \ No newline at end of file diff --git a/diffulex/strategy/fast_dllm_v2/engine/model_runner.py b/diffulex/strategy/fast_dllm_v2/engine/model_runner.py new file mode 100644 index 0000000..1f5f6c4 --- /dev/null +++ b/diffulex/strategy/fast_dllm_v2/engine/model_runner.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import time + +from multiprocessing.synchronize import Event + +import torch +from tqdm import tqdm + +from diffulex.config import Config +from diffulex.engine.sequence import SequenceBase +from diffulex.strategy.fast_dllm_v2.engine.sequence import FDV2Sequence +from diffulex.attention.metadata import set_fetch_fn_for_attn_metadata, set_warming_up, reset_warming_up +from diffulex.engine.model_runner import AutoModelRunner, ModelRunnerBase +from diffulex.strategy.fast_dllm_v2.attention.metadata import fetch_fdv2_attn_metadata, set_fdv2_attn_metadata, reset_fdv2_attn_metadata + + +@AutoModelRunner.register("fast_dllm_v2", is_default=True) +class FastDLLMV2ModelRunner(ModelRunnerBase): + """Reference implementation of Block Diffusion decoding strategy.""" + def __init__(self, config: Config, rank: int, event: Event | list[Event]): + set_fetch_fn_for_attn_metadata(fetch_fdv2_attn_metadata) + self.diffusion_block_size = config.diffusion_block_size + self.mask_token_id = config.mask_token_id + + super().__init__(config, rank, event) + + def prepare_prefill(self, seqs: list[FDV2Sequence]): + input_ids: list[int] = [] + positions: list[int] = [] + cu_seqlens_q = [0] + cu_seqlens_k = [0] + max_seqlen_q = 0 + max_seqlen_k = 0 + slot_mapping: list[int] = [] + block_tables = None + context_lens: list[int] = [] + + for seq in seqs: + seq.init_diffusion_blocks() + + total_seqlen = len(seq) + input_ids.extend(seq[seq.cached_num_tokens:]) + positions.extend(range(seq.cached_num_tokens, total_seqlen)) + context_lens.append(0) + + seqlen_q = total_seqlen - seq.cached_num_tokens + seqlen_k = total_seqlen + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) + cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) + + max_seqlen_q = max(seqlen_q, max_seqlen_q) + max_seqlen_k = max(seqlen_k, max_seqlen_k) + + if not seq.block_table: + continue + has_padding_mask = seq.pad_prefix_len > 0 + for i in range(0, seq.num_prefix_blocks): + if seq.block_cache_missed[i]: + if has_padding_mask and i == seq.num_prefix_blocks - 1: + slot_mapping.extend([-1] * self.block_size) + else: + start = seq.block_table[i] * self.block_size + if i != seq.num_prefix_blocks - 1: + end = start + self.block_size + else: + end = start + seq.prefix_last_block_num_tokens + slot_mapping.extend(range(start, end)) + else: + slot_mapping.extend([-1] * self.block_size) + + block_tables = self.prepare_block_tables(seqs) + input_ids_tensor = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + positions_tensor = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + context_lens_tensor = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + cu_seqlens_q_tensor = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + cu_seqlens_k_tensor = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + slot_mapping_tensor = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + + set_fdv2_attn_metadata( + True, + cu_seqlens_q=cu_seqlens_q_tensor, + cu_seqlens_k=cu_seqlens_k_tensor, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + slot_mapping=slot_mapping_tensor, + context_lens=context_lens_tensor, + block_tables=block_tables, + diffusion_block_size=self.diffusion_block_size, + kv_cache_layout=self.config.kv_cache_layout, + attn_type="block_attention", + decode_mode="static", + ) + return input_ids_tensor, positions_tensor + + def prepare_decode(self, seqs: list[FDV2Sequence]): + input_ids: list[int] = [] + positions: list[int] = [] + cu_seqlens_q = [0] + cu_seqlens_k = [0] + slot_mapping: list[int] = [] + context_lens: list[int] = [] + need_kv_cache_store = False + max_seqlen_q = 0 + max_seqlen_k = 0 + + for seq in seqs: + seq.next_diffusion_step() + + cur_input_ids, cur_positions, cur_context_len = seq.diffusion_decoding_inputs() + + input_ids.extend(cur_input_ids) + positions.extend(cur_positions) + context_lens.append(cur_context_len) + + seqlen_q = self.diffusion_block_size + seqlen_k = self.diffusion_block_size + max_seqlen_q = max(seqlen_q, max_seqlen_q) + max_seqlen_k = max(seqlen_k, max_seqlen_k) + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) + cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) + + if seq.diffusion_blocks[-1].is_active: + slot_mapping.extend([-1] * self.diffusion_block_size) + elif seq.diffusion_blocks[-1].is_to_cache: + need_kv_cache_store = True + num_pages_storing = seq.num_page_blocks_in_active_diffusion_block + total_num_pages = len(seq.block_table) + for i in range(0, num_pages_storing): + start = seq.block_table[(total_num_pages - 1) - num_pages_storing + i] * self.block_size + end = start + self.block_size + slot_mapping.extend(range(start, end)) + + input_ids_tensor = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + positions_tensor = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + cu_seqlens_q_tensor = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + cu_seqlens_k_tensor = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + slot_mapping_tensor = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + context_lens_tensor = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + block_tables = self.prepare_block_tables(seqs) + set_fdv2_attn_metadata( + False, + slot_mapping=slot_mapping_tensor, + context_lens=context_lens_tensor, + cu_seqlens_q=cu_seqlens_q_tensor, + cu_seqlens_k=cu_seqlens_k_tensor, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + block_tables=block_tables, + page_block_size=self.config.kvcache_block_size, + diffusion_block_size=self.diffusion_block_size, + kv_cache_layout=self.config.kv_cache_layout, + need_kv_cache_store=need_kv_cache_store, + ) + return input_ids_tensor, positions_tensor + + @torch.inference_mode() + def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool): + if is_prefill or self.enforce_eager or input_ids.size(0) > 512 * self.diffusion_block_size: + return self.model.compute_logits(self.model(input_ids, positions)) + num_tokens = input_ids.size(0) + attn_metadata = fetch_fdv2_attn_metadata() + graph = self.graphs[next(x for x in self.graph_bs if x >= num_tokens)] + graph_vars = self.graph_vars + for key, value in graph_vars.items(): + if key != "outputs": + value.zero_() + + num_seqs = len(attn_metadata.context_lens) + graph_vars["input_ids"][:num_tokens] = input_ids + graph_vars["positions"][:num_tokens] = positions + graph_vars["slot_mapping"][:num_tokens] = attn_metadata.slot_mapping + graph_vars["context_lens"][:num_seqs] = attn_metadata.context_lens + graph_vars["cu_seqlens_q"][:num_seqs + 1] = attn_metadata.cu_seqlens_q + graph_vars["cu_seqlens_k"][:num_seqs + 1] = attn_metadata.cu_seqlens_k + graph_vars["block_tables"][:num_seqs, : attn_metadata.block_tables.size(1)] = attn_metadata.block_tables + graph.replay() + return self.model.compute_logits(graph_vars["outputs"][:num_tokens]) + + def run(self, seqs: list[SequenceBase], is_prefill: bool) -> list[int]: + input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + logits = self.run_model(input_ids, positions, is_prefill) + sample_output = self.sampler(seqs, logits, temperatures) if self.rank == 0 else None + reset_fdv2_attn_metadata() + return sample_output + + @torch.inference_mode() + def capture_cudagraph(self): + # Enable per-layer forward-plan dispatch to stabilize capture and minimize + # Python branching inside the captured region. + try: + from diffulex.layer.linear import LinearBase + for m in self.model.modules(): + if isinstance(m, LinearBase): + m.enable_forward_plan(True) + except Exception: + pass + + set_warming_up(True) + config = self.config + hf_config = config.hf_config + max_num_seqs = min(self.config.max_num_seqs, 512) + max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + diffusion_block_size = self.diffusion_block_size + + max_num_tokens = max_num_seqs * diffusion_block_size + + input_ids = torch.zeros(max_num_tokens, dtype=torch.int64) + positions = torch.zeros(max_num_tokens, dtype=torch.int64) + slot_mapping = torch.zeros(max_num_tokens, dtype=torch.int32) + context_lens = torch.zeros(max_num_seqs, dtype=torch.int32) + block_tables = torch.zeros(max_num_seqs, max_num_blocks, dtype=torch.int32) + outputs = torch.zeros(max_num_tokens, hf_config.hidden_size) + + cu_seqlens_q = torch.zeros(max_num_seqs + 1, dtype=torch.int32) + for i in range(max_num_seqs + 1): + cu_seqlens_q[i] = i * diffusion_block_size + + cu_seqlens_k = torch.zeros(max_num_seqs + 1, dtype=torch.int32) + for i in range(max_num_seqs + 1): + cu_seqlens_k[i] = i * config.max_model_len + + self.graph_bs = [] + seq_bs_list = [1, 2, 4, 8] + list(range(16, max_num_seqs + 1, 16)) + for num_seqs in seq_bs_list: + self.graph_bs.append(num_seqs * diffusion_block_size) + self.graphs = {} + self.graph_pool = None + + for num_tokens in tqdm(reversed(self.graph_bs), desc="Capturing CUDA graphs"): + num_seqs = num_tokens // diffusion_block_size + graph = torch.cuda.CUDAGraph() + + set_fdv2_attn_metadata( + False, + slot_mapping=slot_mapping[:num_tokens], + context_lens=context_lens[:num_seqs], + cu_seqlens_q=cu_seqlens_q[:num_seqs + 1], + cu_seqlens_k=cu_seqlens_k[:num_seqs + 1], + max_seqlen_q=diffusion_block_size, + max_seqlen_k=config.max_model_len, + block_tables=block_tables[:num_seqs], + diffusion_block_size=diffusion_block_size, + kv_cache_layout=self.config.kv_cache_layout, + need_kv_cache_store=True, + ) + + outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens]) # warmup + with torch.cuda.graph(graph, self.graph_pool): + outputs[:num_tokens] = self.model(input_ids[:num_tokens], positions[:num_tokens]) # capture + if self.graph_pool is None: + self.graph_pool = graph.pool() + self.graphs[num_tokens] = graph + torch.cuda.synchronize() + reset_fdv2_attn_metadata() + + self.graph_vars = dict( + input_ids=input_ids, + positions=positions, + slot_mapping=slot_mapping, + context_lens=context_lens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + block_tables=block_tables, + outputs=outputs, + ) + reset_warming_up() \ No newline at end of file diff --git a/diffulex/strategy/fast_dllm_v2/engine/scheduler.py b/diffulex/strategy/fast_dllm_v2/engine/scheduler.py new file mode 100644 index 0000000..bbfec89 --- /dev/null +++ b/diffulex/strategy/fast_dllm_v2/engine/scheduler.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from diffulex.config import Config +from diffulex.engine.scheduler import AutoScheduler, SchedulerBase +from diffulex.engine.sequence import SequenceStatus +from .sequence import FDV2Sequence + + +@AutoScheduler.register("fast_dllm_v2", is_default=True) +class FastDLLMV2Scheduler(SchedulerBase): + def __init__(self, config: Config): + super().__init__(config) + self.diffusion_block_size = config.diffusion_block_size + + def is_finished(self) -> bool: + return not self.waiting and not self.running + + def add(self, seq: FDV2Sequence) -> None: + self.waiting.append(seq) + + def schedule(self) -> tuple[list[FDV2Sequence], bool]: + scheduled: list[FDV2Sequence] = [] + num_seqs = 0 + num_batched_tokens = 0 + while self.waiting and num_seqs < self.max_num_seqs: + seq = self.waiting[0] + projected = len(seq) + seq.diffusion_block_size + if ( + num_batched_tokens + projected > self.max_num_batched_tokens + or not self.block_manager.can_allocate(seq) + ): + break + num_seqs += 1 + self.block_manager.allocate(seq) + num_batched_tokens += projected - seq.num_cached_tokens + seq.status = SequenceStatus.RUNNING + self.waiting.popleft() + self.running.append(seq) + scheduled.append(seq) + if scheduled: + return scheduled, True + + while self.running and num_seqs < self.max_num_seqs: + seq = self.running.popleft() + while not self.block_manager.can_append(seq): + if self.running: + self.preempt(self.running.pop()) + else: + self.preempt(seq) + break + else: + num_seqs += 1 + self.block_manager.may_append(seq) + scheduled.append(seq) + if not scheduled: + diag = { + "phase": "decode", + "waiting": len(self.waiting), + "running": len(self.running), + "max_num_seqs": self.max_num_seqs, + "max_num_batched_tokens": self.max_num_batched_tokens, + "diffusion_block_size": self.diffusion_block_size, + } + candidates = list(self.running)[:3] + list(self.waiting)[:2] + details = [] + for idx, candidate in enumerate(candidates): + try: + can_append = self.block_manager.can_append(candidate) + except Exception: + can_append = "error" + details.append( + f"[{idx}] status={candidate.status.name}, len={len(candidate)}, " + f"diff_block={getattr(candidate, 'diffusion_block_size', '?')}, " + f"new_tokens={getattr(candidate, 'new_tokens', '?')}, " + f"cached={getattr(candidate, 'num_cached_tokens', '?')}, " + f"can_append={can_append}" + ) + raise RuntimeError( + "BDScheduler: unable to schedule any sequence in decode; " + f"state={diag}; details={' | '.join(details)}" + ) + self.running.extendleft(reversed(scheduled)) + return scheduled, False + + def preempt(self, seq: FDV2Sequence) -> None: + seq.status = SequenceStatus.WAITING + self.block_manager.free(seq) + self.waiting.appendleft(seq) + + def postprocess( + self, + seqs: list[FDV2Sequence], + sample_output, + ) -> dict[int, int]: + n_diff_steps: dict[int, int] = {} + for seq in seqs: + seq.reset_new_tokens() + seq_id = str(seq.seq_id) + true_ids_map = sample_output.true_local_ids_map.get(seq_id, {}) + accepted_ids_map = sample_output.accepted_ids_map.get(seq_id, {}) + sampled_tokens_map = sample_output.sampled_tokens_map.get(seq_id, {}) + for block_id, accepted_ids in accepted_ids_map.items(): + if not accepted_ids: + continue + diffusion_block = seq.diffusion_blocks[int(block_id)] + sampled_tokens = sampled_tokens_map.get(block_id, []) + true_local_ids = true_ids_map.get(block_id, []) + for true_local_id, accepted_id in zip(true_local_ids, accepted_ids): + token = sampled_tokens[accepted_id] + diffusion_block.modify_token(true_local_id, token) + if ( + (not seq.ignore_eos and token.item() == self.eos) + or seq.num_completion_tokens >= seq.max_tokens + ): + seq.meet_eos = True + if seq.meet_eos and seq.diffusion_blocks[-1].available_to_cache: + seq.status = SequenceStatus.FINISHED + self.block_manager.free(seq) + if seq in self.running: + self.running.remove(seq) + n_diff_steps[seq.seq_id] = seq.n_steps + seq.post_process() + return n_diff_steps diff --git a/diffulex/strategy/fast_dllm_v2/engine/sequence.py b/diffulex/strategy/fast_dllm_v2/engine/sequence.py new file mode 100644 index 0000000..d105a55 --- /dev/null +++ b/diffulex/strategy/fast_dllm_v2/engine/sequence.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from enum import Enum, auto +from dataclasses import dataclass + +from diffulex.config import Config +from diffulex.sampling_params import SamplingParams +from diffulex.engine.sequence import AutoSequence, SequenceBase + + +class FDV2BlockStatus(Enum): + ACTIVE = auto() + TO_CACHE = auto() + IN_CACHE = auto() + + +class FDV2SubBlockStatus(Enum): + ACTIVE = auto() + TO_DUAL_CACHE = auto() + IN_DUAL_CACHE = auto() + +@dataclass +class FDV2SubBlock: + sub_block_id: int = 0 + status: FDV2SubBlockStatus = FDV2SubBlockStatus.ACTIVE + +@dataclass +class FDV2Block: + block_id: int = 0 + status: FDV2BlockStatus = FDV2BlockStatus.ACTIVE + + global_start_id: int = 0 + global_end_id: int | None = None + cursor: int = 0 + + mask_token_id: int = 151666 + size: int = 32 + is_prompt: bool = False + + seq: "FDV2Sequence" | None = None + + def __post_init__(self) -> None: + self.global_end_id = self.global_start_id + self.size + + def __getitem__(self, key: int) -> int: + return self.seq[self.global_start_id + key] # type: ignore[index] + + def __len__(self) -> int: + return self.size + + def to_cache(self) -> None: + if self.available_to_cache and not self.is_in_cache: + self.status = FDV2BlockStatus.TO_CACHE + + def in_cache(self) -> None: + if self.is_to_cache: + self.status = FDV2BlockStatus.IN_CACHE + + def modify_token(self, local_token_id: int, modified_to: int) -> None: + if self.seq is None: + raise RuntimeError("Diffusion block is not attached to a sequence.") + target_id = local_token_id + self.global_start_id + assert self.seq.token_ids[target_id] == self.mask_token_id + self.seq.token_ids[target_id] = modified_to.item() # type: ignore[assignment] + self.seq.new_tokens += 1 + + @property + def token_ids(self) -> list[int]: + return self.seq.token_ids[self.global_start_id: self.global_end_id] + + @property + def has_mask_token(self) -> bool: + return any(token == self.mask_token_id for token in self.token_ids) + + @property + def is_active(self) -> bool: + return self.status == FDV2BlockStatus.ACTIVE + + @property + def is_to_cache(self) -> bool: + return self.status == FDV2BlockStatus.TO_CACHE + + @property + def is_in_cache(self) -> bool: + return self.status == FDV2BlockStatus.IN_CACHE + + @property + def available_to_cache(self) -> bool: + return not self.has_mask_token and self.is_active + + @property + def available_in_cache(self) -> bool: + return self.is_to_cache + + @property + def available_to_add_new_block(self) -> bool: + return self.is_in_cache + + @property + def local_mask_tokens(self) -> list[bool]: + return [token_id == self.mask_token_id for token_id in self.token_ids] + + @property + def local_mask_token_ids(self) -> list[int]: + return [idx for idx, is_mask in enumerate(self.local_mask_tokens) if is_mask] + + @property + def global_mask_token_ids(self) -> list[int]: + if self.seq is None: + return [] + offset = self.global_start_id - self.size * sum(block.is_to_cache for block in self.seq.diffusion_blocks) + return [mask_id + offset for mask_id in self.local_mask_token_ids] + + +@AutoSequence.register("fast_dllm_v2", is_default=True) +class FDV2Sequence(SequenceBase): + """Sequence implementation tailored for diffusion-based decoding.""" + + def __init__( + self, + token_ids: list[int], + sampling_params: SamplingParams = SamplingParams(), + config: Config | None = None, + ): + super().__init__(token_ids, sampling_params) + if config is None: + raise ValueError("BDSequence requires a Config instance.") + + self.config = config + self.diffusion_blocks: list[FDV2Block] = [] + self.diffusion_block_size = config.diffusion_block_size + self.mask_token_id = config.mask_token_id + self.n_steps = 0 + + @property + def completion_token_ids(self) -> list[int]: + return self.token_ids[self.prefix_len : ] + + @property + def prefix_len_with_padding(self) -> int: + return self.prefix_len + self.pad_prefix_len + + @property + def diffusion_block_status(self) -> list[FDV2BlockStatus]: + return [block.status for block in self.diffusion_blocks] + + @property + def num_prefix_blocks(self) -> int: + return (self.prefix_len + self.block_size - 1) // self.block_size + + @property + def prefix_last_block_num_tokens(self) -> int: + return self.prefix_len - (self.num_prefix_blocks - 1) * self.block_size + + @property + def active_block_token_ids(self) -> list[int]: + return self.diffusion_blocks[-1].token_ids + + @property + def num_page_blocks_in_active_diffusion_block(self) -> int: + return self.diffusion_block_size // self.block_size + + @property + def cached_num_tokens(self) -> int: + return sum(block.size for block in self.diffusion_blocks if block.is_in_cache) + + @property + def caching_num_tokens(self) -> int: + return sum(block.size for block in self.diffusion_blocks if block.is_to_cache) + + @property + def cached_or_caching_last_token_id(self) -> int: + return max(sum(block.size for block in self.diffusion_blocks if block.is_to_cache or block.is_in_cache) - 1, 0) + + @property + def cached_or_caching_num_tokens(self) -> int: + return self.cached_or_caching_last_token_id + 1 + + @property + def has_to_cache_block(self) -> bool: + return any(block.is_to_cache for block in self.diffusion_blocks) + + @property + def to_cache_last_token_id(self) -> int: + to_cache_num_tokens = 0 + for block in self.diffusion_blocks: + if block.is_to_cache: + to_cache_num_tokens += block.size + return to_cache_num_tokens - 1 + + @property + def num_completion_tokens(self) -> int: + return self.num_tokens - self.num_prompt_tokens + + def reset_new_tokens(self) -> None: + self.new_tokens = 0 + + def diffusion_decoding_inputs(self) -> tuple[list[int], list[int], int]: + return ( + self.active_block_token_ids, + list(range(self.num_tokens - self.diffusion_block_size, self.num_tokens)), + self.num_tokens - self.diffusion_block_size, + ) + + def extend_mask_tokens(self, extend_len: int) -> None: + self.token_ids.extend([self.mask_token_id] * extend_len) + + def init_diffusion_blocks(self) -> None: + """Initialize diffusion blocks: prefix blocks are `TO_CACHE`, last block with mask tokens is `ACTIVE`.""" + self.prefix_len = len(self.token_ids) + block_size = self.diffusion_block_size + + # Calculate prefix blocks and padding + num_prefix_blocks = self.prefix_len // block_size + self.pad_prefix_len = 0 if self.prefix_len % block_size == 0 else block_size - (self.prefix_len % block_size) + + # Add mask tokens for the last prefix block + self.extend_mask_tokens(self.pad_prefix_len) + + # Calculate total blocks needed + total_num_blocks = num_prefix_blocks if self.pad_prefix_len == 0 else num_prefix_blocks + 1 + + # Create all blocks + current_pos = 0 + for block_id in range(total_num_blocks): + # Determine block status + block_tokens = self.token_ids[current_pos:current_pos + block_size] + has_mask_token = any(token == self.mask_token_id for token in block_tokens) + is_last_prefix_block = (block_id == num_prefix_blocks) + + if block_id < num_prefix_blocks: + status = FDV2BlockStatus.TO_CACHE + elif is_last_prefix_block: + status = FDV2BlockStatus.ACTIVE if has_mask_token else FDV2BlockStatus.TO_CACHE + else: + status = FDV2BlockStatus.TO_CACHE + + block = FDV2Block( + block_id=block_id, + status=status, + global_start_id=current_pos, + size=block_size, + mask_token_id=self.mask_token_id, + is_prompt=(block_id <= num_prefix_blocks), + seq=self, + ) + self.diffusion_blocks.append(block) + current_pos += block_size + self.n_steps += 1 + + def next_diffusion_step(self) -> None: + """Append new diffusion block if needed.""" + if self.diffusion_blocks[-1].available_to_add_new_block: + self.extend_mask_tokens(self.diffusion_block_size) + self.diffusion_blocks.append( + FDV2Block( + block_id=len(self.diffusion_blocks), + status=FDV2BlockStatus.ACTIVE, + global_start_id=self.num_tokens - self.diffusion_block_size, + size=self.diffusion_block_size, + mask_token_id=self.mask_token_id, + is_prompt=False, + seq=self, + ) + ) + self.n_steps += 1 + + def post_process(self) -> None: + for block in self.diffusion_blocks: + block.cursor = 0 + if block.is_in_cache: + continue + if block.is_to_cache: + block.in_cache() + elif block.is_active: + if block.available_to_cache: + block.to_cache() + else: + break \ No newline at end of file diff --git a/diffulex/utils/checker.py b/diffulex/utils/checker.py deleted file mode 100755 index e933806..0000000 --- a/diffulex/utils/checker.py +++ /dev/null @@ -1,28 +0,0 @@ -def CHECK_SLOT_MAPPING(seqs, slot_mapping): - # check slot mapping layout - start_idx = 0 - for seq in seqs: - cur_ref_slot_mapping = [] - for idx in range(seq.num_diffusion_blocks): - if seq.active_blocks[idx]: - padding_num_tokens = (seq.num_diffusion_blocks - idx) * seq.diffusion_block_size - cur_ref_slot_mapping.extend([-1] * padding_num_tokens) - break - elif seq.to_cache_blocks[idx]: - cur_ref_slot_mapping.extend([0] * seq.diffusion_block_size) - cur_slot_mapping = slot_mapping[start_idx:start_idx + len(cur_ref_slot_mapping)] - for slot, ref_slot in zip(cur_slot_mapping, cur_ref_slot_mapping): - try: - if ref_slot == -1: - assert slot == -1 - elif ref_slot == 0: - assert slot != -1 - elif ref_slot is not None: - assert slot is not None - except AssertionError: - raise ValueError(f"Slot mapping mismatch: {slot} != {ref_slot}. " - f"Check the implementation of prepare_decode.\n" - f"slot_mapping: {cur_slot_mapping}\n" - f"ref_slot_mapping: {cur_ref_slot_mapping}\n" - f"diff: {[s - r for s, r in zip(cur_slot_mapping, cur_ref_slot_mapping)]}") - start_idx += len(cur_ref_slot_mapping) \ No newline at end of file diff --git a/diffulex/utils/kv_cache_dtype.py b/diffulex/utils/kv_cache_dtype.py new file mode 100644 index 0000000..0ebe312 --- /dev/null +++ b/diffulex/utils/kv_cache_dtype.py @@ -0,0 +1,26 @@ +""" +KV Cache dtype utilities. + +This module has been moved to diffulex.utils.quantization.kv_cache_dtype. +This file is kept for backward compatibility and re-exports from the new location. +""" + +# Re-export everything from the new location for backward compatibility +from diffulex.utils.quantization.kv_cache_dtype import ( + KvCacheDType, + KvCacheDTypeSpec, + parse_kv_cache_dtype, + ensure_scale_tensor, + view_fp8_cache, + _normalize_kv_cache_dtype, + _get_fp8_e4m3_dtype, + _get_fp8_e5m2_dtype, +) + +__all__ = [ + 'KvCacheDType', + 'KvCacheDTypeSpec', + 'parse_kv_cache_dtype', + 'ensure_scale_tensor', + 'view_fp8_cache', +] diff --git a/diffulex/utils/loader.py b/diffulex/utils/loader.py index b2e7cbe..b78f788 100755 --- a/diffulex/utils/loader.py +++ b/diffulex/utils/loader.py @@ -8,6 +8,170 @@ from functools import partial from safetensors import safe_open from diffulex.config import Config +from diffulex.logger import get_logger + +logger = get_logger(__name__) + +def _read_quantize_config(model_dir: str) -> dict: + """Read vLLM-style quantization metadata if present. + + We use this to detect checkpoint formats like `gptq_marlin` which reuse the same + tensor keys (qweight/qzeros/scales[/g_idx]) but have different semantics. + """ + cfg_path = os.path.join(model_dir, "quantize_config.json") + if not os.path.exists(cfg_path): + return {} + try: + with open(cfg_path, "r") as f: + data = json.load(f) + return data if isinstance(data, dict) else {} + except Exception: + return {} + + +def _make_packed_qzeros_constant( + *, + num_groups: int, + out_features: int, + bits: int, + device: torch.device | str, +) -> torch.Tensor: + """Create a GPTQ-style packed qzeros tensor filled with a constant. + + For vLLM GPTQ v1 checkpoints, zeros are stored as (zeros - 1) and then bit-packed + along the output dimension (N). For symmetric quantization, zeros is typically + bias=2^(bits-1), thus stored constant becomes (2^(bits-1) - 1). + + This is primarily used as a *shape-compatible dummy* when loading gptq_marlin + checkpoints where runtime zero-points are intentionally unused (qzeros may be empty). + """ + if bits not in (2, 4, 8): + raise ValueError(f"Unsupported bits={bits} for packed qzeros (expected 2/4/8)") + pack_factor = 32 // bits + if out_features % pack_factor != 0: + raise ValueError( + f"out_features={out_features} not divisible by pack_factor={pack_factor} for bits={bits}" + ) + out_packed = out_features // pack_factor + + # Stored constant for GPTQ v1: bias - 1, where bias = 2^(bits-1). + z = (1 << (bits - 1)) - 1 + packed_val = 0 + for i in range(pack_factor): + packed_val |= (z & ((1 << bits) - 1)) << (bits * i) + + return torch.full( + (int(num_groups), int(out_packed)), + int(packed_val), + dtype=torch.int32, + device=device, + ) + + +def _infer_module_device(module: nn.Module) -> torch.device: + w = getattr(module, "weight", None) + if isinstance(w, torch.Tensor): + return w.device + for p in module.parameters(recurse=False): + return p.device + for b in module.buffers(recurse=False): + return b.device + return torch.device("cpu") + + +def _set_offline_gptq_marlin_weight( + module: nn.Module, + *, + qweight: torch.Tensor, + scales: torch.Tensor, + out_features: int, + in_features: int, + group_size: int, + bits: int, + g_idx: torch.Tensor | None, +) -> None: + """Directly set GPTQ-Marlin-ready offline weights into a Diffulex Linear module. + + This bypasses `set_offline_quantized_weight` because marlin-exported `scales` + use a different layout (e.g. (2*num_groups, out_features/2)) and would fail + the standard GPTQ shape validation. + + We still populate minimal GPTQ metadata/buffers so Diffulex forward chooses + the offline path, and then `LinearBase._maybe_prepare_offline_gptq_marlin` + will only allocate workspace / g_idx metadata (and not repack/permute again). + """ + module_device = _infer_module_device(module) + if qweight.device != module_device: + qweight = qweight.to(device=module_device) + if scales.device != module_device: + scales = scales.to(device=module_device) + if g_idx is not None and g_idx.device != module_device: + g_idx = g_idx.to(device=module_device) + + pack_factor = 32 // int(bits) + group_size_norm = in_features if group_size == -1 else group_size + if group_size_norm <= 0 or in_features % group_size_norm != 0: + raise ValueError(f"Invalid group_size={group_size} for in_features={in_features}") + num_groups = in_features // group_size_norm + + # Minimal qzeros to satisfy offline presence checks. (Marlin GPTQ symmetric doesn't use runtime zp.) + qzeros = _make_packed_qzeros_constant( + num_groups=num_groups, + out_features=out_features, + bits=int(bits), + device=module_device, + ) + + # Populate GPTQ buffers (note: scales here are marlin layout; gptq kernels should not be used). + module.gptq_qweight = qweight + module.gptq_qzeros = qzeros + module.gptq_scales = scales.to(dtype=torch.float16) + if g_idx is None: + module.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device) + else: + if getattr(g_idx, "numel", lambda: 1)() == 0: + module.gptq_g_idx = torch.empty(0, dtype=torch.int32, device=module_device) + else: + module.gptq_g_idx = g_idx.to(dtype=torch.int32) + + # Also mark as marlin-ready so LinearBase won't repack/permute again. + module.gptq_marlin_qweight = qweight + module.gptq_marlin_scales = module.gptq_scales + + module._offline_quant_format = torch.tensor(1, dtype=torch.int8, device=module_device) + module._offline_quant_bits = torch.tensor(int(bits), dtype=torch.int32, device=module_device) + module._offline_quant_group_size = torch.tensor(group_size, dtype=torch.int32, device=module_device) + module._offline_quant_out_features = torch.tensor(out_features, dtype=torch.int32, device=module_device) + module._offline_quant_in_features = torch.tensor(in_features, dtype=torch.int32, device=module_device) + module._gptq_is_shuffled = torch.tensor(False, dtype=torch.bool, device=module_device) + # Keep Python-side mirrors in sync; runtime fast paths rely on these and + # must not `.item()` from CUDA tensors (graph capture / perf). + if hasattr(module, "_offline_quant_format_py"): + module._offline_quant_format_py = 1 + if hasattr(module, "_offline_quant_bits_py"): + module._offline_quant_bits_py = int(bits) + if hasattr(module, "_offline_quant_group_size_py"): + module._offline_quant_group_size_py = int(group_size) + if hasattr(module, "_offline_quant_out_features_py"): + module._offline_quant_out_features_py = int(out_features) + if hasattr(module, "_offline_quant_in_features_py"): + module._offline_quant_in_features_py = int(in_features) + if hasattr(module, "_gptq_is_shuffled_py"): + module._gptq_is_shuffled_py = False + if hasattr(module, "_gptq_marlin_is_prepared_py"): + module._gptq_marlin_is_prepared_py = False + + # Reset marlin-prep caches (workspace/zp/g_idx meta will be created on first forward). + module._gptq_marlin_is_prepared = torch.tensor(False, dtype=torch.bool, device=module_device) + module.gptq_marlin_zp = torch.empty(0, dtype=torch.int32, device=module_device) + module.gptq_marlin_g_idx = torch.empty(0, dtype=torch.int32, device=module_device) + module.gptq_marlin_g_idx_sort_indices = torch.empty(0, dtype=torch.int32, device=module_device) + module.gptq_marlin_workspace = torch.empty(0, dtype=torch.int32, device=module_device) + + # Drop bf16 weight Parameter if present (to free memory and avoid accidental fallback). + if hasattr(module, "_parameters") and "weight" in module._parameters: + module._parameters.pop("weight", None) + setattr(module, "weight", None) def load_lora_config(lora_path: str) -> dict: @@ -41,24 +205,489 @@ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor): param.data.copy_(loaded_weight) +def _load_gptq_awq_weights(model: nn.Module, config: Config): + """Load GPTQ/AWQ offline quantized weights from checkpoint. + + Args: + model: Model module + config: Config with model path + + Returns: + Tuple of (loaded_gptq_count, loaded_awq_count, skipped_count) + """ + loaded_gptq = 0 + loaded_awq = 0 + skipped = 0 + + # Check if model is configured for GPTQ or AWQ + weight_attn_dtype = getattr(config, "linear_attn_weight_dtype", "bf16") or "bf16" + weight_mlp_dtype = getattr(config, "linear_mlp_weight_dtype", "bf16") or "bf16" + quantize_cfg = _read_quantize_config(getattr(config, "model", "")) + checkpoint_format = (quantize_cfg.get("checkpoint_format") or "").strip().lower() + ckpt_bits = int(quantize_cfg.get("bits", 0) or 0) + ckpt_group_size = int(quantize_cfg.get("group_size", 0) or 0) + + # NOTE: marlin variants reuse the same offline GPTQ/AWQ checkpoint keys + # (qweight/qzeros/scales[/g_idx]) and are repacked lazily in `LinearBase` + # on first forward. + gptq_dtypes = {"gptq", "gptq_marlin"} + awq_dtypes = {"awq", "awq_marlin"} + use_gptq = (weight_attn_dtype or "").lower() in gptq_dtypes or (weight_mlp_dtype or "").lower() in gptq_dtypes + use_awq = (weight_attn_dtype or "").lower() in awq_dtypes or (weight_mlp_dtype or "").lower() in awq_dtypes + want_gptq_marlin = (weight_attn_dtype or "").lower() == "gptq_marlin" or (weight_mlp_dtype or "").lower() == "gptq_marlin" + want_awq_marlin = (weight_attn_dtype or "").lower() == "awq_marlin" or (weight_mlp_dtype or "").lower() == "awq_marlin" + is_gptq_marlin_ckpt = checkpoint_format == "gptq_marlin" + is_awq_marlin_ckpt = checkpoint_format == "awq_marlin" + + if not (use_gptq or use_awq): + return loaded_gptq, loaded_awq, skipped + + all_files = list(glob(os.path.join(config.model, "*.safetensors"))) + + # Scan keys once and remember which file contains each key. + # This avoids the O(num_modules * num_files) "search every file for every module" pattern below. + key_to_file: dict[str, str] = {} + module_keys: dict[str, dict[str, str]] = {} + offline_suffixes = (".qweight", ".qzeros", ".scales", ".g_idx") + for file in all_files: + with safe_open(file, "pt", "cpu") as f: + for key in f.keys(): + if not key.endswith(offline_suffixes): + continue + key_to_file[key] = file + # Group by module prefix: {prefix}.qweight, {prefix}.qzeros, {prefix}.scales, {prefix}.g_idx (GPTQ only) + if key.endswith(".qweight"): + prefix = key[:-8] + module_keys.setdefault(prefix, {})["qweight"] = key + elif key.endswith(".qzeros"): + prefix = key[:-7] + module_keys.setdefault(prefix, {})["qzeros"] = key + elif key.endswith(".scales"): + prefix = key[:-7] + module_keys.setdefault(prefix, {})["scales"] = key + else: # .g_idx + prefix = key[:-6] + module_keys.setdefault(prefix, {})["g_idx"] = key + + # Cache modules lookup to avoid rebuilding dict(model.named_modules()) repeatedly. + named_modules = dict(model.named_modules()) + offline_capable_modules: dict[str, nn.Module] = { + name: m for name, m in named_modules.items() if hasattr(m, "set_offline_quantized_weight") + } + + def _find_offline_capable_module(module_name: str) -> nn.Module | None: + """Best-effort resolve module_name to a module with offline quant support.""" + m = offline_capable_modules.get(module_name) + if m is not None: + return m + + # Try a few naming fallbacks (keep behavior compatible with the previous implementation). + leaf = module_name.split(".")[-1] if module_name else module_name + for name, cand in offline_capable_modules.items(): + if ( + name == module_name + or name.endswith("." + module_name) + or module_name.endswith("." + name) + or (name.split(".")[-1] == leaf) + ): + return cand + return None + + def _load_tensors_for_prefix(key_dict: dict[str, str], *, want_g_idx: bool) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """Load qweight/qzeros/scales/(g_idx) from the minimal set of safetensors files.""" + qweight = qzeros = scales = g_idx = None + keys = [key_dict.get("qweight"), key_dict.get("qzeros"), key_dict.get("scales")] + if want_g_idx: + keys.append(key_dict.get("g_idx")) + files_needed = {key_to_file.get(k) for k in keys if k} + files_needed.discard(None) + + for file in files_needed: + with safe_open(file, "pt", "cpu") as f: + if qweight is None and (key_dict.get("qweight") in f.keys()): + qweight = f.get_tensor(key_dict["qweight"]) + if qzeros is None and (key_dict.get("qzeros") in f.keys()): + qzeros = f.get_tensor(key_dict["qzeros"]) + if scales is None and (key_dict.get("scales") in f.keys()): + scales = f.get_tensor(key_dict["scales"]) + if want_g_idx and g_idx is None and ("g_idx" in key_dict) and (key_dict["g_idx"] in f.keys()): + g_idx = f.get_tensor(key_dict["g_idx"]) + return qweight, qzeros, scales, g_idx + + # Load GPTQ/AWQ weights for each module + packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) + + for prefix, key_dict in module_keys.items(): + if "qweight" not in key_dict or "qzeros" not in key_dict or "scales" not in key_dict: + continue # Skip incomplete sets + + # Map prefix to module name + module_name = prefix + for k, (v, _) in packed_modules_mapping.items(): + if k in prefix: + module_name = prefix.replace(k, v) + break + + try: + module = _find_offline_capable_module(module_name) + if module is None: + skipped += 1 + continue + + # Determine format: check if g_idx exists (GPTQ) or not (AWQ) + has_g_idx = "g_idx" in key_dict + is_gptq_keyset = has_g_idx or is_gptq_marlin_ckpt + if is_gptq_keyset and use_gptq: + format = "gptq" + elif (not is_gptq_keyset) and use_awq: + format = "awq" + else: + # Prefer GPTQ if both are enabled and g_idx exists + format = "gptq" if (use_gptq and is_gptq_keyset) else ("awq" if use_awq else None) + + if format is None: + skipped += 1 + continue + + # Load tensors from the minimal set of safetensors files. + qweight, qzeros, scales, g_idx = _load_tensors_for_prefix( + key_dict, want_g_idx=(format == "gptq") + ) + + if qweight is None or qzeros is None or scales is None: + skipped += 1 + continue + + # Infer dimensions from tensor shapes (vLLM standard format) WITHOUT + # assuming bits=4. This enables GPTQ W2/W4/W8 checkpoints. + if format == "gptq": + if is_gptq_marlin_ckpt: + # gptq_marlin export uses Marlin repacked qweight/scales layouts. + # Empirically (vLLM marlin): qweight is packed on K in tiles of 16, + # so qweight.shape[0] == in_features / 16; and scales carries original N. + out_features = int(scales.shape[1]) if scales.ndim == 2 else int(qweight.shape[1]) + in_features = int(qweight.shape[0]) * 16 + if ckpt_bits not in (4, 8): + logger.warning( + f"gptq_marlin requires bits=4/8, got bits={ckpt_bits} for {module_name}. Skipping." + ) + skipped += 1 + continue + # Keep pack_factor for dummy qzeros creation later. + pack_factor = 32 // int(ckpt_bits) + else: + # Standard GPTQ: qweight [K/pack, N] + out_features = int(qweight.shape[1]) + # qzeros: [K/group, N/pack] (may be empty for some checkpoints) + if getattr(qzeros, "numel", lambda: 1)() == 0: + if ckpt_bits not in (2, 4, 8): + logger.warning( + f"qzeros is empty and cannot infer bits for {module_name}. " + "Please ensure quantize_config.json contains bits (2/4/8). Skipping." + ) + skipped += 1 + continue + pack_factor = 32 // int(ckpt_bits) + else: + if int(qzeros.shape[1]) <= 0 or out_features % int(qzeros.shape[1]) != 0: + logger.warning( + f"Cannot infer GPTQ pack_factor from qzeros for {module_name}: " + f"qzeros.shape={tuple(qzeros.shape)}, qweight.shape={tuple(qweight.shape)}. Skipping." + ) + skipped += 1 + continue + pack_factor = out_features // int(qzeros.shape[1]) # 32 / bits + in_features = int(qweight.shape[0]) * pack_factor + else: + # awq: qweight: [K, N/pack], scales: [K/group, N] + out_features = int(scales.shape[1]) if scales.ndim == 2 else int(qweight.shape[1]) + if int(qweight.shape[1]) <= 0 or out_features % int(qweight.shape[1]) != 0: + logger.warning( + f"Cannot infer AWQ pack_factor from scales/qweight for {module_name}: " + f"scales.shape={tuple(scales.shape)}, qweight.shape={tuple(qweight.shape)}. Skipping." + ) + skipped += 1 + continue + pack_factor = out_features // int(qweight.shape[1]) # 32 / bits (expected 8 for AWQ 4-bit) + in_features = int(qweight.shape[0]) + + # Infer group_size from qzeros/scales. + # qzeros/scales are groupwise on K (in_features). + group_size = 128 + if ckpt_group_size not in (0, None): + # quantize_config.json stores actual group_size (may be -1) + group_size = int(ckpt_group_size) + else: + if is_gptq_marlin_ckpt and len(scales.shape) == 2 and int(scales.shape[0]) > 0: + # vLLM marlin_permute_scales keeps shape [num_groups, N] for most cases. + # Some older/alternate layouts may use [2*num_groups, N/2]. + num_groups = int(scales.shape[0]) + if num_groups > 0 and in_features % num_groups == 0: + group_size = in_features // num_groups + elif num_groups % 2 == 0 and (in_features % (num_groups // 2)) == 0: + # Fallback for legacy 2*num_groups layouts. + group_size = in_features // (num_groups // 2) + else: + num_groups = int(qzeros.shape[0]) if getattr(qzeros, "numel", lambda: 1)() > 0 else 0 + if num_groups > 0 and in_features % num_groups == 0: + group_size = in_features // num_groups + elif len(scales.shape) == 2 and int(scales.shape[0]) > 0 and in_features % int(scales.shape[0]) == 0: + group_size = in_features // int(scales.shape[0]) + + # For gptq_marlin checkpoints qzeros may be empty; create a shape-compatible dummy + # packed qzeros so LinearBase considers offline weights present. + if ( + format == "gptq" + and getattr(qzeros, "numel", lambda: 1)() == 0 + and (want_gptq_marlin or is_gptq_marlin_ckpt) + and ckpt_bits in (2, 4, 8) + ): + group_size_norm = in_features if group_size == -1 else group_size + if group_size_norm <= 0 or (in_features % group_size_norm) != 0: + logger.warning( + f"Invalid group_size={group_size} for {module_name} with in_features={in_features}. Skipping." + ) + skipped += 1 + continue + num_groups = in_features // group_size_norm + try: + qzeros = _make_packed_qzeros_constant( + num_groups=num_groups, + out_features=out_features, + bits=int(ckpt_bits), + device=qweight.device, + ) + except Exception as e: + logger.warning(f"Failed to create dummy qzeros for {module_name}: {e}. Skipping.") + skipped += 1 + continue + + # Handle tensor parallel sharding (TP>1). + # ColumnParallelLinear: tp_dim=0 (shard N/out_features) + # RowParallelLinear : tp_dim=1 (shard K/in_features) + tp_size = int(getattr(module, "tp_size", 1) or 1) + tp_rank = int(getattr(module, "tp_rank", 0) or 0) + tp_dim = getattr(module, "tp_dim", None) + if tp_size > 1: + if tp_dim not in (0, 1): + logger.warning( + f"Unsupported tp_dim={tp_dim} for offline quantized weights. Skipping {module_name}." + ) + skipped += 1 + continue + + # Shard along output features (N) for column-parallel modules. + if tp_dim == 0: + if out_features % tp_size != 0: + logger.warning( + f"out_features={out_features} not divisible by TP={tp_size} for {module_name}. " + "Skipping offline quant weights for this module." + ) + skipped += 1 + continue + out_per = out_features // tp_size + out_start = tp_rank * out_per + out_end = out_start + out_per + if out_per % pack_factor != 0: + logger.warning( + f"out_features_per_partition={out_per} not divisible by pack_factor={pack_factor} " + f"for {module_name}. Skipping." + ) + skipped += 1 + continue + out_packed_per = out_per // pack_factor + out_packed_start = out_start // pack_factor + out_packed_end = out_packed_start + out_packed_per + + if format == "gptq": + if is_gptq_marlin_ckpt: + # Marlin qweight packs N by a factor (bits/2): N_packed = N * (bits/2) + n_factor = int(ckpt_bits) // 2 + if n_factor <= 0: + logger.warning( + f"invalid gptq_marlin n_factor for bits={ckpt_bits} ({module_name}). Skipping." + ) + skipped += 1 + continue + qweight = qweight[:, (out_start * n_factor):(out_end * n_factor)] + # scales keep original N + scales = scales[:, out_start:out_end] + # qzeros stays dummy/empty; g_idx stays on K. + out_features = out_per + else: + # qweight: [K/pack, N] + qweight = qweight[:, out_start:out_end] + # qzeros: [K/group, N/pack] + qzeros = qzeros[:, out_packed_start:out_packed_end] + # scales: [K/group, N] + scales = scales[:, out_start:out_end] + out_features = out_per + else: + # awq qweight: [K, N/pack] + qweight = qweight[:, out_packed_start:out_packed_end] + qzeros = qzeros[:, out_packed_start:out_packed_end] + scales = scales[:, out_start:out_end] + out_features = out_per + + # Shard along input features (K) for row-parallel modules. + elif tp_dim == 1: + if in_features % tp_size != 0: + logger.warning( + f"in_features={in_features} not divisible by TP={tp_size} for {module_name}. " + "Skipping offline quant weights for this module." + ) + skipped += 1 + continue + in_per = in_features // tp_size + in_start = tp_rank * in_per + in_end = in_start + in_per + if group_size <= 0 or (in_per % group_size) != 0 or (in_start % group_size) != 0: + logger.warning( + f"group_size={group_size} incompatible with TP sharding for {module_name} " + f"(in_per={in_per}, in_start={in_start}). Skipping." + ) + skipped += 1 + continue + g_start = in_start // group_size + g_end = in_end // group_size + + if format == "gptq": + if is_gptq_marlin_ckpt: + # Marlin qweight packs K in tiles of 16: K_packed = K / 16 + if in_start % 16 != 0: + logger.warning( + f"gptq_marlin requires in_start divisible by 16, got in_start={in_start} " + f"for {module_name}. Skipping." + ) + skipped += 1 + continue + q_start = in_start // 16 + q_end = in_end // 16 + qweight = qweight[q_start:q_end, :] + # Shard scales on group dimension (K/group). + # vLLM marlin_permute_scales typically returns [num_groups, N]. + group_size_norm = in_features if group_size == -1 else group_size + expected_num_groups = in_features // group_size_norm if group_size_norm > 0 else 0 + if expected_num_groups <= 0: + logger.warning( + f"invalid expected_num_groups={expected_num_groups} for {module_name}. Skipping." + ) + skipped += 1 + continue + if int(scales.shape[0]) == expected_num_groups: + scales = scales[g_start:g_end, :] + elif int(scales.shape[0]) == 2 * expected_num_groups: + # Legacy/alternate layout: [2*num_groups, N/2] + scales = scales[(2 * g_start):(2 * g_end), :] + else: + logger.warning( + f"unexpected gptq_marlin scales.shape[0]={int(scales.shape[0])} " + f"(expected {expected_num_groups} or {2*expected_num_groups}) for {module_name}. Skipping." + ) + skipped += 1 + continue + if g_idx is not None and getattr(g_idx, "numel", lambda: 1)() > 0: + g_idx = g_idx[in_start:in_end] + in_features = in_per + else: + # qweight: [K/pack, N] (packed on K) + if in_start % pack_factor != 0: + logger.warning( + f"in_start={in_start} not divisible by pack_factor={pack_factor} " + f"for {module_name}. Skipping." + ) + skipped += 1 + continue + q_start = in_start // pack_factor + q_end = in_end // pack_factor + qweight = qweight[q_start:q_end, :] + qzeros = qzeros[g_start:g_end, :] + scales = scales[g_start:g_end, :] + if g_idx is not None and getattr(g_idx, "numel", lambda: 1)() > 0: + g_idx = g_idx[in_start:in_end] + in_features = in_per + else: + # awq qweight: [K, N/pack] + qweight = qweight[in_start:in_end, :] + qzeros = qzeros[g_start:g_end, :] + scales = scales[g_start:g_end, :] + in_features = in_per + + # Treat empty g_idx as "not provided" for GPTQ (desc_act=False checkpoints often store empty). + if g_idx is not None and getattr(g_idx, "numel", lambda: 1)() == 0: + g_idx = None + + # Set offline quantized weight + try: + if format == "gptq" and is_gptq_marlin_ckpt: + if ckpt_bits not in (4, 8): + raise ValueError(f"gptq_marlin checkpoint requires bits=4/8, got bits={ckpt_bits}") + _set_offline_gptq_marlin_weight( + module, + qweight=qweight, + scales=scales, + out_features=out_features, + in_features=in_features, + group_size=group_size, + bits=int(ckpt_bits), + g_idx=g_idx, + ) + else: + module.set_offline_quantized_weight( + format=format, + qweight=qweight, + qzeros=qzeros, + scales=scales, + out_features=out_features, + in_features=in_features, + group_size=group_size, + g_idx=g_idx, + ) + if format == "gptq": + loaded_gptq += 1 + else: + loaded_awq += 1 + except Exception as e: + logger.exception(f"Failed to load offline quantized weights for {module_name}: {e}") + skipped += 1 + + except Exception as e: + logger.exception(f"Error loading offline quantized weights for {prefix}: {e}") + skipped += 1 + + return loaded_gptq, loaded_awq, skipped + + def load_model(model: nn.Module, config: Config): """Load model weights and optionally LoRA weights.""" # Enable LoRA for linear layers if LoRA is enabled if config.use_lora and config.lora_path: lora_config = load_lora_config(config.lora_path) if lora_config: - print(f"LoRA Config Loaded: {lora_config}") + logger.info(f"LoRA Config Loaded: {lora_config}") model = enable_lora_for_model(model, lora_config) else: - print("No adapter_config.json found, using default LoRA parameters") + logger.info("No adapter_config.json found, using default LoRA parameters") default_config = {'r': 16, 'lora_alpha': 32.0, 'lora_dropout': 0.0} model = enable_lora_for_model(model, default_config) - # Load base model weights + # First, try to load offline quantized weights (GPTQ/AWQ) + loaded_gptq, loaded_awq, skipped_offline = _load_gptq_awq_weights(model, config) + if loaded_gptq > 0 or loaded_awq > 0: + print(f"Loaded offline quantized weights: GPTQ={loaded_gptq}, AWQ={loaded_awq}, skipped={skipped_offline}") + + # Load base model weights (only for non-offline-quantized layers) packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) for file in tqdm(glob(os.path.join(config.model, "*.safetensors")), desc="Loading base model"): with safe_open(file, "pt", "cpu") as f: for weight_name in f.keys(): + # Skip GPTQ/AWQ keys (already loaded) + if any( + weight_name.endswith(suffix) + for suffix in [".qweight", ".qzeros", ".scales", ".g_idx"] + ): + continue + for k in packed_modules_mapping: if k in weight_name: @@ -73,31 +702,52 @@ def load_model(model: nn.Module, config: Config): param_name = weight_name.replace(k, v) if "layernorm" in param_name: - param = model.get_parameter(param_name) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, f.get_tensor(weight_name)) + try: + param = model.get_parameter(param_name) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, f.get_tensor(weight_name)) + except (AttributeError, KeyError): + # Try buffer fallback for non-parameter weights + try: + buffer = model.get_buffer(param_name) + buffer.copy_(f.get_tensor(weight_name)) + except (AttributeError, KeyError): + pass else: - param = model.get_parameter(param_name) - weight_loader = partial(getattr(param, "weight_loader"), param, f.get_tensor(weight_name)) - if shard_id is None: - weight_loader() - else: - weight_loader(shard_id) + try: + param = model.get_parameter(param_name) + weight_loader = partial(getattr(param, "weight_loader"), param, f.get_tensor(weight_name)) + if shard_id is None: + weight_loader() + else: + weight_loader(shard_id) + except (AttributeError, KeyError): + # Parameter might not exist if offline quantized weights were loaded + # Skip it silently + pass break else: - param = model.get_parameter(weight_name) - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, f.get_tensor(weight_name)) + try: + param = model.get_parameter(weight_name) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, f.get_tensor(weight_name)) + except (AttributeError, KeyError): + # Try buffer fallback for non-parameter weights + try: + buffer = model.get_buffer(weight_name) + buffer.copy_(f.get_tensor(weight_name)) + except (AttributeError, KeyError): + pass # Load LoRA weights if enabled if config.use_lora and config.lora_path: if os.path.exists(config.lora_path): - print(f"Loading LoRA weights from {config.lora_path}") + logger.info(f"Loading LoRA weights from {config.lora_path}") load_lora_weights_fn = partial(load_lora_weights, model, config.lora_path) packed_modules_mapping = packed_modules_mapping if config.model_name == "llada" else None model = load_lora_weights_fn(packed_modules_mapping=packed_modules_mapping) else: - print(f"Warning: LoRA path {config.lora_path} does not exist, skipping LoRA loading") + logger.warning(f"LoRA path {config.lora_path} does not exist, skipping LoRA loading") return model @@ -189,16 +839,16 @@ def load_lora_weights(model: nn.Module, lora_path: str, packed_modules_mapping: module.lora_B.data.copy_(found_b) applied_count += 1 except Exception as e: - print(f"Failed to load LoRA weights for {name}: {e}") + logger.warning(f"Failed to load LoRA weights for {name}: {e}") for module in model.modules(): if hasattr(module, 'merge_lora'): module.merge_lora() - print(f"LoRA weights applied to {applied_count} layers and merged") + logger.info(f"LoRA weights applied to {applied_count} layers and merged") except Exception as e: - print(f"Error loading LoRA weights: {e}") - print("Continuing with base model only") + logger.error(f"Error loading LoRA weights: {e}") + logger.warning("Continuing with base model only") return model diff --git a/diffulex/utils/quantization/__init__.py b/diffulex/utils/quantization/__init__.py new file mode 100644 index 0000000..78f8013 --- /dev/null +++ b/diffulex/utils/quantization/__init__.py @@ -0,0 +1,69 @@ +""" +Quantization module for diffulex. + +This module provides a flexible, extensible quantization architecture that supports: +- KV Cache quantization +- Future: Weight quantization, Activation quantization, etc. + +The architecture uses strategy pattern with context management to minimize coupling. +""" + +from diffulex.utils.quantization.context import ( + QuantizationContext, + get_quantization_context, + set_kv_cache_strategy, + get_kv_cache_strategy, +) +from diffulex.utils.quantization.factory import QuantizationStrategyFactory +from diffulex.utils.quantization.config import ( + QuantizationConfig, + KVCacheQuantConfig, + WeightQuantConfig, + ActivationQuantConfig, +) +from diffulex.utils.quantization.registry import ( + create_kv_cache_strategy, + registered_kv_cache_dtypes, +) +from diffulex.utils.quantization.strategy import ( + QuantizationStrategy, + KVCacheQuantizationStrategy, + WeightQuantizationStrategy, +) +# Re-export kv_cache_dtype utilities for backward compatibility +from diffulex.utils.quantization.kv_cache_dtype import ( + KvCacheDType, + KvCacheDTypeSpec, + parse_kv_cache_dtype, + ensure_scale_tensor, + view_fp8_cache, +) + +__all__ = [ + # Context + 'QuantizationContext', + 'get_quantization_context', + 'set_kv_cache_strategy', + 'get_kv_cache_strategy', + # Factory + 'QuantizationStrategyFactory', + # Config + 'QuantizationConfig', + 'KVCacheQuantConfig', + 'WeightQuantConfig', + 'ActivationQuantConfig', + # Registry + 'create_kv_cache_strategy', + 'registered_kv_cache_dtypes', + # Strategy interfaces + 'QuantizationStrategy', + 'KVCacheQuantizationStrategy', + 'WeightQuantizationStrategy', + # KV Cache dtype utilities (for backward compatibility) + 'KvCacheDType', + 'KvCacheDTypeSpec', + 'parse_kv_cache_dtype', + 'ensure_scale_tensor', + 'view_fp8_cache', +] + diff --git a/diffulex/utils/quantization/config.py b/diffulex/utils/quantization/config.py new file mode 100644 index 0000000..5e30ef9 --- /dev/null +++ b/diffulex/utils/quantization/config.py @@ -0,0 +1,71 @@ +""" +Quantization configuration objects for Diffulex. + +Diffulex currently exposes a single user-facing knob: `Config.kv_cache_dtype`. +This module introduces explicit config dataclasses so we can extend to +weights/activations quantization without growing ad-hoc fields everywhere. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class KVCacheQuantConfig: + """KV-cache quantization configuration.""" + + dtype: str = "bf16" + # Future extension points: + # - scale_mode: "running_max" | "static" | ... + # - per_channel: bool + + +@dataclass(frozen=True) +class WeightQuantConfig: + """Weight quantization configuration (placeholder).""" + + method: str = "none" + # Linear (by kind) + linear_attn_dtype: str = "bf16" + linear_mlp_dtype: str = "bf16" + + +@dataclass(frozen=True) +class ActivationQuantConfig: + """Activation quantization configuration (placeholder).""" + + # Linear activations (by kind) + linear_attn_dtype: str = "bf16" + linear_mlp_dtype: str = "bf16" + + +@dataclass(frozen=True) +class QuantizationConfig: + """Top-level quantization configuration for Diffulex.""" + + kv_cache: KVCacheQuantConfig = KVCacheQuantConfig() + weights: WeightQuantConfig = WeightQuantConfig() + activations: ActivationQuantConfig = ActivationQuantConfig() + + @classmethod + def from_diffulex_config(cls, config) -> "QuantizationConfig": + # Keep this tolerant: Diffulex's Config is a simple dataclass and may evolve. + kv_cache_dtype = getattr(config, "kv_cache_dtype", "bf16") or "bf16" + linear_attn_weight_dtype = getattr(config, "linear_attn_weight_dtype", "bf16") or "bf16" + linear_mlp_weight_dtype = getattr(config, "linear_mlp_weight_dtype", "bf16") or "bf16" + linear_attn_act_dtype = getattr(config, "linear_attn_act_dtype", "bf16") or "bf16" + linear_mlp_act_dtype = getattr(config, "linear_mlp_act_dtype", "bf16") or "bf16" + return cls( + kv_cache=KVCacheQuantConfig(dtype=kv_cache_dtype), + weights=WeightQuantConfig( + linear_attn_dtype=linear_attn_weight_dtype, + linear_mlp_dtype=linear_mlp_weight_dtype, + ), + activations=ActivationQuantConfig( + linear_attn_dtype=linear_attn_act_dtype, + linear_mlp_dtype=linear_mlp_act_dtype, + ), + ) + + diff --git a/diffulex/utils/quantization/context.py b/diffulex/utils/quantization/context.py new file mode 100644 index 0000000..183319a --- /dev/null +++ b/diffulex/utils/quantization/context.py @@ -0,0 +1,183 @@ +""" +Quantization context manager. + +This module provides a thread-local context for managing quantization strategies, +avoiding the need to pass quantization parameters through multiple layers. +""" + +from typing import Dict, Optional +from threading import local + +from diffulex.utils.quantization.strategy import ( + QuantizationStrategy, + KVCacheQuantizationStrategy, + WeightQuantizationStrategy, + LinearQuantizationStrategy, +) + + +class QuantizationContext: + """ + Quantization context manager. + + Uses thread-local storage to support multi-threaded/multi-process environments. + Quantization strategies are registered and retrieved through context, avoiding parameter passing. + """ + + _thread_local = local() + + def __init__(self): + self._strategies: Dict[str, QuantizationStrategy] = {} + # Step-local cache for activation quantization (e.g., W8A8 per-row quant). + # Keyed by tensor identity+layout to allow reuse within a single forward/step. + self._act_quant_cache: Dict[tuple, tuple] = {} + + @classmethod + def current(cls) -> 'QuantizationContext': + """Get current thread's quantization context.""" + if not hasattr(cls._thread_local, 'context'): + cls._thread_local.context = QuantizationContext() + return cls._thread_local.context + + def set_strategy(self, key: str, strategy: QuantizationStrategy): + """Set a quantization strategy.""" + self._strategies[key] = strategy + + def get_strategy(self, key: str, default: Optional[QuantizationStrategy] = None) -> Optional[QuantizationStrategy]: + """Get a quantization strategy.""" + return self._strategies.get(key, default) + + def get_kv_cache_strategy(self) -> Optional[KVCacheQuantizationStrategy]: + """Get KV Cache quantization strategy.""" + strategy = self._strategies.get('kv_cache') + if strategy is None: + return None + if isinstance(strategy, KVCacheQuantizationStrategy): + return strategy + raise TypeError( + f"KV cache strategy must be KVCacheQuantizationStrategy, got {type(strategy)}" + ) + + def get_weight_strategy(self) -> Optional[WeightQuantizationStrategy]: + """Get weight quantization strategy.""" + strategy = self._strategies.get('weight') + if strategy is None: + return None + if isinstance(strategy, WeightQuantizationStrategy): + return strategy + raise TypeError( + f"Weight strategy must be WeightQuantizationStrategy, got {type(strategy)}" + ) + + def set_linear_strategy(self, kind: str, strategy: LinearQuantizationStrategy) -> None: + """Set Linear quantization strategy for a kind ("attn"/"mlp"/"other").""" + key = f"linear_{(kind or 'other').strip().lower() or 'other'}" + self._strategies[key] = strategy + + def get_linear_strategy(self, kind: str) -> Optional[LinearQuantizationStrategy]: + """Get Linear quantization strategy for a kind ("attn"/"mlp"/"other").""" + key = f"linear_{(kind or 'other').strip().lower() or 'other'}" + strategy = self._strategies.get(key) + if strategy is None: + return None + if isinstance(strategy, LinearQuantizationStrategy): + return strategy + raise TypeError( + f"{key} strategy must be LinearQuantizationStrategy, got {type(strategy)}" + ) + + def clear(self): + """Clear all strategies.""" + self._strategies.clear() + self._act_quant_cache.clear() + + # ---- Activation quantization cache helpers (step-local) ---- + def clear_act_quant_cache(self) -> None: + self._act_quant_cache.clear() + + def _act_quant_cache_key(self, x) -> tuple: + # Include version to avoid reusing after in-place mutation. + # data_ptr() is stable for the tensor storage; combine with shape/stride/dtype/device. + try: + version = getattr(x, "_version", None) + except Exception: + version = None + return ( + int(x.data_ptr()), + tuple(x.shape), + tuple(x.stride()), + str(x.dtype), + str(x.device), + int(version) if version is not None else -1, + ) + + def get_cached_act_quant(self, x): + return self._act_quant_cache.get(self._act_quant_cache_key(x)) + + def set_cached_act_quant(self, x, x_q, x_scales) -> None: + self._act_quant_cache[self._act_quant_cache_key(x)] = (x_q, x_scales) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Optionally clear context on exit, or keep it for reuse + pass + + +# Global convenience functions +def get_quantization_context() -> QuantizationContext: + """Get current quantization context.""" + return QuantizationContext.current() + + +def set_kv_cache_strategy(strategy: KVCacheQuantizationStrategy): + """Set KV Cache quantization strategy.""" + ctx = QuantizationContext.current() + ctx.set_strategy('kv_cache', strategy) + + +def get_kv_cache_strategy() -> Optional[KVCacheQuantizationStrategy]: + """Get KV Cache quantization strategy.""" + ctx = QuantizationContext.current() + return ctx.get_kv_cache_strategy() + + +def set_weight_strategy(strategy: WeightQuantizationStrategy): + """Set weight quantization strategy.""" + ctx = QuantizationContext.current() + ctx.set_strategy('weight', strategy) + + +def get_weight_strategy() -> Optional[WeightQuantizationStrategy]: + """Get weight quantization strategy.""" + ctx = QuantizationContext.current() + return ctx.get_weight_strategy() + + +def set_linear_strategy(kind: str, strategy: LinearQuantizationStrategy) -> None: + """Set Linear quantization strategy for a kind ("attn"/"mlp"/"other").""" + ctx = QuantizationContext.current() + ctx.set_linear_strategy(kind, strategy) + + +def get_linear_strategy(kind: str) -> Optional[LinearQuantizationStrategy]: + """Get Linear quantization strategy for a kind ("attn"/"mlp"/"other").""" + ctx = QuantizationContext.current() + return ctx.get_linear_strategy(kind) + + +def clear_act_quant_cache() -> None: + """Clear step-local activation quant cache for the current thread.""" + QuantizationContext.current().clear_act_quant_cache() + + +def get_cached_act_quant(x): + """Get cached (x_q, x_scales) for activation quantization, or None.""" + return QuantizationContext.current().get_cached_act_quant(x) + + +def set_cached_act_quant(x, x_q, x_scales) -> None: + """Set cached (x_q, x_scales) for activation quantization.""" + QuantizationContext.current().set_cached_act_quant(x, x_q, x_scales) + diff --git a/diffulex/utils/quantization/factory.py b/diffulex/utils/quantization/factory.py new file mode 100644 index 0000000..ee7e3b6 --- /dev/null +++ b/diffulex/utils/quantization/factory.py @@ -0,0 +1,88 @@ +""" +Quantization strategy factory. + +This module provides factory functions to create quantization strategies from configuration. +""" + +from typing import Optional + +from diffulex.utils.quantization.context import QuantizationContext +from diffulex.utils.quantization.config import QuantizationConfig +from diffulex.utils.quantization.registry import create_kv_cache_strategy as _create_kv_cache_strategy +from diffulex.utils.quantization.registry import create_linear_strategy as _create_linear_strategy +from diffulex.utils.quantization.strategy import KVCacheQuantizationStrategy + +# Ensure built-in strategies are imported so they can register themselves. +from diffulex.utils.quantization import strategies as _builtin_strategies # noqa: F401 + + +class QuantizationStrategyFactory: + """Quantization strategy factory.""" + + @staticmethod + def create_kv_cache_strategy(dtype: Optional[str] = None) -> KVCacheQuantizationStrategy: + """ + Create KV Cache quantization strategy. + + Args: + dtype: KV cache dtype string: + - None or "bf16": BF16 (no quantization) + - "fp16": FP16 (no quantization, future support) + - "fp32": FP32 (no quantization, future support) + - "fp8" or "fp8_e4m3": FP8 E4M3 with running max + - "fp8_e5m2": FP8 E5M2 with running max + + Returns: + KV Cache quantization strategy instance + + Raises: + ValueError: If dtype is not supported + """ + # NOTE: dtype normalization + compatibility handling lives in the registry. + return _create_kv_cache_strategy(dtype or "bf16") + + @staticmethod + def create_from_config(config) -> QuantizationContext: + """ + Create and configure quantization context from config object. + + Args: + config: Configuration object that may contain quantization-related fields: + - kv_cache_dtype: KV cache dtype string + - weight_dtype: Weight dtype string (future) + + Returns: + Configured quantization context + """ + ctx = QuantizationContext.current() + + quant_cfg = QuantizationConfig.from_diffulex_config(config) + + # KV Cache strategy + strategy = QuantizationStrategyFactory.create_kv_cache_strategy(quant_cfg.kv_cache.dtype) + strategy.configure(diffulex_config=config) + ctx.set_strategy('kv_cache', strategy) + + # Linear strategies (weights + activations) by kind + linear_attn = _create_linear_strategy( + weight_dtype=quant_cfg.weights.linear_attn_dtype, + act_dtype=quant_cfg.activations.linear_attn_dtype, + ) + linear_attn.configure(diffulex_config=config) + ctx.set_linear_strategy("attn", linear_attn) + + linear_mlp = _create_linear_strategy( + weight_dtype=quant_cfg.weights.linear_mlp_dtype, + act_dtype=quant_cfg.activations.linear_mlp_dtype, + ) + linear_mlp.configure(diffulex_config=config) + ctx.set_linear_strategy("mlp", linear_mlp) + + # Future: Weight strategy + # weight_dtype = getattr(config, 'weight_dtype', None) + # if weight_dtype: + # strategy = QuantizationStrategyFactory.create_weight_strategy(weight_dtype) + # ctx.set_strategy('weight', strategy) + + return ctx + diff --git a/diffulex/utils/quantization/kv_cache_dtype.py b/diffulex/utils/quantization/kv_cache_dtype.py new file mode 100644 index 0000000..34190c6 --- /dev/null +++ b/diffulex/utils/quantization/kv_cache_dtype.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import IntEnum +from typing import Any + +import torch + +try: + # vLLM provides a platform-specific fp8 dtype (can be e4m3fn / e4m3fnuz, etc.) + from vllm.platforms import current_platform # type: ignore +except Exception: # pragma: no cover + current_platform = None + + +class KvCacheDType(IntEnum): + BF16 = 0 + FP16 = 1 + FP32 = 2 + FP8_E4M3 = 3 + FP8_E5M2 = 4 + + +@dataclass(frozen=True) +class KvCacheDTypeSpec: + enum: KvCacheDType + is_fp8: bool + fp8_view_dtype: torch.dtype | None + fp8_min: float | None + fp8_max: float | None + + +def _normalize_kv_cache_dtype(kv_cache_dtype: str) -> str: + s = (kv_cache_dtype or "").strip().lower() + aliases = { + "bf16": "bf16", + "bfloat16": "bf16", + "fp16": "fp16", + "float16": "fp16", + "fp32": "fp32", + "float32": "fp32", + "fp8": "fp8_e4m3", + "fp8_e4m3": "fp8_e4m3", + "e4m3": "fp8_e4m3", + "fp8_e5m2": "fp8_e5m2", + "e5m2": "fp8_e5m2", + } + if s not in aliases: + raise ValueError( + f"Unsupported kv_cache_dtype={kv_cache_dtype!r}. " + "Supported: bf16/fp16/fp32/fp8/fp8_e4m3/fp8_e5m2" + ) + return aliases[s] + + +def _get_fp8_e4m3_dtype() -> torch.dtype: + if current_platform is None: + if hasattr(torch, "float8_e4m3fn"): + return torch.float8_e4m3fn # type: ignore[attr-defined] + raise RuntimeError("FP8 requested but vLLM current_platform is unavailable.") + return current_platform.fp8_dtype() + + +def _get_fp8_e5m2_dtype() -> torch.dtype: + if hasattr(torch, "float8_e5m2"): + return torch.float8_e5m2 # type: ignore[attr-defined] + if hasattr(torch, "float8_e5m2fnuz"): + return torch.float8_e5m2fnuz # type: ignore[attr-defined] + raise RuntimeError( + "FP8 E5M2 requested but this torch build does not expose float8_e5m2 dtype." + ) + + +def parse_kv_cache_dtype(kv_cache_dtype: str) -> KvCacheDTypeSpec: + norm = _normalize_kv_cache_dtype(kv_cache_dtype) + if norm == "bf16": + return KvCacheDTypeSpec(KvCacheDType.BF16, False, None, None, None) + if norm == "fp16": + return KvCacheDTypeSpec(KvCacheDType.FP16, False, None, None, None) + if norm == "fp32": + return KvCacheDTypeSpec(KvCacheDType.FP32, False, None, None, None) + + if norm == "fp8_e4m3": + fp8 = _get_fp8_e4m3_dtype() + enum = KvCacheDType.FP8_E4M3 + elif norm == "fp8_e5m2": + fp8 = _get_fp8_e5m2_dtype() + enum = KvCacheDType.FP8_E5M2 + else: # pragma: no cover + raise AssertionError(norm) + + info = torch.finfo(fp8) + return KvCacheDTypeSpec( + enum=enum, + is_fp8=True, + fp8_view_dtype=fp8, + fp8_min=float(info.min), + fp8_max=float(info.max), + ) + + +def ensure_scale_tensor( + scale: Any, + *, + num_kv_heads: int, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Returns a CUDA tensor suitable for Triton: + - shape [num_kv_heads] (per-head) + - dtype float32 by default + """ + if scale is None: + return torch.ones((num_kv_heads,), device=device, dtype=dtype) + if isinstance(scale, (float, int)): + return torch.full((num_kv_heads,), float(scale), device=device, dtype=dtype) + if isinstance(scale, torch.Tensor): + if scale.numel() == 1: + return torch.full((num_kv_heads,), float(scale.item()), device=device, dtype=dtype) + if scale.numel() != num_kv_heads: + raise ValueError( + f"scale must be scalar or shape [num_kv_heads]={num_kv_heads}, got {tuple(scale.shape)}" + ) + return scale.to(device=device, dtype=dtype).contiguous() + raise TypeError(f"Unsupported scale type: {type(scale)}") + + +def view_fp8_cache(cache: torch.Tensor, kv_cache_dtype: str) -> torch.Tensor: + """ + FP8 KV cache uses uint8 as storage for compatibility. This returns a view tensor + whose dtype is fp8, so Triton will see the correct element type. + """ + spec = parse_kv_cache_dtype(kv_cache_dtype) + if not spec.is_fp8: + return cache + assert spec.fp8_view_dtype is not None + if cache.dtype == torch.uint8: + return cache.view(spec.fp8_view_dtype) + if cache.dtype == spec.fp8_view_dtype: + return cache + raise AssertionError( + f"FP8 cache must be torch.uint8 (storage) or {spec.fp8_view_dtype}, got {cache.dtype}" + ) + + diff --git a/diffulex/utils/quantization/quantize_model.py b/diffulex/utils/quantization/quantize_model.py new file mode 100644 index 0000000..4c004c5 --- /dev/null +++ b/diffulex/utils/quantization/quantize_model.py @@ -0,0 +1,806 @@ +#!/usr/bin/env python3 +"""离线量化脚本:将模型权重量化为 vLLM 标准 GPTQ/AWQ 格式 + +支持两种量化格式(对齐 vLLM 权重格式): +- GPTQ: qweight/qzeros 为 int32 packed,scales 为 fp16,g_idx 可选(常见 desc_act=False 时为空) +- GPTQ_MARLIN: 导出 Marlin-ready 的 GPTQ 权重布局(qweight 已 repack,scales 已 permute,zp 为空) +- AWQ : qweight/qzeros 为 int32 packed,scales 为 fp16 + +使用方法: + python -m diffulex.utils.quantization.quantize_model \ + --model-path /path/to/model \ + --output-path /path/to/output \ + --quant-format gptq_marlin \ + --group-size 128 \ + --bits 4 \ + --quant-method auto \ + --calib-text-file /path/to/calib.txt \ + --calib-num-samples 128 \ + --calib-seq-len 512 + +说明: +- `quant-method=simple`:沿用当前“直接分组量化/舍入”的旧实现(不需要校准数据,不是真 GPTQ/AWQ)。 +- `quant-method=auto`:使用 `auto-gptq` / `awq(autoawq)` 做真正的校准量化,然后导出为 vLLM/Diffulex 可加载的权重格式。 +""" + +from __future__ import annotations + +import argparse +import os +import json +import random +from pathlib import Path +from typing import Optional + +import torch +import torch.nn as nn +from tqdm import tqdm +from safetensors.torch import save_file + +# Import model loading utilities +import sys +from pathlib import Path as PathLib + +# Add project root to path +_REPO_ROOT = PathLib(__file__).resolve().parents[3] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from safetensors import safe_open +from glob import glob + + +def _require_vllm(): + try: + from vllm.scalar_type import scalar_types # type: ignore + from vllm.model_executor.layers.quantization.utils.quant_utils import ( # type: ignore + awq_pack, + gptq_pack, + pack_cols, + quantize_weights, + ) + except Exception as e: # pragma: no cover + raise RuntimeError( + "离线 GPTQ/AWQ 打包已切换到 vLLM 标准格式,需要可 import 的 vLLM。" + ) from e + return scalar_types, quantize_weights, gptq_pack, awq_pack, pack_cols + + +def _require_vllm_marlin(): + # Marlin 预处理依赖 CUDA custom ops + try: + from vllm import _custom_ops as ops # type: ignore + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( # type: ignore + marlin_permute_scales, + ) + except Exception as e: # pragma: no cover + raise RuntimeError( + "导出 gptq_marlin 格式需要可 import 的 vLLM Marlin(含 CUDA custom ops)。" + ) from e + return ops, marlin_permute_scales + + +def _require_auto_gptq(): + try: + from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig # type: ignore + except Exception as e: # pragma: no cover + raise RuntimeError( + "未能导入 auto-gptq。请确认已在当前 .venv 安装(例如:BUILD_CUDA_EXT=0 pip install auto-gptq)。" + ) from e + return AutoGPTQForCausalLM, BaseQuantizeConfig + + +def _require_awq(): + try: + from awq import AutoAWQForCausalLM # type: ignore + except Exception as e: # pragma: no cover + raise RuntimeError( + "未能导入 awq(autoawq 的导入名是 `awq`)。" + ) from e + return AutoAWQForCausalLM + + +def _load_calib_texts( + calib_text_file: str, *, num_samples: int, seed: int +) -> list[str]: + p = Path(calib_text_file) + if not p.exists(): + raise FileNotFoundError(f"calib_text_file 不存在: {calib_text_file}") + lines = [ln.strip() for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines()] + lines = [ln for ln in lines if ln] + if not lines: + raise ValueError(f"calib_text_file 为空: {calib_text_file}") + if num_samples <= 0: + raise ValueError(f"calib_num_samples 必须 > 0, got {num_samples}") + if len(lines) <= num_samples: + return lines[:num_samples] + rng = random.Random(seed) + return rng.sample(lines, k=num_samples) + + +def _build_autogptq_examples( + tokenizer, texts: list[str], *, seq_len: int +) -> list[dict[str, torch.Tensor]]: + if seq_len <= 0: + raise ValueError(f"calib_seq_len 必须 > 0, got {seq_len}") + + # AutoGPTQ 会自行 collate/pad;这里用 fixed max_length 保持输入一致。 + examples: list[dict[str, torch.Tensor]] = [] + for t in texts: + enc = tokenizer( + t, + return_tensors="pt", + truncation=True, + max_length=seq_len, + padding="max_length", + ) + examples.append( + { + "input_ids": enc["input_ids"], + "attention_mask": enc.get("attention_mask", torch.ones_like(enc["input_ids"])), + } + ) + return examples + + +def _quantize_to_vllm_gptq( + weight: torch.Tensor, *, group_size: int, bits: int, use_v2_format: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Quantize and pack weights into vLLM GPTQ checkpoint format. + + Input: + weight: fp32 [N, K] (PyTorch Linear weight) + Output (vLLM format): + qweight: int32 [K/pack, N] + qzeros : int32 [K/group, N/pack] (GPTQ v1 stores (zeros - 1); v2 stores zeros) + scales : fp16 [K/group, N] + g_idx : int32 empty tensor (desc_act=False) + """ + scalar_types, quantize_weights, gptq_pack, _, pack_cols = _require_vllm() + # vLLM GPTQConfig mentions 2/3/4/8, but the standard vLLM int32 packing + # used by `gptq_pack/pack_cols` requires 32 % bits == 0. + # So we support 2/4/8 here; 3-bit would need a different packing scheme. + if bits not in (2, 4, 8): + raise ValueError( + f"GPTQ bits 仅支持 2/4/8(vLLM 标准 int32 pack 要求 32%bits==0),当前 bits={bits}" + ) + + # vLLM operates on (K, N) + w = weight.T.contiguous() + size_k, size_n = w.shape + group_size_norm = size_k if group_size == -1 else group_size + if group_size_norm <= 0 or size_k % group_size_norm != 0: + raise ValueError(f"Invalid group_size={group_size} for in_features={size_k}") + + if bits == 2: + quant_type = scalar_types.uint2b2 + elif bits == 4: + quant_type = scalar_types.uint4b8 + else: # bits == 8 + quant_type = scalar_types.uint8b128 + + _, w_q, w_s, _ = quantize_weights(w, quant_type, group_size_norm, zero_points=False) + + pack_factor = 32 // bits + qweight = gptq_pack(w_q, bits, size_k, size_n).contiguous() # [K/pack, N] + + num_groups = size_k // group_size_norm + zeros = torch.full( + (num_groups, size_n), + int(getattr(quant_type, "bias", 0)), + dtype=torch.int32, + device=w.device, + ) + # GPTQ v1 stores zeros-1 in the checkpoint. + zeros_to_store = zeros if use_v2_format else (zeros - 1) + qzeros = pack_cols(zeros_to_store, bits, num_groups, size_n).contiguous() # [K/group, N/pack] + + scales = w_s.to(torch.float16).contiguous() # [K/group, N] + g_idx = torch.empty((0,), dtype=torch.int32, device=w.device) + return qweight, qzeros, scales, g_idx + + +def _quantize_to_vllm_gptq_marlin( + weight: torch.Tensor, *, group_size: int, bits: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Quantize weights and export marlin-ready GPTQ layout. + + 该导出格式对齐 vLLM `MarlinLinearKernel.process_weights_after_loading` 的结果: + - qweight: 已执行 `gptq_marlin_repack` + - scales : 已执行 `marlin_permute_scales` + - qzeros : 置空(Marlin GPTQ symmetric 路径不使用 runtime zp) + - g_idx : 空(desc_act=False) + + 注意:需要在 CUDA 上执行(`gptq_marlin_repack` 为 CUDA op)。 + """ + if weight.device.type != "cuda": + raise ValueError("gptq_marlin 导出需要 device=cuda(Marlin repack 为 CUDA op)") + + ops, marlin_permute_scales = _require_vllm_marlin() + + # 先按 vLLM 标准 GPTQ(symmetric, zero_points=False)量化并打包 + qweight, _qzeros, scales, g_idx = _quantize_to_vllm_gptq( + weight, group_size=group_size, bits=bits, use_v2_format=False + ) + + # vLLM GPTQ packing 的 shape 基于 w=(K,N);这里 size_k=in_features, size_n=out_features + size_k = weight.shape[1] + size_n = weight.shape[0] + group_size_norm = size_k if group_size == -1 else group_size + + # desc_act=False 时 perm 为空 + empty_perm = torch.empty((0,), dtype=torch.int32, device=weight.device) + + marlin_qweight = ops.gptq_marlin_repack( + qweight.contiguous(), + perm=empty_perm, + size_k=size_k, + size_n=size_n, + num_bits=bits, + is_a_8bit=False, + ).contiguous() + + marlin_scales = marlin_permute_scales( + scales.contiguous(), + size_k=size_k, + size_n=size_n, + group_size=group_size_norm, + is_a_8bit=False, + ).contiguous() + + # Marlin GPTQ symmetric 不使用 runtime zero points,导出空 qzeros 保持一致性 + marlin_qzeros = torch.empty((0,), dtype=torch.int32, device=weight.device) + marlin_g_idx = g_idx # already empty + + return marlin_qweight, marlin_qzeros, marlin_scales, marlin_g_idx + + +def _quantize_to_vllm_awq( + weight: torch.Tensor, *, group_size: int, bits: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Quantize and pack weights into vLLM AWQ checkpoint format. + + Input: + weight: fp32 [N, K] + Output (vLLM format): + qweight: int32 [K, N/pack] + qzeros : int32 [K/group, N/pack] + scales : fp16 [K/group, N] + """ + scalar_types, quantize_weights, _, awq_pack, _ = _require_vllm() + if bits != 4: + raise ValueError(f"AWQ 目前仅支持 4-bit,当前 bits={bits}") + + w = weight.T.contiguous() + size_k, size_n = w.shape + group_size_norm = size_k if group_size == -1 else group_size + if group_size_norm <= 0 or size_k % group_size_norm != 0: + raise ValueError(f"Invalid group_size={group_size} for in_features={size_k}") + + quant_type = scalar_types.uint4 + _, w_q, w_s, w_zp = quantize_weights(w, quant_type, group_size_norm, zero_points=True) + if w_zp is None: + raise RuntimeError("AWQ zero_points=True 但未生成 zero points,vLLM 量化返回异常。") + + qweight = awq_pack(w_q, bits, size_k, size_n).contiguous() # [K, N/pack] + num_groups = size_k // group_size_norm + qzeros = awq_pack(w_zp.to(torch.int32), bits, num_groups, size_n).contiguous() # [K/group, N/pack] + scales = w_s.to(torch.float16).contiguous() # [K/group, N] + return qweight, qzeros, scales + + +@torch.inference_mode() +def _export_autogptq_to_vllm_weights( + *, + gptq_base_model: nn.Module, + quant_format: str, + target_modules: Optional[list[str]], + desc_act: bool, + bits: int, + group_size: int, +) -> dict[str, torch.Tensor]: + """ + 从 auto-gptq 的量化后模型中抽取 qweight/qzeros/scales/g_idx,并按 vLLM/Diffulex 的命名导出。 + - quant_format == "gptq": 直接导出 QuantLinear 的 buffers。 + - quant_format == "gptq_marlin": 在导出前使用 vLLM Marlin 的 repack/permute,且导出空 qzeros/g_idx。 + """ + quantized_weights: dict[str, torch.Tensor] = {} + + if quant_format not in ("gptq", "gptq_marlin"): + raise ValueError(f"Unexpected quant_format for auto-gptq export: {quant_format}") + + if quant_format == "gptq_marlin": + if not torch.cuda.is_available(): + raise RuntimeError("导出 gptq_marlin 需要 CUDA(vLLM Marlin repack 为 CUDA op)。") + ops, marlin_permute_scales = _require_vllm_marlin() + + for module_name, module in gptq_base_model.named_modules(): + # AutoGPTQ 的 QuantLinear(triton/cuda)会有这些 buffer + if not (hasattr(module, "qweight") and hasattr(module, "qzeros") and hasattr(module, "scales")): + continue + + # 过滤:保持和旧脚本一致,默认不量化 lm_head + if "lm_head" in module_name: + continue + if target_modules and not any(t in module_name for t in target_modules): + continue + + qweight = getattr(module, "qweight") + qzeros = getattr(module, "qzeros") + scales = getattr(module, "scales") + g_idx = getattr(module, "g_idx", None) + + if not isinstance(qweight, torch.Tensor) or not isinstance(qzeros, torch.Tensor) or not isinstance(scales, torch.Tensor): + continue + + if quant_format == "gptq": + quantized_weights[f"{module_name}.qweight"] = qweight.detach().cpu().contiguous() + quantized_weights[f"{module_name}.qzeros"] = qzeros.detach().cpu().contiguous() + quantized_weights[f"{module_name}.scales"] = scales.detach().cpu().contiguous() + if desc_act and isinstance(g_idx, torch.Tensor) and g_idx.numel() > 0: + quantized_weights[f"{module_name}.g_idx"] = g_idx.detach().to(dtype=torch.int32).cpu().contiguous() + else: + quantized_weights[f"{module_name}.g_idx"] = torch.empty((0,), dtype=torch.int32) + continue + + # gptq_marlin 导出:用 vLLM 的 repack/permute 变成 Marlin-ready layout + in_features = int(getattr(module, "infeatures", 0)) + out_features = int(getattr(module, "outfeatures", 0)) + if in_features <= 0 or out_features <= 0: + # fallback:从张量形状推断(qweight shape: [K/pack, N]) + out_features = int(qweight.shape[1]) + pack = 32 // bits + in_features = int(qweight.shape[0] * pack) + + group_size_norm = in_features if group_size == -1 else group_size + empty_perm = torch.empty((0,), dtype=torch.int32, device="cuda") + + qweight_cuda = qweight.contiguous().to(device="cuda") + scales_cuda = scales.contiguous().to(device="cuda", dtype=torch.float16) + + marlin_qweight = ops.gptq_marlin_repack( + qweight_cuda, + perm=empty_perm, + size_k=in_features, + size_n=out_features, + num_bits=bits, + is_a_8bit=(bits == 8), + ).contiguous() + marlin_scales = marlin_permute_scales( + scales_cuda, + size_k=in_features, + size_n=out_features, + group_size=group_size_norm, + is_a_8bit=(bits == 8), + ).contiguous() + + quantized_weights[f"{module_name}.qweight"] = marlin_qweight.detach().cpu().contiguous() + quantized_weights[f"{module_name}.qzeros"] = torch.empty((0,), dtype=torch.int32) + quantized_weights[f"{module_name}.scales"] = marlin_scales.detach().cpu().contiguous() + quantized_weights[f"{module_name}.g_idx"] = torch.empty((0,), dtype=torch.int32) + + return quantized_weights + + +@torch.inference_mode() +def _export_awq_to_vllm_weights( + *, + awq_base_model: nn.Module, + target_modules: Optional[list[str]], +) -> dict[str, torch.Tensor]: + """ + 从 awq(pack 后)模型中抽取 qweight/qzeros/scales,并按 vLLM/Diffulex 的命名导出。 + """ + quantized_weights: dict[str, torch.Tensor] = {} + for module_name, module in awq_base_model.named_modules(): + if not (hasattr(module, "qweight") and hasattr(module, "qzeros") and hasattr(module, "scales")): + continue + if "lm_head" in module_name: + continue + if target_modules and not any(t in module_name for t in target_modules): + continue + + qweight = getattr(module, "qweight") + qzeros = getattr(module, "qzeros") + scales = getattr(module, "scales") + if not isinstance(qweight, torch.Tensor) or not isinstance(qzeros, torch.Tensor) or not isinstance(scales, torch.Tensor): + continue + + quantized_weights[f"{module_name}.qweight"] = qweight.detach().cpu().contiguous() + quantized_weights[f"{module_name}.qzeros"] = qzeros.detach().cpu().contiguous() + quantized_weights[f"{module_name}.scales"] = scales.detach().cpu().contiguous() + return quantized_weights + + +def quantize_model( + model_path: str, + output_path: str, + quant_format: str = "gptq", + group_size: int = 128, + bits: int = 4, + target_modules: Optional[list[str]] = None, + device: str = "cpu", + quant_method: str = "auto", + calib_text_file: Optional[str] = None, + calib_num_samples: int = 128, + calib_seq_len: int = 512, + calib_batch_size: int = 1, + calib_seed: int = 0, + # GPTQ config + desc_act: bool = False, + sym: bool = True, + damp_percent: float = 0.01, + true_sequential: bool = True, + use_triton: bool = True, +) -> None: + """Quantize model weights to GPTQ/AWQ format. + + Args: + model_path: Path to input model directory (containing safetensors files) + output_path: Path to output directory (will create if not exists) + quant_format: "gptq" or "awq" + group_size: Group size for quantization (default: 128) + bits: Number of bits per weight (default: 4) + target_modules: List of module name patterns to quantize (e.g., ["q_proj", "k_proj"]). + If None, quantizes all linear layers. + device: Device to use for quantization ("cpu" or "cuda") + quant_method: "auto"(真 GPTQ/AWQ,需校准数据)或 "simple"(旧实现,无校准) + calib_text_file: 校准文本文件(每行一条样本) + """ + if quant_format not in ["gptq", "gptq_marlin", "awq"]: + raise ValueError( + f"Unsupported quant_format: {quant_format}. Must be 'gptq', 'gptq_marlin' or 'awq'" + ) + if quant_method not in ["auto", "simple"]: + raise ValueError("quant_method must be 'auto' or 'simple'") + + # Marlin GPTQ 强约束:对称量化 + 不使用 act-order + if quant_format == "gptq_marlin": + desc_act = False + sym = True + + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + # Load model config (for tokenizer special tokens, etc.) + _ = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + quantized_weights: dict[str, torch.Tensor] = {} + metadata = { + "quant_format": quant_format, + "quant_method": quant_method, + "group_size": group_size, + "bits": bits, + "quantized_modules": [], + } + + # ---------------------------- + # 真 GPTQ/AWQ(需要校准数据) + # ---------------------------- + if quant_method == "auto": + if calib_text_file is None: + raise ValueError("quant_method=auto 需要提供 --calib-text-file") + + texts = _load_calib_texts(calib_text_file, num_samples=calib_num_samples, seed=calib_seed) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + if quant_format in ("gptq", "gptq_marlin"): + if quant_format == "gptq_marlin" and device != "cuda": + raise ValueError("导出 gptq_marlin 需要 --device cuda") + + AutoGPTQForCausalLM, BaseQuantizeConfig = _require_auto_gptq() + examples = _build_autogptq_examples(tokenizer, texts, seq_len=calib_seq_len) + + qcfg = BaseQuantizeConfig( + bits=int(bits), + group_size=int(group_size), + damp_percent=float(damp_percent), + desc_act=bool(desc_act), + sym=bool(sym), + true_sequential=bool(true_sequential), + ) + + model_init_kwargs = { + "trust_remote_code": True, + } + # 让 AutoGPTQ 自己用 accelerate 做 device_map;CPU 模式下走默认加载。 + if device == "cuda": + model_init_kwargs["device_map"] = "auto" + model_init_kwargs["torch_dtype"] = torch.float16 + + gptq_model = AutoGPTQForCausalLM.from_pretrained( + model_path, + qcfg, + **model_init_kwargs, + ) + gptq_model.quantize( + examples, + batch_size=int(calib_batch_size), + use_triton=bool(use_triton), + cache_examples_on_gpu=(device == "cuda"), + ) + + quantized_weights = _export_autogptq_to_vllm_weights( + gptq_base_model=gptq_model.model, + quant_format=quant_format, + target_modules=target_modules, + desc_act=bool(desc_act), + bits=int(bits), + group_size=int(group_size), + ) + + else: # awq + if bits != 4: + raise ValueError(f"AWQ 目前仅支持 4-bit,当前 bits={bits}") + AutoAWQForCausalLM = _require_awq() + + awq_model = AutoAWQForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + safetensors=True, + device_map="auto" if device == "cuda" else None, + torch_dtype="auto", + ) + + awq_model.quantize( + tokenizer=tokenizer, + quant_config={ + "zero_point": True, + "q_group_size": int(group_size), + "w_bit": int(bits), + "version": "GEMM", + }, + calib_data=texts, + max_calib_samples=int(calib_num_samples), + max_calib_seq_len=int(calib_seq_len), + ) + awq_model.pack() + + quantized_weights = _export_awq_to_vllm_weights( + awq_base_model=awq_model.model, + target_modules=target_modules, + ) + + # ---------------------------- + # 旧实现(无校准,不是真 GPTQ/AWQ) + # ---------------------------- + else: + safetensors_files = list(glob(os.path.join(model_path, "*.safetensors"))) + if not safetensors_files: + raise ValueError(f"No safetensors files found in {model_path}") + + print(f"Found {len(safetensors_files)} safetensors files") + + all_weight_keys: list[str] = [] + for file in safetensors_files: + with safe_open(file, "pt", device) as f: + all_weight_keys.extend(f.keys()) + + linear_weight_keys: list[str] = [] + for key in all_weight_keys: + if any(skip in key for skip in [".bias", ".norm", ".embed", ".lm_head"]): + continue + if not key.endswith(".weight"): + continue + if target_modules and not any(target in key for target in target_modules): + continue + linear_weight_keys.append(key) + + print(f"Found {len(linear_weight_keys)} linear layer weights to quantize") + + for key in tqdm(linear_weight_keys, desc="Quantizing weights (simple)"): + weight = None + for file in safetensors_files: + with safe_open(file, "pt", device) as f: + if key in f.keys(): + weight = f.get_tensor(key) + break + + if weight is None: + print(f"Warning: Could not load weight for {key}") + continue + if weight.dim() != 2: + print(f"Skipping {key}: not a 2D weight (shape: {weight.shape})") + continue + + out_features, in_features = weight.shape + weight_fp32 = weight.to(torch.float32).to(device) + prefix = key[:-7] # Remove ".weight" + + if quant_format == "gptq": + qweight, qzeros, scales, g_idx = _quantize_to_vllm_gptq( + weight_fp32, group_size=group_size, bits=bits, use_v2_format=False + ) + quantized_weights[f"{prefix}.qweight"] = qweight.cpu() + quantized_weights[f"{prefix}.qzeros"] = qzeros.cpu() + quantized_weights[f"{prefix}.scales"] = scales.cpu() + quantized_weights[f"{prefix}.g_idx"] = g_idx.cpu() + + elif quant_format == "gptq_marlin": + qweight, qzeros, scales, g_idx = _quantize_to_vllm_gptq_marlin( + weight_fp32, group_size=group_size, bits=bits + ) + quantized_weights[f"{prefix}.qweight"] = qweight.cpu() + quantized_weights[f"{prefix}.qzeros"] = qzeros.cpu() + quantized_weights[f"{prefix}.scales"] = scales.cpu() + quantized_weights[f"{prefix}.g_idx"] = g_idx.cpu() + + else: # awq + qweight, qzeros, scales = _quantize_to_vllm_awq( + weight_fp32, group_size=group_size, bits=bits + ) + quantized_weights[f"{prefix}.qweight"] = qweight.cpu() + quantized_weights[f"{prefix}.qzeros"] = qzeros.cpu() + quantized_weights[f"{prefix}.scales"] = scales.cpu() + + metadata["quantized_modules"].append( + { + "name": prefix, + "out_features": int(out_features), + "in_features": int(in_features), + "group_size": group_size, + "bits": bits, + } + ) + + if device == "cuda": + torch.cuda.empty_cache() + + # Copy all model files (config, tokenizer, etc.) to output directory + import shutil + print(f"\nCopying model files to {output_path}...") + model_path_obj = Path(model_path) + + # First, copy original safetensors files (for non-quantized layers like lm_head, embeddings, etc.) + print(" Copying original safetensors files (for non-quantized layers)...") + for file in model_path_obj.glob("*.safetensors"): + dest_file = output_path / file.name + shutil.copy2(file, dest_file) + print(f" Copied {file.name}") + + # Copy other non-safetensors files + for file in model_path_obj.iterdir(): + if file.is_file() and not file.name.endswith('.safetensors'): + dest_file = output_path / file.name + shutil.copy2(file, dest_file) + print(f" Copied {file.name}") + + # Save quantized weights to safetensors (this will add quantized weights to the directory) + output_file = output_path / f"model_quantized_{quant_format}.safetensors" + print(f"\nSaving quantized weights to {output_file}...") + save_file(quantized_weights, output_file) + + # Save metadata + metadata_file = output_path / f"quantization_metadata_{quant_format}.json" + with open(metadata_file, "w") as f: + json.dump(metadata, f, indent=2) + + # vLLM/Diffulex 会读取 quantize_config.json 识别量化类型与超参 + if quant_format in ("gptq", "gptq_marlin", "awq"): + if quant_format == "gptq_marlin": + cfg_desc_act = False + cfg_sym = True + cfg_ckpt = "gptq_marlin" + elif quant_format == "gptq": + cfg_desc_act = bool(desc_act) + cfg_sym = bool(sym) + cfg_ckpt = "gptq" + else: # awq + cfg_desc_act = False + cfg_sym = False + cfg_ckpt = "awq" + + quantize_cfg = { + "bits": int(bits), + "group_size": int(group_size), + "desc_act": bool(cfg_desc_act), + "sym": bool(cfg_sym), + "lm_head": False, + "checkpoint_format": cfg_ckpt, + } + with open(output_path / "quantize_config.json", "w", encoding="utf-8") as f: + json.dump(quantize_cfg, f, indent=2) + + print(f"\n✓ Quantization complete!") + print(f" - Quant method: {quant_method}") + print(f" - Output directory: {output_path}") + print(f" - Quantized weights file: {output_file}") + print(f" - Metadata file: {metadata_file}") + print(f"\n You can now use this directory directly as model path:") + print(f" --model-path {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="离线量化模型权重为 GPTQ/AWQ 格式", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--model-path", type=str, required=True, help="输入模型路径") + parser.add_argument("--output-path", type=str, required=True, help="输出路径") + parser.add_argument( + "--quant-format", + type=str, + choices=["gptq", "gptq_marlin", "awq"], + default="gptq", + help="量化格式: gptq / gptq_marlin / awq", + ) + parser.add_argument("--group-size", type=int, default=128, help="量化组大小 (默认: 128)") + parser.add_argument("--bits", type=int, default=4, help="每个权重的位数 (默认: 4)") + parser.add_argument("--target-modules", type=str, help="要量化的模块名称模式(逗号分隔),例如: q_proj,k_proj,v_proj") + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cpu", help="量化设备 (默认: cpu)") + parser.add_argument( + "--quant-method", + type=str, + choices=["auto", "simple"], + default="auto", + help="量化方法: auto(真 GPTQ/AWQ, 需要校准数据) / simple(旧实现, 无校准)", + ) + parser.add_argument("--calib-text-file", type=str, default=None, help="校准文本文件(每行一条样本)") + parser.add_argument("--calib-num-samples", type=int, default=128, help="校准样本数 (默认: 128)") + parser.add_argument("--calib-seq-len", type=int, default=512, help="校准序列长度 (默认: 512)") + parser.add_argument("--calib-batch-size", type=int, default=1, help="校准 batch size (默认: 1)") + parser.add_argument("--calib-seed", type=int, default=0, help="校准采样随机种子 (默认: 0)") + parser.add_argument("--desc-act", action="store_true", help="GPTQ act-order(desc_act) (默认: False)") + parser.add_argument("--sym", dest="sym", action="store_true", default=True, help="GPTQ symmetric quant (默认: True)") + parser.add_argument("--no-sym", dest="sym", action="store_false", help="关闭 GPTQ symmetric quant") + parser.add_argument("--damp-percent", type=float, default=0.01, help="GPTQ damp_percent (默认: 0.01)") + parser.add_argument( + "--true-sequential", + dest="true_sequential", + action="store_true", + default=True, + help="GPTQ true_sequential (默认: True)", + ) + parser.add_argument( + "--no-true-sequential", + dest="true_sequential", + action="store_false", + help="关闭 GPTQ true_sequential", + ) + parser.add_argument( + "--use-triton", + dest="use_triton", + action="store_true", + default=True, + help="AutoGPTQ 使用 Triton backend (默认: True)", + ) + parser.add_argument( + "--no-triton", + dest="use_triton", + action="store_false", + help="关闭 AutoGPTQ Triton backend(可能回退到 CUDA extension)", + ) + + args = parser.parse_args() + + target_modules = None + if args.target_modules: + target_modules = [m.strip() for m in args.target_modules.split(",")] + + quantize_model( + model_path=args.model_path, + output_path=args.output_path, + quant_format=args.quant_format, + group_size=args.group_size, + bits=args.bits, + target_modules=target_modules, + device=args.device, + quant_method=args.quant_method, + calib_text_file=args.calib_text_file, + calib_num_samples=args.calib_num_samples, + calib_seq_len=args.calib_seq_len, + calib_batch_size=args.calib_batch_size, + calib_seed=args.calib_seed, + desc_act=bool(args.desc_act), + sym=bool(args.sym), + damp_percent=float(args.damp_percent), + true_sequential=bool(args.true_sequential), + use_triton=bool(args.use_triton), + ) + + +if __name__ == "__main__": + main() diff --git a/diffulex/utils/quantization/registry.py b/diffulex/utils/quantization/registry.py new file mode 100644 index 0000000..0b7be63 --- /dev/null +++ b/diffulex/utils/quantization/registry.py @@ -0,0 +1,172 @@ +""" +Quantization registries for Diffulex. + +This module provides lightweight registries similar in spirit to vLLM's +quantization directory: core runtime code should not hard-code concrete +strategy classes. Instead, it should dispatch by strategy capabilities +and let factories/registries decide which strategy to instantiate. +""" + +from __future__ import annotations + +from typing import Callable, Dict + +from diffulex.utils.quantization.kv_cache_dtype import _normalize_kv_cache_dtype +from diffulex.utils.quantization.strategy import ( + KVCacheQuantizationStrategy, + LinearQuantizationStrategy, +) + +# A builder returns a fully constructed strategy instance. +KVCacheStrategyBuilder = Callable[[], KVCacheQuantizationStrategy] + +_KV_CACHE_BUILDERS: Dict[str, KVCacheStrategyBuilder] = {} + + +def register_kv_cache_strategy(*dtype_aliases: str) -> Callable[[KVCacheStrategyBuilder], KVCacheStrategyBuilder]: + """Register a KV-cache strategy builder for one or more dtype aliases.""" + + def _decorator(builder: KVCacheStrategyBuilder) -> KVCacheStrategyBuilder: + for alias in dtype_aliases: + key = _normalize_kv_cache_dtype(alias) + _KV_CACHE_BUILDERS[key] = builder + return builder + + return _decorator + + +def create_kv_cache_strategy(kv_cache_dtype: str) -> KVCacheQuantizationStrategy: + """Create a KV-cache quantization strategy from a dtype string.""" + key = _normalize_kv_cache_dtype(kv_cache_dtype) + builder = _KV_CACHE_BUILDERS.get(key) + if builder is None: + raise ValueError( + f"Unsupported kv_cache_dtype={kv_cache_dtype!r} (normalized={key!r}). " + f"Registered: {sorted(_KV_CACHE_BUILDERS.keys())}" + ) + return builder() + + +def registered_kv_cache_dtypes() -> list[str]: + return sorted(_KV_CACHE_BUILDERS.keys()) + + +# ---- Linear (weights + activations) registry ---- +LinearStrategyBuilder = Callable[[], LinearQuantizationStrategy] +_LINEAR_BUILDERS: Dict[tuple[str, str], LinearStrategyBuilder] = {} + + +def _normalize_linear_dtype(dtype: str) -> str: + """Normalize Linear quantization dtype/method strings. + + We intentionally keep this lightweight: the concrete semantics (weight-only, + activation-only, etc.) live in the strategy implementations. + """ + s = (dtype or "").strip().lower() + # Reserved internal sentinel for generic fallback strategy registration. + if s in {"__stub__", "__fallback__"}: + return "__stub__" + aliases = { + "": "bf16", + "none": "bf16", + "bf16": "bf16", + "bfloat16": "bf16", + # Integer + "int8": "int8", + "i8": "int8", + "int4": "int4", + "i4": "int4", + # FP8 + "fp8": "fp8_e4m3", + "fp8_e4m3": "fp8_e4m3", + "e4m3": "fp8_e4m3", + "fp8_e5m2": "fp8_e5m2", + "e5m2": "fp8_e5m2", + # Weight-only methods (placeholders) + "gptq": "gptq", + "gptq_marlin": "gptq_marlin", + "gptq_marlin_24": "gptq_marlin_24", + "awq": "awq", + "awq_marlin": "awq_marlin", + "gptq_awq": "gptq_awq", + # vLLM-style fused W8A16 path (AllSpark): keep user-facing alias "marlin" + # for backward compatibility. Normalized key is "marlin_int8" to avoid + # conflating with other quant methods. + "marlin": "marlin_int8", + "marlin_int8": "marlin_int8", + } + if s not in aliases: + raise ValueError( + f"Unsupported linear quant dtype={dtype!r}. " + "Supported: bf16/int8/int4/fp8/fp8_e4m3/fp8_e5m2/gptq/awq/marlin" + ) + return aliases[s] + + +def register_linear_strategy( + *, + weight_dtype: str, + act_dtype: str, +) -> Callable[[LinearStrategyBuilder], LinearStrategyBuilder]: + """Register a Linear strategy builder for a (weight_dtype, act_dtype) pair.""" + + w = _normalize_linear_dtype(weight_dtype) + a = _normalize_linear_dtype(act_dtype) + + def _decorator(builder: LinearStrategyBuilder) -> LinearStrategyBuilder: + _LINEAR_BUILDERS[(w, a)] = builder + return builder + + return _decorator + + +def create_linear_strategy(*, weight_dtype: str, act_dtype: str) -> LinearQuantizationStrategy: + """Create a Linear quantization strategy from weight/activation dtype strings. + + If an exact pair is not registered, we fall back to: + - bf16/bf16: a built-in BF16 strategy (registered by default) + - otherwise: a generic stub strategy that raises NotImplementedError at runtime + (registered by default). + """ + w = _normalize_linear_dtype(weight_dtype) + a = _normalize_linear_dtype(act_dtype) + builder = _LINEAR_BUILDERS.get((w, a)) + if builder is not None: + return builder() + + # Fall back to generic stub builder if present. + stub = _LINEAR_BUILDERS.get(("__stub__", "__stub__")) + if stub is None: + raise ValueError( + f"Unsupported linear strategy pair (weight_dtype={weight_dtype!r}, act_dtype={act_dtype!r}) " + f"(normalized={(w, a)!r}). Registered pairs: {sorted(_LINEAR_BUILDERS.keys())}" + ) + s = stub() + # Attach requested formats for better error messages / future dispatch. + try: + setattr(s, "weight_dtype", w) + setattr(s, "act_dtype", a) + except Exception: + pass + return s + + +def registered_linear_dtypes() -> list[str]: + """Return the normalized dtype/method names accepted by `_normalize_linear_dtype`.""" + # Keep this list stable for CLI/help messages. + return [ + "bf16", + "int8", + "int4", + "fp8_e4m3", + "fp8_e5m2", + "gptq", + "gptq_marlin", + "gptq_marlin_24", + "awq", + "awq_marlin", + "gptq_awq", + "marlin_int8", + ] + + diff --git a/diffulex/utils/quantization/strategies/__init__.py b/diffulex/utils/quantization/strategies/__init__.py new file mode 100644 index 0000000..1fcc216 --- /dev/null +++ b/diffulex/utils/quantization/strategies/__init__.py @@ -0,0 +1,38 @@ +""" +Quantization strategy implementations. +""" + +from diffulex.utils.quantization.strategies.no_quantization import NoQuantizationStrategy +from diffulex.utils.quantization.strategies.kv_cache_bf16 import KVCacheBF16Strategy +from diffulex.utils.quantization.strategies.kv_cache_fp8_running_max import KVCacheFP8RunningMaxStrategy +from diffulex.utils.quantization.strategies.linear_bf16 import LinearBF16Strategy +from diffulex.utils.quantization.strategies.linear_stub import LinearStubStrategy +from diffulex.utils.quantization.strategies.linear_int8_w8a16 import LinearInt8W8A16Strategy # noqa: F401 +from diffulex.utils.quantization.strategies.linear_int4_w4a16 import LinearInt4W4A16Strategy # noqa: F401 +from diffulex.utils.quantization.strategies.linear_int8_w8a8 import LinearInt8W8A8Strategy # noqa: F401 +from diffulex.utils.quantization.strategies.linear_int4_w4a8 import LinearInt4W4A8Strategy # noqa: F401 +from diffulex.utils.quantization.strategies.linear_fp8_w8a16 import LinearFP8W8A16Strategy # noqa: F401 +from diffulex.utils.quantization.strategies.linear_fp8_w8a8 import LinearFP8W8A8Strategy # noqa: F401 +from diffulex.utils.quantization.strategies.linear_gptq_w4a16 import LinearGPTQW4A16Strategy # noqa: F401 +from diffulex.utils.quantization.strategies.linear_gptq_marlin_w4a16 import LinearGPTQMarlinW4A16Strategy # noqa: F401 +from diffulex.utils.quantization.strategies.linear_awq_w4a16 import LinearAWQW4A16Strategy # noqa: F401 +from diffulex.utils.quantization.strategies.linear_awq_marlin_w4a16 import LinearAWQMarlinW4A16Strategy # noqa: F401 + +__all__ = [ + 'NoQuantizationStrategy', + 'KVCacheBF16Strategy', + 'KVCacheFP8RunningMaxStrategy', + 'LinearBF16Strategy', + 'LinearStubStrategy', + 'LinearInt8W8A16Strategy', + 'LinearInt4W4A16Strategy', + 'LinearInt8W8A8Strategy', + 'LinearInt4W4A8Strategy', + 'LinearFP8W8A16Strategy', + 'LinearFP8W8A8Strategy', + 'LinearGPTQW4A16Strategy', + 'LinearGPTQMarlinW4A16Strategy', + 'LinearAWQW4A16Strategy', + 'LinearAWQMarlinW4A16Strategy', +] + diff --git a/diffulex/utils/quantization/strategies/kv_cache_bf16.py b/diffulex/utils/quantization/strategies/kv_cache_bf16.py new file mode 100644 index 0000000..5710d48 --- /dev/null +++ b/diffulex/utils/quantization/strategies/kv_cache_bf16.py @@ -0,0 +1,61 @@ +""" +BF16 KV Cache quantization strategy (no actual quantization, just storage format). +""" + +import torch +from typing import Optional +from diffulex.utils.quantization.strategy import KVCacheQuantizationStrategy +from diffulex.utils.quantization.registry import register_kv_cache_strategy + + +class KVCacheBF16Strategy(KVCacheQuantizationStrategy): + """BF16 KV Cache strategy (no quantization, just storage format).""" + + @property + def name(self) -> str: + return "kv_cache_bf16" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + """Returns BF16 storage dtype.""" + return torch.bfloat16, 2 + + def compute_scales(self, k: torch.Tensor, v: torch.Tensor, + num_kv_heads: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """No scales needed for BF16 (no quantization).""" + # Return dummy scales (will not be used) + k_scale = torch.ones((num_kv_heads,), device=device, dtype=torch.float32) + v_scale = torch.ones((num_kv_heads,), device=device, dtype=torch.float32) + return k_scale, v_scale + + def update_scales(self, k: torch.Tensor, v: torch.Tensor, + k_scale: Optional[torch.Tensor], v_scale: Optional[torch.Tensor], + num_kv_heads: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """No scales needed for BF16 (no quantization).""" + if k_scale is None: + k_scale = torch.ones((num_kv_heads,), device=device, dtype=torch.float32) + if v_scale is None: + v_scale = torch.ones((num_kv_heads,), device=device, dtype=torch.float32) + return k_scale, v_scale + + def quantize(self, tensor: torch.Tensor, scale: Optional[torch.Tensor] = None, **kwargs) -> tuple[torch.Tensor, None]: + """No quantization, just convert to BF16 if needed.""" + if tensor.dtype != torch.bfloat16: + tensor = tensor.to(torch.bfloat16) + return tensor, None + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: None, **kwargs) -> torch.Tensor: + """No dequantization needed.""" + return quantized + + def get_scale_shape(self, original_shape: tuple[int, ...], num_kv_heads: int, **kwargs) -> tuple[int, ...]: + """No scale needed for BF16.""" + return (0,) # Empty shape + + +# NOTE: fp16/fp32 are currently routed to the BF16 kernels in Diffulex. +# Keeping them registered avoids breaking older configs while we add +# true fp16/fp32 KV-cache kernels in the future. +@register_kv_cache_strategy("bf16", "bfloat16", "fp16", "float16", "fp32", "float32") +def _build_kv_cache_bf16() -> KVCacheBF16Strategy: + return KVCacheBF16Strategy() + diff --git a/diffulex/utils/quantization/strategies/kv_cache_fp8_running_max.py b/diffulex/utils/quantization/strategies/kv_cache_fp8_running_max.py new file mode 100644 index 0000000..6e8a76e --- /dev/null +++ b/diffulex/utils/quantization/strategies/kv_cache_fp8_running_max.py @@ -0,0 +1,252 @@ +""" +FP8 KV Cache quantization strategy using running max for scale management. +""" + +import torch +from typing import Optional + +from diffulex.utils.quantization.strategy import KVCacheQuantizationStrategy +from diffulex.utils.quantization.kv_cache_dtype import parse_kv_cache_dtype, view_fp8_cache +from diffulex.utils.quantization.registry import register_kv_cache_strategy + + +class KVCacheFP8RunningMaxStrategy(KVCacheQuantizationStrategy): + """FP8 KV Cache quantization strategy using running max for scale management.""" + + def __init__(self, dtype: str = "fp8_e4m3"): + """ + Initialize FP8 KV Cache strategy. + + Args: + dtype: FP8 dtype string ("fp8_e4m3" or "fp8_e5m2") + """ + self.dtype_str = dtype + self.spec = parse_kv_cache_dtype(dtype) + if not self.spec.is_fp8: + raise ValueError(f"Expected FP8 dtype, got {dtype}") + + @property + def name(self) -> str: + return f"kv_cache_fp8_running_max_{self.dtype_str}" + + @property + def kv_cache_format(self) -> str: + return "fp8" + + @property + def requires_runtime_scales(self) -> bool: + return True + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + """Returns uint8 as storage dtype for FP8 (FP8 values are stored as uint8).""" + return torch.uint8, 1 + + def view_kv_cache_for_kernels(self, cache: torch.Tensor) -> torch.Tensor: + # For kernels expecting float8 dtype (e.g. TileLang FP8 decode), keep uint8 + # storage but return a float8 view tensor. + return view_fp8_cache(cache, self.dtype_str) + + def quantize_kv_for_store( + self, + k: torch.Tensor, + v: torch.Tensor, + *, + k_scale: Optional[torch.Tensor], + v_scale: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Vectorized quantization for KV cache store. + + Args: + k/v: [N, H, D] BF16/FP16/FP32 + k_scale/v_scale: [H] float32 + Returns: + k_q/v_q: [N, H, D] uint8 + """ + if k_scale is None or v_scale is None: + raise ValueError("FP8 quantization requires k_scale and v_scale") + k_q, _ = self.quantize(k, scale=k_scale) + v_q, _ = self.quantize(v, scale=v_scale) + return k_q, v_q + + def compute_scales(self, k: torch.Tensor, v: torch.Tensor, + num_kv_heads: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute initial quantization scales for K and V. + + Args: + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + num_kv_heads: Number of KV heads + device: Target device + + Returns: + (k_scale, v_scale): Scales with shape [num_kv_heads] + """ + eps = 1e-8 + fp8_max = float(self.spec.fp8_max) + + # Compute per-head absmax: [num_kv_heads] + k_absmax = k.to(torch.float32).abs().amax(dim=(0, 2)) + v_absmax = v.to(torch.float32).abs().amax(dim=(0, 2)) + + # Compute scales + k_scale = (k_absmax / fp8_max).clamp_min(eps) + v_scale = (v_absmax / fp8_max).clamp_min(eps) + + return k_scale.to(device, dtype=torch.float32), v_scale.to(device, dtype=torch.float32) + + def update_scales(self, k: torch.Tensor, v: torch.Tensor, + k_scale: Optional[torch.Tensor], v_scale: Optional[torch.Tensor], + num_kv_heads: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """ + Update quantization scales using running max strategy. + + This method follows vLLM's RunningMax approach: + 1. Compute current batch's per-head absmax + 2. Update running max (max of current running max and current absmax) + 3. Convert running max to scale (absmax / fp8_max) + + Args: + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + k_scale: Current K scale (None if first time) - shape [num_kv_heads] + v_scale: Current V scale (None if first time) - shape [num_kv_heads] + num_kv_heads: Number of KV heads + device: Target device + + Returns: + (updated_k_scale, updated_v_scale): Updated scales with shape [num_kv_heads] + """ + eps = 1e-8 + fp8_max = float(self.spec.fp8_max) + + # Compute current per-head absmax: [num_kv_heads] + k_absmax = k.to(torch.float32).abs().amax(dim=(0, 2)) + v_absmax = v.to(torch.float32).abs().amax(dim=(0, 2)) + + # Update running max + # Note: k_scale/v_scale are scales (already divided by fp8_max), so we need to + # convert them back to absmax before comparing with current absmax + if k_scale is None: + k_absmax_running = k_absmax.clone().detach() + else: + # Convert scale back to absmax for comparison + k_absmax_running = k_scale * fp8_max + # Update running max: take max of current running max and current batch absmax + k_absmax_running = torch.maximum(k_absmax_running, k_absmax) + + if v_scale is None: + v_absmax_running = v_absmax.clone().detach() + else: + # Convert scale back to absmax for comparison + v_absmax_running = v_scale * fp8_max + # Update running max: take max of current running max and current batch absmax + v_absmax_running = torch.maximum(v_absmax_running, v_absmax) + + # Compute scales from running max (absmax / fp8_max) + k_scale = (k_absmax_running / fp8_max).clamp_min(eps) + v_scale = (v_absmax_running / fp8_max).clamp_min(eps) + + return k_scale.to(device, dtype=torch.float32), v_scale.to(device, dtype=torch.float32) + + def quantize(self, tensor: torch.Tensor, scale: torch.Tensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize a tensor using FP8. + + Args: + tensor: Tensor to quantize [seq_len, head_dim] or [seq_len, num_heads, head_dim] + scale: Quantization scale [1] or [num_heads] + **kwargs: Additional arguments + + Returns: + (quantized_tensor, scale): Tuple of quantized tensor (uint8) and scale + """ + assert scale is not None, "FP8 quantization requires scale" + assert self.spec.fp8_view_dtype is not None + assert self.spec.fp8_min is not None and self.spec.fp8_max is not None + + # Handle both [seq_len, head_dim] and [seq_len, num_heads, head_dim] input shapes + if tensor.dim() == 2: + # [seq_len, head_dim] case: scale should be [1] + descale = (1.0 / scale).view(-1, 1) # [1, 1] for broadcasting to [seq_len, head_dim] + elif tensor.dim() == 3: + # [seq_len, num_heads, head_dim] case: scale should be [num_heads] + descale = (1.0 / scale).view(1, -1, 1) # [1, num_heads, 1] for broadcasting + else: + raise ValueError(f"Expected 2D or 3D tensor, got {tensor.dim()}D tensor with shape {tensor.shape}") + + # Quantize: value / scale, then clamp to FP8 range + quantized = (tensor.float() * descale).clamp( + min=float(self.spec.fp8_min), + max=float(self.spec.fp8_max) + ) + + # Convert to FP8 view dtype, then view as uint8 for storage + quantized_fp8 = quantized.to(self.spec.fp8_view_dtype) + quantized_uint8 = quantized_fp8.view(torch.uint8) + + return quantized_uint8, scale + + def dequantize(self, quantized: torch.Tensor, scale: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Dequantize a tensor from FP8. + + Args: + quantized: Quantized tensor (uint8 storage) [seq_len, num_heads, head_dim] + scale: Quantization scale [num_heads] + **kwargs: Additional arguments + + Returns: + Dequantized tensor + """ + assert scale is not None, "FP8 dequantization requires scale" + assert self.spec.fp8_view_dtype is not None + + # View uint8 as FP8 dtype + fp8_tensor = quantized.view(self.spec.fp8_view_dtype).float() + + # Reshape scale to broadcast: [num_heads] -> [1, num_heads, 1] + scale_view = scale.view(1, -1, 1) + + # Dequantize: value * scale + return fp8_tensor * scale_view + + def get_scale_shape(self, original_shape: tuple[int, ...], num_kv_heads: int, **kwargs) -> tuple[int, ...]: + """ + Returns the shape of scale tensor. + + Args: + original_shape: Original tensor shape (not used for KV cache) + num_kv_heads: Number of KV heads + + Returns: + Scale shape: [num_kv_heads] + """ + return (num_kv_heads,) + + def init_scales(self, num_kv_heads: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """ + Initialize quantization scales for K and V. + + Args: + num_kv_heads: Number of KV heads + device: Target device + + Returns: + (k_scale, v_scale): Initial scales with shape [num_kv_heads], initialized to 1.0 + """ + # Initialize scales to 1.0 (will be updated on first update_scales call) + k_scale = torch.ones((num_kv_heads,), device=device, dtype=torch.float32) + v_scale = torch.ones((num_kv_heads,), device=device, dtype=torch.float32) + return k_scale, v_scale + + +@register_kv_cache_strategy("fp8", "fp8_e4m3", "e4m3") +def _build_kv_cache_fp8_e4m3() -> KVCacheFP8RunningMaxStrategy: + return KVCacheFP8RunningMaxStrategy("fp8_e4m3") + + +@register_kv_cache_strategy("fp8_e5m2", "e5m2") +def _build_kv_cache_fp8_e5m2() -> KVCacheFP8RunningMaxStrategy: + return KVCacheFP8RunningMaxStrategy("fp8_e5m2") + diff --git a/diffulex/utils/quantization/strategies/linear_awq_marlin_w4a16.py b/diffulex/utils/quantization/strategies/linear_awq_marlin_w4a16.py new file mode 100644 index 0000000..bb19518 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_awq_marlin_w4a16.py @@ -0,0 +1,166 @@ +""" +AWQ Marlin (W4, A16) Linear strategy using vLLM Marlin CUDA kernels. + +- Input activations: bf16 (cast to fp16 for vLLM marlin kernel) +- Weights: offline AWQ vLLM standard format (qweight/qzeros/scales) +- One-time repack/permutation is performed by Diffulex `LinearBase` and passed in via kwargs: + - awq_marlin_qweight / awq_marlin_scales / awq_marlin_zp + - awq_marlin_workspace +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + +try: + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( # type: ignore + apply_awq_marlin_linear, + marlin_make_empty_g_idx, + should_use_atomic_add_reduce, + marlin_permute_bias, + ) + from vllm.scalar_type import scalar_types # type: ignore +except Exception: # pragma: no cover + apply_awq_marlin_linear = None # type: ignore + marlin_make_empty_g_idx = None # type: ignore + should_use_atomic_add_reduce = None # type: ignore + marlin_permute_bias = None # type: ignore + scalar_types = None # type: ignore + + +@register_linear_strategy(weight_dtype="awq_marlin", act_dtype="bf16") +def _build_linear_awq_marlin_w4a16() -> LinearQuantizationStrategy: + return LinearAWQMarlinW4A16Strategy() + + +class LinearAWQMarlinW4A16Strategy(LinearQuantizationStrategy): + def __init__(self) -> None: + super().__init__() + self._available: bool = bool(apply_awq_marlin_linear is not None and scalar_types is not None) + self._empty_cache: dict[int, torch.Tensor] = {} + self._bias_cache: dict[tuple[int, int], torch.Tensor] = {} + self._atomic_add_cache: dict[tuple[int, int, int, int, int], bool] = {} + + @property + def name(self) -> str: + return "linear_awq_marlin_w4a16" + + @property + def linear_weight_format(self) -> str: + return "awq_marlin" + + @property + def linear_act_format(self) -> str: + return "bf16" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + return torch.int32, 4 + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs) -> tuple[int, ...]: + # Same as AWQ: [K/group, N] + if len(original_shape) != 2: + raise ValueError(f"Expected 2D weight shape, got {original_shape}") + out_features, in_features = original_shape + group_size = int(kwargs.get("group_size", 128)) + group_size = in_features if group_size == -1 else group_size + if group_size <= 0 or in_features % group_size != 0: + raise ValueError(f"Invalid group_size={group_size} for in_features={in_features}") + num_groups = in_features // group_size + return (num_groups, out_features) + + def quantize(self, tensor: torch.Tensor, **kwargs) -> tuple[torch.Tensor, Any]: + return tensor, {} + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs) -> torch.Tensor: + if quantized.is_floating_point(): + return quantized + raise NotImplementedError("AWQ Marlin 不提供 Python dequantize;请使用 vLLM Marlin CUDA kernel。") + + def linear_forward( + self, + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + *, + quant_kind: str, + qweight: torch.Tensor, + scales: torch.Tensor, + zp: torch.Tensor, + workspace: Optional[torch.Tensor] = None, + in_features: int = 0, + out_features: int = 0, + group_size: int = 128, + tp_dim: Optional[int] = None, + ) -> torch.Tensor: + _ = quant_kind, weight, group_size, tp_dim + if not self._available or workspace is None: + raise RuntimeError("awq_marlin 需要 vLLM (marlin_utils + scalar_types);当前环境不可用。") + if in_features <= 0 or out_features <= 0: + raise RuntimeError("awq_marlin: missing in_features/out_features.") + + device = x.device + dev_key = int(device.index) if device.type == "cuda" and device.index is not None else -1 + + # AWQ marlin does not use g_idx/perm; pass empty tensors (cached). + empty = self._empty_cache.get(dev_key) + if empty is None: + empty = marlin_make_empty_g_idx(device) if marlin_make_empty_g_idx is not None else torch.empty((0,), device=device, dtype=torch.int32) + self._empty_cache[dev_key] = empty + + # Cache permuted bias. + marlin_bias = None + if bias is not None: + bkey = (dev_key, int(bias.data_ptr())) + marlin_bias = self._bias_cache.get(bkey) + if marlin_bias is None: + marlin_bias = marlin_permute_bias(bias) if marlin_permute_bias is not None else bias + self._bias_cache[bkey] = marlin_bias + + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (int(out_features),) + m = int(reshaped_x.shape[0]) + n = int(out_features) + k = int(reshaped_x.shape[1]) + dtype_id = 1 if reshaped_x.dtype == torch.bfloat16 else (2 if reshaped_x.dtype == torch.float16 else 0) + use_atomic_add = False + if should_use_atomic_add_reduce is not None: + akey = (dev_key, dtype_id, m, n, k) + cached = self._atomic_add_cache.get(akey) + if cached is None: + cached = bool( + should_use_atomic_add_reduce( + m=m, n=n, k=k, device=device, dtype=reshaped_x.dtype + ) + ) + self._atomic_add_cache[akey] = cached + use_atomic_add = cached + + out = torch.ops._C.gptq_marlin_gemm( + reshaped_x, + None, + qweight, + marlin_bias, + scales, + None, + None, + zp, + empty, + empty, + workspace, + scalar_types.uint4.id, + m, + n, + k, + True, # is_k_full + use_atomic_add, + True, # use_fp32_reduce + False, # is_zp_float + ) + out = out.reshape(out_shape) + return out.to(dtype=x.dtype) if out.dtype != x.dtype else out + diff --git a/diffulex/utils/quantization/strategies/linear_awq_w4a16.py b/diffulex/utils/quantization/strategies/linear_awq_w4a16.py new file mode 100644 index 0000000..ea6675d --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_awq_w4a16.py @@ -0,0 +1,157 @@ +""" +AWQ W4A16 Linear quantization strategy (vLLM standard format). + +- Weight format: vLLM AWQ (packed int32 qweight/qzeros + fp16 scales) +- Activation: bf16 (no activation quantization) +- Forward: vLLM custom op `awq_gemm` (with the same heuristic as vLLM) + +No TileLang dependency. +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + +try: + from vllm import _custom_ops as ops # type: ignore +except Exception: # pragma: no cover + ops = None # type: ignore + +try: + # Triton fallback path for AWQ GEMM (works even when C++/CUDA ops are not built). + from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton # type: ignore +except Exception: # pragma: no cover + awq_gemm_triton = None # type: ignore + + +@register_linear_strategy(weight_dtype="awq", act_dtype="bf16") +def _build_linear_awq_w4a16() -> LinearQuantizationStrategy: + return LinearAWQW4A16Strategy() + + +class LinearAWQW4A16Strategy(LinearQuantizationStrategy): + def __init__(self) -> None: + super().__init__() + # Resolve the concrete kernel entry points once (avoid per-call dispatch). + self._awq_gemm_cpp = None + self._awq_gemm_triton = awq_gemm_triton + + awq_gemm = None + try: + if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "awq_gemm"): + awq_gemm = torch.ops._C.awq_gemm + except Exception: + awq_gemm = None + # Prefer the real C++ op if present; otherwise keep `None` and fall back to Triton. + self._awq_gemm_cpp = awq_gemm + # Keep the python wrapper as a last resort (it may route to Triton or to torch.ops._C). + self._awq_gemm_py = ops.awq_gemm if (ops is not None and hasattr(ops, "awq_gemm")) else None + self._ops_available: bool = bool( + self._awq_gemm_cpp is not None or self._awq_gemm_triton is not None or self._awq_gemm_py is not None + ) + + @property + def name(self) -> str: + return "linear_awq_w4a16" + + @property + def linear_weight_format(self) -> str: + return "awq" + + @property + def linear_act_format(self) -> str: + return "bf16" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + # vLLM AWQ stores packed weights in int32. + return torch.int32, 4 + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs) -> tuple[int, ...]: + # vLLM AWQ scales: [K/group, N], where Linear weight is (N, K). + if len(original_shape) != 2: + raise ValueError(f"Expected 2D weight shape, got {original_shape}") + out_features, in_features = original_shape + group_size = int(kwargs.get("group_size", 128)) + group_size = in_features if group_size == -1 else group_size + if group_size <= 0 or in_features % group_size != 0: + raise ValueError(f"Invalid group_size={group_size} for in_features={in_features}") + num_groups = in_features // group_size + return (num_groups, out_features) + + def quantize(self, tensor: torch.Tensor, **kwargs) -> tuple[torch.Tensor, Any]: + # Offline AWQ is handled by `diffulex.utils.quantization.quantize_model`. + return tensor, {} + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs) -> torch.Tensor: + if quantized.is_floating_point(): + return quantized + raise NotImplementedError( + "AWQ dequantize is not implemented in Diffulex. " + "Use vLLM kernels via linear_forward." + ) + + def linear_forward( + self, + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + *, + quant_kind: str, + awq_qweight: Optional[torch.Tensor] = None, + awq_qzeros: Optional[torch.Tensor] = None, + awq_scales: Optional[torch.Tensor] = None, + pack_factor: int = 8, + out_features: Optional[int] = None, + in_features: Optional[int] = None, + group_size: int = 128, + ) -> torch.Tensor: + _ = quant_kind, weight, pack_factor, in_features, group_size + if not self._ops_available: + raise RuntimeError( + "vLLM is required for AWQ W4A16 (missing `vllm._custom_ops`). " + "Please install/build vLLM with CUDA ops." + ) + qweight = awq_qweight + qzeros = awq_qzeros + scales = awq_scales + if qweight is None or qzeros is None or scales is None: + if weight is None: + raise RuntimeError("AWQ offline weights missing packed tensors and bf16 weight is not present.") + return F.linear(x, weight, bias) + + # vLLM AWQ kernels expect FP16 activations. + x_in = x if x.dtype == torch.float16 else x.to(dtype=torch.float16) + + # Use known out_features if provided (avoid per-call inference). + n = int(out_features) if out_features is not None else int(scales.shape[1]) + out_shape = x.shape[:-1] + (n,) + reshaped_x = x_in.reshape(-1, x_in.shape[-1]) + + # Always use awq_gemm to avoid large temporary dequantized weight allocations. + # vLLM API: + # - C++ op: awq_gemm(input, qweight, scales, qzeros, split_k_iters) + # - Triton : awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) + split_k_iters = 1 + if self._awq_gemm_triton is not None: + out = self._awq_gemm_triton(reshaped_x, qweight, scales, qzeros, split_k_iters) # type: ignore[misc] + elif self._awq_gemm_cpp is not None: + out = self._awq_gemm_cpp(reshaped_x, qweight, scales, qzeros, split_k_iters) # type: ignore[misc] + elif self._awq_gemm_py is not None: + out = self._awq_gemm_py(reshaped_x, qweight, scales, qzeros, split_k_iters) # type: ignore[misc] + else: + raise RuntimeError( + "vLLM is required for AWQ W4A16 but no available kernel entry point was found " + "(missing both Triton and C++ awq_gemm)." + ) + + if bias is not None: + out.add_(bias.to(dtype=out.dtype)) + out = out.reshape(out_shape) + return out.to(dtype=x.dtype) if out.dtype != x.dtype else out + diff --git a/diffulex/utils/quantization/strategies/linear_bf16.py b/diffulex/utils/quantization/strategies/linear_bf16.py new file mode 100644 index 0000000..82d12bf --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_bf16.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import torch + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + + +@register_linear_strategy(weight_dtype="bf16", act_dtype="bf16") +def _build_linear_bf16() -> LinearQuantizationStrategy: + return LinearBF16Strategy() + + +class LinearBF16Strategy(LinearQuantizationStrategy): + """Default Linear strategy: no quantization (bf16/bf16).""" + + @property + def name(self) -> str: + return "linear_bf16" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + # No special storage; keep as-is. + return torch.bfloat16, 2 + + def quantize(self, tensor: torch.Tensor, **kwargs): + _ = kwargs + return tensor, None + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata, **kwargs) -> torch.Tensor: + _ = scale_or_metadata, kwargs + return quantized + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs) -> tuple[int, ...]: + _ = original_shape, kwargs + return tuple() + + + + diff --git a/diffulex/utils/quantization/strategies/linear_fp8_w8a16.py b/diffulex/utils/quantization/strategies/linear_fp8_w8a16.py new file mode 100644 index 0000000..b25cf99 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_fp8_w8a16.py @@ -0,0 +1,140 @@ +""" +FP8 W8A16 Linear quantization strategy (FP8 weight + bf16 activation), TileLang-free. + +vLLM-aligned implementation: +- Weight quantization: `vllm._custom_ops.scaled_fp8_quant` (FP8 weight + per-tensor scale). +- Forward: use vLLM's `Fp8LinearOp` (CUTLASS scaled_mm when available). + +Note: +- vLLM 的 FP8 linear 核心路径以 e4m3 为主(由 vLLM 当前平台决定的 fp8 dtype)。 +- 为了避免“静默走慢路径”,这里不再使用 `F.linear` 的反量化 GEMM。 +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + + +@register_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") +def _build_linear_fp8_e4m3_w8a16() -> LinearQuantizationStrategy: + return LinearFP8W8A16Strategy("fp8_e4m3") + + +@register_linear_strategy(weight_dtype="fp8_e5m2", act_dtype="bf16") +def _build_linear_fp8_e5m2_w8a16() -> LinearQuantizationStrategy: + return LinearFP8W8A16Strategy("fp8_e5m2") + + +class LinearFP8W8A16Strategy(LinearQuantizationStrategy): + def __init__(self, weight_dtype: str = "fp8_e4m3") -> None: + super().__init__() + self.weight_dtype_str = weight_dtype + # Cache: id(weight) -> (q_fp8_KN [K,N], scale_fp32 [1]) + self._weight_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} + + try: + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( # type: ignore + Fp8LinearOp, + ) + except Exception as e: # pragma: no cover + raise RuntimeError("FP8 需要 vLLM(Fp8LinearOp / _custom_ops)。") from e + + # dynamic activation quantization to FP8 inside vLLM + self._fp8_linear = Fp8LinearOp(act_quant_static=False) + + @property + def name(self) -> str: + return f"linear_fp8_{self.weight_dtype_str}_w8a16" + + @property + def linear_weight_format(self) -> str: + return self.weight_dtype_str + + @property + def linear_act_format(self) -> str: + return "bf16" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + # vLLM stores fp8 weights as float8 dtype tensor + return torch.uint8, 1 + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs: Any) -> tuple[int, ...]: + _ = kwargs + if len(original_shape) != 2: + raise ValueError(f"Expected 2D weight [N,K], got {original_shape}") + # per-tensor scale + return (1,) + + def quantize(self, tensor: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, Any]: + _ = kwargs + if tensor.dim() != 2: + raise ValueError(f"Expected 2D weight [N,K], got shape={tuple(tensor.shape)}") + from vllm import _custom_ops as ops # type: ignore + from vllm.platforms import current_platform # type: ignore + + # vLLM: per-tensor scale, output dtype = current_platform.fp8_dtype() + q_fp8, scale = ops.scaled_fp8_quant(tensor.to(torch.float32).contiguous(), scale=None) + # Keep transpose-view for CUTLASS expectation (b.stride(0) == 1). + q_kn_fp8 = q_fp8.t() # [K,N] fp8 dtype, non-contiguous + scale = scale.to(torch.float32).reshape(1).contiguous() + return q_kn_fp8, {"scales": scale, "fp8_dtype": current_platform.fp8_dtype()} + + def quantize_weight_for_kernel( + self, + weight: torch.Tensor, + *, + device: torch.device | None = None, + **_: Any, + ) -> tuple[torch.Tensor, Any]: + q_fp8, meta = self.quantize(weight) + if device is not None: + q_fp8 = q_fp8.to(device=device) + meta["scales"] = meta["scales"].to(device=device) + return q_fp8, meta["scales"] + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs: Any) -> torch.Tensor: + _ = kwargs + raise RuntimeError("FP8 不提供 dequantize 路径(避免走慢的反量化 + F.linear)。") + + def linear_forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + quant_kind: str, + quant_scales: Optional[torch.Tensor] = None, + out_features: Optional[int] = None, + ) -> torch.Tensor: + _ = quant_kind, out_features + if weight is not None and quant_scales is not None: + # Expected: weight is fp8 K×N tensor (transpose-view is fine). + q_kn = weight.to(device=x.device) + scales = quant_scales.to(device=x.device, dtype=torch.float32).reshape(1) + else: + wid = id(weight) + cached = self._weight_cache.get(wid) + if cached is None or cached[0].device != x.device: + q_fp8, meta = self.quantize(weight) + q_fp8 = q_fp8.to(device=x.device) + scales = meta["scales"].to(device=x.device, dtype=torch.float32).reshape(1) + q_kn = q_fp8 + self._weight_cache[wid] = (q_fp8, scales) + else: + q_kn, scales = cached + + # vLLM Fp8LinearOp expects weight as [K,N] fp8 tensor and per-tensor scale. + return self._fp8_linear.apply( + input=x, + weight=q_kn, + weight_scale=scales, + out_dtype=x.dtype if x.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16, + input_scale=None, + bias=bias, + ) + diff --git a/diffulex/utils/quantization/strategies/linear_fp8_w8a8.py b/diffulex/utils/quantization/strategies/linear_fp8_w8a8.py new file mode 100644 index 0000000..d7f48c6 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_fp8_w8a8.py @@ -0,0 +1,129 @@ +""" +FP8 W8A8 Linear quantization strategy (FP8 weight + FP8 activation), TileLang-free. + +vLLM-aligned implementation: +- Weight quantization: `vllm._custom_ops.scaled_fp8_quant` (per-tensor scale). +- Activation quantization + GEMM: vLLM `Fp8LinearOp` (CUTLASS scaled_mm when available). +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + + +def _require_fp8_linear_op(): + try: + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( # type: ignore + Fp8LinearOp, + ) + except Exception as e: # pragma: no cover + raise RuntimeError("FP8 需要 vLLM(Fp8LinearOp / _custom_ops)。") from e + return Fp8LinearOp + + +@register_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="fp8_e4m3") +def _build_linear_fp8_e4m3_w8a8() -> LinearQuantizationStrategy: + return LinearFP8W8A8Strategy("fp8_e4m3", "fp8_e4m3") + + +@register_linear_strategy(weight_dtype="fp8_e5m2", act_dtype="fp8_e5m2") +def _build_linear_fp8_e5m2_w8a8() -> LinearQuantizationStrategy: + return LinearFP8W8A8Strategy("fp8_e5m2", "fp8_e5m2") + + +class LinearFP8W8A8Strategy(LinearQuantizationStrategy): + def __init__(self, weight_dtype: str = "fp8_e4m3", act_dtype: str = "fp8_e4m3") -> None: + super().__init__() + self.weight_dtype_str = weight_dtype + self.act_dtype_str = act_dtype + # Cache: id(weight) -> (q_fp8_KN [K,N], scale_fp32 [1]) + self._weight_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} + Fp8LinearOp = _require_fp8_linear_op() + self._fp8_linear = Fp8LinearOp(act_quant_static=False) + + @property + def name(self) -> str: + return f"linear_fp8_{self.weight_dtype_str}_w8a8" + + @property + def linear_weight_format(self) -> str: + return self.weight_dtype_str + + @property + def linear_act_format(self) -> str: + return self.act_dtype_str + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + return torch.uint8, 1 + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs: Any) -> tuple[int, ...]: + _ = kwargs + if len(original_shape) != 2: + raise ValueError(f"Expected 2D weight [N,K], got {original_shape}") + return (1,) + + def quantize(self, tensor: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, Any]: + _ = kwargs + if tensor.dim() != 2: + raise ValueError(f"Expected 2D weight [N,K], got shape={tuple(tensor.shape)}") + from vllm import _custom_ops as ops # type: ignore + from vllm.platforms import current_platform # type: ignore + + q_fp8, scale = ops.scaled_fp8_quant(tensor.to(torch.float32).contiguous(), scale=None) + q_kn_fp8 = q_fp8.t() # [K,N], stride(0)==1 + scale = scale.to(torch.float32).reshape(1).contiguous() + return q_kn_fp8, {"scales": scale, "fp8_dtype": current_platform.fp8_dtype()} + + def quantize_weight_for_kernel( + self, + weight: torch.Tensor, + *, + device: torch.device | None = None, + **_: Any, + ) -> tuple[torch.Tensor, Any]: + q_fp8, meta = self.quantize(weight) + if device is not None: + q_fp8 = q_fp8.to(device=device) + meta["scales"] = meta["scales"].to(device=device) + return q_fp8, meta["scales"] + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs: Any) -> torch.Tensor: + _ = kwargs + raise RuntimeError("FP8 不提供 dequantize 路径(避免走慢的反量化 + F.linear)。") + + def linear_forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + quant_kind: str, + **kwargs: Any, + ) -> torch.Tensor: + _ = quant_kind + wid = id(weight) + cached = self._weight_cache.get(wid) + if cached is None or cached[0].device != x.device: + q_fp8, meta = self.quantize(weight) + q_fp8 = q_fp8.to(device=x.device) + w_scale = meta["scales"].to(device=x.device, dtype=torch.float32).reshape(1) + self._weight_cache[wid] = (q_fp8, w_scale) + else: + q_fp8, w_scale = cached + + q_kn = q_fp8 + + return self._fp8_linear.apply( + input=x, + weight=q_kn, + weight_scale=w_scale, + out_dtype=x.dtype if x.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16, + input_scale=None, + bias=bias, + ) + diff --git a/diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py b/diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py new file mode 100644 index 0000000..1425c85 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_gptq_marlin_w4a16.py @@ -0,0 +1,204 @@ +""" +GPTQ Marlin (W4/W8, A16) Linear strategy using vLLM Marlin CUDA kernels. + +- Input activations: bf16 (cast to fp16 for vLLM marlin kernel) +- Weights: offline GPTQ vLLM standard format (qweight/qzeros/scales/g_idx) +- One-time repack/permutation is performed by Diffulex `LinearBase` and passed in via kwargs: + - gptq_marlin_qweight / gptq_marlin_scales / gptq_marlin_zp + - gptq_marlin_g_idx / gptq_marlin_g_idx_sort_indices + - gptq_marlin_workspace + +This strategy intentionally does NOT fall back to F.linear silently: if marlin tensors +are missing, it raises to avoid accidentally benchmarking a slow path. +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + +try: + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( # type: ignore + apply_gptq_marlin_linear, + marlin_is_k_full, + marlin_make_empty_g_idx, + should_use_atomic_add_reduce, + marlin_permute_bias, + ) + from vllm.scalar_type import scalar_types # type: ignore +except Exception: # pragma: no cover + apply_gptq_marlin_linear = None # type: ignore + marlin_is_k_full = None # type: ignore + marlin_make_empty_g_idx = None # type: ignore + should_use_atomic_add_reduce = None # type: ignore + marlin_permute_bias = None # type: ignore + scalar_types = None # type: ignore + + +@register_linear_strategy(weight_dtype="gptq_marlin", act_dtype="bf16") +def _build_linear_gptq_marlin_w4a16() -> LinearQuantizationStrategy: + return LinearGPTQMarlinW4A16Strategy() + + +class LinearGPTQMarlinW4A16Strategy(LinearQuantizationStrategy): + def __init__(self) -> None: + super().__init__() + self._available: bool = bool(apply_gptq_marlin_linear is not None and scalar_types is not None) + self._empty_cache: dict[int, torch.Tensor] = {} + self._bias_cache: dict[tuple[int, int], torch.Tensor] = {} + self._atomic_add_cache: dict[tuple[int, int, int, int, int], bool] = {} + + @property + def name(self) -> str: + return "linear_gptq_marlin_w4a16" + + @property + def linear_weight_format(self) -> str: + return "gptq_marlin" + + @property + def linear_act_format(self) -> str: + return "bf16" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + return torch.int32, 4 + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs) -> tuple[int, ...]: + # Same as GPTQ: [K/group, N] + if len(original_shape) != 2: + raise ValueError(f"Expected 2D weight shape, got {original_shape}") + out_features, in_features = original_shape + group_size = int(kwargs.get("group_size", 128)) + group_size = in_features if group_size == -1 else group_size + if group_size <= 0 or in_features % group_size != 0: + raise ValueError(f"Invalid group_size={group_size} for in_features={in_features}") + num_groups = in_features // group_size + return (num_groups, out_features) + + def quantize(self, tensor: torch.Tensor, **kwargs) -> tuple[torch.Tensor, Any]: + return tensor, {} + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs) -> torch.Tensor: + if quantized.is_floating_point(): + return quantized + raise NotImplementedError("GPTQ Marlin 不提供 Python dequantize;请使用 vLLM Marlin CUDA kernel。") + + def linear_forward( + self, + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + *, + quant_kind: str, + qweight: torch.Tensor, + scales: torch.Tensor, + zp: torch.Tensor, + g_idx: Optional[torch.Tensor] = None, + g_idx_sort_indices: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + in_features: int = 0, + out_features: int = 0, + group_size: int = 128, + weight_bits: int = 0, + tp_dim: Optional[int] = None, + ) -> torch.Tensor: + _ = quant_kind, weight, group_size + if not self._available or workspace is None: + raise RuntimeError("gptq_marlin 需要 vLLM (marlin_utils + scalar_types);当前环境不可用。") + + if in_features <= 0 or out_features <= 0: + raise RuntimeError("gptq_marlin: missing in_features/out_features.") + + if weight_bits == 4: + wtype = scalar_types.uint4b8 + elif weight_bits == 8: + wtype = scalar_types.uint8b128 + else: + raise RuntimeError(f"gptq_marlin: unsupported weight_bits={weight_bits} (expected 4 or 8)") + + device = x.device + dev_key = int(device.index) if device.type == "cuda" and device.index is not None else -1 + + # g_idx can be empty (desc_act=False). Prefer already-correct tensors; avoid per-call to(). + if g_idx is None or g_idx.numel() == 0: + empty = self._empty_cache.get(dev_key) + if empty is None: + empty = marlin_make_empty_g_idx(device) if marlin_make_empty_g_idx is not None else torch.empty((0,), device=device, dtype=torch.int32) + self._empty_cache[dev_key] = empty + g_idx_t = empty + else: + g_idx_t = g_idx + if g_idx_sort_indices is None or g_idx_sort_indices.numel() == 0: + empty = self._empty_cache.get(dev_key) + if empty is None: + empty = marlin_make_empty_g_idx(device) if marlin_make_empty_g_idx is not None else torch.empty((0,), device=device, dtype=torch.int32) + self._empty_cache[dev_key] = empty + g_idx_sort_t = empty + else: + g_idx_sort_t = g_idx_sort_indices + + # Determine whether K is full (needed by marlin kernel). Row-parallel layers set tp_dim=1 in Diffulex. + row_parallel = bool(tp_dim == 1) + has_g_idx = bool(g_idx_t.numel() > 0) + is_k_full = True if marlin_is_k_full is None else marlin_is_k_full(has_g_idx, row_parallel) + + # Cache permuted bias (Marlin expects permuted bias order). + marlin_bias = None + if bias is not None: + bkey = (dev_key, int(bias.data_ptr())) + marlin_bias = self._bias_cache.get(bkey) + if marlin_bias is None: + marlin_bias = marlin_permute_bias(bias) if marlin_permute_bias is not None else bias + self._bias_cache[bkey] = marlin_bias + + # Flatten like F.linear: [*,K] -> [M,K] + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (int(out_features),) + + # Cache heuristic for atomic-add reduction (depends on M/N/K, device, dtype). + m = int(reshaped_x.shape[0]) + n = int(out_features) + k = int(reshaped_x.shape[1]) + dtype_id = 1 if reshaped_x.dtype == torch.bfloat16 else (2 if reshaped_x.dtype == torch.float16 else 0) + use_atomic_add = False + if should_use_atomic_add_reduce is not None: + akey = (dev_key, dtype_id, m, n, k) + cached = self._atomic_add_cache.get(akey) + if cached is None: + cached = bool( + should_use_atomic_add_reduce( + m=m, n=n, k=k, device=device, dtype=reshaped_x.dtype + ) + ) + self._atomic_add_cache[akey] = cached + use_atomic_add = cached + + # Directly call the underlying CUDA op to minimize Python glue. + out = torch.ops._C.gptq_marlin_gemm( + reshaped_x, + None, + qweight, + marlin_bias, + scales, + None, + None, + zp, + g_idx_t, + g_idx_sort_t, + workspace, + wtype.id, + m, + n, + k, + is_k_full, + use_atomic_add, + True, # use_fp32_reduce + False, # is_zp_float + ) + return out.reshape(out_shape) + diff --git a/diffulex/utils/quantization/strategies/linear_gptq_w4a16.py b/diffulex/utils/quantization/strategies/linear_gptq_w4a16.py new file mode 100644 index 0000000..7adfd10 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_gptq_w4a16.py @@ -0,0 +1,153 @@ +""" +GPTQ W4A16 Linear quantization strategy (vLLM standard format). + +- Weight format: vLLM GPTQ (packed int32 qweight/qzeros + fp16 scales) +- Activation: bf16 (no activation quantization) +- Forward: vLLM custom op `gptq_gemm` + +Design notes: +- Diffulex follows vLLM's fast path: run `gptq_shuffle` once (handled by + `LinearBase._maybe_prepare_offline_gptq`) and then call `gptq_gemm` with + `use_exllama=True`. +- No TileLang dependency. +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + +try: + from vllm import _custom_ops as ops # type: ignore +except Exception: # pragma: no cover + ops = None # type: ignore + + +@register_linear_strategy(weight_dtype="gptq", act_dtype="bf16") +def _build_linear_gptq_w4a16() -> LinearQuantizationStrategy: + return LinearGPTQW4A16Strategy() + + +class LinearGPTQW4A16Strategy(LinearQuantizationStrategy): + def __init__(self) -> None: + super().__init__() + self._ops_available: bool = bool(ops is not None and hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "gptq_gemm")) + # Cache empty g_idx tensor per device to avoid per-call allocations. + self._empty_cache: dict[int, torch.Tensor] = {} + + @property + def name(self) -> str: + return "linear_gptq_w4a16" + + @property + def linear_weight_format(self) -> str: + return "gptq" + + @property + def linear_act_format(self) -> str: + return "bf16" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + # vLLM GPTQ stores packed weights in int32. + return torch.int32, 4 + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs) -> tuple[int, ...]: + # vLLM GPTQ scales: [K/group, N], where Linear weight is (N, K). + if len(original_shape) != 2: + raise ValueError(f"Expected 2D weight shape, got {original_shape}") + out_features, in_features = original_shape + group_size = int(kwargs.get("group_size", 128)) + group_size = in_features if group_size == -1 else group_size + if group_size <= 0 or in_features % group_size != 0: + raise ValueError(f"Invalid group_size={group_size} for in_features={in_features}") + num_groups = in_features // group_size + return (num_groups, out_features) + + def quantize(self, tensor: torch.Tensor, **kwargs) -> tuple[torch.Tensor, Any]: + # Offline GPTQ is handled by `diffulex.utils.quantization.quantize_model`. + return tensor, {} + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs) -> torch.Tensor: + if quantized.is_floating_point(): + return quantized + raise NotImplementedError( + "GPTQ dequantize is not implemented in Diffulex. " + "Use vLLM kernels via linear_forward." + ) + + def linear_forward( + self, + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + *, + quant_kind: str, + gptq_qweight: Optional[torch.Tensor] = None, + gptq_qzeros: Optional[torch.Tensor] = None, + gptq_scales: Optional[torch.Tensor] = None, + gptq_g_idx: Optional[torch.Tensor] = None, + weight_bits: int = 0, + use_v2_format: bool = False, + out_features: Optional[int] = None, + in_features: Optional[int] = None, + group_size: int = 128, + ) -> torch.Tensor: + _ = quant_kind, weight, in_features, group_size + if not self._ops_available: + raise RuntimeError( + "vLLM is required for GPTQ W4A16 (missing `vllm._custom_ops`). " + "Please install/build vLLM with CUDA ops." + ) + qweight = gptq_qweight + qzeros = gptq_qzeros + scales = gptq_scales + g_idx = gptq_g_idx + + if qweight is None or qzeros is None or scales is None: + # correctness fallback (should not happen for offline GPTQ weights) + if weight is None: + raise RuntimeError("GPTQ offline weights missing packed tensors and bf16 weight is not present.") + return F.linear(x, weight, bias) + + if weight_bits <= 0: + raise RuntimeError("GPTQ requires explicit weight_bits (>0) for the CUDA kernel path.") + + # vLLM GPTQ kernels expect FP16 activations. + x_in = x if x.dtype == torch.float16 else x.to(dtype=torch.float16) + x2 = x_in.reshape(-1, x_in.shape[-1]) if x_in.dim() != 2 else x_in + if not x2.is_contiguous(): + x2 = x2.contiguous() + + device = x.device + dev_key = int(device.index) if device.type == "cuda" and device.index is not None else -1 + if g_idx is None or g_idx.numel() == 0: + empty = self._empty_cache.get(dev_key) + if empty is None or empty.device != device: + empty = torch.empty((0,), device=device, dtype=torch.int) + self._empty_cache[dev_key] = empty + g_idx_t = empty + else: + g_idx_t = g_idx if (g_idx.device == device and g_idx.dtype == torch.int) else g_idx.to(device=device, dtype=torch.int) + + output = torch.ops._C.gptq_gemm( + x2, + qweight, + qzeros, + scales, + g_idx_t, + True, # use_exllama + bool(use_v2_format), + int(weight_bits), + ) + if bias is not None: + output.add_(bias.to(dtype=output.dtype)) + + out_shape = x.shape[:-1] + (int(out_features) if out_features is not None else int(qweight.shape[-1]),) + output = output.reshape(out_shape) + return output.to(dtype=x.dtype) if output.dtype != x.dtype else output + diff --git a/diffulex/utils/quantization/strategies/linear_int4_w4a16.py b/diffulex/utils/quantization/strategies/linear_int4_w4a16.py new file mode 100644 index 0000000..870a860 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_int4_w4a16.py @@ -0,0 +1,157 @@ +""" +W4A16 Linear quantization strategy (int4 weight + bf16 activation), TileLang-free. + +vLLM-aligned behavior: +- vLLM 在 sm89(如 4090)上并没有“在线 int4 -> 快 GEMM”的通用路径; + 真正的 int4 加速通常依赖 GPTQ/AWQ 的 marlin/cutlass 以及对应的离线权重格式。 +- 为避免“看起来是 int4 但实际在跑 bf16 GEMM”,默认禁止静默走 `F.linear` 慢路径。 + +如需临时允许 correctness-first 慢 fallback,可设置环境变量: + `DIFFULEX_ALLOW_SLOW_QUANT_FALLBACK=1` +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + + +@register_linear_strategy(weight_dtype="int4", act_dtype="bf16") +def _build_linear_int4_w4a16() -> LinearQuantizationStrategy: + return LinearInt4W4A16Strategy() + + +class LinearInt4W4A16Strategy(LinearQuantizationStrategy): + def __init__(self) -> None: + super().__init__() + # Cache: id(weight) -> (packed_int8 [N, ceil(K/2)], scales_fp32 [N]) + self._weight_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} + + @property + def name(self) -> str: + return "linear_int4_w4a16" + + @property + def linear_weight_format(self) -> str: + return "int4" + + @property + def linear_act_format(self) -> str: + return "bf16" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + return torch.int8, 1 + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs: Any) -> tuple[int, ...]: + _ = kwargs + if len(original_shape) != 2: + raise ValueError(f"Expected 2D weight [N,K], got {original_shape}") + return (original_shape[0],) + + @staticmethod + def _pack_int4_to_int8(int4_tensor: torch.Tensor) -> torch.Tensor: + # int4_tensor: int8 [N,K] values in [-8,7] + n, k = int4_tensor.shape + t = int4_tensor.clamp(-8, 7).to(torch.int16) + u = (t + 8).to(torch.uint8) # [0,15] + if k % 2 != 0: + u = torch.cat([u, torch.full((n, 1), 8, device=u.device, dtype=torch.uint8)], dim=1) + k = k + 1 + u2 = u.view(n, k // 2, 2) + packed = (u2[:, :, 0] | (u2[:, :, 1] << 4)).to(torch.int8) + return packed.contiguous() + + @staticmethod + def _unpack_int8_to_int4(packed: torch.Tensor, *, original_k: int) -> torch.Tensor: + # packed: int8 [N, ceil(K/2)] (two nibbles per byte) + p = packed.view(torch.uint8) + low = (p & 0x0F).to(torch.int16) - 8 + high = ((p >> 4) & 0x0F).to(torch.int16) - 8 + n, pk = packed.shape + out = torch.empty((n, pk * 2), device=packed.device, dtype=torch.int16) + out[:, 0::2] = low + out[:, 1::2] = high + return out[:, :original_k].to(torch.int8).contiguous() + + def quantize(self, tensor: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, Any]: + _ = kwargs + if tensor.dim() != 2: + raise ValueError(f"Expected 2D weight [N,K], got shape={tuple(tensor.shape)}") + w = tensor.to(torch.bfloat16) + abs_max = w.abs().amax(dim=-1, keepdim=True) # [N,1] + scales = (abs_max.clamp(min=1e-8) / 7.0).to(torch.float32).squeeze(-1) # [N] + q = torch.round(w.to(torch.float32) / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8) + packed = self._pack_int4_to_int8(q) + return packed, {"scales": scales} + + def quantize_weight_for_kernel( + self, + weight: torch.Tensor, + *, + device: torch.device | None = None, + **_: Any, + ) -> tuple[torch.Tensor, Any]: + """Quantize+pack bf16 weight for kernel consumption. + + Returns: + (packed_int8 [N, ceil(K/2)], scales_fp32 [N]) + """ + packed, meta = self.quantize(weight) + if device is not None: + packed = packed.to(device=device) + meta["scales"] = meta["scales"].to(device=device) + return packed, meta["scales"] + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs: Any) -> torch.Tensor: + original_k = int(kwargs.get("original_in_features", 0)) + if original_k <= 0: + raise ValueError("original_in_features is required to dequantize int4 weights") + scales = scale_or_metadata.get("scales") if isinstance(scale_or_metadata, dict) else scale_or_metadata + if scales is None: + raise ValueError("scales required for dequantization") + q = self._unpack_int8_to_int4(quantized, original_k=original_k).to(torch.float32) + w = q * scales.to(torch.float32).unsqueeze(-1) + return w.to(torch.bfloat16) + + def linear_forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + quant_kind: str, + **kwargs: Any, + ) -> torch.Tensor: + _ = quant_kind + if not bool(int(__import__("os").environ.get("DIFFULEX_ALLOW_SLOW_QUANT_FALLBACK", "0"))): + raise RuntimeError( + "当前平台/配置下 `int4` 在线量化没有可用的 vLLM 快 kernel(例如 4090/sm89 无 CUTLASS W4A8)。" + "为避免静默退化到 bf16 GEMM,已禁止 `F.linear` 慢 fallback。" + "请改用 `gptq/awq`(vLLM 标准打包格式)或设置 DIFFULEX_ALLOW_SLOW_QUANT_FALLBACK=1 临时开启。" + ) + original_k = int(kwargs.get("original_in_features", x.shape[-1])) + quant_scales = kwargs.get("quant_scales", None) + + if weight is not None and weight.dtype == torch.int8 and quant_scales is not None: + packed = weight.to(device=x.device) + scales = quant_scales.to(device=x.device, dtype=torch.float32) + else: + wid = id(weight) + cached = self._weight_cache.get(wid) + if cached is None or cached[0].device != x.device: + packed, meta = self.quantize(weight) + packed = packed.to(device=x.device) + scales = meta["scales"].to(device=x.device, dtype=torch.float32) + self._weight_cache[wid] = (packed, scales) + else: + packed, scales = cached + + # Slow fallback (explicitly opted-in). + w_deq = self.dequantize(packed, {"scales": scales}, original_in_features=original_k) + return F.linear(x, w_deq, bias) + diff --git a/diffulex/utils/quantization/strategies/linear_int4_w4a8.py b/diffulex/utils/quantization/strategies/linear_int4_w4a8.py new file mode 100644 index 0000000..decb19d --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_int4_w4a8.py @@ -0,0 +1,77 @@ +""" +W4A8 Linear quantization strategy (int4 weight + int8 activation), TileLang-free. + +vLLM-aligned behavior: +- vLLM 的 CUTLASS W4A8 kernel 需要 sm90(Hopper);在 sm89(如 4090)上不可用。 +- 为避免静默退化到 bf16 GEMM,默认禁止 `F.linear` 慢 fallback。 + +如需临时允许 correctness-first 慢 fallback,可设置: + `DIFFULEX_ALLOW_SLOW_QUANT_FALLBACK=1` +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + +from .linear_int4_w4a16 import LinearInt4W4A16Strategy + + +@register_linear_strategy(weight_dtype="int4", act_dtype="int8") +def _build_linear_int4_w4a8() -> LinearQuantizationStrategy: + return LinearInt4W4A8Strategy() + + +class LinearInt4W4A8Strategy(LinearQuantizationStrategy): + def __init__(self) -> None: + super().__init__() + self._w4a16 = LinearInt4W4A16Strategy() + + @property + def name(self) -> str: + return "linear_int4_w4a8" + + @property + def linear_weight_format(self) -> str: + return "int4" + + @property + def linear_act_format(self) -> str: + return "int8" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + return torch.int8, 1 + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs: Any) -> tuple[int, ...]: + return self._w4a16.get_scale_shape(original_shape, **kwargs) + + def quantize(self, tensor: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, Any]: + return self._w4a16.quantize(tensor, **kwargs) + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs: Any) -> torch.Tensor: + return self._w4a16.dequantize(quantized, scale_or_metadata, **kwargs) + + def linear_forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + quant_kind: str, + **kwargs: Any, + ) -> torch.Tensor: + _ = quant_kind + if not bool(int(__import__("os").environ.get("DIFFULEX_ALLOW_SLOW_QUANT_FALLBACK", "0"))): + raise RuntimeError( + "当前平台/配置下 `int4` 在线量化没有可用的 vLLM 快 kernel(例如 4090/sm89 无 CUTLASS W4A8)。" + "为避免静默退化到 bf16 GEMM,已禁止 `F.linear` 慢 fallback。" + "请改用 `gptq/awq`(vLLM 标准打包格式)或设置 DIFFULEX_ALLOW_SLOW_QUANT_FALLBACK=1 临时开启。" + ) + # Correctness-first: reuse W4A16 implementation. + return self._w4a16.linear_forward(x, weight, bias, quant_kind="other", **kwargs) + diff --git a/diffulex/utils/quantization/strategies/linear_int8_w8a16.py b/diffulex/utils/quantization/strategies/linear_int8_w8a16.py new file mode 100644 index 0000000..67ab104 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_int8_w8a16.py @@ -0,0 +1,29 @@ +""" +W8A16 Linear quantization strategy (int8 weight + bf16 activation). + +This path is now implemented by reusing Diffulex's marlin(AllSpark)-style W8A16 +strategy, which matches vLLM's effective fast path and avoids TileLang. +""" + +from __future__ import annotations + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + +from .linear_marlin_int8_w8a16 import LinearMarlinInt8W8A16Strategy + + +class LinearInt8W8A16Strategy(LinearMarlinInt8W8A16Strategy): + """ + Compatibility alias for the historical Diffulex strategy name. + + This keeps the registry and `strategies.__init__` imports stable while + reusing the vLLM-aligned marlin(AllSpark) W8A16 implementation. + """ + + +@register_linear_strategy(weight_dtype="int8", act_dtype="bf16") +def _build_linear_int8_w8a16() -> LinearQuantizationStrategy: + # Alias to marlin(AllSpark) W8A16 implementation. + return LinearInt8W8A16Strategy() + diff --git a/diffulex/utils/quantization/strategies/linear_int8_w8a8.py b/diffulex/utils/quantization/strategies/linear_int8_w8a8.py new file mode 100644 index 0000000..ba07440 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_int8_w8a8.py @@ -0,0 +1,153 @@ +""" +W8A8 Linear quantization strategy (int8 weight + int8 activation), TileLang-free. + +Implementation (vLLM-aligned): +- Activation quantization: `vllm._custom_ops.scaled_int8_quant` (dynamic per-token). +- GEMM+dequant: `vllm._custom_ops.cutlass_scaled_mm` (CUTLASS, with internal + triton fallback depending on shape/platform) — no `F.linear` slow path. + +Notes: +- Weight is stored as int8 in **K×N** layout (transposed), matching vLLM CUTLASS + kernels. +- Weight scale is stored as **[1, N]** float32 for broadcasting. +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch # type: ignore + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + + +try: + from vllm import _custom_ops as _vllm_ops # type: ignore +except Exception: # pragma: no cover + _vllm_ops = None # type: ignore + + +@register_linear_strategy(weight_dtype="int8", act_dtype="int8") +def _build_linear_int8_w8a8() -> LinearQuantizationStrategy: + return LinearInt8W8A8Strategy() + + +class LinearInt8W8A8Strategy(LinearQuantizationStrategy): + def __init__(self) -> None: + super().__init__() + # Cache: id(weight) -> (qweight_int8 [N,K], w_scales_fp32 [N]) + self._weight_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} + + @property + def name(self) -> str: + return "linear_int8_w8a8" + + @property + def linear_weight_format(self) -> str: + return "int8" + + @property + def linear_act_format(self) -> str: + return "int8" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + return torch.int8, 1 + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs: Any) -> tuple[int, ...]: + _ = kwargs + if len(original_shape) != 2: + raise ValueError(f"Expected 2D weight [N,K], got {original_shape}") + return (original_shape[0],) + + def quantize(self, tensor: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, Any]: + _ = kwargs + if tensor.dim() != 2: + raise ValueError(f"Expected 2D weight [N,K], got shape={tuple(tensor.shape)}") + # per-output-channel symmetric int8, store K×N for cutlass_scaled_mm + w = tensor.to(torch.float32) + abs_max = w.abs().amax(dim=-1, keepdim=False) # [N] + scales = (abs_max.clamp(min=1e-8) / 127.0).to(torch.float32) # [N] + q_nk = torch.round(w / scales.unsqueeze(-1)).clamp(-127, 127).to(torch.int8) # [N,K] + # NOTE: vLLM CUTLASS scaled_mm expects b.stride(0) == 1, which is true + # for a transpose-view (non-contiguous) but not for a contiguous K×N tensor. + q_kn = q_nk.t() # [K,N], stride(0)==1 + scale_b = scales.unsqueeze(0).contiguous() # [1,N] + return q_kn, {"scales": scale_b} + + def quantize_weight_for_kernel( + self, + weight: torch.Tensor, + *, + device: torch.device | None = None, + **_: Any, + ) -> tuple[torch.Tensor, Any]: + # Return int8 K×N weights + fp32 [1,N] scales for vLLM CUTLASS path. + q_kn, meta = self.quantize(weight) + if device is not None: + q_kn = q_kn.to(device=device) + meta["scales"] = meta["scales"].to(device=device) + return q_kn, meta["scales"] + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs: Any) -> torch.Tensor: + _ = kwargs + scales = scale_or_metadata.get("scales") if isinstance(scale_or_metadata, dict) else scale_or_metadata + if scales is None: + raise ValueError("scales required for dequantization") + raise RuntimeError( + "W8A8 不提供 dequantize 路径(避免走慢的 bf16 GEMM)。" + ) + + def linear_forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + quant_kind: str, + quant_scales: Optional[torch.Tensor] = None, + out_features: Optional[int] = None, + ) -> torch.Tensor: + _ = quant_kind + if _vllm_ops is None: + raise RuntimeError("vLLM custom ops are required for W8A8 (scaled_int8_quant / cutlass_scaled_mm).") + + # Weight/scales: prefer load-time quantized buffers. + if weight is not None and weight.dtype == torch.int8 and quant_scales is not None: + qweight = weight + w_scales = quant_scales + else: + wid = id(weight) + cached = self._weight_cache.get(wid) + if cached is None or cached[0].device != x.device: + qweight, meta = self.quantize(weight) + qweight = qweight.to(device=x.device) + w_scales = meta["scales"].to(device=x.device, dtype=torch.float32) + self._weight_cache[wid] = (qweight, w_scales) + else: + qweight, w_scales = cached + + orig_shape = x.shape + x2 = x.reshape(-1, x.shape[-1]) if x.dim() != 2 else x + if x2.dtype not in (torch.bfloat16, torch.float16): + x2 = x2.to(torch.bfloat16) + if not x2.is_contiguous(): + x2 = x2.contiguous() + + # dynamic per-token int8 quant + fused GEMM+dequant + x_q, x_s, _ = _vllm_ops.scaled_int8_quant(x2, scale=None, azp=None, symmetric=True) + y = _vllm_ops.cutlass_scaled_mm( + x_q, + qweight, + scale_a=x_s, + scale_b=w_scales, + out_dtype=x2.dtype, + bias=bias.to(dtype=x2.dtype) if bias is not None else None, + ) + + if orig_shape == x2.shape: + return y + if x.dim() == 1: + return y.squeeze(0) + return y.reshape(*orig_shape[:-1], y.shape[-1]) + diff --git a/diffulex/utils/quantization/strategies/linear_marlin_int8_w8a16.py b/diffulex/utils/quantization/strategies/linear_marlin_int8_w8a16.py new file mode 100644 index 0000000..ceb3630 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_marlin_int8_w8a16.py @@ -0,0 +1,332 @@ +"""W8A16 Linear quantization strategy using vLLM custom ops. + +This strategy uses vLLM's fused AllSpark W8A16 path via `vllm._custom_ops`: +- per-out-channel int8 quantization stored as uint8 (+128 bias) +- one-time N32K16 reorder (AllSpark repack) +- fused dequant + GEMM (AllSpark w8a16 gemm) + +Important: +- We intentionally do NOT vendor/compile a local AllSpark/Marlin extension in + Diffulex anymore. If `vllm._custom_ops` is unavailable, this strategy fails + fast (instead of silently compiling or falling back to a slow/oom-prone path). +""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch +import torch.nn.functional as F + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + +try: + import vllm._custom_ops as _vllm_ops +except Exception: + _vllm_ops = None + + +def _allspark_is_available() -> bool: + return bool( + _vllm_ops is not None + and hasattr(_vllm_ops, "allspark_w8a16_gemm") + and hasattr(_vllm_ops, "allspark_repack_weight") + ) + +def _allspark_repack_weight(b_qweight_kn: torch.Tensor, scales_1xn: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Repack KxN uint8 qweight + 1xN scales into (N_32,K) + (1,N_32) for AllSpark GEMM.""" + if _vllm_ops is None or not hasattr(_vllm_ops, "allspark_repack_weight"): + raise RuntimeError("vLLM custom ops are unavailable: missing `allspark_repack_weight`.") + q_reorder, s_reorder, _ = _vllm_ops.allspark_repack_weight( + b_qweight_kn, + scales_1xn, + None, + False, + ) + return q_reorder, s_reorder + + +@register_linear_strategy(weight_dtype="marlin_int8", act_dtype="bf16") +def _build_linear_marlin_int8_w8a16() -> LinearQuantizationStrategy: + return LinearMarlinInt8W8A16Strategy() + + +class LinearMarlinInt8W8A16Strategy(LinearQuantizationStrategy): + """W8A16 strategy using vLLM custom ops (AllSpark fused GEMM + repack).""" + + def __init__(self) -> None: + super().__init__() + # Cache for bf16 Parameters only (load-time quantized path bypasses this). + self._weight_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} + # Cache device info and thresholds to reduce per-call CPU overhead. + self._sm_info_cache: dict[int, tuple[int, int]] = {} + self._quant_block_n: int = 256 + self._cublas_m_thr: int = 256 + # One-time availability check (avoid calling `_allspark_is_available()` on every linear). + self._allspark_available: bool = _allspark_is_available() + + def configure(self, *, diffulex_config: Any | None = None) -> None: + # Prefer explicit config fields over environment-variable based tuning. + if diffulex_config is None: + return + try: + bn = int(getattr(diffulex_config, "linear_w8a16_quant_block_n", self._quant_block_n)) + self._quant_block_n = max(1, bn) + except Exception: + pass + try: + thr = int(getattr(diffulex_config, "linear_w8a16_allspark_cublas_m_threshold", self._cublas_m_thr)) + self._cublas_m_thr = max(1, thr) + except Exception: + pass + + @property + def name(self) -> str: + # NOTE: Keep strategy naming consistent with the public W8A16 INT8 path. + # The underlying implementation is a Marlin/AllSpark-style fused kernel, + # but the user-facing strategy name should not be tied to a particular kernel brand. + return "linear_int8_w8a16" + + @property + def linear_weight_format(self) -> str: + # Important: keep "int8" so LinearBase load-time quantization path triggers + # and drops bf16 weights to save memory. + return "int8" + + @property + def linear_act_format(self) -> str: + return "bf16" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + # We store qweight as uint8 (bias128 representation). + return torch.uint8, 1 + + # ---- Required abstract methods (for registry/factory instantiation) ---- + def quantize(self, tensor: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, Any]: + """Reference per-output-channel symmetric int8 quantization. + + Returns: + quantized_int8: [N,K] int8 + scales: [N] bf16 + """ + _ = kwargs + if tensor.dim() != 2: + raise ValueError(f"Expected 2D weight [N,K], got shape={tuple(tensor.shape)}") + if tensor.dtype != torch.bfloat16: + tensor = tensor.to(dtype=torch.bfloat16) + abs_max = torch.abs(tensor).max(dim=-1, keepdim=True)[0] # [N,1] + scales = (abs_max.clamp(min=1e-8) / 127.0).to(dtype=torch.bfloat16) # [N,1] + q = torch.round(tensor.to(torch.float32) / scales.to(torch.float32)).clamp(-128, 127).to(torch.int8) + return q, scales.squeeze(-1) + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs: Any) -> torch.Tensor: + """Reference dequantization back to bf16.""" + _ = kwargs + scales = scale_or_metadata.get("scales") if isinstance(scale_or_metadata, dict) else scale_or_metadata + if scales is None: + raise ValueError("scales required for dequantization") + if scales.dim() == 1: + scales = scales.unsqueeze(-1) + return (quantized.to(torch.float32) * scales.to(torch.float32)).to(torch.bfloat16) + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs: Any) -> tuple[int, ...]: + _ = kwargs + if len(original_shape) < 2: + raise ValueError(f"Expected weight shape with at least 2 dims, got {original_shape}") + return (original_shape[0],) + + def quantize_weight_for_kernel( + self, + weight: torch.Tensor, + *, + device: torch.device | None = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, Any]: + """Quantize+repack bf16 weight for AllSpark fused kernel. + + Input: + weight: [N, K] bf16/fp16 + Output: + qweight_reorder: [N_32align, K] uint8 in N32K16 reorder layout + scales_reorder: [N_32align] bf16 scales (reordered/padded) + """ + _ = kwargs + if device is not None: + weight = weight.to(device=device) + + if weight.dim() != 2: + raise ValueError(f"Expected 2D weight [N,K], got shape={tuple(weight.shape)}") + + # Ensure bf16 for stable scales. + if weight.dtype != torch.bfloat16: + weight = weight.to(dtype=torch.bfloat16) + + n, k = weight.shape + n_32 = ((n + 31) // 32) * 32 + + # Per-output-channel symmetric scale. + abs_max = torch.abs(weight).max(dim=-1)[0] # [N] + scales = (abs_max.clamp(min=1e-8) / 127.0).to(dtype=torch.bfloat16) # [N] + + # IMPORTANT (OOM fix): + # Avoid allocating a full [N,K] fp32 copy (and an extra transpose buffer). + # Quantize in small row blocks and (when using AllSpark) write directly into + # the repack input layout B_kn=[K,N], so we never materialize q_u8 + transpose. + block_n = max(1, int(self._quant_block_n)) + + if self._allspark_available: + # AllSpark repack expects B in (K,N) contiguous layout. + b_kn = torch.empty((k, n), device=weight.device, dtype=torch.uint8) # [K,N] + for i in range(0, n, block_n): + j = min(i + block_n, n) + w_blk = weight[i:j, :] # [B,K] + s_blk = scales[i:j].unsqueeze(-1) # [B,1] + # Quantize to signed int in bf16 to minimize temporary memory. + q_i16 = torch.round(w_blk / s_blk).clamp(-128, 127).to(torch.int16) # [B,K] + q_u8_blk = (q_i16 + 128).to(torch.uint8) # [B,K] + # Write directly into [K,N] buffer. + b_kn[:, i:j] = q_u8_blk.transpose(0, 1) + else: + # Fallback storage (no reorder). Keep [N,K] and [N] (padded to N_32). + # Note: forward will detect unavailable allspark and fallback further. + q_pad = torch.full((n_32, k), 128, device=weight.device, dtype=torch.uint8) + for i in range(0, n, block_n): + j = min(i + block_n, n) + w_blk = weight[i:j, :] # [B,K] + s_blk = scales[i:j].unsqueeze(-1) # [B,1] + q_i16 = torch.round(w_blk / s_blk).clamp(-128, 127).to(torch.int16) # [B,K] + q_pad[i:j, :] = (q_i16 + 128).to(torch.uint8) + if n_32 != n: + s_pad = torch.zeros((n_32,), device=scales.device, dtype=torch.bfloat16) + s_pad[:n] = scales + return q_pad.contiguous(), s_pad.contiguous() + return q_pad[:n, :].contiguous(), scales.contiguous() + + # vLLM expects scales in [1, N] layout for repack. + q_reorder, s_reorder_1xn = _allspark_repack_weight( + b_kn.contiguous(), + scales.unsqueeze(0).contiguous(), + ) + + # Store scales as 1D for LinearBase buffers; linear_forward will reshape as needed. + s_1d = s_reorder_1xn.reshape(-1).to(dtype=torch.bfloat16) + return q_reorder.contiguous(), s_1d.contiguous() + + def quantize_act_for_kernel( + self, + x: torch.Tensor, + *, + device: torch.device | None = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, Any]: + _ = kwargs + if device is not None: + x = x.to(device=device) + # No activation quantization for W8A16. + return x, None + + def _get_sm_info(self, device: torch.device) -> tuple[int, int]: + # get_device_properties is relatively expensive on hot paths; cache per device index. + try: + idx = int(device.index) if device.index is not None else int(torch.cuda.current_device()) + except Exception: + idx = -1 + cached = self._sm_info_cache.get(idx) + if cached is not None: + return cached + try: + props = torch.cuda.get_device_properties(device) + sm_count = int(getattr(props, "multi_processor_count", 0)) + sm_version = int(props.major) * 10 + int(props.minor) + self._sm_info_cache[idx] = (sm_count, sm_version) + return sm_count, sm_version + except Exception: + self._sm_info_cache[idx] = (0, 0) + return 0, 0 + + def linear_forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + quant_kind: str, + quant_scales: Optional[torch.Tensor] = None, + out_features: Optional[int] = None, + ) -> torch.Tensor: + _ = quant_kind + if not self._allspark_available or _vllm_ops is None: + # correctness fallback only when bf16 weight exists + if weight is not None and getattr(weight, "dtype", None) in (torch.float16, torch.bfloat16): + return F.linear(x, weight, bias) + raise RuntimeError( + "vLLM AllSpark W8A16 fused kernel is unavailable. " + "Please ensure vLLM custom ops are installed and loadable (`import vllm._custom_ops`)." + ) + + orig_shape = x.shape + x2 = x.reshape(-1, x.shape[-1]) if x.dim() != 2 else x + if x2.device.type != "cuda": + if weight is not None and getattr(weight, "dtype", None) in (torch.float16, torch.bfloat16): + return F.linear(x, weight, bias) + raise RuntimeError("AllSpark W8A16 requires CUDA inputs.") + + if x2.dtype != torch.bfloat16: + x2 = x2.to(dtype=torch.bfloat16) + if not x2.is_contiguous(): + x2 = x2.contiguous() + + # Load-time quantized module path: weight is uint8/int8 buffer and scales provided. + if weight is not None and weight.dtype in (torch.uint8, torch.int8): + if quant_scales is None: + raise ValueError("quant_scales is required when weight is quantized") + qweight = weight + scales = quant_scales + else: + # Safety net for bf16 weights (should be rare in steady-state). + weight_id = id(weight) + cached = self._weight_cache.get(weight_id) + if cached is None or cached[0].device != x2.device: + qweight, scales = self.quantize_weight_for_kernel(weight, device=x2.device) + self._weight_cache[weight_id] = (qweight, scales) + else: + qweight, scales = cached + + m, k = x2.shape + n_32, k_w = qweight.shape + if k_w != k or (k & 15) != 0: + if weight is not None and getattr(weight, "dtype", None) in (torch.float16, torch.bfloat16): + y = F.linear(x, weight, bias) + return y + raise RuntimeError(f"AllSpark W8A16 requires K%16==0 and matching K. Got x.K={k}, w.K={k_w}.") + + n = int(out_features) if out_features is not None else (int(bias.numel()) if bias is not None else int(min(scales.numel(), n_32))) + n = n_32 if (n <= 0 or n > n_32) else n + scales_1xn = scales if scales.dim() == 2 else scales.view(1, -1) + + sm_count, sm_version = self._get_sm_info(x2.device) + y2 = _vllm_ops.allspark_w8a16_gemm( + x2, + qweight, + scales_1xn, + None, # b_qzeros + n, + -1, # group_size (only supports -1) + sm_count, + sm_version, + self._cublas_m_thr, + False, # has_zp + True, # n32k16_reorder + ) + if bias is not None: + y2 = y2 + bias + if orig_shape == x2.shape: + return y2 + if x.dim() == 1: + return y2.squeeze(0) + return y2.reshape(*orig_shape[:-1], y2.shape[-1]) + + # NOTE: We intentionally do not provide a generic dequantize+F.linear fallback for reordered weights. + # It materializes a full bf16 matrix and is prone to OOM on large models. + diff --git a/diffulex/utils/quantization/strategies/linear_stub.py b/diffulex/utils/quantization/strategies/linear_stub.py new file mode 100644 index 0000000..76d7d33 --- /dev/null +++ b/diffulex/utils/quantization/strategies/linear_stub.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +from diffulex.utils.quantization.registry import register_linear_strategy +from diffulex.utils.quantization.strategy import LinearQuantizationStrategy + + +@register_linear_strategy(weight_dtype="__stub__", act_dtype="__stub__") +def _build_linear_stub() -> LinearQuantizationStrategy: + # Default fallback stub. Actual requested dtypes will be attached by the caller + # via attributes after creation if needed. + return LinearStubStrategy(weight_dtype="__stub__", act_dtype="__stub__") + + +@dataclass +class LinearStubStrategy(LinearQuantizationStrategy): + """Generic stub for any non-bf16 Linear quantization combination.""" + + weight_dtype: str + act_dtype: str + + @property + def name(self) -> str: + return f"linear_stub(w={self.weight_dtype},a={self.act_dtype})" + + @property + def linear_weight_format(self) -> str: + return self.weight_dtype + + @property + def linear_act_format(self) -> str: + return self.act_dtype + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + # Placeholder; real implementations may store packed weights in int4/int8 etc. + return torch.uint8, 1 + + def quantize(self, tensor: torch.Tensor, **kwargs): + raise NotImplementedError(f"{self.name}: quantize is not implemented (stub). kwargs={list(kwargs.keys())}") + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs) -> torch.Tensor: + raise NotImplementedError(f"{self.name}: dequantize is not implemented (stub). kwargs={list(kwargs.keys())}") + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs) -> tuple[int, ...]: + _ = original_shape, kwargs + return tuple() + + def linear_forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + quant_kind: str, + **kwargs: Any, + ) -> torch.Tensor: + _ = x, weight, bias, kwargs + raise NotImplementedError( + "Linear quantization kernel is not implemented yet. " + f"kind={quant_kind!r}, weight_dtype={self.weight_dtype!r}, act_dtype={self.act_dtype!r}" + ) + + + + diff --git a/diffulex/utils/quantization/strategies/no_quantization.py b/diffulex/utils/quantization/strategies/no_quantization.py new file mode 100644 index 0000000..a2b75c1 --- /dev/null +++ b/diffulex/utils/quantization/strategies/no_quantization.py @@ -0,0 +1,31 @@ +""" +No quantization strategy (default, backward compatible). +""" + +import torch +from diffulex.utils.quantization.strategy import QuantizationStrategy + + +class NoQuantizationStrategy(QuantizationStrategy): + """No quantization strategy (default, backward compatible).""" + + @property + def name(self) -> str: + return "no_quantization" + + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + """Returns BF16 as default storage dtype.""" + return torch.bfloat16, 2 + + def quantize(self, tensor: torch.Tensor, **kwargs) -> tuple[torch.Tensor, None]: + """No quantization, return tensor as-is.""" + return tensor, None + + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: None, **kwargs) -> torch.Tensor: + """No dequantization needed.""" + return quantized + + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs) -> tuple[int, ...]: + """No scale needed.""" + return (0,) # Empty shape + diff --git a/diffulex/utils/quantization/strategy.py b/diffulex/utils/quantization/strategy.py new file mode 100644 index 0000000..7c3b01a --- /dev/null +++ b/diffulex/utils/quantization/strategy.py @@ -0,0 +1,320 @@ +""" +Quantization strategy interfaces. + +This module defines abstract base classes for different types of quantization strategies. +""" + +from abc import ABC, abstractmethod +from typing import Any, Optional, Protocol + +import torch +import torch.nn.functional as F + + +class _AttnMetaDataLike(Protocol): + """A minimal protocol for attention metadata used by Diffulex runtime. + + We avoid importing `diffulex.attention.metadata` here to reduce the chance + of circular imports. + """ + + k_scale: Optional[torch.Tensor] + v_scale: Optional[torch.Tensor] + + +class QuantizationStrategy(ABC): + """Quantization strategy abstract base class.""" + + @property + @abstractmethod + def name(self) -> str: + """Strategy name.""" + pass + + @abstractmethod + def get_storage_dtype(self) -> tuple[torch.dtype, int]: + """ + Returns storage dtype and itemsize. + + Returns: + (storage_dtype, itemsize): Tuple of storage dtype and item size in bytes. + """ + pass + + @abstractmethod + def quantize(self, tensor: torch.Tensor, **kwargs) -> tuple[torch.Tensor, Any]: + """ + Quantize a tensor. + + Args: + tensor: Input tensor to quantize. + **kwargs: Additional arguments for quantization. + + Returns: + (quantized_tensor, scale_or_metadata): Tuple of quantized tensor and scale/metadata. + """ + pass + + @abstractmethod + def dequantize(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs) -> torch.Tensor: + """ + Dequantize a tensor. + + Args: + quantized: Quantized tensor to dequantize. + scale_or_metadata: Scale or metadata needed for dequantization. + **kwargs: Additional arguments for dequantization. + + Returns: + Dequantized tensor. + """ + pass + + @abstractmethod + def get_scale_shape(self, original_shape: tuple[int, ...], **kwargs) -> tuple[int, ...]: + """ + Returns the shape of scale tensor. + + Args: + original_shape: Original tensor shape. + **kwargs: Additional arguments (e.g., num_kv_heads for KV cache). + + Returns: + Scale tensor shape. + """ + pass + + def configure(self, *, diffulex_config: Any | None = None) -> None: + """Optional hook to configure a strategy from Diffulex `Config`. + + We intentionally keep this a no-op by default to avoid forcing configuration + plumbing through every call site. Strategy-specific tuning knobs should be + surfaced via explicit fields on `diffulex.config.Config`, not environment variables. + """ + _ = diffulex_config + return + + # ---- Optional capability flags / helpers (non-abstract) ---- + # These helpers are used to avoid hard-coding isinstance(...) checks in the runtime. + @property + def requires_runtime_scales(self) -> bool: + """Whether this strategy requires runtime scale tensors to be allocated/updated.""" + return False + + +class KVCacheQuantizationStrategy(QuantizationStrategy): + """KV Cache quantization strategy interface (extended interface).""" + + # NOTE: We use a small string tag for dispatch instead of importing enums everywhere. + # Known values: + # - "bf16": no quantization, cache stored as bf16 + # - "fp8": FP8 cache stored as uint8 (storage) with float8 view for kernels + @property + def kv_cache_format(self) -> str: + return "bf16" + + @abstractmethod + def compute_scales(self, k: torch.Tensor, v: torch.Tensor, + num_kv_heads: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute quantization scales for K and V. + + Args: + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + num_kv_heads: Number of KV heads + device: Target device + + Returns: + (k_scale, v_scale): Tuple of K and V scales, shape [num_kv_heads] + """ + pass + + @abstractmethod + def update_scales(self, k: torch.Tensor, v: torch.Tensor, + k_scale: Optional[torch.Tensor], v_scale: Optional[torch.Tensor], + num_kv_heads: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + """ + Update quantization scales (e.g., using running max strategy). + + Args: + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + k_scale: Current K scale (None if first time) + v_scale: Current V scale (None if first time) + num_kv_heads: Number of KV heads + device: Target device + + Returns: + (updated_k_scale, updated_v_scale): Updated scales, shape [num_kv_heads] + """ + pass + + def init_scales(self, num_kv_heads: int, device: torch.device) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Initialize quantization scales for K and V. + + This method should be called once per layer to initialize scale tensors. + Strategies that don't require scales (e.g., BF16) should return (None, None). + + Args: + num_kv_heads: Number of KV heads + device: Target device + + Returns: + (k_scale, v_scale): Initial scales, shape [num_kv_heads], or (None, None) if not needed + """ + # Default implementation: return None (no scales needed) + return None, None + + # ---- Diffulex integration helpers (non-abstract) ---- + @property + def requires_kv_cache_scales(self) -> bool: + """Whether KV cache kernels / decode require per-head scales.""" + return self.requires_runtime_scales + + def maybe_set_attn_metadata_scales( + self, + attn_metadata: _AttnMetaDataLike, + *, + k_scale: Optional[torch.Tensor], + v_scale: Optional[torch.Tensor], + ) -> None: + """Populate `attn_metadata.k_scale/v_scale` when needed.""" + if not self.requires_kv_cache_scales: + return + if k_scale is None or v_scale is None: + raise ValueError( + f"{self.name} requires k_scale/v_scale but got " + f"k_scale={k_scale is not None}, v_scale={v_scale is not None}" + ) + attn_metadata.k_scale = k_scale + attn_metadata.v_scale = v_scale + + def view_kv_cache_for_kernels(self, cache: torch.Tensor) -> torch.Tensor: + """Return a view of cache suitable for kernel consumption. + + - BF16 strategies: return as-is + - FP8 strategies: subclasses may return a float8 view while keeping uint8 storage + """ + return cache + + def quantize_kv_for_store( + self, + k: torch.Tensor, + v: torch.Tensor, + *, + k_scale: Optional[torch.Tensor], + v_scale: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize K/V for KV cache store (optional helper). + + Returns: + (k_quantized, v_quantized): Usually uint8 tensors for FP8 strategies. + """ + raise NotImplementedError(f"{self.name} does not implement quantize_kv_for_store") + + +class WeightQuantizationStrategy(QuantizationStrategy): + """Weight quantization strategy interface (for future extension).""" + + @abstractmethod + def quantize_weight(self, weight: torch.Tensor, **kwargs) -> tuple[torch.Tensor, Any]: + """ + Quantize model weights. + + Args: + weight: Weight tensor to quantize. + **kwargs: Additional arguments for quantization. + + Returns: + (quantized_weight, scale_or_metadata): Tuple of quantized weight and scale/metadata. + """ + pass + + @abstractmethod + def dequantize_weight(self, quantized: torch.Tensor, scale_or_metadata: Any, **kwargs) -> torch.Tensor: + """ + Dequantize model weights. + + Args: + quantized: Quantized weight tensor. + scale_or_metadata: Scale or metadata needed for dequantization. + **kwargs: Additional arguments for dequantization. + + Returns: + Dequantized weight tensor. + """ + pass + + +class LinearQuantizationStrategy(QuantizationStrategy): + """Linear layer quantization strategy interface (weights + activations). + + This is an architecture hook: kernels/packed weights can be implemented later. + The runtime (Linear layers) should dispatch by `quant_kind` ("attn"/"mlp"/"other") + and use this strategy to compute the Linear output. + """ + + @property + def linear_weight_format(self) -> str: + """Small tag used for kernel dispatch for weights. + + Known values (initial set): + - "bf16": no weight quantization + - "int8"/"int4"/"fp8_e4m3"/"fp8_e5m2"/"gptq"/"awq": placeholders + """ + return "bf16" + + @property + def linear_act_format(self) -> str: + """Small tag used for kernel dispatch for activations.""" + return "bf16" + + def quantize_weight_for_kernel( + self, + weight: torch.Tensor, + *, + device: torch.device | None = None, + **_: Any, + ) -> tuple[torch.Tensor, Any]: + """Optionally quantize/pack weight for kernel consumption. + + Default behavior: no-op, returns (weight, None). + """ + if device is not None: + weight = weight.to(device=device) + return weight, None + + def quantize_act_for_kernel( + self, + x: torch.Tensor, + *, + device: torch.device | None = None, + **_: Any, + ) -> tuple[torch.Tensor, Any]: + """Optionally quantize activations for kernel consumption. + + Default behavior: no-op, returns (x, None). + """ + if device is not None: + x = x.to(device=device) + return x, None + + def linear_forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + quant_kind: str, + **kwargs: Any, + ) -> torch.Tensor: + """Compute Linear output for a given kind. + + Default behavior: `F.linear(x, weight, bias)` (no quantization). + Quantized strategies may override this to call custom kernels. + """ + _ = quant_kind, kwargs + return F.linear(x, weight, bias) + diff --git a/diffulex_bench/README.md b/diffulex_bench/README.md new file mode 100644 index 0000000..158b266 --- /dev/null +++ b/diffulex_bench/README.md @@ -0,0 +1,323 @@ +# Diffulex Benchmark + +Benchmark framework for evaluating Diffulex inference engine using lm-evaluation-harness. + +## Features + +- ✅ **lm-evaluation-harness Integration**: Full support for 50+ evaluation tasks +- ✅ **YAML Configuration**: Clean and readable configuration files +- ✅ **Professional Logging**: Colored output with rich formatting +- ✅ **Flexible Configuration**: Support both config files and command-line arguments +- ✅ **Multiple Models**: Support for Dream, SDAR, Fast-dLLM-v2 models +- ✅ **Multiple Strategies**: D2F, Block Diffusion, Fast-dLLM decoding strategies + +## Quick Start + +### Installation + +```bash +# Install dependencies +pip install lm-eval rich colorama + +# Install diffulex (if not already installed) +pip install -e . +``` + +### Using Configuration File (Recommended) + +1. **Create or use existing config file**: + +```bash +# Copy example config +cp diffulex_bench/configs/example.yml my_config.yml + +# Edit the config file +vim my_config.yml +``` + +2. **Run benchmark**: + +```bash +python -m diffulex_bench.main --config my_config.yml +``` + +### Using Command Line Arguments + +```bash +python -m diffulex_bench.main \ + --model-path /path/to/model \ + --model-name dream \ + --decoding-strategy d2f \ + --dataset gsm8k \ + --dataset-limit 100 \ + --temperature 0.0 \ + --max-tokens 256 \ + --output-dir ./results +``` + +## Configuration Files + +Configuration files are located in `diffulex_bench/configs/` directory. We use YAML format for better readability. + +### Configuration Structure + +Configurations are organized into two sections: + +1. **`engine`**: Engine configuration (model weights, LoRA, model name, strategy, inference parameters) +2. **`eval`**: Evaluation configuration (dataset, tasks, sampling parameters, output settings) + +### Example Configuration + +See `diffulex_bench/configs/example.yml` for a complete example: + +```yaml +# Engine configuration - Parameters for Diffulex engine +engine: + # Model and weights + model_path: "/path/to/your/model" + model_name: "dream" + decoding_strategy: "d2f" + mask_token_id: 151666 + + # LoRA configuration + use_lora: false + lora_path: "" + + # Parallelism and memory + tensor_parallel_size: 1 + data_parallel_size: 1 + gpu_memory_utilization: 0.9 + max_model_len: 2048 + + # D2F-specific parameters + accept_threshold: 0.9 + complete_threshold: 0.95 + add_new_block_threshold: 0.1 + +# Evaluation configuration - Parameters for benchmark +eval: + # Task/Dataset + dataset_name: "gsm8k" + dataset_limit: 100 + + # Sampling + temperature: 0.0 + max_tokens: 256 + + # Output + output_dir: "benchmark_results" +``` + +### Pre-configured Examples + +- `configs/example.yml`: Complete example with all options +- `configs/dream_d2f_gsm8k.yml`: Dream model with D2F strategy on GSM8K + +## Supported Tasks + +The framework supports all tasks available in lm-evaluation-harness, including: + +- **GSM8K**: Math word problems +- **HumanEval**: Code generation +- **HellaSwag**: Commonsense reasoning +- **MMLU**: Massive multitask language understanding +- And 50+ more tasks... + +See [lm-evaluation-harness tasks](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_table.md) for the complete list. + +## Model Configuration + +### Model Types + +- `dream`: Dream model +- `sdar`: SDAR model +- `fast_dllm_v2`: Fast-dLLM-v2 model + +### Decoding Strategies + +- `d2f`: Discrete Diffusion Forcing +- `block_diffusion`: Block Diffusion +- `fast_dllm`: Fast-dLLM + +### Example: Dream with D2F + +```yaml +engine: + model_path: "/path/to/dream/model" + model_name: "dream" + decoding_strategy: "d2f" + mask_token_id: 151666 + accept_threshold: 0.9 + complete_threshold: 0.95 + add_new_block_threshold: 0.1 + +eval: + dataset_name: "gsm8k" + temperature: 0.0 + max_tokens: 256 +``` + +## Command Line Arguments + +### Basic Arguments + +```bash +--config PATH # Configuration file path (YAML or JSON) +--model-path PATH # Model path (required if no config) +--dataset TASK # Task name (e.g., gsm8k, humaneval) +--output-dir PATH # Output directory +``` + +### Model Arguments + +```bash +--model-name NAME # Model name: dream, sdar, fast_dllm_v2 +--decoding-strategy STR # Strategy: d2f, block_diffusion, fast_dllm +--mask-token-id ID # Mask token ID +``` + +### Inference Arguments + +```bash +--tensor-parallel-size N # Tensor parallel size +--data-parallel-size N # Data parallel size +--gpu-memory-utilization F # GPU memory utilization (0.0-1.0) +--max-model-len N # Maximum model length +``` + +### Sampling Arguments + +```bash +--temperature F # Sampling temperature +--max-tokens N # Maximum tokens to generate +``` + +### Logging Arguments + +```bash +--log-file PATH # Log file path (optional) +--log-level LEVEL # Log level: DEBUG, INFO, WARNING, ERROR +``` + +## Output + +Results are saved to the output directory (default: `benchmark_results/`) with: + +- Evaluation results in JSON format +- Detailed metrics and statistics +- Configuration used for the run +- Timestamp information + +## Examples + +### Example 1: GSM8K Evaluation + +```bash +python -m diffulex_bench.main \ + --config diffulex_bench/configs/dream_d2f_gsm8k.yml \ + --dataset-limit 100 +``` + +### Example 2: Custom Configuration + +```bash +python -m diffulex_bench.main \ + --model-path /path/to/model \ + --model-name dream \ + --decoding-strategy d2f \ + --dataset gsm8k \ + --temperature 0.0 \ + --max-tokens 512 \ + --output-dir ./my_results \ + --log-file ./benchmark.log +``` + +### Example 3: Using Default Config + +```bash +# If configs/example.yml exists, it will be used automatically +python -m diffulex_bench.main \ + --model-path /path/to/model \ + --dataset gsm8k +``` + +## Architecture + +``` +main.py (Entry Point) + ↓ +arg_parser.py (Argument Parsing) + ↓ +config.py (Configuration Management) + ↓ +run_benchmark() (Benchmark Execution) + ↓ +lm_eval.cli_evaluate() (Evaluation Framework) + ↓ +DiffulexLM (Model Interface) + ↓ +BenchmarkRunner (Engine Wrapper) + ↓ +Diffulex (Inference Engine) +``` + +## Advanced Usage + +### Custom Model Integration + +The framework uses `DiffulexLM` class which wraps `BenchmarkRunner`. You can extend it for custom models: + +```python +from diffulex_bench.lm_eval_model import DiffulexLM + +# DiffulexLM automatically registers with lm_eval +# Use it in lm_eval commands +``` + +### Programmatic Usage + +```python +from diffulex_bench.config import BenchmarkConfig, EngineConfig, EvalConfig +from diffulex_bench.main import run_benchmark + +# Load from YAML file +config = BenchmarkConfig.from_yaml("diffulex_bench/configs/example.yml") +run_benchmark(config) + +# Or create programmatically +engine = EngineConfig( + model_path="/path/to/model", + model_name="dream", + decoding_strategy="d2f", +) +eval_config = EvalConfig( + dataset_name="gsm8k", + temperature=0.0, + max_tokens=256, +) +config = BenchmarkConfig(engine=engine, eval=eval_config) +run_benchmark(config) +``` + +## Troubleshooting + +### Common Issues + +1. **lm-eval not found**: Install with `pip install lm-eval` +2. **Config file not found**: Check path or use absolute path +3. **Model loading fails**: Verify model path and model_name match +4. **Out of memory**: Reduce `gpu_memory_utilization` or `max_model_len` + +### Getting Help + +- Check logs with `--log-level DEBUG` +- Save logs to file with `--log-file benchmark.log` +- Verify configuration with `--config` option + +## Notes + +1. The framework uses **lm-evaluation-harness** for all evaluation logic +2. Configuration files use **YAML** format (JSON also supported) +3. All evaluation metrics are computed by lm-eval +4. Results follow lm-eval output format +5. GPU environment is recommended for best performance diff --git a/diffulex_bench/__init__.py b/diffulex_bench/__init__.py new file mode 100644 index 0000000..b9a730d --- /dev/null +++ b/diffulex_bench/__init__.py @@ -0,0 +1,36 @@ +""" +Diffulex Benchmark - Benchmark framework for evaluating Diffulex inference engine performance +""" + +from diffulex_bench.runner import BenchmarkRunner +from diffulex_bench.datasets import load_benchmark_dataset +from diffulex_bench.metrics import compute_metrics +from diffulex.logger import setup_logger, get_logger +from diffulex_bench.config import BenchmarkConfig, EngineConfig, EvalConfig + +# Import lm_eval model to register it +try: + from diffulex_bench.lm_eval_model import DiffulexLM + __all__ = [ + "BenchmarkRunner", + "load_benchmark_dataset", + "compute_metrics", + "setup_logger", + "get_logger", + "BenchmarkConfig", + "EngineConfig", + "EvalConfig", + "DiffulexLM", + ] +except ImportError: + __all__ = [ + "BenchmarkRunner", + "load_benchmark_dataset", + "compute_metrics", + "setup_logger", + "get_logger", + "BenchmarkConfig", + "EngineConfig", + "EvalConfig", + ] + diff --git a/diffulex_bench/arg_parser.py b/diffulex_bench/arg_parser.py new file mode 100644 index 0000000..d4f786c --- /dev/null +++ b/diffulex_bench/arg_parser.py @@ -0,0 +1,307 @@ +""" +Argument Parser - Command line argument parsing for benchmark +""" + +import argparse +from pathlib import Path + + +def create_argument_parser() -> argparse.ArgumentParser: + """ + Create and configure argument parser for benchmark + + Returns: + Configured ArgumentParser instance + """ + parser = argparse.ArgumentParser( + description="Diffulex Benchmark using lm-evaluation-harness", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Using configuration file (recommended) + python -m diffulex_bench.main --config diffulex_bench/configs/example.yml + + # Using command line arguments + python -m diffulex_bench.main \\ + --model-path /path/to/model \\ + --dataset gsm8k \\ + --dataset-limit 100 \\ + --output-dir ./results + + # With custom model settings + python -m diffulex_bench.main \\ + --model-path /path/to/model \\ + --model-name dream \\ + --decoding-strategy d2f \\ + --dataset gsm8k \\ + --temperature 0.0 \\ + --max-tokens 256 + """ + ) + + # Logging arguments + parser.add_argument( + "--log-file", + type=str, + default=None, + help="Log file path (optional)", + ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level", + ) + + # Configuration file + parser.add_argument( + "--config", + type=str, + help="Configuration file path (YAML or JSON). Default: configs/example.yml", + ) + + # Model arguments + parser.add_argument( + "--model-path", + type=str, + help="Model path", + ) + parser.add_argument( + "--tokenizer-path", + type=str, + default=None, + help="Tokenizer path (defaults to model-path)", + ) + parser.add_argument( + "--model-name", + type=str, + default="dream", + choices=["dream", "sdar", "fast_dllm_v2"], + help="Model name", + ) + parser.add_argument( + "--decoding-strategy", + type=str, + default="d2f", + choices=["d2f", "block_diffusion", "fast_dllm"], + help="Decoding strategy", + ) + parser.add_argument( + "--mask-token-id", + type=int, + default=151666, + help="Mask token ID", + ) + + # Inference arguments + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=1, + help="Tensor parallel size", + ) + parser.add_argument( + "--data-parallel-size", + type=int, + default=1, + help="Data parallel size", + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.9, + help="GPU memory utilization", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=2048, + help="Maximum model length", + ) + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=4096, + help="Maximum number of batched tokens", + ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=128, + help="Maximum number of sequences", + ) + + # Sampling arguments + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=256, + help="Maximum tokens to generate", + ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Ignore EOS token", + ) + + # Dataset arguments + parser.add_argument( + "--dataset", + type=str, + default="gsm8k", + help="Dataset/task name (e.g., gsm8k, humaneval)", + ) + parser.add_argument( + "--dataset-split", + type=str, + default="test", + help="Dataset split", + ) + parser.add_argument( + "--dataset-limit", + type=int, + default=None, + help="Limit number of samples", + ) + + # Output arguments + parser.add_argument( + "--output-dir", + type=str, + default="benchmark_results", + help="Output directory", + ) + parser.add_argument( + "--save-results", + action="store_true", + default=True, + help="Save results to file", + ) + parser.add_argument( + "--no-save-results", + dest="save_results", + action="store_false", + help="Do not save results to file", + ) + + # LoRA arguments + parser.add_argument( + "--use-lora", + action="store_true", + help="Use LoRA", + ) + parser.add_argument( + "--lora-path", + type=str, + default="", + help="LoRA path", + ) + + # Engine arguments + parser.add_argument( + "--enforce-eager", + action="store_true", + help="Enforce eager mode (disable CUDA graphs)", + ) + parser.add_argument( + "--no-enforce-eager", + dest="enforce_eager", + action="store_false", + help="Disable eager mode (enable CUDA graphs when supported)", + ) + parser.set_defaults(enforce_eager=None) + parser.add_argument( + "--kv-cache-layout", + type=str, + default="unified", + choices=["unified", "distinct"], + help="KV cache layout", + ) + + # D2F-specific arguments + parser.add_argument( + "--accept-threshold", + type=float, + default=0.9, + help="Accept threshold for D2F", + ) + parser.add_argument( + "--complete-threshold", + type=float, + default=0.95, + help="Complete threshold for D2F", + ) + parser.add_argument( + "--add-new-block-threshold", + type=float, + default=0.1, + help="Add new block threshold for D2F", + ) + parser.add_argument( + "--diffusion-block-size", + type=int, + default=32, + help="Diffusion block size", + ) + + # Quantization arguments + parser.add_argument( + "--kv-cache-dtype", + type=str, + default=None, + choices=["bf16", "fp16", "fp32", "fp8_e4m3", "fp8_e5m2"], + help="KV cache data type", + ) + parser.add_argument( + "--decode-mode", + type=str, + default=None, + choices=["static", "varlen"], + help="Decode mode (static or varlen)", + ) + parser.add_argument( + "--linear-attn-weight-dtype", + type=str, + default=None, + help="Linear attention weight dtype", + ) + parser.add_argument( + "--linear-mlp-weight-dtype", + type=str, + default=None, + help="Linear MLP weight dtype", + ) + parser.add_argument( + "--linear-attn-act-dtype", + type=str, + default=None, + help="Linear attention activation dtype", + ) + parser.add_argument( + "--linear-mlp-act-dtype", + type=str, + default=None, + help="Linear MLP activation dtype", + ) + + return parser + + +def get_default_config_path() -> Path: + """ + Get default configuration file path + + Returns: + Path to default config file + """ + config_dir = Path(__file__).parent / "configs" + default_config = config_dir / "example.yml" + return default_config + diff --git a/diffulex_bench/config.py b/diffulex_bench/config.py new file mode 100644 index 0000000..2c9afab --- /dev/null +++ b/diffulex_bench/config.py @@ -0,0 +1,271 @@ +""" +Benchmark Configuration - Configuration management with separated engine and eval configs +""" + +from dataclasses import dataclass, field +from typing import Optional, Dict, Any +import json +import yaml + + +@dataclass +class EngineConfig: + """ + Engine configuration - Parameters for Diffulex engine initialization + """ + # Model and weights + model_path: str + tokenizer_path: Optional[str] = None + model_name: str = "dream" # Options: dream, sdar, fast_dllm_v2 + decoding_strategy: str = "d2f" # Options: d2f, block_diffusion, fast_dllm + mask_token_id: int = 151666 + + # LoRA configuration + use_lora: bool = False + lora_path: str = "" + + # Parallelism configuration + tensor_parallel_size: int = 1 + data_parallel_size: int = 1 + + # Memory and capacity configuration + gpu_memory_utilization: float = 0.9 + max_model_len: int = 2048 + max_num_batched_tokens: int = 4096 + max_num_seqs: int = 128 + + # Engine behavior configuration + enforce_eager: bool = False + kv_cache_layout: str = "unified" # Options: unified, distinct + + # D2F-specific configuration + accept_threshold: float = 0.9 + complete_threshold: float = 0.95 + add_new_block_threshold: float = 0.1 + diffusion_block_size: int = 32 + + # Quantization configuration + kv_cache_dtype: Optional[str] = None # "bf16", "fp16", "fp32", "fp8_e4m3", "fp8_e5m2" + decode_mode: Optional[str] = None # "static" or "varlen" + linear_attn_weight_dtype: Optional[str] = None # "bf16", "int8", "int4", "fp8_e4m3", etc. + linear_mlp_weight_dtype: Optional[str] = None + linear_attn_act_dtype: Optional[str] = None + linear_mlp_act_dtype: Optional[str] = None + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "EngineConfig": + """Create engine configuration from dictionary""" + return cls(**config_dict) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + field.name: getattr(self, field.name) + for field in self.__dataclass_fields__.values() + } + + def get_diffulex_kwargs(self) -> Dict[str, Any]: + """Get arguments to pass to Diffulex engine""" + return { + 'model_name': self.model_name, + 'decoding_strategy': self.decoding_strategy, + 'mask_token_id': self.mask_token_id, + 'tensor_parallel_size': self.tensor_parallel_size, + 'data_parallel_size': self.data_parallel_size, + 'gpu_memory_utilization': self.gpu_memory_utilization, + 'max_model_len': self.max_model_len, + 'max_num_batched_tokens': self.max_num_batched_tokens, + 'max_num_seqs': self.max_num_seqs, + 'use_lora': self.use_lora, + 'lora_path': self.lora_path if self.use_lora else "", + 'enforce_eager': self.enforce_eager, + 'kv_cache_layout': self.kv_cache_layout, + 'accept_threshold': self.accept_threshold, + 'complete_threshold': self.complete_threshold, + 'add_new_block_threshold': self.add_new_block_threshold, + 'diffusion_block_size': self.diffusion_block_size, + } + + # Add quantization parameters if specified + if self.kv_cache_dtype is not None: + kwargs['kv_cache_dtype'] = self.kv_cache_dtype + if self.decode_mode is not None: + kwargs['decode_mode'] = self.decode_mode + if self.linear_attn_weight_dtype is not None: + kwargs['linear_attn_weight_dtype'] = self.linear_attn_weight_dtype + if self.linear_mlp_weight_dtype is not None: + kwargs['linear_mlp_weight_dtype'] = self.linear_mlp_weight_dtype + if self.linear_attn_act_dtype is not None: + kwargs['linear_attn_act_dtype'] = self.linear_attn_act_dtype + if self.linear_mlp_act_dtype is not None: + kwargs['linear_mlp_act_dtype'] = self.linear_mlp_act_dtype + + return kwargs + + +@dataclass +class EvalConfig: + """ + Evaluation configuration - Parameters for benchmark evaluation + """ + # Task/Dataset configuration + dataset_name: str = "gsm8k" # Task name (e.g., gsm8k, humaneval) + dataset_split: str = "test" + dataset_limit: Optional[int] = None + + # Sampling configuration + temperature: float = 0.0 + max_tokens: int = 256 + ignore_eos: bool = False + + # Output configuration + output_dir: str = "benchmark_results" + save_results: bool = True + use_tqdm: bool = True + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "EvalConfig": + """Create evaluation configuration from dictionary""" + return cls(**config_dict) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + field.name: getattr(self, field.name) + for field in self.__dataclass_fields__.values() + } + + def get_sampling_params(self): + """Get sampling parameters""" + from diffulex import SamplingParams + return SamplingParams( + temperature=self.temperature, + max_tokens=self.max_tokens, + ignore_eos=self.ignore_eos, + ) + + +@dataclass +class BenchmarkConfig: + """ + Benchmark configuration - Combines engine and evaluation configurations + """ + engine: EngineConfig + eval: EvalConfig + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "BenchmarkConfig": + """ + Create benchmark configuration from dictionary + + Supports both flat and nested dictionary structures for backward compatibility + """ + # Check if config_dict has nested structure + if 'engine' in config_dict and 'eval' in config_dict: + engine = EngineConfig.from_dict(config_dict['engine']) + eval_config = EvalConfig.from_dict(config_dict['eval']) + else: + # Flat structure - backward compatibility + # Split fields into engine and eval + engine_fields = { + 'model_path', 'tokenizer_path', 'model_name', 'decoding_strategy', + 'mask_token_id', 'use_lora', 'lora_path', 'tensor_parallel_size', + 'data_parallel_size', 'gpu_memory_utilization', 'max_model_len', + 'max_num_batched_tokens', 'max_num_seqs', 'enforce_eager', + 'kv_cache_layout', 'accept_threshold', 'complete_threshold', + 'add_new_block_threshold', 'diffusion_block_size', + 'kv_cache_dtype', 'decode_mode', 'linear_attn_weight_dtype', + 'linear_mlp_weight_dtype', 'linear_attn_act_dtype', 'linear_mlp_act_dtype' + } + + engine_dict = {k: v for k, v in config_dict.items() if k in engine_fields} + eval_dict = {k: v for k, v in config_dict.items() if k not in engine_fields} + + engine = EngineConfig.from_dict(engine_dict) + eval_config = EvalConfig.from_dict(eval_dict) + + return cls(engine=engine, eval=eval_config) + + @classmethod + def from_json(cls, json_path: str) -> "BenchmarkConfig": + """Load configuration from JSON file""" + with open(json_path, 'r', encoding='utf-8') as f: + config_dict = json.load(f) + return cls.from_dict(config_dict) + + @classmethod + def from_yaml(cls, yaml_path: str) -> "BenchmarkConfig": + """Load configuration from YAML file""" + with open(yaml_path, 'r', encoding='utf-8') as f: + config_dict = yaml.safe_load(f) + return cls.from_dict(config_dict) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary with nested structure""" + return { + 'engine': self.engine.to_dict(), + 'eval': self.eval.to_dict(), + } + + def save_json(self, json_path: str): + """Save to JSON file""" + with open(json_path, 'w', encoding='utf-8') as f: + json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) + + def save_yaml(self, yaml_path: str): + """Save to YAML file""" + with open(yaml_path, 'w', encoding='utf-8') as f: + yaml.dump(self.to_dict(), f, allow_unicode=True, default_flow_style=False) + + def get_diffulex_kwargs(self) -> Dict[str, Any]: + """Get arguments to pass to Diffulex engine""" + return self.engine.get_diffulex_kwargs() + + def get_sampling_params(self): + """Get sampling parameters""" + return self.eval.get_sampling_params() + + # Convenience properties for backward compatibility + @property + def model_path(self) -> str: + return self.engine.model_path + + @property + def tokenizer_path(self) -> Optional[str]: + return self.engine.tokenizer_path + + @property + def model_name(self) -> str: + return self.engine.model_name + + @property + def decoding_strategy(self) -> str: + return self.engine.decoding_strategy + + @property + def dataset_name(self) -> str: + return self.eval.dataset_name + + @property + def dataset_limit(self) -> Optional[int]: + return self.eval.dataset_limit + + @property + def output_dir(self) -> str: + return self.eval.output_dir + + @dataset_name.setter + def dataset_name(self, value: str): + self.eval.dataset_name = value + + @dataset_limit.setter + def dataset_limit(self, value: Optional[int]): + self.eval.dataset_limit = value + + @output_dir.setter + def output_dir(self, value: str): + self.eval.output_dir = value + + @model_path.setter + def model_path(self, value: str): + self.engine.model_path = value diff --git a/diffulex_bench/configs/__init__.py b/diffulex_bench/configs/__init__.py new file mode 100644 index 0000000..51b7ec8 --- /dev/null +++ b/diffulex_bench/configs/__init__.py @@ -0,0 +1,4 @@ +""" +Configuration files for Diffulex benchmarks +""" + diff --git a/diffulex_bench/configs/dream_d2f_gsm8k.yml b/diffulex_bench/configs/dream_d2f_gsm8k.yml new file mode 100644 index 0000000..74d1b07 --- /dev/null +++ b/diffulex_bench/configs/dream_d2f_gsm8k.yml @@ -0,0 +1,29 @@ +# Dream model with D2F strategy on GSM8K dataset +# Quick configuration example + +engine: + model_path: "/path/to/dream/model" + model_name: "dream" + decoding_strategy: "d2f" + mask_token_id: 151666 + + tensor_parallel_size: 1 + data_parallel_size: 1 + gpu_memory_utilization: 0.9 + max_model_len: 4096 + + use_lora: false + enforce_eager: false + + accept_threshold: 0.9 + complete_threshold: 0.95 + add_new_block_threshold: 0.1 + +eval: + dataset_name: "gsm8k" + dataset_limit: 100 + + temperature: 0.0 + max_tokens: 256 + + output_dir: "benchmark_results" diff --git a/diffulex_bench/configs/example.yml b/diffulex_bench/configs/example.yml new file mode 100644 index 0000000..bbdcbc5 --- /dev/null +++ b/diffulex_bench/configs/example.yml @@ -0,0 +1,60 @@ +# Diffulex Benchmark Configuration Example +# This configuration uses nested structure with engine and eval sections + +# Engine configuration - Parameters for Diffulex engine initialization +engine: + # Model and weights + model_path: "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + tokenizer_path: null # Optional, defaults to model_path + model_name: "dream" # Options: dream, sdar, fast_dllm_v2 + decoding_strategy: "d2f" # Options: d2f, block_diffusion, fast_dllm + mask_token_id: 151666 + + # LoRA configuration + use_lora: false + lora_path: "" + + # Parallelism configuration + tensor_parallel_size: 1 + data_parallel_size: 1 + + # Memory and capacity configuration + gpu_memory_utilization: 0.7 + max_model_len: 4096 + max_num_batched_tokens: 4096 + max_num_seqs: 128 + + # Engine behavior configuration + enforce_eager: true # Set to true for varlen mode to avoid CUDA graph capture error + kv_cache_layout: "unified" # Options: unified, distinct + + # D2F-specific configuration + accept_threshold: 0.9 + complete_threshold: 0.95 + add_new_block_threshold: 0.1 + diffusion_block_size: 32 + + # Quantization configuration + kv_cache_dtype: null # Options: bf16, fp16, fp32, fp8_e4m3, fp8_e5m2 + decode_mode: "varlen" # Options: static, varlen + linear_attn_weight_dtype: null # Options: bf16, int8, int4, fp8_e4m3, fp8_e5m2, etc. + linear_mlp_weight_dtype: null + linear_attn_act_dtype: null + linear_mlp_act_dtype: null + +# Evaluation configuration - Parameters for benchmark evaluation +eval: + # Task/Dataset configuration + dataset_name: "gsm8k" # Options: gsm8k, humaneval, etc. + dataset_split: "test" + dataset_limit: 10 # Optional, limit number of samples (set to 10 for testing) + + # Sampling configuration + temperature: 0.0 + max_tokens: 512 + ignore_eos: false + + # Output configuration + output_dir: "benchmark_results" + save_results: true + use_tqdm: true diff --git a/diffulex_bench/datasets.py b/diffulex_bench/datasets.py new file mode 100644 index 0000000..3a882cf --- /dev/null +++ b/diffulex_bench/datasets.py @@ -0,0 +1,118 @@ +""" +Benchmark Datasets - Dataset loaders for benchmark evaluation +Supports common evaluation datasets such as GSM8K, HumanEval, etc. +""" + +from typing import List, Dict, Any, Optional, Callable +from datasets import load_dataset + + +def load_gsm8k( + split: str = "test", + limit: Optional[int] = None, + prompt_template: Optional[Callable[[str], str]] = None, +) -> List[Dict[str, Any]]: + """ + Load GSM8K dataset + + Args: + split: Dataset split, default "test" + limit: Limit number of samples, None means all + prompt_template: Prompt template function that takes question string and returns full prompt + + Returns: + List of dataset items, each containing 'prompt' and 'answer' fields + """ + dataset = load_dataset("gsm8k", "main", split=split) + + if limit: + dataset = dataset[:limit] + + results = [] + for item in dataset: + question = item["question"] + answer = item["answer"] + + if prompt_template: + prompt = prompt_template(question) + else: + # Default template + prompt = f"Question: {question}\nAnswer:" + + results.append({ + 'prompt': prompt, + 'answer': answer, + 'question': question, + }) + + return results + + +def load_humaneval( + limit: Optional[int] = None, + prompt_template: Optional[Callable[[str], str]] = None, +) -> List[Dict[str, Any]]: + """ + Load HumanEval dataset + + Args: + limit: Limit number of samples, None means all + prompt_template: Prompt template function that takes prompt string and returns full prompt + + Returns: + List of dataset items, each containing 'prompt', 'test', 'entry_point' fields + """ + dataset = load_dataset("openai/humaneval", split="test") + + if limit: + dataset = dataset[:limit] + + results = [] + for item in dataset: + prompt = item["prompt"] + test = item["test"] + entry_point = item["entry_point"] + + if prompt_template: + full_prompt = prompt_template(prompt) + else: + full_prompt = prompt + + results.append({ + 'prompt': full_prompt, + 'original_prompt': prompt, + 'test': test, + 'entry_point': entry_point, + 'task_id': item.get('task_id', ''), + }) + + return results + + +def load_benchmark_dataset( + dataset_name: str, + **kwargs +) -> List[Dict[str, Any]]: + """ + Unified dataset loading interface + + Args: + dataset_name: Dataset name, supports "gsm8k", "humaneval" + **kwargs: Arguments passed to the specific dataset loader + + Returns: + List of dataset items + """ + loaders = { + 'gsm8k': load_gsm8k, + 'humaneval': load_humaneval, + } + + if dataset_name not in loaders: + raise ValueError( + f"Unknown dataset: {dataset_name}. " + f"Supported datasets: {list(loaders.keys())}" + ) + + return loaders[dataset_name](**kwargs) + diff --git a/diffulex_bench/lm_eval_model.py b/diffulex_bench/lm_eval_model.py new file mode 100644 index 0000000..4d66882 --- /dev/null +++ b/diffulex_bench/lm_eval_model.py @@ -0,0 +1,329 @@ +""" +LM Eval Model - Diffulex integration with lm-evaluation-harness +""" + +import logging +import time +import json +from typing import List, Optional, Tuple, Type, TypeVar, Union +from pathlib import Path + +from lm_eval import utils +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model + +from diffulex import Diffulex, SamplingParams +from diffulex_bench.runner import BenchmarkRunner +from diffulex.logger import setup_logger, get_logger + +T = TypeVar("T", bound="LM") +eval_logger = logging.getLogger(__name__) + + +@register_model("diffulex") +class DiffulexLM(LM): + """ + Diffulex model integration for lm-evaluation-harness + """ + + def __init__( + self, + pretrained: str, + batch_size: Optional[Union[int, str]] = 1, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, type]] = "auto", + max_new_tokens: Optional[int] = 256, + max_length: Optional[int] = 2048, + add_bos_token: Optional[bool] = False, + trust_remote_code: Optional[bool] = True, + temperature: Optional[float] = 0.0, + model_name: Optional[str] = "dream", + decoding_strategy: Optional[str] = "d2f", + mask_token_id: Optional[int] = 151666, + tensor_parallel_size: Optional[int] = 1, + data_parallel_size: Optional[int] = 1, + gpu_memory_utilization: Optional[float] = 0.9, + max_model_len: Optional[int] = 2048, + max_num_batched_tokens: Optional[int] = 4096, + max_num_seqs: Optional[int] = 128, + use_lora: Optional[bool] = False, + lora_path: Optional[str] = "", + enforce_eager: Optional[bool] = False, + kv_cache_layout: Optional[str] = "unified", + accept_threshold: Optional[float] = 0.9, + complete_threshold: Optional[float] = 0.95, + add_new_block_threshold: Optional[float] = 0.1, + diffusion_block_size: Optional[int] = 32, + save_dir: Optional[str] = None, + wait_ready: Optional[bool] = True, + kv_cache_dtype: Optional[str] = None, + decode_mode: Optional[str] = None, + linear_attn_weight_dtype: Optional[str] = None, + linear_mlp_weight_dtype: Optional[str] = None, + linear_attn_act_dtype: Optional[str] = None, + linear_mlp_act_dtype: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__() + + # Setup logger + self.logger = get_logger(__name__) + + assert isinstance(pretrained, str) + assert isinstance(batch_size, (int, str)) + + self.pretrained = pretrained + self.batch_size_per_gpu = batch_size + if isinstance(batch_size, str): + self.batch_size_per_gpu = int(batch_size) + + self.max_length = max_length + self.add_bos_token = add_bos_token + self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.save_dir = save_dir + + # Diffulex-specific parameters + self.model_name = model_name + self.decoding_strategy = decoding_strategy + self.mask_token_id = mask_token_id + + # Statistics tracking + self.total_generated_tokens = 0 + self.total_nfe = 0 # Number of Forward Evaluations (diffusion steps) + self.total_generation_time = 0.0 + self.total_samples = 0 + self.all_generation_times = [] + self.all_nfe = [] + self.all_tokens = [] + + # Initialize Diffulex runner + self.runner = BenchmarkRunner( + model_path=pretrained, + tokenizer_path=pretrained, + wait_ready=wait_ready, + model_name=model_name, + decoding_strategy=decoding_strategy, + mask_token_id=mask_token_id, + tensor_parallel_size=tensor_parallel_size, + data_parallel_size=data_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs, + use_lora=use_lora, + lora_path=lora_path if use_lora else "", + enforce_eager=enforce_eager, + kv_cache_layout=kv_cache_layout, + accept_threshold=accept_threshold, + complete_threshold=complete_threshold, + add_new_block_threshold=add_new_block_threshold, + diffusion_block_size=diffusion_block_size, + kv_cache_dtype=kv_cache_dtype, + decode_mode=decode_mode, + linear_attn_weight_dtype=linear_attn_weight_dtype, + linear_mlp_weight_dtype=linear_mlp_weight_dtype, + linear_attn_act_dtype=linear_attn_act_dtype, + linear_mlp_act_dtype=linear_mlp_act_dtype, + ) + + self.tokenizer = self.runner.tokenizer + + # Create sampling params + self.sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_new_tokens, + ) + + self.logger.success("Diffulex engine initialized successfully") + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return "cuda" # Diffulex manages device internally + + @property + def rank(self): + return 0 + + @property + def world_size(self): + return 1 + + def tok_decode(self, tokens, skip_special_tokens=True): + """Decode tokens to text""" + if isinstance(tokens, list) and len(tokens) > 0 and isinstance(tokens[0], list): + return [self.tokenizer.decode(t, skip_special_tokens=skip_special_tokens) for t in tokens] + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def tok_encode(self, text, add_special_tokens=True): + """Encode text to tokens""" + return self.tokenizer( + text, return_tensors="pt", add_special_tokens=add_special_tokens + ).input_ids + + @classmethod + def create_from_arg_string( + cls: Type[T], arg_string: str, additional_config: Optional[dict] = None + ) -> T: + """ + Creates an instance of the LM class using the given argument string and additional config. + + Args: + arg_string: A string containing arguments in the format key1=value1,key2=value2 + additional_config: Optional dictionary containing additional configuration parameters + + Returns: + Instance of the LM class + """ + additional_config = {} if additional_config is None else additional_config + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + def apply_chat_template( + self, chat_history, add_generation_prompt: bool = True + ) -> str: + """ + Apply a chat template to a list of chat history between user and model. + """ + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + return chat_templated + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def generate_until(self, requests: List[Instance], disable_tqdm: bool = False): + """ + Generate text until stopping conditions are met. + + Args: + requests: List of generation requests + disable_tqdm: Whether to disable progress bar + + Returns: + List of generated texts + """ + self.logger.info(f"Processing {len(requests)} generation requests...") + + # Prepare prompts + prompts = [] + gen_args = [] + + for req in requests: + prompt = req.arguments[0] + if self.add_bos_token and self.tokenizer.bos_token: + prompt = self.tokenizer.bos_token + prompt + prompts.append(prompt) + gen_args.append(req.arguments[1] if len(req.arguments) > 1 else {}) + + # Run generation + start_time = time.time() + outputs = self.runner.generate( + prompts, + self.sampling_params, + use_tqdm=not disable_tqdm, + ) + end_time = time.time() + + total_time = end_time - start_time + + # Extract results and accumulate statistics + results = [] + num_tokens = 0 + num_nfe = 0 + + for output in outputs: + text = output.get('text', '') + results.append(text) + + token_ids = output.get('token_ids', []) + n_diff_steps = output.get('n_diff_steps', 0) + + num_tokens += len(token_ids) + num_nfe += n_diff_steps + + self.all_generation_times.append(total_time / len(outputs) if outputs else 0) + self.all_nfe.append(n_diff_steps) + self.all_tokens.append(len(token_ids)) + + # Update statistics + self.total_samples += len(requests) + self.total_generated_tokens += num_tokens + self.total_nfe += num_nfe + self.total_generation_time += total_time + + # Log statistics + if self.total_samples > 0: + avg_tokens = self.total_generated_tokens / self.total_samples + avg_nfe = self.total_nfe / self.total_samples + avg_time = self.total_generation_time / self.total_samples + throughput = num_tokens / total_time if total_time > 0 else 0 + + self.logger.info( + f"Generated {len(results)} samples | " + f"Tokens: {num_tokens} | " + f"NFE: {num_nfe} | " + f"Time: {total_time:.2f}s | " + f"Throughput: {throughput:.2f} tok/s" + ) + + # Save statistics if save_dir is provided + if self.save_dir is not None: + self._save_statistics() + + return results + + def _save_statistics(self): + """Save statistics to file""" + import os + os.makedirs(self.save_dir, exist_ok=True) + + stats = { + 'total_samples': self.total_samples, + 'total_tokens': self.total_generated_tokens, + 'total_nfe': self.total_nfe, + 'total_time': self.total_generation_time, + 'avg_tokens_per_sample': self.total_generated_tokens / self.total_samples if self.total_samples > 0 else 0, + 'avg_nfe_per_sample': self.total_nfe / self.total_samples if self.total_samples > 0 else 0, + 'avg_time_per_sample': self.total_generation_time / self.total_samples if self.total_samples > 0 else 0, + 'throughput_tok_s': self.total_generated_tokens / self.total_generation_time if self.total_generation_time > 0 else 0, + 'nfe_per_token': self.total_nfe / self.total_generated_tokens if self.total_generated_tokens > 0 else 0, + 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'), + } + + stats_path = os.path.join(self.save_dir, 'diffulex_stats.json') + with open(stats_path, 'w', encoding='utf-8') as f: + json.dump(stats, f, indent=2, ensure_ascii=False) + + self.logger.info(f"Statistics saved to {stats_path}") + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + """ + Compute log-likelihood of continuations given contexts. + + Note: This is a placeholder implementation. Full loglikelihood computation + for diffusion models requires special handling. + """ + self.logger.warning( + "loglikelihood computation for diffusion models is not fully implemented. " + "Returning placeholder values." + ) + return [(0.0, False) for _ in requests] + + def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + """Compute log-likelihood of sequences.""" + raise NotImplementedError( + "loglikelihood_rolling is not implemented for diffusion models" + ) + diff --git a/diffulex_bench/logger.py b/diffulex_bench/logger.py new file mode 100644 index 0000000..444ee65 --- /dev/null +++ b/diffulex_bench/logger.py @@ -0,0 +1,16 @@ +""" +Logger module for diffulex_bench - Re-exports from diffulex.logger +""" + +# Re-export logger functionality from diffulex core package +from diffulex.logger import ( + setup_logger, + get_logger, + LoggerMixin, +) + +__all__ = [ + "setup_logger", + "get_logger", + "LoggerMixin", +] diff --git a/diffulex_bench/main.py b/diffulex_bench/main.py new file mode 100644 index 0000000..f6a7ae8 --- /dev/null +++ b/diffulex_bench/main.py @@ -0,0 +1,305 @@ +""" +Benchmark Main Entry - Main entry point for benchmark using lm-evaluation-harness +""" + +import sys +import logging +from pathlib import Path + +from diffulex_bench.config import BenchmarkConfig, EngineConfig, EvalConfig +from diffulex.logger import setup_logger, get_logger +from diffulex_bench.arg_parser import create_argument_parser, get_default_config_path + +try: + from lm_eval.__main__ import cli_evaluate +except ImportError: + cli_evaluate = None + + +def config_to_model_args(config: BenchmarkConfig) -> str: + """ + Convert BenchmarkConfig to lm_eval model_args string format + + Args: + config: Benchmark configuration + + Returns: + Model arguments string in key=value format + """ + engine = config.engine + eval_config = config.eval + + args_dict = { + 'pretrained': engine.model_path, + 'model_name': engine.model_name, + 'decoding_strategy': engine.decoding_strategy, + 'mask_token_id': engine.mask_token_id, + 'tensor_parallel_size': engine.tensor_parallel_size, + 'data_parallel_size': engine.data_parallel_size, + 'gpu_memory_utilization': engine.gpu_memory_utilization, + 'max_model_len': engine.max_model_len, + 'max_num_batched_tokens': engine.max_num_batched_tokens, + 'max_num_seqs': engine.max_num_seqs, + 'temperature': eval_config.temperature, + 'max_new_tokens': eval_config.max_tokens, + 'use_lora': engine.use_lora, + 'enforce_eager': engine.enforce_eager, + 'kv_cache_layout': engine.kv_cache_layout, + 'accept_threshold': engine.accept_threshold, + 'complete_threshold': engine.complete_threshold, + 'add_new_block_threshold': engine.add_new_block_threshold, + 'diffusion_block_size': engine.diffusion_block_size, + 'wait_ready': True, + } + + # Add quantization parameters if specified + if engine.kv_cache_dtype is not None: + args_dict['kv_cache_dtype'] = engine.kv_cache_dtype + if engine.decode_mode is not None: + args_dict['decode_mode'] = engine.decode_mode + if engine.linear_attn_weight_dtype is not None: + args_dict['linear_attn_weight_dtype'] = engine.linear_attn_weight_dtype + if engine.linear_mlp_weight_dtype is not None: + args_dict['linear_mlp_weight_dtype'] = engine.linear_mlp_weight_dtype + if engine.linear_attn_act_dtype is not None: + args_dict['linear_attn_act_dtype'] = engine.linear_attn_act_dtype + if engine.linear_mlp_act_dtype is not None: + args_dict['linear_mlp_act_dtype'] = engine.linear_mlp_act_dtype + + if engine.tokenizer_path: + args_dict['tokenizer_path'] = engine.tokenizer_path + + if engine.use_lora and engine.lora_path: + args_dict['lora_path'] = engine.lora_path + + # Convert to string format: key1=value1,key2=value2 + args_list = [f"{k}={v}" for k, v in args_dict.items()] + return ','.join(args_list) + + +def dataset_name_to_tasks(dataset_name: str) -> str: + """ + Convert dataset name to lm_eval task name + + Args: + dataset_name: Dataset name (e.g., "gsm8k", "humaneval") + + Returns: + lm_eval task name + """ + mapping = { + 'gsm8k': 'gsm8k', + 'humaneval': 'humaneval', + } + return mapping.get(dataset_name, dataset_name) + + +def run_benchmark(config: BenchmarkConfig) -> None: + """ + Run benchmark using lm-evaluation-harness + + Args: + config: Benchmark configuration + """ + logger = get_logger(__name__) + + if cli_evaluate is None: + logger.error( + "lm-evaluation-harness is not installed. " + "Please install it with: pip install lm-eval" + ) + sys.exit(1) + + benchmark_info = [ + '=' * 80, + 'Diffulex Benchmark (using lm-evaluation-harness)', + '=' * 80, + f'Model: {config.engine.model_path}', + f'Model Name: {config.engine.model_name}', + f'Decoding Strategy: {config.engine.decoding_strategy}', + f'Tasks: {config.eval.dataset_name}', + f'Output Directory: {config.eval.output_dir}', + '=' * 80, + ] + logger.info('\n'.join(benchmark_info)) + + # Convert config to lm_eval arguments + model_args = config_to_model_args(config) + tasks = dataset_name_to_tasks(config.eval.dataset_name) + + # Prepare sys.argv for lm_eval + original_argv = sys.argv.copy() + + try: + sys.argv = [ + "lm_eval", + "--model", "diffulex", + "--model_args", model_args, + "--tasks", tasks, + "--batch_size", "1", + "--output_path", config.eval.output_dir, + ] + + if config.eval.dataset_limit: + sys.argv.extend(["--limit", str(config.eval.dataset_limit)]) + + # Add any additional lm_eval arguments from config if needed + # For now, we use default batch_size=1 + + lm_eval_info = [ + '=' * 80, + 'Starting lm-evaluation-harness evaluation...', + '=' * 80, + f'Model args: {model_args}', + f'Tasks: {tasks}', + '=' * 80, + ] + logger.info('\n'.join(lm_eval_info)) + + # Run lm_eval + cli_evaluate() + + logger.success("Evaluation completed successfully") + + except Exception as e: + logger.error(f"Evaluation failed: {e}", exc_info=True) + sys.exit(1) + finally: + # Restore original argv + sys.argv = original_argv + + +def load_config_from_args(args) -> BenchmarkConfig: + """ + Load configuration from command line arguments + + Args: + args: Parsed command line arguments + + Returns: + BenchmarkConfig instance + """ + logger = get_logger(__name__) + + # Try to load from config file + if args.config: + config_path = Path(args.config) + else: + # Try default config path + default_config = get_default_config_path() + if default_config.exists(): + config_path = default_config + logger.info(f"Using default config: {config_path}") + else: + config_path = None + + if config_path and config_path.exists(): + if config_path.suffix in ['.yaml', '.yml']: + config = BenchmarkConfig.from_yaml(str(config_path)) + elif config_path.suffix == '.json': + config = BenchmarkConfig.from_json(str(config_path)) + else: + logger.error(f"Unsupported config file format: {config_path.suffix}") + sys.exit(1) + logger.info(f"Loaded configuration from: {config_path}") + + # Override with command line arguments if provided + if args.model_path: + config.engine.model_path = args.model_path + if getattr(args, "tokenizer_path", None): + config.engine.tokenizer_path = args.tokenizer_path + if args.dataset: + config.eval.dataset_name = args.dataset + if args.dataset_limit is not None: + config.eval.dataset_limit = args.dataset_limit + if getattr(args, "max_tokens", None) is not None: + config.eval.max_tokens = args.max_tokens + if getattr(args, "temperature", None) is not None: + config.eval.temperature = args.temperature + if args.output_dir: + config.eval.output_dir = args.output_dir + + # Engine overrides (make bench configs reusable for eager vs CUDA Graph comparisons) + if getattr(args, "enforce_eager", None) is not None: + config.engine.enforce_eager = bool(args.enforce_eager) + if getattr(args, "kv_cache_layout", None) is not None: + config.engine.kv_cache_layout = args.kv_cache_layout + if getattr(args, "decode_mode", None) is not None: + config.engine.decode_mode = args.decode_mode + if getattr(args, "kv_cache_dtype", None) is not None: + config.engine.kv_cache_dtype = args.kv_cache_dtype + if getattr(args, "max_model_len", None) is not None: + config.engine.max_model_len = args.max_model_len + if getattr(args, "max_num_seqs", None) is not None: + config.engine.max_num_seqs = args.max_num_seqs + if getattr(args, "max_num_batched_tokens", None) is not None: + config.engine.max_num_batched_tokens = args.max_num_batched_tokens + else: + if not args.model_path: + logger.error("Either --config or --model-path must be provided") + sys.exit(1) + + # Create config from command line arguments + engine = EngineConfig( + model_path=args.model_path, + tokenizer_path=args.tokenizer_path, + model_name=args.model_name, + decoding_strategy=args.decoding_strategy, + mask_token_id=args.mask_token_id, + tensor_parallel_size=args.tensor_parallel_size, + data_parallel_size=args.data_parallel_size, + gpu_memory_utilization=args.gpu_memory_utilization, + max_model_len=args.max_model_len, + max_num_batched_tokens=getattr(args, 'max_num_batched_tokens', 4096), + max_num_seqs=getattr(args, 'max_num_seqs', 128), + use_lora=args.use_lora, + lora_path=args.lora_path, + kv_cache_layout=getattr(args, 'kv_cache_layout', 'unified'), + accept_threshold=args.accept_threshold, + complete_threshold=args.complete_threshold, + add_new_block_threshold=args.add_new_block_threshold, + diffusion_block_size=args.diffusion_block_size, + kv_cache_dtype=getattr(args, 'kv_cache_dtype', None), + decode_mode=getattr(args, 'decode_mode', None), + # Force enforce_eager=True for varlen mode to avoid CUDA graph capture error + enforce_eager=True if getattr(args, 'decode_mode', None) == 'varlen' else (args.enforce_eager if hasattr(args, 'enforce_eager') else False), + linear_attn_weight_dtype=getattr(args, 'linear_attn_weight_dtype', None), + linear_mlp_weight_dtype=getattr(args, 'linear_mlp_weight_dtype', None), + linear_attn_act_dtype=getattr(args, 'linear_attn_act_dtype', None), + linear_mlp_act_dtype=getattr(args, 'linear_mlp_act_dtype', None), + ) + + eval_config = EvalConfig( + dataset_name=args.dataset, + dataset_split=getattr(args, 'dataset_split', 'test'), + dataset_limit=args.dataset_limit, + temperature=args.temperature, + max_tokens=args.max_tokens, + ignore_eos=getattr(args, 'ignore_eos', False), + output_dir=args.output_dir, + save_results=args.save_results, + ) + + config = BenchmarkConfig(engine=engine, eval=eval_config) + + return config + + +def main(): + """Main function""" + parser = create_argument_parser() + args = parser.parse_args() + + # Setup logger + log_level = getattr(logging, args.log_level.upper()) + setup_logger("diffulex_bench", level=log_level, log_file=args.log_file) + + # Load configuration + config = load_config_from_args(args) + + # Run benchmark using lm_eval + run_benchmark(config) + + +if __name__ == "__main__": + main() diff --git a/diffulex_bench/metrics.py b/diffulex_bench/metrics.py new file mode 100644 index 0000000..88e5a49 --- /dev/null +++ b/diffulex_bench/metrics.py @@ -0,0 +1,126 @@ +""" +Benchmark Metrics - Evaluation metrics computation +""" + +import re +from typing import List, Dict, Any, Optional +import json + + +def extract_number(text: str) -> Optional[float]: + """ + Extract number from text (for GSM8K and other math problems) + + Args: + text: Input text + + Returns: + Extracted number, or None if not found + """ + # Try to match #### number format (GSM8K standard format) + pattern = r'####\s*(-?\d+(?:\.\d+)?)' + match = re.search(pattern, text) + if match: + return float(match.group(1)) + + # Try to match the last number + numbers = re.findall(r'-?\d+(?:\.\d+)?', text) + if numbers: + try: + return float(numbers[-1]) + except ValueError: + pass + + return None + + +def gsm8k_accuracy( + predictions: List[str], + ground_truths: List[str], +) -> float: + """ + Calculate GSM8K accuracy + + Args: + predictions: List of predicted texts + ground_truths: List of ground truth answers (including full solution process) + + Returns: + Accuracy (0-1) + """ + if len(predictions) != len(ground_truths): + raise ValueError("Predictions and ground_truths must have the same length") + + correct = 0 + for pred, gt in zip(predictions, ground_truths): + pred_num = extract_number(pred) + gt_num = extract_number(gt) + + if pred_num is not None and gt_num is not None: + if abs(pred_num - gt_num) < 1e-6: + correct += 1 + + return correct / len(predictions) if predictions else 0.0 + + +def humaneval_pass_at_k( + results: List[Dict[str, Any]], + k: int = 1, +) -> float: + """ + Calculate HumanEval Pass@k metric + + Args: + results: List of results, each should contain 'output', 'test', 'entry_point' fields + k: k value, default 1 + + Returns: + Pass@k score + """ + # Note: Full HumanEval evaluation requires code execution, this is just a framework + # In practice, need to integrate code execution environment (e.g., Docker) + # Returns None, actual evaluation requires implementing code execution logic + return None + + +def compute_metrics( + outputs: List[Dict[str, Any]], + ground_truths: Optional[List[str]] = None, + dataset_name: str = "gsm8k", +) -> Dict[str, Any]: + """ + Compute evaluation metrics + + Args: + outputs: List of generation results + ground_truths: List of ground truth answers (optional) + dataset_name: Dataset name, used to select appropriate evaluation method + + Returns: + Dictionary of metrics + """ + metrics = {} + + # Basic statistics + total_tokens = sum(len(o.get('token_ids', [])) for o in outputs) + avg_diff_steps = sum(o.get('n_diff_steps', 0) for o in outputs) / len(outputs) if outputs else 0 + total_time = sum(o.get('generation_time', 0) for o in outputs) + + metrics['num_samples'] = len(outputs) + metrics['total_tokens'] = total_tokens + metrics['avg_tokens_per_sample'] = total_tokens / len(outputs) if outputs else 0 + metrics['avg_diff_steps'] = avg_diff_steps + metrics['total_time'] = total_time + metrics['throughput_tok_s'] = total_tokens / total_time if total_time > 0 else 0 + + # Dataset-specific metrics + if ground_truths and dataset_name == "gsm8k": + predictions = [o.get('text', '') for o in outputs] + metrics['accuracy'] = gsm8k_accuracy(predictions, ground_truths) + elif ground_truths and dataset_name == "humaneval": + # HumanEval requires code execution, this is just a framework + metrics['pass_at_1'] = None # Need to implement code execution logic + metrics['note'] = "HumanEval evaluation requires code execution environment" + + return metrics + diff --git a/diffulex_bench/report.py b/diffulex_bench/report.py new file mode 100644 index 0000000..c4c7622 --- /dev/null +++ b/diffulex_bench/report.py @@ -0,0 +1,113 @@ +""" +Benchmark Report - Report generation for benchmark results +""" + +import json +from pathlib import Path +from typing import Dict, Any, List, Optional +import pandas as pd + + +def generate_report(results_file: str, output_file: Optional[str] = None) -> str: + """ + Generate benchmark report + + Args: + results_file: Path to results JSON file + output_file: Path to output report file, if None prints to console + + Returns: + Report text + """ + with open(results_file, 'r', encoding='utf-8') as f: + results = json.load(f) + + config = results['config'] + metrics = results['metrics'] + + # Generate report + report_lines = [] + append_line = lambda line: report_lines.append(line) + append_line("=" * 80) + append_line("Diffulex Benchmark Report") + append_line("=" * 80) + append_line("") + append_line("Configuration:") + append_line(f" Model: {config.get('model_path', 'N/A')}") + append_line(f" Model Name: {config.get('model_name', 'N/A')}") + append_line(f" Decoding Strategy: {config.get('decoding_strategy', 'N/A')}") + append_line(f" Dataset: {config.get('dataset_name', 'N/A')}") + append_line(f" Tensor Parallel Size: {config.get('tensor_parallel_size', 'N/A')}") + append_line(f" Data Parallel Size: {config.get('data_parallel_size', 'N/A')}") + append_line("") + append_line("Metrics:") + append_line(f" Number of Samples: {metrics.get('num_samples', 'N/A')}") + append_line(f" Total Tokens: {metrics.get('total_tokens', 'N/A')}") + append_line(f" Average Tokens per Sample: {metrics.get('avg_tokens_per_sample', 0):.2f}") + append_line(f" Average Diffusion Steps: {metrics.get('avg_diff_steps', 0):.2f}") + append_line(f" Total Time: {metrics.get('total_time', 0):.2f} seconds") + append_line(f" Throughput: {metrics.get('throughput_tok_s', 0):.2f} tokens/s") + + if 'accuracy' in metrics and metrics['accuracy'] is not None: + report_lines.append(f" Accuracy: {metrics['accuracy']:.4f}") + + report_lines.append("") + report_lines.append(f"Timestamp: {results.get('timestamp', 'N/A')}") + report_lines.append("=" * 80) + + report_text = "\n".join(report_lines) + + # Save or output + if output_file: + with open(output_file, 'w', encoding='utf-8') as f: + f.write(report_text) + print(f"Report saved to: {output_file}") + else: + print(report_text) + + return report_text + + +def compare_results(result_files: List[str], output_file: Optional[str] = None) -> pd.DataFrame: + """ + Compare multiple benchmark results + + Args: + result_files: List of result file paths + output_file: Path to output CSV file, if None only returns DataFrame + + Returns: + DataFrame with comparison results + """ + rows = [] + + for result_file in result_files: + with open(result_file, 'r', encoding='utf-8') as f: + results = json.load(f) + + config = results['config'] + metrics = results['metrics'] + + row = { + 'model_path': config.get('model_path', 'N/A'), + 'model_name': config.get('model_name', 'N/A'), + 'decoding_strategy': config.get('decoding_strategy', 'N/A'), + 'dataset': config.get('dataset_name', 'N/A'), + 'num_samples': metrics.get('num_samples', 0), + 'total_tokens': metrics.get('total_tokens', 0), + 'avg_tokens_per_sample': metrics.get('avg_tokens_per_sample', 0), + 'avg_diff_steps': metrics.get('avg_diff_steps', 0), + 'throughput_tok_s': metrics.get('throughput_tok_s', 0), + 'accuracy': metrics.get('accuracy', None), + 'timestamp': results.get('timestamp', 'N/A'), + } + rows.append(row) + + df = pd.DataFrame(rows) + + if output_file: + df.to_csv(output_file, index=False, encoding='utf-8') + print(f"Comparison saved to: {output_file}") + + return df + diff --git a/diffulex_bench/runner.py b/diffulex_bench/runner.py new file mode 100644 index 0000000..9617bc4 --- /dev/null +++ b/diffulex_bench/runner.py @@ -0,0 +1,193 @@ +""" +Benchmark Runner - Benchmark runner that wraps Diffulex inference engine +Provides a unified interface for benchmarking +""" + +import time +from typing import List, Dict, Any, Optional + +from diffulex import Diffulex, SamplingParams +from transformers import AutoTokenizer +from diffulex.logger import get_logger + + +class BenchmarkRunner: + """ + Benchmark runner that wraps the Diffulex inference engine + """ + + def __init__( + self, + model_path: str, + tokenizer_path: Optional[str] = None, + wait_ready: bool = True, + **diffulex_kwargs + ): + """ + Initialize the benchmark runner + + Args: + model_path: Path to the model + tokenizer_path: Path to the tokenizer, if None uses model_path + wait_ready: Whether to wait for engine to be fully initialized before returning + **diffulex_kwargs: Additional arguments to pass to Diffulex + """ + self.model_path = model_path + self.tokenizer_path = tokenizer_path or model_path + self.logger = get_logger(__name__) + + # Initialize Diffulex engine + self.logger.info("Initializing Diffulex engine...") + self.llm = Diffulex(model_path, **diffulex_kwargs) + + # Wait for engine to be ready if requested + if wait_ready: + self._wait_for_ready() + + # Load tokenizer + self.logger.info("Loading tokenizer...") + self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_path, + trust_remote_code=True + ) + self.logger.success("Tokenizer loaded successfully") + + def _wait_for_ready(self, timeout: float = 300.0, check_interval: float = 0.5): + """ + Wait for the Diffulex engine to be fully initialized and ready + + Args: + timeout: Maximum time to wait in seconds + check_interval: Interval between readiness checks in seconds + """ + start_time = time.time() + + # Check if it's a DP worker (has _ask method) or TP worker + if hasattr(self.llm, '_ask'): + # DP worker: wait for all child processes to be ready + # by sending a lightweight command to each + dp_size = getattr(self.llm, 'dp_size', 1) + self.logger.info(f"[DiffulexDPWorker (DP={dp_size})]: Waiting for {dp_size} DiffulexTPWorker subprocesses to be ready...") + + while time.time() - start_time < timeout: + try: + # Try to send a lightweight command to check readiness + # Use is_finished as a lightweight check + for i in range(dp_size): + self.llm._ask(i, "is_finished") + self.logger.success("All DiffulexTPWorker subprocesses are ready") + return + except (EOFError, RuntimeError, AttributeError, ConnectionError) as e: + # Process not ready yet, wait and retry + elapsed = time.time() - start_time + if elapsed < timeout: + time.sleep(check_interval) + else: + raise RuntimeError( + f"Timeout waiting for DP workers to be ready after {elapsed:.1f}s: {e}" + ) from e + else: + # TP worker: wait for all subprocesses to be ready + # Check if subprocesses are alive and wait a bit for initialization + if hasattr(self.llm, 'ps') and self.llm.ps: + num_subprocesses = len(self.llm.ps) + self.logger.info(f"Waiting for {num_subprocesses} TP subprocess(es) to be ready...") + + while time.time() - start_time < timeout: + # Check if all subprocesses are alive + all_alive = all(p.is_alive() for p in self.llm.ps) + + if all_alive: + # Give subprocesses a bit more time to complete initialization + # The main process initialization is synchronous, but subprocesses + # may still be initializing (model loading, warmup, etc.) + # Subprocesses will synchronize via barrier in ModelRunnerBase.__init__ + # So we just need to wait a bit for them to complete initialization + time.sleep(2.0) # Wait a bit for subprocess initialization + self.logger.success("All TP subprocesses are ready") + return + else: + # Some process died, check which one + dead_processes = [ + i for i, p in enumerate(self.llm.ps) if not p.is_alive() + ] + exit_codes = [ + self.llm.ps[i].exitcode for i in dead_processes + ] + raise RuntimeError( + f"TP subprocess(es) {dead_processes} terminated during initialization. " + f"Exit code(s): {exit_codes}" + ) + + elapsed = time.time() - start_time + raise RuntimeError( + f"Timeout waiting for TP subprocesses to be ready after {elapsed:.1f}s" + ) + else: + # Single process TP worker, should be ready immediately + # Main process initialization is synchronous + self.logger.success("TP worker is ready") + return + + def generate( + self, + prompts: List[str], + sampling_params: SamplingParams, + use_tqdm: bool = True, + ) -> List[Dict[str, Any]]: + """ + Generate text + + Args: + prompts: List of input prompts + sampling_params: Sampling parameters + use_tqdm: Whether to show progress bar + + Returns: + List of generation results, each containing text, token_ids, n_diff_steps + """ + start_time = time.time() + outputs = self.llm.generate(prompts, sampling_params, use_tqdm=use_tqdm) + end_time = time.time() + + # Add timing information + total_time = end_time - start_time + for output in outputs: + output['generation_time'] = total_time / len(outputs) if outputs else 0 + + return outputs + + def evaluate_batch( + self, + prompts: List[str], + sampling_params: SamplingParams, + use_tqdm: bool = True, + ) -> Dict[str, Any]: + """ + Evaluate a batch of prompts + + Args: + prompts: List of input prompts + sampling_params: Sampling parameters + use_tqdm: Whether to show progress bar + + Returns: + Evaluation result dictionary containing generation results and statistics + """ + outputs = self.generate(prompts, sampling_params, use_tqdm=use_tqdm) + + # Calculate statistics + total_tokens = sum(len(o['token_ids']) for o in outputs) + total_time = sum(o.get('generation_time', 0) for o in outputs) + avg_diff_steps = sum(o.get('n_diff_steps', 0) for o in outputs) / len(outputs) if outputs else 0 + + return { + 'outputs': outputs, + 'num_samples': len(outputs), + 'total_tokens': total_tokens, + 'total_time': total_time, + 'avg_tokens_per_sample': total_tokens / len(outputs) if outputs else 0, + 'avg_diff_steps': avg_diff_steps, + 'throughput_tok_s': total_tokens / total_time if total_time > 0 else 0, + } + diff --git a/diffulex_kernel/__init__.py b/diffulex_kernel/__init__.py index 2369bb6..38ab37d 100644 --- a/diffulex_kernel/__init__.py +++ b/diffulex_kernel/__init__.py @@ -1,2 +1,54 @@ -from diffulex_kernel.python.dllm_flash_attn import dllm_flash_attn_decode, dllm_flash_attn_prefill -from diffulex_kernel.python.kv_cache_kernels import store_kvcache_distinct_layout, store_kvcache_unified_layout \ No newline at end of file +"""Diffulex CUDA kernel package. + +Keep this module lightweight: importing `diffulex_kernel` should not eagerly +import optional heavy deps (e.g. TileLang) unless the corresponding kernels are +actually used. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from diffulex_kernel.python.dllm_flash_attn_kernels import ( # noqa: F401 + dllm_flash_attn_decode as dllm_flash_attn_decode, + dllm_flash_attn_prefill as dllm_flash_attn_prefill, + ) + from diffulex_kernel.python.kv_cache_kernels import ( # noqa: F401 + load_kvcache as load_kvcache, + store_kvcache_distinct_layout as store_kvcache_distinct_layout, + store_kvcache_unified_layout as store_kvcache_unified_layout, + ) + + +def __getattr__(name: str): + if name == "dllm_flash_attn_decode": + from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_decode + + return dllm_flash_attn_decode + if name == "dllm_flash_attn_prefill": + from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_prefill + + return dllm_flash_attn_prefill + if name == "store_kvcache_distinct_layout": + from diffulex_kernel.python.kv_cache_kernels import store_kvcache_distinct_layout + + return store_kvcache_distinct_layout + if name == "store_kvcache_unified_layout": + from diffulex_kernel.python.kv_cache_kernels import store_kvcache_unified_layout + + return store_kvcache_unified_layout + if name == "load_kvcache": + from diffulex_kernel.python.kv_cache_kernels import load_kvcache + + return load_kvcache + raise AttributeError(name) + + +__all__ = [ + "dllm_flash_attn_decode", + "dllm_flash_attn_prefill", + "store_kvcache_distinct_layout", + "store_kvcache_unified_layout", + "load_kvcache", +] diff --git a/diffulex_kernel/python/auto_tuner.py b/diffulex_kernel/python/auto_tuner.py index f9b5ea0..72311b3 100644 --- a/diffulex_kernel/python/auto_tuner.py +++ b/diffulex_kernel/python/auto_tuner.py @@ -21,4 +21,40 @@ def build_configs(): "NUM_STAGES": c[2], "NUM_THREADS": c[3], } for c in CONFIGS + ] + + +def build_linear_configs(): + """Autotune configs for TileLang linear/GEMM-style kernels. + + Notes: + - Keys intentionally match the linear kernel function kwargs in `linear_kernels.py` + (lowercase: block_M/block_N/block_K/num_stages/threads). + - Keep the search space modest; these kernels are instantiated for many (M,N,K) shapes. + """ + BLOCK_M_LIST = [32, 64, 128] + BLOCK_N_LIST = [64, 128] + BLOCK_K_LIST = [64, 128] + NUM_STAGES_LIST = [2, 3] + THREADS_LIST = [128, 256] + + CONFIGS = list( + itertools.product( + BLOCK_M_LIST, + BLOCK_N_LIST, + BLOCK_K_LIST, + NUM_STAGES_LIST, + THREADS_LIST, + ) + ) + + return [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "threads": c[4], + } + for c in CONFIGS ] \ No newline at end of file diff --git a/diffulex_kernel/python/dllm_flash_attn.py b/diffulex_kernel/python/dllm_flash_attn.py deleted file mode 100644 index 9b2caa3..0000000 --- a/diffulex_kernel/python/dllm_flash_attn.py +++ /dev/null @@ -1,637 +0,0 @@ -import torch -import tilelang -import tilelang.language as T - -from flash_attn import flash_attn_varlen_func -from tilelang.autotuner import set_autotune_inputs - -from diffulex_kernel.python.auto_tuner import build_configs -from diffulex_kernel.python.kv_cache_kernels import load_kvcache -from diffulex.attention.metadata import AttnMetaDataBase, is_warming_up - -# from tilelang.engine.callback import register_cuda_postproc_callback -# @register_cuda_postproc_callback -# def tilelang_callback_cuda_postproc(code, _): -# code = "// tilelang_callback_cuda_postproc: generated CUDA code by TileLang\n" + code -# print(code) -# return code - - -kernel_config = None - - -@tilelang.autotune(configs=build_configs()) -@tilelang.jit( - out_idx=[-1], - pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,}, -) -def dllm_flash_attn_prefill_kernel( - NUM_SEQS: int, - NUM_GROUPS: int, - Q_LEN: int, - KV_LEN: int, - NUM_HEADS: int, - HEAD_DIM: int, - IS_BLOCK_ATTN: bool, - DIFFUSION_BLOCK_SIZE: int, - BLOCK_M: int = 64, - BLOCK_N: int = 64, - NUM_STAGES: int = 1, - NUM_THREADS: int = 128, -): - SCALE = (1.0 / HEAD_DIM)**0.5 * 1.44269504 # log2(e) - NUM_KV_HEADS = NUM_HEADS // NUM_GROUPS - Q_SHAPE = [Q_LEN, NUM_HEADS, HEAD_DIM] - KV_SHAPE = [KV_LEN, NUM_KV_HEADS, HEAD_DIM] - O_SHAPE = [Q_LEN, NUM_HEADS, HEAD_DIM] - DTYPE = "bfloat16" - ACCUM_DTYPE = "float" - - @T.prim_func - def kernel( - Q: T.Tensor(Q_SHAPE, DTYPE), - K: T.Tensor(KV_SHAPE, DTYPE), - V: T.Tensor(KV_SHAPE, DTYPE), - cu_seqlens_q: T.Tensor(NUM_SEQS + 1, "int32"), - cu_seqlens_k: T.Tensor(NUM_SEQS + 1, "int32"), - max_seqlen_q: T.int32, - O: T.Tensor(O_SHAPE, DTYPE), - ): - with T.Kernel(T.ceildiv(max_seqlen_q, BLOCK_M), NUM_HEADS, NUM_SEQS, threads=NUM_THREADS) as (bx, by, bz): - Q_shared = T.alloc_shared([BLOCK_M, HEAD_DIM], DTYPE) - K_shared = T.alloc_shared([BLOCK_N, HEAD_DIM], DTYPE) - V_shared = T.alloc_shared([BLOCK_N, HEAD_DIM], DTYPE) - O_shared = T.alloc_shared([BLOCK_M, HEAD_DIM], DTYPE) - - acc_score = T.alloc_fragment([BLOCK_M, BLOCK_N], ACCUM_DTYPE) - acc_score_cast = T.alloc_fragment([BLOCK_M, BLOCK_N], DTYPE) - acc_output = T.alloc_fragment([BLOCK_M, HEAD_DIM], ACCUM_DTYPE) - scores_max = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - scores_max_prev = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - scores_scale = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - scores_sum = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - log_sum = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - - T.annotate_layout({ - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - - q_block_idx = bx - seq_idx = bz - head_idx = by - kv_head_idx = head_idx // NUM_GROUPS - - q_start_idx = cu_seqlens_q[seq_idx] - kv_start_idx = cu_seqlens_k[seq_idx] - q_end_idx = cu_seqlens_q[seq_idx + 1] - kv_end_idx = cu_seqlens_k[seq_idx + 1] - - cur_q_seqlen = q_end_idx - q_start_idx - cur_kv_seqlen = kv_end_idx - kv_start_idx - - T.copy(Q[q_start_idx + q_block_idx * BLOCK_M : q_start_idx + (q_block_idx + 1) * BLOCK_M, head_idx, :], Q_shared) - - T.fill(acc_output, 0) - T.fill(acc_score, 0) - T.fill(log_sum, 0) - T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) - - # The same boundary condition as naive causal mask - loop_range = ( - T.min(T.ceildiv(cur_q_seqlen + (q_block_idx + 1) * BLOCK_M, BLOCK_N), T.ceildiv(cur_kv_seqlen, BLOCK_N)) - if IS_BLOCK_ATTN else T.ceildiv(cur_kv_seqlen, BLOCK_N) - ) - for kv_block_idx in T.Pipelined(loop_range, num_stages=NUM_STAGES): - T.copy(K[kv_start_idx + kv_block_idx * BLOCK_N : kv_start_idx + (kv_block_idx + 1) * BLOCK_N, kv_head_idx, :], K_shared) - - # Initialize acc_score with mask - if IS_BLOCK_ATTN: - for i, j in T.Parallel(BLOCK_M, BLOCK_N): - num_diffusion_blocks = (q_block_idx * BLOCK_M + i) // DIFFUSION_BLOCK_SIZE + 1 - acc_score[i, j] = T.if_then_else( - (num_diffusion_blocks * DIFFUSION_BLOCK_SIZE <= kv_block_idx * BLOCK_N + j) or - (q_block_idx * BLOCK_M + i >= cur_q_seqlen or - kv_block_idx * BLOCK_N + j >= cur_kv_seqlen), -1e9, 0 - ) - else: - for i, j in T.Parallel(BLOCK_M, BLOCK_N): - acc_score[i, j] = T.if_then_else( - (q_block_idx * BLOCK_M + i >= cur_q_seqlen or - kv_block_idx * BLOCK_N + j >= cur_kv_seqlen), -1e9, 0 - ) - - # Compute attention scores - T.gemm(Q_shared, K_shared, acc_score, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - # Compute online softmax - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) - T.reduce_max(acc_score, scores_max, dim=1, clear=False) # T.reduce_max(acc_score, scores_max, dim=1, clear=True) # TODO: check if this is correct - for i in T.Parallel(BLOCK_M): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - - for i in T.Parallel(BLOCK_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * SCALE - scores_max[i] * SCALE) - - for i, j in T.Parallel(BLOCK_M, BLOCK_N): - acc_score[i, j] = T.exp2(acc_score[i, j] * SCALE - scores_max[i] * SCALE) - - T.reduce_sum(acc_score, scores_sum, dim=1) - for i in T.Parallel(BLOCK_M): - log_sum[i] = log_sum[i] * scores_scale[i] + scores_sum[i] - - T.copy(acc_score, acc_score_cast) - for i, j in T.Parallel(BLOCK_M, HEAD_DIM): - acc_output[i, j] *= scores_scale[i] - - # Compute attention output - T.copy(V[kv_start_idx + kv_block_idx * BLOCK_N : kv_start_idx + (kv_block_idx + 1) * BLOCK_N, kv_head_idx, :], V_shared) - T.gemm(acc_score_cast, V_shared, acc_output, policy=T.GemmWarpPolicy.FullRow) - - for i, j in T.Parallel(BLOCK_M, HEAD_DIM): - acc_output[i, j] /= log_sum[i] - - T.copy(acc_output, O_shared) - for i, d_idx in T.Parallel(BLOCK_M, HEAD_DIM): - if i + q_block_idx * BLOCK_M < cur_q_seqlen: - O[i + q_start_idx + q_block_idx * BLOCK_M, head_idx, d_idx] = O_shared[i, d_idx] - - return kernel - - -@tilelang.jit( - out_idx=[-1], - pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,}, -) -def dllm_flash_attn_decode_kernel( - NUM_SEQS: int, - NUM_GROUPS: int, - NUM_PAGE_BLOCKS: int, - Q_LEN: int, - KV_LEN: int, - NUM_HEADS: int, - HEAD_DIM: int, - IS_BLOCK_ATTN: bool, - DIFFUSION_BLOCK_SIZE: int, - MAX_SEQ_NUM_BLOCKS: int, - PAGE_BLOCK_SIZE: int = 32, - BLOCK_M: int = 64, - BLOCK_N: int = 64, - NUM_STAGES: int = 1, - NUM_THREADS: int = 128, -): - SCALE = (1.0 / HEAD_DIM)**0.5 * 1.44269504 # log2(e) - NUM_KV_HEADS = NUM_HEADS // NUM_GROUPS - Q_SHAPE = [Q_LEN, NUM_HEADS, HEAD_DIM] - KV_SHAPE = [KV_LEN, NUM_KV_HEADS, HEAD_DIM] - O_SHAPE = [Q_LEN, NUM_HEADS, HEAD_DIM] - K_CACHE_SHAPE = [NUM_PAGE_BLOCKS, PAGE_BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM] - V_CACHE_SHAPE = [NUM_PAGE_BLOCKS, PAGE_BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM] - BLOCK_TABLE_SHAPE = [NUM_SEQS, MAX_SEQ_NUM_BLOCKS] - DTYPE = "bfloat16" - ACCUM_DTYPE = "float" - - @T.prim_func - def kernel( - Q: T.Tensor(Q_SHAPE, DTYPE), - K: T.Tensor(KV_SHAPE, DTYPE), - V: T.Tensor(KV_SHAPE, DTYPE), - K_Cache: T.Tensor(K_CACHE_SHAPE, DTYPE), - V_Cache: T.Tensor(V_CACHE_SHAPE, DTYPE), - block_tables: T.Tensor(BLOCK_TABLE_SHAPE, "int32"), - context_lens: T.Tensor(NUM_SEQS, "int32"), - cu_seqlens_q: T.Tensor(NUM_SEQS + 1, "int32"), - cu_seqlens_k: T.Tensor(NUM_SEQS + 1, "int32"), - max_seqlen_q: T.int32, - O: T.Tensor(O_SHAPE, DTYPE), - ): - with T.Kernel(NUM_SEQS, NUM_HEADS, threads=NUM_THREADS) as (bx, by): - Q_shared = T.alloc_shared([BLOCK_M, HEAD_DIM], DTYPE) - K_shared = T.alloc_shared([BLOCK_N, HEAD_DIM], DTYPE) - V_shared = T.alloc_shared([BLOCK_N, HEAD_DIM], DTYPE) - O_shared = T.alloc_shared([BLOCK_M, HEAD_DIM], DTYPE) - K_Cache_shared = T.alloc_shared([PAGE_BLOCK_SIZE, HEAD_DIM], DTYPE) - V_Cache_shared = T.alloc_shared([PAGE_BLOCK_SIZE, HEAD_DIM], DTYPE) - - acc_score_kv = T.alloc_fragment([BLOCK_M, BLOCK_N], ACCUM_DTYPE) - acc_score_kv_cast = T.alloc_fragment([BLOCK_M, BLOCK_N], DTYPE) - acc_score_kvcache = T.alloc_fragment([BLOCK_M, PAGE_BLOCK_SIZE], ACCUM_DTYPE) - acc_score_kvcache_cast = T.alloc_fragment([BLOCK_M, PAGE_BLOCK_SIZE], DTYPE) - - acc_output = T.alloc_fragment([BLOCK_M, HEAD_DIM], ACCUM_DTYPE) - scores_max = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - scores_max_prev = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - scores_scale = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - scores_sum = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - log_sum = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - - T.annotate_layout({ - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - - seq_idx = bx - head_idx = by - kv_head_idx = head_idx // NUM_GROUPS - - q_start_idx = cu_seqlens_q[seq_idx] - kv_start_idx = cu_seqlens_k[seq_idx] - q_end_idx = cu_seqlens_q[seq_idx + 1] - kv_end_idx = cu_seqlens_k[seq_idx + 1] - - cur_q_seqlen = q_end_idx - q_start_idx - cur_kv_seqlen = kv_end_idx - kv_start_idx - - cur_context_len = context_lens[seq_idx] - - T.copy(Q[q_start_idx : q_start_idx + BLOCK_M, head_idx, :], Q_shared) - - T.fill(acc_output, 0) - T.fill(acc_score_kv, 0) - T.fill(acc_score_kvcache, 0) - T.fill(log_sum, 0) - T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) - - # ========================== - # Stage 1: KV Cache Attention (Context) - # ========================== - for page_block_idx_local in T.Pipelined(MAX_SEQ_NUM_BLOCKS, num_stages=NUM_STAGES): - page_block_idx_global = block_tables[seq_idx, page_block_idx_local] - if page_block_idx_global >= 0: - T.copy(K_Cache[page_block_idx_global, :, kv_head_idx, :], K_Cache_shared) - - for i, j in T.Parallel(BLOCK_M, PAGE_BLOCK_SIZE): - acc_score_kvcache[i, j] = T.if_then_else( - (i >= cur_q_seqlen or - page_block_idx_local * PAGE_BLOCK_SIZE + j >= cur_context_len), -1e9, 0 - ) - - # Compute attention scores - T.gemm(Q_shared, K_Cache_shared, acc_score_kvcache, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - # Compute online softmax - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) - T.reduce_max(acc_score_kvcache, scores_max, dim=1, clear=False) - for i in T.Parallel(BLOCK_M): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - - for i in T.Parallel(BLOCK_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * SCALE - scores_max[i] * SCALE) - - for i, j in T.Parallel(BLOCK_M, PAGE_BLOCK_SIZE): - acc_score_kvcache[i, j] = T.exp2(acc_score_kvcache[i, j] * SCALE - scores_max[i] * SCALE) - - T.reduce_sum(acc_score_kvcache, scores_sum, dim=1) - for i in T.Parallel(BLOCK_M): - log_sum[i] = log_sum[i] * scores_scale[i] + scores_sum[i] - - T.copy(acc_score_kvcache, acc_score_kvcache_cast) - - # Scale previous output accumulator - for i, j in T.Parallel(BLOCK_M, HEAD_DIM): - acc_output[i, j] *= scores_scale[i] - - # Accumulate current V_cache contribution - T.copy(V_Cache[page_block_idx_global, :, kv_head_idx, :], V_Cache_shared) - T.gemm(acc_score_kvcache_cast, V_Cache_shared, acc_output, policy=T.GemmWarpPolicy.FullRow) - - if page_block_idx_local == MAX_SEQ_NUM_BLOCKS - 1: - # ========================== - # Stage 2: Fresh KV Attention (Self-Attn) - # ========================== - T.copy(K[kv_start_idx : kv_start_idx + BLOCK_N, kv_head_idx, :], K_shared) - - for i, j in T.Parallel(BLOCK_M, BLOCK_N): - acc_score_kv[i, j] = T.if_then_else(i >= cur_q_seqlen or j >= cur_kv_seqlen, -1e9, 0) - - T.gemm(Q_shared, K_shared, acc_score_kv, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) - T.reduce_max(acc_score_kv, scores_max, dim=1, clear=False) - for i in T.Parallel(BLOCK_M): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - - for i in T.Parallel(BLOCK_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * SCALE - scores_max[i] * SCALE) - - for i, j in T.Parallel(BLOCK_M, BLOCK_N): - acc_score_kv[i, j] = T.exp2(acc_score_kv[i, j] * SCALE - scores_max[i] * SCALE) - - T.reduce_sum(acc_score_kv, scores_sum, dim=1) - for i in T.Parallel(BLOCK_M): - log_sum[i] = log_sum[i] * scores_scale[i] + scores_sum[i] - - T.copy(acc_score_kv, acc_score_kv_cast) - - # Scale previous output - for i, j in T.Parallel(BLOCK_M, HEAD_DIM): - acc_output[i, j] *= scores_scale[i] - - T.copy(V[kv_start_idx : kv_start_idx + BLOCK_N, kv_head_idx, :], V_shared) - - # Accumulate current V contribution - T.gemm(acc_score_kv_cast, V_shared, acc_output, policy=T.GemmWarpPolicy.FullRow) - - # Finalize - for i, j in T.Parallel(BLOCK_M, HEAD_DIM): - acc_output[i, j] /= log_sum[i] - - T.copy(acc_output, O_shared) - for i, d_idx in T.Parallel(BLOCK_M, HEAD_DIM): - if i < cur_q_seqlen: - O[i + q_start_idx, head_idx, d_idx] = O_shared[i, d_idx] - - return kernel - - -@tilelang.jit( - out_idx=[-1], - pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,}, -) -def dllm_flash_attn_decode_kernel_legacy( - NUM_SEQS: int, - NUM_GROUPS: int, - NUM_PAGE_BLOCKS: int, - Q_LEN: int, - KV_LEN: int, - NUM_HEADS: int, - HEAD_DIM: int, - IS_BLOCK_ATTN: bool, - DIFFUSION_BLOCK_SIZE: int, - MAX_SEQ_NUM_BLOCKS: int, - PAGE_BLOCK_SIZE: int = 32, - BLOCK_M: int = 64, - BLOCK_N: int = 64, - NUM_STAGES: int = 1, - NUM_THREADS: int = 128, -): - SCALE = (1.0 / HEAD_DIM)**0.5 * 1.44269504 # log2(e) - NUM_KV_HEADS = NUM_HEADS // NUM_GROUPS - Q_SHAPE = [Q_LEN, NUM_HEADS, HEAD_DIM] - KV_SHAPE = [KV_LEN, NUM_KV_HEADS, HEAD_DIM] - O_SHAPE = [Q_LEN, NUM_HEADS, HEAD_DIM] - K_CACHE_SHAPE = [NUM_PAGE_BLOCKS, PAGE_BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM] - V_CACHE_SHAPE = [NUM_PAGE_BLOCKS, PAGE_BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM] - BLOCK_TABLE_SHAPE = [NUM_SEQS, MAX_SEQ_NUM_BLOCKS] - DTYPE = "bfloat16" - ACCUM_DTYPE = "float" - - @T.prim_func - def kernel( - Q: T.Tensor(Q_SHAPE, DTYPE), - K: T.Tensor(KV_SHAPE, DTYPE), - V: T.Tensor(KV_SHAPE, DTYPE), - K_Cache: T.Tensor(K_CACHE_SHAPE, DTYPE), - V_Cache: T.Tensor(V_CACHE_SHAPE, DTYPE), - block_tables: T.Tensor(BLOCK_TABLE_SHAPE, "int32"), - context_lens: T.Tensor(NUM_SEQS, "int32"), - cu_seqlens_q: T.Tensor(NUM_SEQS + 1, "int32"), - cu_seqlens_k: T.Tensor(NUM_SEQS + 1, "int32"), - max_seqlen_q: T.int32, - O: T.Tensor(O_SHAPE, DTYPE), - ): - with T.Kernel(NUM_SEQS, NUM_HEADS, threads=NUM_THREADS) as (bx, by): - Q_shared = T.alloc_shared([BLOCK_M, HEAD_DIM], DTYPE) - K_shared = T.alloc_shared([BLOCK_N, HEAD_DIM], DTYPE) - V_shared = T.alloc_shared([BLOCK_N, HEAD_DIM], DTYPE) - O_shared = T.alloc_shared([BLOCK_M, HEAD_DIM], DTYPE) - K_Cache_shared = T.alloc_shared([PAGE_BLOCK_SIZE, HEAD_DIM], DTYPE) - V_Cache_shared = T.alloc_shared([PAGE_BLOCK_SIZE, HEAD_DIM], DTYPE) - - acc_score_kv = T.alloc_fragment([BLOCK_M, BLOCK_N], ACCUM_DTYPE) - acc_score_kv_cast = T.alloc_fragment([BLOCK_M, BLOCK_N], DTYPE) - acc_score_kvcache = T.alloc_fragment([BLOCK_M, PAGE_BLOCK_SIZE], ACCUM_DTYPE) - acc_score_kvcache_cast = T.alloc_fragment([BLOCK_M, PAGE_BLOCK_SIZE], DTYPE) - - acc_output = T.alloc_fragment([BLOCK_M, HEAD_DIM], ACCUM_DTYPE) - scores_max = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - scores_max_prev = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - scores_scale = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - scores_sum = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - log_sum = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) - block_table = T.alloc_fragment([MAX_SEQ_NUM_BLOCKS], "int32") - - T.annotate_layout({ - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - O_shared: tilelang.layout.make_swizzled_layout(O_shared), - }) - - seq_idx = bx - head_idx = by - kv_head_idx = head_idx // NUM_GROUPS - - q_start_idx = cu_seqlens_q[seq_idx] - kv_start_idx = cu_seqlens_k[seq_idx] - q_end_idx = cu_seqlens_q[seq_idx + 1] - kv_end_idx = cu_seqlens_k[seq_idx + 1] - - cur_q_seqlen = q_end_idx - q_start_idx - cur_kv_seqlen = kv_end_idx - kv_start_idx - - cur_context_len = context_lens[seq_idx] - - T.copy(block_tables[seq_idx, :], block_table) - T.copy(Q[q_start_idx : q_start_idx + BLOCK_M, head_idx, :], Q_shared) - - T.fill(acc_output, 0) - T.fill(acc_score_kv, 0) - T.fill(acc_score_kvcache, 0) - T.fill(log_sum, 0) - T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) - - # ========================== - # Stage 1: KV Cache Attention (Context) - # ========================== - for page_block_idx_local in T.Pipelined(MAX_SEQ_NUM_BLOCKS, num_stages=NUM_STAGES): - page_block_idx_global = block_table[page_block_idx_local] - if page_block_idx_global >= 0: - T.copy(K_Cache[page_block_idx_global, :, kv_head_idx, :], K_Cache_shared) - - for i, j in T.Parallel(BLOCK_M, PAGE_BLOCK_SIZE): - acc_score_kvcache[i, j] = T.if_then_else( - (i >= cur_q_seqlen or - page_block_idx_local * PAGE_BLOCK_SIZE + j >= cur_context_len), -1e9, 0 - ) - - # Compute attention scores - T.gemm(Q_shared, K_Cache_shared, acc_score_kvcache, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - # Compute online softmax - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) - T.reduce_max(acc_score_kvcache, scores_max, dim=1, clear=False) - for i in T.Parallel(BLOCK_M): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - - for i in T.Parallel(BLOCK_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * SCALE - scores_max[i] * SCALE) - - for i, j in T.Parallel(BLOCK_M, PAGE_BLOCK_SIZE): - acc_score_kvcache[i, j] = T.exp2(acc_score_kvcache[i, j] * SCALE - scores_max[i] * SCALE) - - T.reduce_sum(acc_score_kvcache, scores_sum, dim=1) - for i in T.Parallel(BLOCK_M): - log_sum[i] = log_sum[i] * scores_scale[i] + scores_sum[i] - - T.copy(acc_score_kvcache, acc_score_kvcache_cast) - - # Scale previous output accumulator - for i, j in T.Parallel(BLOCK_M, HEAD_DIM): - acc_output[i, j] *= scores_scale[i] - - # Accumulate current V_cache contribution - T.copy(V_Cache[page_block_idx_global, :, kv_head_idx, :], V_Cache_shared) - T.gemm(acc_score_kvcache_cast, V_Cache_shared, acc_output, policy=T.GemmWarpPolicy.FullRow) - - if page_block_idx_local == MAX_SEQ_NUM_BLOCKS - 1: - # ========================== - # Stage 2: Fresh KV Attention (Self-Attn) - # ========================== - T.copy(K[kv_start_idx : kv_start_idx + BLOCK_N, kv_head_idx, :], K_shared) - - for i, j in T.Parallel(BLOCK_M, BLOCK_N): - acc_score_kv[i, j] = T.if_then_else(i >= cur_q_seqlen or j >= cur_kv_seqlen, -1e9, 0) - - T.gemm(Q_shared, K_shared, acc_score_kv, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) - T.reduce_max(acc_score_kv, scores_max, dim=1, clear=False) - for i in T.Parallel(BLOCK_M): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - - for i in T.Parallel(BLOCK_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * SCALE - scores_max[i] * SCALE) - - for i, j in T.Parallel(BLOCK_M, BLOCK_N): - acc_score_kv[i, j] = T.exp2(acc_score_kv[i, j] * SCALE - scores_max[i] * SCALE) - - T.reduce_sum(acc_score_kv, scores_sum, dim=1) - for i in T.Parallel(BLOCK_M): - log_sum[i] = log_sum[i] * scores_scale[i] + scores_sum[i] - - T.copy(acc_score_kv, acc_score_kv_cast) - - # Scale previous output - for i, j in T.Parallel(BLOCK_M, HEAD_DIM): - acc_output[i, j] *= scores_scale[i] - - T.copy(V[kv_start_idx : kv_start_idx + BLOCK_N, kv_head_idx, :], V_shared) - - # Accumulate current V contribution - T.gemm(acc_score_kv_cast, V_shared, acc_output, policy=T.GemmWarpPolicy.FullRow) - - # Finalize - for i, j in T.Parallel(BLOCK_M, HEAD_DIM): - acc_output[i, j] /= log_sum[i] - - T.copy(acc_output, O_shared) - for i, d_idx in T.Parallel(BLOCK_M, HEAD_DIM): - if i < cur_q_seqlen: - O[i + q_start_idx, head_idx, d_idx] = O_shared[i, d_idx] - - return kernel - - -def dllm_flash_attn_prefill( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: float, - attn_metadata: AttnMetaDataBase -) -> torch.Tensor: - if attn_metadata.attn_type == "full_attention": - return flash_attn_varlen_func( - q, k, v, - attn_metadata.cu_seqlens_q, attn_metadata.cu_seqlens_k, - attn_metadata.max_seqlen_q, attn_metadata.max_seqlen_k, - softmax_scale=scale, block_table=None - ) - elif attn_metadata.attn_type == "block_attention": - if is_warming_up(): - global kernel_config - with set_autotune_inputs([ - q, k, v, - attn_metadata.cu_seqlens_q, - attn_metadata.cu_seqlens_k, - attn_metadata.max_seqlen_q, - ]): - prefill_kernel = dllm_flash_attn_prefill_kernel( - attn_metadata.num_seqs, - q.shape[1] // k.shape[1], - q.shape[0], - k.shape[0], - q.shape[1], - q.shape[2], - attn_metadata.attn_type == "block_attention", - attn_metadata.diffusion_block_size - ) - kernel_config = prefill_kernel.config - return prefill_kernel( - q, k, v, - attn_metadata.cu_seqlens_q, - attn_metadata.cu_seqlens_k, - attn_metadata.max_seqlen_q, - ) - else: - prefill_kernel = dllm_flash_attn_prefill_kernel( - attn_metadata.num_seqs, - q.shape[1] // k.shape[1], - q.shape[0], - k.shape[0], - q.shape[1], - q.shape[2], - attn_metadata.attn_type == "block_attention", - attn_metadata.diffusion_block_size, - **kernel_config - ) - return prefill_kernel( - q, k, v, - attn_metadata.cu_seqlens_q, - attn_metadata.cu_seqlens_k, - attn_metadata.max_seqlen_q, - ) - - -def dllm_flash_attn_decode( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - scale: float, - attn_metadata: AttnMetaDataBase -) -> torch.Tensor: - if attn_metadata.decode_mode == "static": - decode_kernel = dllm_flash_attn_decode_kernel( - attn_metadata.num_seqs, - q.shape[1] // k.shape[1], - k_cache.shape[0], - q.shape[0], - k.shape[0], - q.shape[1], - q.shape[2], - attn_metadata.attn_type == "block_attention", - attn_metadata.diffusion_block_size, - attn_metadata.block_tables.shape[1], - attn_metadata.page_block_size, - **kernel_config - ) - - return decode_kernel( - q, k, v, k_cache, v_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.cu_seqlens_q, - attn_metadata.cu_seqlens_k, - attn_metadata.max_seqlen_q, - ) - elif attn_metadata.decode_mode == "varlen": - k_comb, v_comb = load_kvcache(k_cache, v_cache, attn_metadata, k, v) - return flash_attn_varlen_func(q, k_comb, v_comb, - attn_metadata.cu_seqlens_q, attn_metadata.cu_seqlens_k, - attn_metadata.max_seqlen_q, attn_metadata.max_seqlen_k, - softmax_scale=scale, block_table=None) \ No newline at end of file diff --git a/diffulex_kernel/python/dllm_flash_attn_kernels.py b/diffulex_kernel/python/dllm_flash_attn_kernels.py new file mode 100644 index 0000000..1535ab0 --- /dev/null +++ b/diffulex_kernel/python/dllm_flash_attn_kernels.py @@ -0,0 +1,193 @@ +""" +Diffulex Flash-Attn kernel wrappers. + +Goals: +- Decode path should NOT require TileLang at import time. +- Prefill behavior remains unchanged (TileLang for block attention / flash-attn varlen otherwise), + but TileLang is imported lazily only when prefill is called. +""" + +from __future__ import annotations + +import os + +import torch +from flash_attn import flash_attn_varlen_func + +from diffulex.attention.metadata import AttnMetaDataBase +from diffulex_kernel.python.kv_cache_kernels import load_kvcache +from diffulex_kernel.python.paged_attn_decode_triton import paged_attn_decode_unified_triton + + +def dllm_flash_attn_prefill( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float, + attn_metadata: AttnMetaDataBase, +) -> torch.Tensor: + """ + Prefill attention wrapper. + + TileLang is imported lazily so decode-only usage does not depend on TileLang. + """ + from diffulex_kernel.python.dllm_flash_attn_prefill_tilelang import ( + dllm_flash_attn_prefill_tilelang, + ) + + return dllm_flash_attn_prefill_tilelang(q, k, v, scale, attn_metadata) + + +def _decode_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + attn_metadata: AttnMetaDataBase, +) -> torch.Tensor: + """ + Varlen decode path: + - gather/dequant KV cache with Triton `load_kvcache` + - run `flash_attn_varlen_func` + """ + do_profile = os.getenv("DIFFULEX_PROFILE_KVCACHE", "0") == "1" + if do_profile and q.is_cuda: + e0, e1, e2 = ( + torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True), + ) + e0.record() + k_comb, v_comb = load_kvcache(k_cache, v_cache, attn_metadata, k, v) + e1.record() + out = flash_attn_varlen_func( + q, + k_comb, + v_comb, + attn_metadata.cu_seqlens_q, + attn_metadata.cu_seqlens_k, + attn_metadata.max_seqlen_q, + attn_metadata.max_seqlen_k, + softmax_scale=scale, + block_table=None, + ) + e2.record() + e2.synchronize() + print( + f"[DIFFULEX_PROFILE_KVCACHE] decode(varlen) " + f"load_kvcache={e0.elapsed_time(e1):.3f}ms flash_attn={e1.elapsed_time(e2):.3f}ms" + ) + return out + + k_comb, v_comb = load_kvcache(k_cache, v_cache, attn_metadata, k, v) + return flash_attn_varlen_func( + q, + k_comb, + v_comb, + attn_metadata.cu_seqlens_q, + attn_metadata.cu_seqlens_k, + attn_metadata.max_seqlen_q, + attn_metadata.max_seqlen_k, + softmax_scale=scale, + block_table=None, + ) + + +def _decode_static_unified_triton_bf16( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + attn_metadata: AttnMetaDataBase, +) -> torch.Tensor: + return paged_attn_decode_unified_triton( + q, + k, + v, + k_cache, + v_cache, + attn_metadata, + softmax_scale=scale, + fp8_cache=False, + ) + + +def _decode_static_unified_triton_fp8_cache( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + attn_metadata: AttnMetaDataBase, +) -> torch.Tensor: + if attn_metadata.k_scale is None or attn_metadata.v_scale is None: + raise ValueError("FP8 KV decode requires k_scale and v_scale in metadata") + + # KV cache is stored as uint8 for FP8, but Triton expects float8 view dtype. + from diffulex.utils.quantization.context import get_kv_cache_strategy + + strategy = get_kv_cache_strategy() + if strategy is None or getattr(strategy, "kv_cache_format", "bf16") != "fp8": + raise ValueError(f"Expected kv_cache_format='fp8', got strategy={type(strategy)}") + + k_cache_fp8 = strategy.view_kv_cache_for_kernels(k_cache) + v_cache_fp8 = strategy.view_kv_cache_for_kernels(v_cache) + + return paged_attn_decode_unified_triton( + q, + k, + v, + k_cache_fp8, + v_cache_fp8, + attn_metadata, + softmax_scale=scale, + fp8_cache=True, + ) + + +def dllm_flash_attn_decode( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + scale: float, + attn_metadata: AttnMetaDataBase, +) -> torch.Tensor: + """ + Decode attention wrapper: + - static: Triton paged-attention over (paged) KV cache + current-step KV + - varlen: load_kvcache (Triton gather/dequant) + flash-attn varlen + """ + from diffulex.utils.quantization.context import get_kv_cache_strategy + + kv_strategy = get_kv_cache_strategy() + kv_fmt = getattr(kv_strategy, "kv_cache_format", "bf16") if kv_strategy is not None else "bf16" + + decode_mode = getattr(attn_metadata, "decode_mode", "varlen") + if decode_mode == "static": + # Only unified layout is supported in static paged-attention for now. + if getattr(attn_metadata, "kv_cache_layout", "unified") != "unified": + return _decode_varlen(q, k, v, k_cache, v_cache, scale, attn_metadata) + + if kv_fmt == "bf16": + return _decode_static_unified_triton_bf16(q, k, v, k_cache, v_cache, scale, attn_metadata) + if kv_fmt == "fp8": + return _decode_static_unified_triton_fp8_cache(q, k, v, k_cache, v_cache, scale, attn_metadata) + raise ValueError(f"Unsupported kv_cache_format={kv_fmt!r} for static decode") + + if decode_mode == "varlen": + return _decode_varlen(q, k, v, k_cache, v_cache, scale, attn_metadata) + + raise ValueError(f"Unsupported decode mode: {decode_mode!r}") + + +__all__ = [ + "dllm_flash_attn_prefill", + "dllm_flash_attn_decode", +] diff --git a/diffulex_kernel/python/dllm_flash_attn_prefill_tilelang.py b/diffulex_kernel/python/dllm_flash_attn_prefill_tilelang.py new file mode 100644 index 0000000..17dfaf9 --- /dev/null +++ b/diffulex_kernel/python/dllm_flash_attn_prefill_tilelang.py @@ -0,0 +1,250 @@ +import torch +import tilelang +import tilelang.language as T + +from flash_attn import flash_attn_varlen_func +from tilelang.autotuner import set_autotune_inputs + +from diffulex_kernel.python.auto_tuner import build_configs +from diffulex.attention.metadata import AttnMetaDataBase, is_warming_up + + +kernel_config = None + + +@tilelang.autotune(configs=build_configs()) +@tilelang.jit( + # NOTE: Disable TMA and warp specialized for now to avoid compile error on Hopper + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def dllm_flash_attn_prefill_kernel( + NUM_SEQS: int, + NUM_GROUPS: int, + Q_LEN: int, + KV_LEN: int, + NUM_HEADS: int, + HEAD_DIM: int, + IS_BLOCK_ATTN: bool, + DIFFUSION_BLOCK_SIZE: int, + BLOCK_M: int = 64, + BLOCK_N: int = 64, + NUM_STAGES: int = 1, + NUM_THREADS: int = 128, +): + SCALE = (1.0 / HEAD_DIM) ** 0.5 * 1.44269504 # log2(e) + NUM_KV_HEADS = NUM_HEADS // NUM_GROUPS + Q_SHAPE = [Q_LEN, NUM_HEADS, HEAD_DIM] + KV_SHAPE = [KV_LEN, NUM_KV_HEADS, HEAD_DIM] + O_SHAPE = [Q_LEN, NUM_HEADS, HEAD_DIM] + DTYPE = "bfloat16" + ACCUM_DTYPE = "float" + + @T.prim_func + def kernel( + Q: T.Tensor(Q_SHAPE, DTYPE), + K: T.Tensor(KV_SHAPE, DTYPE), + V: T.Tensor(KV_SHAPE, DTYPE), + cu_seqlens_q: T.Tensor(NUM_SEQS + 1, "int32"), + cu_seqlens_k: T.Tensor(NUM_SEQS + 1, "int32"), + max_seqlen_q: T.int32, + O: T.Tensor(O_SHAPE, DTYPE), + ): + with T.Kernel(T.ceildiv(max_seqlen_q, BLOCK_M), NUM_HEADS, NUM_SEQS, threads=NUM_THREADS) as (bx, by, bz): + Q_shared = T.alloc_shared([BLOCK_M, HEAD_DIM], DTYPE) + K_shared = T.alloc_shared([BLOCK_N, HEAD_DIM], DTYPE) + V_shared = T.alloc_shared([BLOCK_N, HEAD_DIM], DTYPE) + O_shared = T.alloc_shared([BLOCK_M, HEAD_DIM], DTYPE) + + acc_score = T.alloc_fragment([BLOCK_M, BLOCK_N], ACCUM_DTYPE) + acc_score_cast = T.alloc_fragment([BLOCK_M, BLOCK_N], DTYPE) + acc_output = T.alloc_fragment([BLOCK_M, HEAD_DIM], ACCUM_DTYPE) + scores_max = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) + scores_max_prev = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) + scores_scale = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) + scores_sum = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) + log_sum = T.alloc_fragment([BLOCK_M], ACCUM_DTYPE) + + T.annotate_layout( + { + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + } + ) + + q_block_idx = bx + seq_idx = bz + head_idx = by + kv_head_idx = head_idx // NUM_GROUPS + + q_start_idx = cu_seqlens_q[seq_idx] + kv_start_idx = cu_seqlens_k[seq_idx] + q_end_idx = cu_seqlens_q[seq_idx + 1] + kv_end_idx = cu_seqlens_k[seq_idx + 1] + + cur_q_seqlen = q_end_idx - q_start_idx + cur_kv_seqlen = kv_end_idx - kv_start_idx + + T.copy( + Q[q_start_idx + q_block_idx * BLOCK_M : q_start_idx + (q_block_idx + 1) * BLOCK_M, head_idx, :], + Q_shared, + ) + + T.fill(acc_output, 0) + T.fill(acc_score, 0) + T.fill(log_sum, 0) + T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) + + loop_range = ( + T.min( + T.ceildiv(cur_q_seqlen + (q_block_idx + 1) * BLOCK_M, BLOCK_N), + T.ceildiv(cur_kv_seqlen, BLOCK_N), + ) + if IS_BLOCK_ATTN + else T.ceildiv(cur_kv_seqlen, BLOCK_N) + ) + for kv_block_idx in T.Pipelined(loop_range, num_stages=NUM_STAGES): + T.copy( + K[kv_start_idx + kv_block_idx * BLOCK_N : kv_start_idx + (kv_block_idx + 1) * BLOCK_N, kv_head_idx, :], + K_shared, + ) + + if IS_BLOCK_ATTN: + for i, j in T.Parallel(BLOCK_M, BLOCK_N): + num_diffusion_blocks = (q_block_idx * BLOCK_M + i) // DIFFUSION_BLOCK_SIZE + 1 + acc_score[i, j] = T.if_then_else( + (num_diffusion_blocks * DIFFUSION_BLOCK_SIZE <= kv_block_idx * BLOCK_N + j) + or (q_block_idx * BLOCK_M + i >= cur_q_seqlen or kv_block_idx * BLOCK_N + j >= cur_kv_seqlen), + -1e9, + 0, + ) + else: + for i, j in T.Parallel(BLOCK_M, BLOCK_N): + acc_score[i, j] = T.if_then_else( + (q_block_idx * BLOCK_M + i >= cur_q_seqlen or kv_block_idx * BLOCK_N + j >= cur_kv_seqlen), + -1e9, + 0, + ) + + T.gemm(Q_shared, K_shared, acc_score, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(ACCUM_DTYPE)) + T.reduce_max(acc_score, scores_max, dim=1, clear=False) + for i in T.Parallel(BLOCK_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + for i in T.parallel(BLOCK_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * SCALE - scores_max[i] * SCALE) + + for i, j in T.Parallel(BLOCK_M, BLOCK_N): + acc_score[i, j] = T.exp2(acc_score[i, j] * SCALE - scores_max[i] * SCALE) + + T.reduce_sum(acc_score, scores_sum, dim=1) + for i in T.Parallel(BLOCK_M): + log_sum[i] = log_sum[i] * scores_scale[i] + scores_sum[i] + + T.copy(acc_score, acc_score_cast) + for i, j in T.Parallel(BLOCK_M, HEAD_DIM): + acc_output[i, j] *= scores_scale[i] + + T.copy( + V[kv_start_idx + kv_block_idx * BLOCK_N : kv_start_idx + (kv_block_idx + 1) * BLOCK_N, kv_head_idx, :], + V_shared, + ) + T.gemm(acc_score_cast, V_shared, acc_output, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(BLOCK_M, HEAD_DIM): + acc_output[i, j] /= log_sum[i] + + T.copy(acc_output, O_shared) + for i, d_idx in T.Parallel(BLOCK_M, HEAD_DIM): + if i + q_block_idx * BLOCK_M < cur_q_seqlen: + O[i + q_start_idx + q_block_idx * BLOCK_M, head_idx, d_idx] = O_shared[i, d_idx] + + return kernel + + +def dllm_flash_attn_prefill_tilelang( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float, + attn_metadata: AttnMetaDataBase, +) -> torch.Tensor: + """ + TileLang-based prefill implementation (existing behavior). + Kept in a separate module so importing decode kernels doesn't require TileLang. + """ + global kernel_config + if attn_metadata.attn_type == "full_attention": + return flash_attn_varlen_func( + q, + k, + v, + attn_metadata.cu_seqlens_q, + attn_metadata.cu_seqlens_k, + attn_metadata.max_seqlen_q, + attn_metadata.max_seqlen_k, + softmax_scale=scale, + block_table=None, + ) + if attn_metadata.attn_type != "block_attention": + raise ValueError(f"Unsupported attn_type={attn_metadata.attn_type!r} for prefill") + + if is_warming_up(): + with set_autotune_inputs( + [ + q, + k, + v, + attn_metadata.cu_seqlens_q, + attn_metadata.cu_seqlens_k, + attn_metadata.max_seqlen_q, + ] + ): + prefill_kernel = dllm_flash_attn_prefill_kernel( + attn_metadata.num_seqs, + q.shape[1] // k.shape[1], + q.shape[0], + k.shape[0], + q.shape[1], + q.shape[2], + attn_metadata.attn_type == "block_attention", + attn_metadata.diffusion_block_size, + ) + kernel_config = prefill_kernel.config + return prefill_kernel( + q, + k, + v, + attn_metadata.cu_seqlens_q, + attn_metadata.cu_seqlens_k, + attn_metadata.max_seqlen_q, + ) + + config_kwargs = kernel_config if kernel_config is not None else {} + prefill_kernel = dllm_flash_attn_prefill_kernel( + attn_metadata.num_seqs, + q.shape[1] // k.shape[1], + q.shape[0], + k.shape[0], + q.shape[1], + q.shape[2], + attn_metadata.attn_type == "block_attention", + attn_metadata.diffusion_block_size, + **config_kwargs, + ) + return prefill_kernel( + q, + k, + v, + attn_metadata.cu_seqlens_q, + attn_metadata.cu_seqlens_k, + attn_metadata.max_seqlen_q, + ) + diff --git a/diffulex_kernel/python/kv_cache_kernels.py b/diffulex_kernel/python/kv_cache_kernels.py index b235f83..8010042 100755 --- a/diffulex_kernel/python/kv_cache_kernels.py +++ b/diffulex_kernel/python/kv_cache_kernels.py @@ -4,12 +4,13 @@ import triton.language as tl from typing import Tuple +import os from diffulex.attention.metadata import AttnMetaDataBase @triton.jit -def dllm_store_kvcache_kernel_unified( +def dllm_store_kvcache_kernel_unified_bf16( key_ptr, key_stride, value_ptr, @@ -19,6 +20,7 @@ def dllm_store_kvcache_kernel_unified( slot_mapping_ptr, D: tl.constexpr ): + """BF16 unified layout store kernel - no quantization, direct storage.""" token_idx = tl.program_id(0) slot = tl.load(slot_mapping_ptr + token_idx) if slot < 0: @@ -33,7 +35,7 @@ def dllm_store_kvcache_kernel_unified( @triton.jit -def dllm_store_kvcache_kernel_distinct( +def dllm_store_kvcache_kernel_distinct_bf16( k_ptr, v_ptr, k_cache_ptr, v_cache_ptr, slot_mapping_ptr, k_stride, v_stride, k_cache_stride_nblks, k_cache_stride_h, k_cache_stride_dx, k_cache_stride_blk_sz, k_cache_stride_x, @@ -41,6 +43,7 @@ def dllm_store_kvcache_kernel_distinct( nheads, hdim, blk_sz, x: tl.constexpr, D: tl.constexpr ): + """BF16 distinct layout store kernel - no quantization, direct storage.""" # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: D2F @@ -83,10 +86,97 @@ def dllm_store_kvcache_kernel_distinct( tl.store(v_cache_ptr + v_cache_offs, v) -def store_kvcache_distinct_layout(key: torch.Tensor, value: torch.Tensor, +@triton.jit +def dllm_store_kvcache_kernel_distinct_fp8( + k_quantized_ptr, v_quantized_ptr, k_cache_ptr, v_cache_ptr, slot_mapping_ptr, + k_quantized_stride, v_quantized_stride, + k_cache_stride_nblks, k_cache_stride_h, k_cache_stride_dx, k_cache_stride_blk_sz, k_cache_stride_x, + v_cache_stride_nblks, v_cache_stride_h, v_cache_stride_d, v_cache_stride_blk_sz, + nheads, hdim, blk_sz, + x: tl.constexpr, D: tl.constexpr +): + """FP8 distinct layout store kernel - stores already quantized uint8 key/value to cache.""" + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: D2F + + # Organization: SJTU DENG Lab + # Author: Drew Jin (JIN. Yijie, @drewjin) + # Date: 2025-12-29 + # Email: drewjin0827@gmail.com + # All rights reserved. + + token_idx = tl.program_id(0) + slot_idx = tl.load(slot_mapping_ptr + token_idx) + if slot_idx < 0: + return + + blk_idx = slot_idx // blk_sz + off_blk = slot_idx % blk_sz + + offs_d = tl.arange(0, D) + offs_k = token_idx * k_quantized_stride + offs_d + offs_v = token_idx * v_quantized_stride + offs_d + k_uint8 = tl.load(k_quantized_ptr + offs_k) + v_uint8 = tl.load(v_quantized_ptr + offs_v) + + h_ids = offs_d // hdim + h_offs = offs_d % hdim + x_ids = h_offs // x + x_offs = h_offs % x + + k_cache_offs = (blk_idx * k_cache_stride_nblks + h_ids * k_cache_stride_h + + x_ids * k_cache_stride_dx + off_blk * k_cache_stride_blk_sz + + x_offs * k_cache_stride_x) + v_cache_offs = (blk_idx * v_cache_stride_nblks + h_ids * v_cache_stride_h + + h_offs * v_cache_stride_d + off_blk * v_cache_stride_blk_sz) + + tl.store(k_cache_ptr + k_cache_offs, k_uint8) + tl.store(v_cache_ptr + v_cache_offs, v_uint8) + + +def _store_kvcache_distinct_fp8(key: torch.Tensor, value: torch.Tensor, + k_cache: torch.Tensor, v_cache: torch.Tensor, + slot_mapping: torch.Tensor, + k_scale: torch.Tensor, v_scale: torch.Tensor, + *, strategy) -> None: + """Helper function for FP8 distinct layout store. + + Quantizes BF16 key/value to FP8 (uint8 storage) using strategy, then stores to cache. + """ + # k_cache: [num_blks, h, hdim // x, blk_sz, x] + # v_cache: [num_blks, h, hdim, blk_sz] + NBlks, NHeads, HDim_x, Blk_sz, x = k_cache.shape + HDim = HDim_x * x + N, num_kv_heads, head_dim = key.shape + D = num_kv_heads * head_dim + + assert HDim == head_dim and NHeads == num_kv_heads + assert N == slot_mapping.numel() + + # Vectorized quantization: [N, H, D] -> [N, H, D] uint8, then flatten to [N, H*D] + k_q, v_q = strategy.quantize_kv_for_store(key, value, k_scale=k_scale, v_scale=v_scale) + key_quantized = k_q.reshape(N, D).contiguous() + value_quantized = v_q.reshape(N, D).contiguous() + + assert key_quantized.dtype == torch.uint8, f"Expected uint8, got {key_quantized.dtype}" + assert value_quantized.dtype == torch.uint8, f"Expected uint8, got {value_quantized.dtype}" + + GRID = (N, ) + dllm_store_kvcache_kernel_distinct_fp8[GRID]( + key_quantized, value_quantized, + k_cache, v_cache, + slot_mapping, + key_quantized.stride(0), value_quantized.stride(0), + *k_cache.stride(), *v_cache.stride(), + NHeads, HDim, Blk_sz, + x, D + ) + + +def _store_kvcache_distinct_bf16(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, - slot_mapping: torch.Tensor, attn_metadata: AttnMetaDataBase) -> None: - # TODO: implement diffusion lm kv cache store + slot_mapping: torch.Tensor) -> None: + """Helper function for BF16 distinct layout store.""" # k_cache: [num_blks, h, hdim // x, blk_sz, x] # v_cache: [num_blks, h, hdim, blk_sz] NBlks, NHeads, HDim_x, Blk_sz, x = k_cache.shape @@ -96,7 +186,7 @@ def store_kvcache_distinct_layout(key: torch.Tensor, value: torch.Tensor, assert N == slot_mapping.numel() GRID = (N, ) - dllm_store_kvcache_kernel_distinct[GRID]( + dllm_store_kvcache_kernel_distinct_bf16[GRID]( key, value, k_cache, v_cache, slot_mapping, @@ -107,9 +197,79 @@ def store_kvcache_distinct_layout(key: torch.Tensor, value: torch.Tensor, ) -def store_kvcache_unified_layout(key: torch.Tensor, value: torch.Tensor, +@triton.jit +def dllm_store_kvcache_kernel_unified_fp8( + key_quantized_ptr, + key_quantized_stride, + value_quantized_ptr, + value_quantized_stride, + k_cache_ptr, + v_cache_ptr, + slot_mapping_ptr, + D: tl.constexpr +): + """FP8 unified layout store kernel - stores already quantized uint8 key/value to cache. + + For unified layout cache shape [num_blocks, block_size, num_kv_heads, head_dim], + we assume stride(1) == D (where D = num_kv_heads * head_dim), so offset is slot * D. + This matches the BF16 kernel's behavior. + """ + token_idx = tl.program_id(0) + slot = tl.load(slot_mapping_ptr + token_idx) + if slot < 0: + return + key_offsets = token_idx * key_quantized_stride + tl.arange(0, D) + value_offsets = token_idx * value_quantized_stride + tl.arange(0, D) + key_uint8 = tl.load(key_quantized_ptr + key_offsets) + value_uint8 = tl.load(value_quantized_ptr + value_offsets) + # For unified layout with stride(1) == D, offset is slot * D + cache_offsets = slot * D + tl.arange(0, D) + tl.store(k_cache_ptr + cache_offsets, key_uint8) + tl.store(v_cache_ptr + cache_offsets, value_uint8) + + +def _store_kvcache_unified_fp8(key: torch.Tensor, value: torch.Tensor, + k_cache: torch.Tensor, v_cache: torch.Tensor, + slot_mapping: torch.Tensor, + k_scale: torch.Tensor, v_scale: torch.Tensor, + *, strategy) -> None: + """Helper function for FP8 unified layout store. + + Quantizes BF16 key/value to FP8 (uint8 storage) using strategy, then stores to cache. + """ + N, num_kv_heads, head_dim = key.shape + D = num_kv_heads * head_dim + + # Vectorized quantization: [N, H, D] -> [N, H, D] uint8, then flatten to [N, H*D] + k_q, v_q = strategy.quantize_kv_for_store(key, value, k_scale=k_scale, v_scale=v_scale) + key_quantized = k_q.reshape(N, D).contiguous() + value_quantized = v_q.reshape(N, D).contiguous() + + assert key_quantized.dtype == torch.uint8, f"Expected uint8, got {key_quantized.dtype}" + assert value_quantized.dtype == torch.uint8, f"Expected uint8, got {value_quantized.dtype}" + assert N == slot_mapping.numel(), f"`N`: {N}, `slot_mapping.numel()`: {slot_mapping.numel()}" + + # For unified layout, cache shape is [num_blocks, block_size, num_kv_heads, head_dim] + # BF16 kernel uses cache directly (no view) and assumes stride(1) == D + # For FP8, we should do the same to match BF16 behavior + assert k_cache.stride(1) == D and v_cache.stride(1) == D, \ + f"Expected stride(1) == D ({D}), got k_cache.stride(1)={k_cache.stride(1)}, v_cache.stride(1)={v_cache.stride(1)}" + + # Use cache directly, matching BF16 kernel behavior + # Kernel uses slot * D as offset, which works with stride(1) == D + # Pass cache directly to kernel, matching BF16 kernel behavior + # The kernel expects cache to have stride(1) == D, which we've already verified + dllm_store_kvcache_kernel_unified_fp8[(N,)]( + key_quantized, key_quantized.stride(0), + value_quantized, value_quantized.stride(0), + k_cache, v_cache, slot_mapping, D + ) + + +def _store_kvcache_unified_bf16(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, - slot_mapping: torch.Tensor, attn_metadata: AttnMetaDataBase) -> None: + slot_mapping: torch.Tensor) -> None: + """Helper function for BF16 unified layout store.""" N, num_heads, head_dim = key.shape D = num_heads * head_dim assert key.stride(-1) == 1 and value.stride(-1) == 1 @@ -117,7 +277,7 @@ def store_kvcache_unified_layout(key: torch.Tensor, value: torch.Tensor, assert k_cache.stride(1) == D and v_cache.stride(1) == D assert N == slot_mapping.numel(), f"`N`: {N}, `slot_mapping.numel()`: {slot_mapping.numel()}" - dllm_store_kvcache_kernel_unified[(N,)]( + dllm_store_kvcache_kernel_unified_bf16[(N,)]( key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D @@ -125,7 +285,7 @@ def store_kvcache_unified_layout(key: torch.Tensor, value: torch.Tensor, @triton.jit -def load_kvcache_kernel(k_cache_ptr, v_cache_ptr, +def load_kvcache_kernel_bf16(k_cache_ptr, v_cache_ptr, k_new_ptr, v_new_ptr, block_table_ptr, k_out_ptr, v_out_ptr, @@ -227,12 +387,416 @@ def load_kvcache_kernel(k_cache_ptr, v_cache_ptr, tl.store(v_out_ptr + offs_cur_kv_new_to_out, v_new) -def load_kvcache(k_cache: torch.Tensor, v_cache: torch.Tensor, +@triton.jit +def load_kvcache_kernel_bf16_distinct( + k_cache_ptr, + v_cache_ptr, + k_new_ptr, + v_new_ptr, + block_table_ptr, + k_out_ptr, + v_out_ptr, + seqlens_ptr, + ctxlens_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + # distinct cache strides + k_cache_stride_nblks, + k_cache_stride_h, + k_cache_stride_dx, + k_cache_stride_blk_sz, + k_cache_stride_x, + v_cache_stride_nblks, + v_cache_stride_h, + v_cache_stride_d, + v_cache_stride_blk_sz, + # new / out / block_table strides + kv_new_stride_s, + kv_new_stride_h, + kv_new_stride_d, + block_table_stride_nseqs, + block_table_stride_maxblks, + kv_out_stride_s, + kv_out_stride_h, + kv_out_stride_d, + ctxlens_stride, + seqlens_stride, + cu_seqlens_q_stride, + cu_seqlens_k_stride, + LAST_BLK_ID: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + DIFFUSION_BLOCK_SIZE: tl.constexpr, + KV_LOAD_UNROLL_FACTOR: tl.constexpr, + X: tl.constexpr, +): + """ + Distinct layout BF16 load kernel. + + Layouts: + - k_cache: [NBlks, Hkv, HEAD_DIM//X, PAGE_SIZE, X] + - v_cache: [NBlks, Hkv, HEAD_DIM, PAGE_SIZE] + """ + seq_idx = tl.program_id(0) + local_blk_idx = tl.program_id(1) + kv_head_idx = tl.program_id(2) + + off_local_blk = seq_idx * block_table_stride_nseqs + local_blk_idx * block_table_stride_maxblks + global_blk_idx = tl.load(block_table_ptr + off_local_blk) + + if global_blk_idx != -1: + off_ctxlen = seq_idx * ctxlens_stride + global_ctxlen = tl.load(ctxlens_ptr + off_ctxlen) + cur_window_sz = (local_blk_idx + 1) * PAGE_SIZE + prev_window_sz = local_blk_idx * PAGE_SIZE + local_ctxlen = tl.where(global_ctxlen > cur_window_sz, PAGE_SIZE, global_ctxlen % PAGE_SIZE) + if global_ctxlen > prev_window_sz: + offs_kv_cache_seq = tl.arange(0, PAGE_SIZE) + offs_kv_cache_hdim = tl.arange(0, HEAD_DIM) + + x_ids = offs_kv_cache_hdim // X + x_offs = offs_kv_cache_hdim % X + + offs_k = ( + global_blk_idx * k_cache_stride_nblks + + kv_head_idx * k_cache_stride_h + + x_ids[:, None] * k_cache_stride_dx + + offs_kv_cache_seq[None, :] * k_cache_stride_blk_sz + + x_offs[:, None] * k_cache_stride_x + ) + offs_v = ( + global_blk_idx * v_cache_stride_nblks + + kv_head_idx * v_cache_stride_h + + offs_kv_cache_hdim[:, None] * v_cache_stride_d + + offs_kv_cache_seq[None, :] * v_cache_stride_blk_sz + ) + + kv_cache_mask = offs_kv_cache_seq[None, :] < local_ctxlen + k_cache = tl.load(k_cache_ptr + offs_k, mask=kv_cache_mask, other=0.0) + v_cache = tl.load(v_cache_ptr + offs_v, mask=kv_cache_mask, other=0.0) + + off_cu_seqlens_k = seq_idx * cu_seqlens_k_stride + kv_out_start_idx = tl.load(cu_seqlens_k_ptr + off_cu_seqlens_k) + cur_kv_cache_to_out_start_idx = kv_out_start_idx + prev_window_sz + offs_kv_cache_to_out = ( + (cur_kv_cache_to_out_start_idx + offs_kv_cache_seq[None, :]) * kv_out_stride_s + + kv_head_idx * kv_out_stride_h + + offs_kv_cache_hdim[:, None] * kv_out_stride_d + ) + tl.store(k_out_ptr + offs_kv_cache_to_out, k_cache, mask=kv_cache_mask) + tl.store(v_out_ptr + offs_kv_cache_to_out, v_cache, mask=kv_cache_mask) + + if local_blk_idx == LAST_BLK_ID: + off_cu_seqlens_q = seq_idx * cu_seqlens_q_stride + off_seqlens = seq_idx * seqlens_stride + kv_new_start_idx = tl.load(cu_seqlens_q_ptr + off_cu_seqlens_q) + active_seqlen = tl.load(seqlens_ptr + off_seqlens) + + offs_kv_new_seq = tl.arange(0, DIFFUSION_BLOCK_SIZE) + offs_kv_new_hdim = tl.arange(0, HEAD_DIM) + + for diff_blk_idx in tl.range(active_seqlen // DIFFUSION_BLOCK_SIZE, loop_unroll_factor=KV_LOAD_UNROLL_FACTOR): + off_diff_blk = diff_blk_idx * DIFFUSION_BLOCK_SIZE + cur_kv_new_start_idx = kv_new_start_idx + off_diff_blk + offs_cur_kv_new_seq = ( + (cur_kv_new_start_idx + offs_kv_new_seq[None, :]) * kv_new_stride_s + + kv_head_idx * kv_new_stride_h + + offs_kv_new_hdim[:, None] * kv_new_stride_d + ) + k_new = tl.load(k_new_ptr + offs_cur_kv_new_seq) + v_new = tl.load(v_new_ptr + offs_cur_kv_new_seq) + + off_ctxlen = seq_idx * ctxlens_stride + off_cu_seqlens_k = seq_idx * cu_seqlens_k_stride + global_ctxlen = tl.load(ctxlens_ptr + off_ctxlen) + kv_out_start_idx = tl.load(cu_seqlens_k_ptr + off_cu_seqlens_k) + cur_kv_new_to_out_start_idx = global_ctxlen + kv_out_start_idx + off_diff_blk + offs_cur_kv_new_to_out = ( + (cur_kv_new_to_out_start_idx + offs_kv_new_seq[None, :]) * kv_out_stride_s + + kv_head_idx * kv_out_stride_h + + offs_kv_new_hdim[:, None] * kv_out_stride_d + ) + tl.store(k_out_ptr + offs_cur_kv_new_to_out, k_new) + tl.store(v_out_ptr + offs_cur_kv_new_to_out, v_new) + + +@triton.jit +def load_kvcache_kernel_fp8_distinct( + k_cache_ptr, + v_cache_ptr, + k_scale_ptr, + v_scale_ptr, + k_new_ptr, + v_new_ptr, + block_table_ptr, + k_out_ptr, + v_out_ptr, + seqlens_ptr, + ctxlens_ptr, + cu_seqlens_q_ptr, + cu_seqlens_k_ptr, + # distinct cache strides + k_cache_stride_nblks, + k_cache_stride_h, + k_cache_stride_dx, + k_cache_stride_blk_sz, + k_cache_stride_x, + v_cache_stride_nblks, + v_cache_stride_h, + v_cache_stride_d, + v_cache_stride_blk_sz, + # new / out / block_table strides + kv_new_stride_s, + kv_new_stride_h, + kv_new_stride_d, + block_table_stride_nseqs, + block_table_stride_maxblks, + kv_out_stride_s, + kv_out_stride_h, + kv_out_stride_d, + ctxlens_stride, + seqlens_stride, + cu_seqlens_q_stride, + cu_seqlens_k_stride, + LAST_BLK_ID: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + DIFFUSION_BLOCK_SIZE: tl.constexpr, + KV_LOAD_UNROLL_FACTOR: tl.constexpr, + X: tl.constexpr, +): + """ + Distinct layout FP8 load kernel: + - Gather paged KV cache blocks from distinct K/V layouts. + - Dequantize FP8 -> BF16 and apply per-head scale inside kernel. + + Layouts: + - k_cache: [NBlks, Hkv, HEAD_DIM//X, PAGE_SIZE, X] (float8 view) + - v_cache: [NBlks, Hkv, HEAD_DIM, PAGE_SIZE] (float8 view) + """ + seq_idx = tl.program_id(0) + local_blk_idx = tl.program_id(1) + kv_head_idx = tl.program_id(2) + + off_local_blk = seq_idx * block_table_stride_nseqs + local_blk_idx * block_table_stride_maxblks + global_blk_idx = tl.load(block_table_ptr + off_local_blk) + + k_scale = tl.load(k_scale_ptr + kv_head_idx).to(tl.float32) + v_scale = tl.load(v_scale_ptr + kv_head_idx).to(tl.float32) + + if global_blk_idx != -1: + off_ctxlen = seq_idx * ctxlens_stride + global_ctxlen = tl.load(ctxlens_ptr + off_ctxlen) + cur_window_sz = (local_blk_idx + 1) * PAGE_SIZE + prev_window_sz = local_blk_idx * PAGE_SIZE + local_ctxlen = tl.where(global_ctxlen > cur_window_sz, PAGE_SIZE, global_ctxlen % PAGE_SIZE) + if global_ctxlen > prev_window_sz: + offs_kv_cache_seq = tl.arange(0, PAGE_SIZE) + offs_kv_cache_hdim = tl.arange(0, HEAD_DIM) + + x_ids = offs_kv_cache_hdim // X + x_offs = offs_kv_cache_hdim % X + + offs_k = ( + global_blk_idx * k_cache_stride_nblks + + kv_head_idx * k_cache_stride_h + + x_ids[:, None] * k_cache_stride_dx + + offs_kv_cache_seq[None, :] * k_cache_stride_blk_sz + + x_offs[:, None] * k_cache_stride_x + ) + offs_v = ( + global_blk_idx * v_cache_stride_nblks + + kv_head_idx * v_cache_stride_h + + offs_kv_cache_hdim[:, None] * v_cache_stride_d + + offs_kv_cache_seq[None, :] * v_cache_stride_blk_sz + ) + + kv_cache_mask = offs_kv_cache_seq[None, :] < local_ctxlen + k_cache = tl.load(k_cache_ptr + offs_k, mask=kv_cache_mask, other=0.0).to(tl.float32) * k_scale + v_cache = tl.load(v_cache_ptr + offs_v, mask=kv_cache_mask, other=0.0).to(tl.float32) * v_scale + k_cache_bf16 = k_cache.to(tl.bfloat16) + v_cache_bf16 = v_cache.to(tl.bfloat16) + + off_cu_seqlens_k = seq_idx * cu_seqlens_k_stride + kv_out_start_idx = tl.load(cu_seqlens_k_ptr + off_cu_seqlens_k) + cur_kv_cache_to_out_start_idx = kv_out_start_idx + prev_window_sz + offs_kv_cache_to_out = ( + (cur_kv_cache_to_out_start_idx + offs_kv_cache_seq[None, :]) * kv_out_stride_s + + kv_head_idx * kv_out_stride_h + + offs_kv_cache_hdim[:, None] * kv_out_stride_d + ) + tl.store(k_out_ptr + offs_kv_cache_to_out, k_cache_bf16, mask=kv_cache_mask) + tl.store(v_out_ptr + offs_kv_cache_to_out, v_cache_bf16, mask=kv_cache_mask) + + if local_blk_idx == LAST_BLK_ID: + off_cu_seqlens_q = seq_idx * cu_seqlens_q_stride + off_seqlens = seq_idx * seqlens_stride + kv_new_start_idx = tl.load(cu_seqlens_q_ptr + off_cu_seqlens_q) + active_seqlen = tl.load(seqlens_ptr + off_seqlens) + + offs_kv_new_seq = tl.arange(0, DIFFUSION_BLOCK_SIZE) + offs_kv_new_hdim = tl.arange(0, HEAD_DIM) + + for diff_blk_idx in tl.range(active_seqlen // DIFFUSION_BLOCK_SIZE, loop_unroll_factor=KV_LOAD_UNROLL_FACTOR): + off_diff_blk = diff_blk_idx * DIFFUSION_BLOCK_SIZE + cur_kv_new_start_idx = kv_new_start_idx + off_diff_blk + offs_cur_kv_new_seq = ( + (cur_kv_new_start_idx + offs_kv_new_seq[None, :]) * kv_new_stride_s + + kv_head_idx * kv_new_stride_h + + offs_kv_new_hdim[:, None] * kv_new_stride_d + ) + k_new = tl.load(k_new_ptr + offs_cur_kv_new_seq) + v_new = tl.load(v_new_ptr + offs_cur_kv_new_seq) + + off_ctxlen = seq_idx * ctxlens_stride + off_cu_seqlens_k = seq_idx * cu_seqlens_k_stride + global_ctxlen = tl.load(ctxlens_ptr + off_ctxlen) + kv_out_start_idx = tl.load(cu_seqlens_k_ptr + off_cu_seqlens_k) + cur_kv_new_to_out_start_idx = global_ctxlen + kv_out_start_idx + off_diff_blk + offs_cur_kv_new_to_out = ( + (cur_kv_new_to_out_start_idx + offs_kv_new_seq[None, :]) * kv_out_stride_s + + kv_head_idx * kv_out_stride_h + + offs_kv_new_hdim[:, None] * kv_out_stride_d + ) + tl.store(k_out_ptr + offs_cur_kv_new_to_out, k_new) + tl.store(v_out_ptr + offs_cur_kv_new_to_out, v_new) + +@triton.jit +def load_kvcache_kernel_fp8_unified( + k_cache_ptr, v_cache_ptr, + k_scale_ptr, v_scale_ptr, + k_new_ptr, v_new_ptr, + block_table_ptr, + k_out_ptr, v_out_ptr, + seqlens_ptr, ctxlens_ptr, + cu_seqlens_q_ptr, cu_seqlens_k_ptr, + kv_cache_stride_nblks, kv_cache_stride_blk, kv_cache_stride_h, kv_cache_stride_d, + kv_new_stride_s, kv_new_stride_h, kv_new_stride_d, + block_table_stride_nseqs, block_table_stride_maxblks, + kv_out_stride_s, kv_out_stride_h, kv_out_stride_d, + ctxlens_stride, seqlens_stride, + cu_seqlens_q_stride, cu_seqlens_k_stride, + LAST_BLK_ID: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + DIFFUSION_BLOCK_SIZE: tl.constexpr, + KV_LOAD_UNROLL_FACTOR: tl.constexpr, +): + """ + Unified layout FP8 load kernel: + - Gather paged KV cache blocks using block_tables/context_lens (same as BF16 kernel) + - Dequantize FP8 -> BF16 and apply per-head scale inside kernel + - Also appends active KV (k_new/v_new) once at LAST_BLK_ID + """ + seq_idx = tl.program_id(0) + local_blk_idx = tl.program_id(1) + kv_head_idx = tl.program_id(2) + + off_local_blk = seq_idx * block_table_stride_nseqs + local_blk_idx * block_table_stride_maxblks + global_blk_idx = tl.load(block_table_ptr + off_local_blk) + + # Per-head scales (float32) + k_scale = tl.load(k_scale_ptr + kv_head_idx).to(tl.float32) + v_scale = tl.load(v_scale_ptr + kv_head_idx).to(tl.float32) + + if global_blk_idx != -1: + off_ctxlen = seq_idx * ctxlens_stride + global_ctxlen = tl.load(ctxlens_ptr + off_ctxlen) + cur_window_sz = (local_blk_idx + 1) * PAGE_SIZE + prev_window_sz = local_blk_idx * PAGE_SIZE + local_ctxlen = tl.where(global_ctxlen > cur_window_sz, PAGE_SIZE, global_ctxlen % PAGE_SIZE) + if global_ctxlen > prev_window_sz: + offs_kv_cache_seq = tl.arange(0, PAGE_SIZE) + offs_kv_cache_hdim = tl.arange(0, HEAD_DIM) + offs_kv_cache = ( + global_blk_idx[None, :] * kv_cache_stride_nblks + + offs_kv_cache_seq[None, :] * kv_cache_stride_blk + + kv_head_idx * kv_cache_stride_h + + offs_kv_cache_hdim[:, None] * kv_cache_stride_d + ) + kv_cache_mask = offs_kv_cache_seq[None, :] < local_ctxlen + + # Load FP8 -> fp32, apply scale, store BF16 + k_cache = tl.load(k_cache_ptr + offs_kv_cache, mask=kv_cache_mask, other=0.0).to(tl.float32) * k_scale + v_cache = tl.load(v_cache_ptr + offs_kv_cache, mask=kv_cache_mask, other=0.0).to(tl.float32) * v_scale + k_cache_bf16 = k_cache.to(tl.bfloat16) + v_cache_bf16 = v_cache.to(tl.bfloat16) + + off_cu_seqlens_k = seq_idx * cu_seqlens_k_stride + kv_out_start_idx = tl.load(cu_seqlens_k_ptr + off_cu_seqlens_k) + cur_kv_cache_to_out_start_idx = kv_out_start_idx + prev_window_sz + offs_kv_cache_to_out = ( + (cur_kv_cache_to_out_start_idx + offs_kv_cache_seq[None, :]) * kv_out_stride_s + + kv_head_idx * kv_out_stride_h + + offs_kv_cache_hdim[:, None] * kv_out_stride_d + ) + tl.store(k_out_ptr + offs_kv_cache_to_out, k_cache_bf16, mask=kv_cache_mask) + tl.store(v_out_ptr + offs_kv_cache_to_out, v_cache_bf16, mask=kv_cache_mask) + + # Load and store active KV only once when first meet + if local_blk_idx == LAST_BLK_ID: + off_cu_seqlens_q = seq_idx * cu_seqlens_q_stride + off_seqlens = seq_idx * seqlens_stride + kv_new_start_idx = tl.load(cu_seqlens_q_ptr + off_cu_seqlens_q) + active_seqlen = tl.load(seqlens_ptr + off_seqlens) + + offs_kv_new_seq = tl.arange(0, DIFFUSION_BLOCK_SIZE) + offs_kv_new_hdim = tl.arange(0, HEAD_DIM) + + for diff_blk_idx in tl.range(active_seqlen // DIFFUSION_BLOCK_SIZE, loop_unroll_factor=KV_LOAD_UNROLL_FACTOR): + off_diff_blk = diff_blk_idx * DIFFUSION_BLOCK_SIZE + cur_kv_new_start_idx = kv_new_start_idx + off_diff_blk + offs_cur_kv_new_seq = ( + (cur_kv_new_start_idx + offs_kv_new_seq[None, :]) * kv_new_stride_s + + kv_head_idx * kv_new_stride_h + + offs_kv_new_hdim[:, None] * kv_new_stride_d + ) + k_new = tl.load(k_new_ptr + offs_cur_kv_new_seq) + v_new = tl.load(v_new_ptr + offs_cur_kv_new_seq) + + off_ctxlen = seq_idx * ctxlens_stride + off_cu_seqlens_k = seq_idx * cu_seqlens_k_stride + global_ctxlen = tl.load(ctxlens_ptr + off_ctxlen) + kv_out_start_idx = tl.load(cu_seqlens_k_ptr + off_cu_seqlens_k) + cur_kv_new_to_out_start_idx = global_ctxlen + kv_out_start_idx + off_diff_blk + offs_cur_kv_new_to_out = ( + (cur_kv_new_to_out_start_idx + offs_kv_new_seq[None, :]) * kv_out_stride_s + + kv_head_idx * kv_out_stride_h + + offs_kv_new_hdim[:, None] * kv_out_stride_d + ) + tl.store(k_out_ptr + offs_cur_kv_new_to_out, k_new) + tl.store(v_out_ptr + offs_cur_kv_new_to_out, v_new) + + +def _load_kvcache_bf16(k_cache: torch.Tensor, v_cache: torch.Tensor, attn_metadata: AttnMetaDataBase, k_new: torch.Tensor, v_new: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert k_cache.shape == v_cache.shape + """Helper function for BF16 load. + + Supports both unified and distinct layouts: + - Unified: k_cache.shape == v_cache.shape == [num_blocks, page_size, num_kv_heads, head_dim] + - Distinct: k_cache.shape = [num_blks, h, hdim // x, blk_sz, x], v_cache.shape = [num_blks, h, hdim, blk_sz] + """ assert k_new.shape == v_new.shape - N_BLOCKS, PAGE_SIZE, H_KV, HEAD_DIM = k_cache.shape + + # Determine layout from cache shape + is_unified = k_cache.shape == v_cache.shape and len(k_cache.shape) == 4 + + if is_unified: + # Unified layout: [num_blocks, page_size, num_kv_heads, head_dim] + N_BLOCKS, PAGE_SIZE, H_KV, HEAD_DIM = k_cache.shape + else: + # Distinct layout: k_cache [num_blks, h, hdim // x, blk_sz, x], v_cache [num_blks, h, hdim, blk_sz] + # For load kernel, we need PAGE_SIZE and HEAD_DIM + # PAGE_SIZE is typically the block size (blk_sz) + # HEAD_DIM is the head dimension + N_BLOCKS = k_cache.shape[0] + H_KV = k_cache.shape[1] + PAGE_SIZE = k_cache.shape[3] # blk_sz + # For distinct layout, HEAD_DIM is the total head dimension + # k_cache: [num_blks, h, hdim // x, blk_sz, x] -> HEAD_DIM = (hdim // x) * x + # v_cache: [num_blks, h, hdim, blk_sz] -> HEAD_DIM = hdim + HEAD_DIM = v_cache.shape[2] # hdim NUM_SEQS, MAX_SEQ_BLOCKS = attn_metadata.block_tables.shape ctxlens = attn_metadata.context_lens @@ -254,26 +818,319 @@ def load_kvcache(k_cache: torch.Tensor, v_cache: torch.Tensor, v_output = torch.empty_like(k_output) GRID = (NUM_SEQS, MAX_SEQ_BLOCKS, H_KV) - load_kvcache_kernel[GRID]( - k_cache, v_cache, - k_new, v_new, - attn_metadata.block_tables, - k_output, v_output, - seqlens, ctxlens, - cu_seqlens_q, cu_seqlens_k, - *k_cache.stride(), - *k_new.stride(), - *attn_metadata.block_tables.stride(), - *k_output.stride(), - ctxlens.stride(0), - seqlens.stride(0), - cu_seqlens_q.stride(0), - cu_seqlens_k.stride(0), - LAST_BLK_ID=attn_metadata.block_tables.shape[-1] - 1, - HEAD_DIM=HEAD_DIM, - PAGE_SIZE=PAGE_SIZE, - DIFFUSION_BLOCK_SIZE=DIFFUSION_BLOCK_SIZE, - KV_LOAD_UNROLL_FACTOR=2 - ) + + if is_unified: + # Unified cache: [NBlks, BlkSz, Hkv, Hdim] + kv_cache_stride_nblks, kv_cache_stride_blk, kv_cache_stride_h, kv_cache_stride_d = k_cache.stride() + load_kvcache_kernel_bf16[GRID]( + k_cache, v_cache, + k_new, v_new, + attn_metadata.block_tables, + k_output, v_output, + seqlens, ctxlens, + cu_seqlens_q, cu_seqlens_k, + kv_cache_stride_nblks, kv_cache_stride_blk, kv_cache_stride_h, kv_cache_stride_d, + *k_new.stride(), + *attn_metadata.block_tables.stride(), + *k_output.stride(), + ctxlens.stride(0), + seqlens.stride(0), + cu_seqlens_q.stride(0), + cu_seqlens_k.stride(0), + LAST_BLK_ID=attn_metadata.block_tables.shape[-1] - 1, + HEAD_DIM=HEAD_DIM, + PAGE_SIZE=PAGE_SIZE, + DIFFUSION_BLOCK_SIZE=DIFFUSION_BLOCK_SIZE, + KV_LOAD_UNROLL_FACTOR=2, + ) + else: + # Distinct cache needs a dedicated gather kernel due to K split layout. + x = int(k_cache.shape[-1]) + load_kvcache_kernel_bf16_distinct[GRID]( + k_cache, v_cache, + k_new, v_new, + attn_metadata.block_tables, + k_output, v_output, + seqlens, ctxlens, + cu_seqlens_q, cu_seqlens_k, + *k_cache.stride(), + *v_cache.stride(), + *k_new.stride(), + *attn_metadata.block_tables.stride(), + *k_output.stride(), + ctxlens.stride(0), + seqlens.stride(0), + cu_seqlens_q.stride(0), + cu_seqlens_k.stride(0), + LAST_BLK_ID=attn_metadata.block_tables.shape[-1] - 1, + HEAD_DIM=HEAD_DIM, + PAGE_SIZE=PAGE_SIZE, + DIFFUSION_BLOCK_SIZE=DIFFUSION_BLOCK_SIZE, + KV_LOAD_UNROLL_FACTOR=2, + X=x, + ) + + return k_output, v_output + + +def store_kvcache_unified_layout(key: torch.Tensor, value: torch.Tensor, + k_cache: torch.Tensor, v_cache: torch.Tensor, + slot_mapping: torch.Tensor, attn_metadata: AttnMetaDataBase) -> None: + """ + Store KV cache (unified layout). + Dynamically selects the appropriate kernel based on quantization strategy from context. + """ + # `slot_mapping` is expected to have one entry per token in `key/value` (dimension 0). + # In some flows (e.g. prefix-cache / partial-prefill), metadata may carry a longer + # mapping for the full sequence while `key/value` only contain the suffix tokens + # actually computed this step. In that case, align by taking the tail. + N = int(key.shape[0]) + if int(slot_mapping.numel()) != N: + if int(slot_mapping.numel()) > N: + slot_mapping = slot_mapping[-N:] + else: + raise AssertionError( + f"slot_mapping is shorter than key/value tokens: " + f"N={N}, slot_mapping.numel()={int(slot_mapping.numel())}" + ) + + from diffulex.utils.quantization.context import get_kv_cache_strategy + strategy = get_kv_cache_strategy() + if strategy is None: + _store_kvcache_unified_bf16(key, value, k_cache, v_cache, slot_mapping) + return + + fmt = getattr(strategy, "kv_cache_format", "bf16") + if fmt == "bf16": + _store_kvcache_unified_bf16(key, value, k_cache, v_cache, slot_mapping) + return + if fmt == "fp8": + if attn_metadata.k_scale is None or attn_metadata.v_scale is None: + raise ValueError("FP8 quantization requires k_scale and v_scale in metadata") + _store_kvcache_unified_fp8( + key, value, k_cache, v_cache, slot_mapping, + attn_metadata.k_scale, attn_metadata.v_scale, + strategy=strategy, + ) + return + raise ValueError(f"Unsupported kv_cache_format={fmt!r} for unified layout (strategy={type(strategy)})") + + +def store_kvcache_distinct_layout(key: torch.Tensor, value: torch.Tensor, + k_cache: torch.Tensor, v_cache: torch.Tensor, + slot_mapping: torch.Tensor, attn_metadata: AttnMetaDataBase) -> None: + """ + Store KV cache (distinct layout). + Dynamically selects the appropriate kernel based on quantization strategy from context. + """ + from diffulex.utils.quantization.context import get_kv_cache_strategy + strategy = get_kv_cache_strategy() + if strategy is None: + _store_kvcache_distinct_bf16(key, value, k_cache, v_cache, slot_mapping) + return + + fmt = getattr(strategy, "kv_cache_format", "bf16") + if fmt == "bf16": + _store_kvcache_distinct_bf16(key, value, k_cache, v_cache, slot_mapping) + return + if fmt == "fp8": + if attn_metadata.k_scale is None or attn_metadata.v_scale is None: + raise ValueError("FP8 quantization requires k_scale and v_scale in metadata") + _store_kvcache_distinct_fp8( + key, value, k_cache, v_cache, slot_mapping, + attn_metadata.k_scale, attn_metadata.v_scale, + strategy=strategy, + ) + return + raise ValueError(f"Unsupported kv_cache_format={fmt!r} for distinct layout (strategy={type(strategy)})") + + +def _load_kvcache_fp8(k_cache: torch.Tensor, v_cache: torch.Tensor, + attn_metadata: AttnMetaDataBase, + k_new: torch.Tensor, v_new: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Helper function for FP8 load. + + Unified layout uses a Triton fused kernel to gather+dequantize+apply-scale on-the-fly. + Distinct layout also uses a fused kernel (no Python full-cache dequant fallback). + + Supports both unified and distinct layouts: + - Unified: [num_blocks, page_size, num_kv_heads, head_dim] + - Distinct: k_cache [num_blks, h, hdim // x, blk_sz, x], v_cache [num_blks, h, hdim, blk_sz] + """ + from diffulex.utils.quantization.context import get_kv_cache_strategy + strategy = get_kv_cache_strategy() + if strategy is None or getattr(strategy, "kv_cache_format", "bf16") != "fp8": + raise ValueError(f"Expected kv_cache_format='fp8', got strategy={type(strategy)}") + + # Get scales from metadata + if attn_metadata.k_scale is None or attn_metadata.v_scale is None: + raise ValueError("FP8 dequantization requires k_scale and v_scale in metadata") + + k_scale = attn_metadata.k_scale # [num_kv_heads] + v_scale = attn_metadata.v_scale # [num_kv_heads] + + # Determine layout from cache shape + # Unified: k_cache.shape == v_cache.shape == [num_blocks, page_size, num_kv_heads, head_dim] + # Distinct: k_cache.shape = [num_blks, h, hdim // x, blk_sz, x], v_cache.shape = [num_blks, h, hdim, blk_sz] + is_unified = k_cache.shape == v_cache.shape and len(k_cache.shape) == 4 + + if is_unified: + # Unified layout: [num_blocks, page_size, num_kv_heads, head_dim] + N_BLOCKS, PAGE_SIZE, H_KV, HEAD_DIM = k_cache.shape + + # Ensure Triton sees float8 element type (storage is uint8 view) + k_cache_fp8 = strategy.view_kv_cache_for_kernels(k_cache) + v_cache_fp8 = strategy.view_kv_cache_for_kernels(v_cache) + + NUM_SEQS, MAX_SEQ_BLOCKS = attn_metadata.block_tables.shape + ctxlens = attn_metadata.context_lens + seqlens = attn_metadata.seq_lens_ts + assert sum(seqlens) == k_new.shape[0] + DIFFUSION_BLOCK_SIZE = attn_metadata.seqs[0].diffusion_block_size + MAX_DIFFUSION_BLOCK_SIZE = max(seqlens) + assert MAX_DIFFUSION_BLOCK_SIZE % DIFFUSION_BLOCK_SIZE == 0 + + total_lens = ctxlens + seqlens + cu_seqlens_q = attn_metadata.cu_seqlens_q + cu_seqlens_k = attn_metadata.cu_seqlens_k + assert sum(total_lens) == cu_seqlens_k[-1] + assert cu_seqlens_q.shape == cu_seqlens_k.shape + assert cu_seqlens_q.shape[0] == NUM_SEQS + 1 + + kv_output_shape = (sum(total_lens).item(), H_KV, HEAD_DIM) + k_output = torch.empty(kv_output_shape, device=k_cache.device, dtype=torch.bfloat16) + v_output = torch.empty_like(k_output) + + # Strides for unified cache: [stride(0), stride(1), stride(2), stride(3)] + kv_cache_stride_nblks, kv_cache_stride_blk, kv_cache_stride_h, kv_cache_stride_d = k_cache_fp8.stride() + + GRID = (NUM_SEQS, MAX_SEQ_BLOCKS, H_KV) + load_kvcache_kernel_fp8_unified[GRID]( + k_cache_fp8, v_cache_fp8, + k_scale, v_scale, + k_new, v_new, + attn_metadata.block_tables, + k_output, v_output, + seqlens, ctxlens, + cu_seqlens_q, cu_seqlens_k, + kv_cache_stride_nblks, kv_cache_stride_blk, kv_cache_stride_h, kv_cache_stride_d, + *k_new.stride(), + *attn_metadata.block_tables.stride(), + *k_output.stride(), + ctxlens.stride(0), + seqlens.stride(0), + cu_seqlens_q.stride(0), + cu_seqlens_k.stride(0), + LAST_BLK_ID=attn_metadata.block_tables.shape[-1] - 1, + HEAD_DIM=HEAD_DIM, + PAGE_SIZE=PAGE_SIZE, + DIFFUSION_BLOCK_SIZE=DIFFUSION_BLOCK_SIZE, + KV_LOAD_UNROLL_FACTOR=2, + ) + + # Optional correctness check: compare with the old Python dequant+BF16-gather reference + if os.getenv("DIFFULEX_DEBUG_FP8_LOAD_REF", "0") == "1": + # Avoid huge overhead accidentally + try: + total_tokens = int(sum(total_lens).item()) + except Exception: + total_tokens = -1 + if 0 <= total_tokens <= 4096: + # Reference dequantization (slow): full cache dequant in Python + k_cache_fp32 = k_cache_fp8.float() + v_cache_fp32 = v_cache_fp8.float() + k_scale_broadcast = k_scale.view(1, 1, -1, 1) + v_scale_broadcast = v_scale.view(1, 1, -1, 1) + k_cache_bf16_ref = (k_cache_fp32 * k_scale_broadcast).to(torch.bfloat16) + v_cache_bf16_ref = (v_cache_fp32 * v_scale_broadcast).to(torch.bfloat16) + k_ref, v_ref = _load_kvcache_bf16(k_cache_bf16_ref, v_cache_bf16_ref, attn_metadata, k_new, v_new) + max_diff_k = (k_ref - k_output).abs().max().item() + max_diff_v = (v_ref - v_output).abs().max().item() + print(f"[DIFFULEX_DEBUG_FP8_LOAD_REF] max_abs_diff k={max_diff_k:.6g} v={max_diff_v:.6g} (total_tokens={total_tokens})") + # Be strict: any mismatch likely indicates indexing/mask/scale bug. + if max_diff_k > 0 or max_diff_v > 0: + raise RuntimeError( + f"FP8 fused load mismatch: max_abs_diff k={max_diff_k} v={max_diff_v}. " + "Set DIFFULEX_DEBUG_FP8_LOAD_REF=0 to disable." + ) + + return k_output, v_output + else: + # Distinct layout: fused gather + dequant + scale in kernel. + k_cache_fp8 = strategy.view_kv_cache_for_kernels(k_cache) + v_cache_fp8 = strategy.view_kv_cache_for_kernels(v_cache) + + NUM_SEQS, MAX_SEQ_BLOCKS = attn_metadata.block_tables.shape + ctxlens = attn_metadata.context_lens + seqlens = attn_metadata.seq_lens_ts + assert sum(seqlens) == k_new.shape[0] + DIFFUSION_BLOCK_SIZE = attn_metadata.seqs[0].diffusion_block_size + MAX_DIFFUSION_BLOCK_SIZE = max(seqlens) + assert MAX_DIFFUSION_BLOCK_SIZE % DIFFUSION_BLOCK_SIZE == 0 + + total_lens = ctxlens + seqlens + cu_seqlens_q = attn_metadata.cu_seqlens_q + cu_seqlens_k = attn_metadata.cu_seqlens_k + assert sum(total_lens) == cu_seqlens_k[-1] + assert cu_seqlens_q.shape == cu_seqlens_k.shape + assert cu_seqlens_q.shape[0] == NUM_SEQS + 1 + + # Distinct cache shapes: + # k_cache: [NBlks, Hkv, HEAD_DIM//x, PAGE_SIZE, x] + # v_cache: [NBlks, Hkv, HEAD_DIM, PAGE_SIZE] + PAGE_SIZE = int(k_cache.shape[3]) + HEAD_DIM = int(v_cache.shape[2]) + H_KV = int(v_cache.shape[1]) + x = int(k_cache.shape[-1]) + + kv_output_shape = (sum(total_lens).item(), H_KV, HEAD_DIM) + k_output = torch.empty(kv_output_shape, device=k_cache.device, dtype=torch.bfloat16) + v_output = torch.empty_like(k_output) + + GRID = (NUM_SEQS, MAX_SEQ_BLOCKS, H_KV) + load_kvcache_kernel_fp8_distinct[GRID]( + k_cache_fp8, v_cache_fp8, + k_scale, v_scale, + k_new, v_new, + attn_metadata.block_tables, + k_output, v_output, + seqlens, ctxlens, + cu_seqlens_q, cu_seqlens_k, + *k_cache_fp8.stride(), + *v_cache_fp8.stride(), + *k_new.stride(), + *attn_metadata.block_tables.stride(), + *k_output.stride(), + ctxlens.stride(0), + seqlens.stride(0), + cu_seqlens_q.stride(0), + cu_seqlens_k.stride(0), + LAST_BLK_ID=attn_metadata.block_tables.shape[-1] - 1, + HEAD_DIM=HEAD_DIM, + PAGE_SIZE=PAGE_SIZE, + DIFFUSION_BLOCK_SIZE=DIFFUSION_BLOCK_SIZE, + KV_LOAD_UNROLL_FACTOR=2, + X=x, + ) + + return k_output, v_output + + +def load_kvcache(k_cache: torch.Tensor, v_cache: torch.Tensor, + attn_metadata: AttnMetaDataBase, + k_new: torch.Tensor, v_new: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Load KV cache. + Dynamically selects the appropriate kernel based on quantization strategy from context. + """ + from diffulex.utils.quantization.context import get_kv_cache_strategy + strategy = get_kv_cache_strategy() + if strategy is None: + return _load_kvcache_bf16(k_cache, v_cache, attn_metadata, k_new, v_new) - return k_output, v_output \ No newline at end of file + fmt = getattr(strategy, "kv_cache_format", "bf16") + if fmt == "bf16": + return _load_kvcache_bf16(k_cache, v_cache, attn_metadata, k_new, v_new) + if fmt == "fp8": + return _load_kvcache_fp8(k_cache, v_cache, attn_metadata, k_new, v_new) + raise ValueError(f"Unsupported kv_cache_format={fmt!r} for load (strategy={type(strategy)})") \ No newline at end of file diff --git a/diffulex_kernel/python/linear_kernels.py b/diffulex_kernel/python/linear_kernels.py new file mode 100644 index 0000000..259f7b9 --- /dev/null +++ b/diffulex_kernel/python/linear_kernels.py @@ -0,0 +1,2362 @@ +""" +W8A16, W4A16, W8A8, W4A8, FP8 W8A16, and FP8 W8A8 Linear GEMM kernels using TileLang. + +- W8A16: int8 weight × bf16 activation matrix multiplication with per-channel dequantization. +- W4A16: int4 weight (packed in int8) × bf16 activation matrix multiplication with per-channel dequantization. +- W8A8: int8 activation × int8 weight matrix multiplication, output int32 accumulator. +- W4A8: int8 activation × int4 weight (packed in int8) matrix multiplication, output int32 accumulator. +- FP8 W8A16: FP8 weight (uint8 storage) × bf16 activation matrix multiplication with per-channel dequantization. +- FP8 W8A8: FP8 weight (uint8 storage) × FP8 activation (uint8 storage) matrix multiplication with fused scaling. +""" + +from __future__ import annotations + +import tilelang +import tilelang.language as T +from tvm import tir + +from diffulex_kernel.python.auto_tuner import build_linear_configs + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[3]) +def w8a16_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """W8A16 GEMM kernel: bf16 activation × int8 weight (per-channel dequantized). + + Args: + M: Number of rows in activation matrix A + N: Number of output channels (rows in weight matrix B) + K: Inner dimension (columns in A, rows in B) + block_M: Block size for M dimension + block_N: Block size for N dimension + block_K: Block size for K dimension + num_stages: Number of pipeline stages + threads: Number of threads per block + + Returns: + Compiled TileLang kernel function with signature: + kernel(A: bf16[M, K], B: int8[N, K], Scales: bf16[N], C: bf16[M, N]) -> None + """ + # Fast path: only generate the simple copy-based kernel when all dims are perfectly tiled. + # Otherwise, generate a masked (tail-safe) kernel to avoid falling back for non-multiple sizes. + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), # activation, shape (M, K) + B: T.Tensor((N, K), T.int8), # quantized weight, shape (N, K) + Scales: T.Tensor((N,), T.bfloat16), # per-channel scales, shape (N,) + C: T.Tensor((M, N), T.bfloat16), # output, shape (M, N) + ): + """W8A16 GEMM kernel implementation. + + Computes C = (A @ q^T) * Scales where q is the int8 quantized weight and Scales is per-output-channel. + This is mathematically equivalent to dequantizing weights inside the K loop, but avoids doing the + multiply-by-scale for every (N, K) element in every K tile. + + This implementation follows the W4A8 pattern with fragments for proper pipelining. + """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_bf16 = tir.const(0, T.bfloat16) + + # Allocate shared memory buffers + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) + B_shared = T.alloc_shared((block_N, block_K), T.int8) + + # Allocate fragments (matching W4A8 pattern for proper pipelining) + B_local = T.alloc_fragment((block_N, block_K), T.int8) + B_bf16_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + B_bf16_prev_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + + # Allocate fragment for accumulation (use float32 for precision) + C_local = T.alloc_fragment((block_M, block_N), T.float32) + C_scaled = T.alloc_fragment((block_M, block_N), T.bfloat16) + + # Optional: Add swizzled layout for B_shared (can improve performance) + # T.annotate_layout({B_shared: tilelang.layout.make_swizzled_layout(B_shared)}) + + # Clear accumulation buffer + T.clear(C_local) + + # Pipeline over K dimension + # Using the same pattern as W4A8: T.Pipelined(K // block_K, num_stages=num_stages) + # The key: we copy B_shared -> B_local, dequantize to B_dequantize_local, + # then copy to B_dequantize_prev_local before GEMM, matching W4A8 exactly + # Note: num_stages must match the number of pipeline operations TileLang detects + # For our case: copy A, copy B, copy B->local, dequantize, copy dequant->prev, gemm + # This creates multiple pipeline stages, so we need to ensure num_stages is appropriate + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + # Load A and B tiles to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + + # Copy B_shared to local fragment (required for proper pipelining) + T.copy(B_shared, B_local) + + # Cast int8 -> bf16 (no scale here; apply scale once at output). + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_local[i, j].astype(T.float32).astype(T.bfloat16) + + # Copy to prev_local (required for pipeline synchronization) + T.copy(B_bf16_local, B_bf16_prev_local) + + # GEMM: C = A @ B_dequant^T + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + else: + # Tail-safe kernel: mask-load A/B, mask-load scales (avoid OOB), store C with mask. + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + # Masked load A -> A_shared + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else( + (m < M) & (kk < K), + A[m, kk], + zero_bf16, + ) + + # Masked load B -> B_shared + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + B_shared[i, j] = T.if_then_else( + (n < N) & (kk < K), + B[n, kk], + zero_i8, + ) + + # Copy B_shared to local fragment (required for proper pipelining) + T.copy(B_shared, B_local) + + # Cast int8 -> bf16 (no scale here; apply scale once at output). + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_local[i, j].astype(T.float32).astype(T.bfloat16) + + # Copy to prev_local (required for pipeline synchronization) + T.copy(B_bf16_local, B_bf16_prev_local) + + # GEMM (padded with zeros for out-of-range A/B) + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + + # Apply per-channel scale at output: + # C[m, n] = (A @ q^T)[m, n] * Scales[n] + if aligned: + for i, j in T.Parallel(block_M, block_N): + scale_f32 = Scales[bx * block_N + j].astype(T.float32) + C_scaled[i, j] = (C_local[i, j] * scale_f32).astype(T.bfloat16) + T.copy( + C_scaled, + C[ + by * block_M : (by + 1) * block_M, + bx * block_N : (bx + 1) * block_N, + ], + ) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + scale_bf16 = T.if_then_else(n < N, Scales[n], zero_bf16) + scale_f32 = scale_bf16.astype(T.float32) + C_scaled[i, j] = (C_local[i, j] * scale_f32).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = C_scaled[i, j] + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[4]) +def w8a16_gemm_bias( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """W8A16 GEMM kernel with fused bias: bf16 activation × int8 weight -> bf16 output, then add bias. + + Signature: + kernel(A: bf16[M,K], B: int8[N,K], Scales: bf16[N], Bias: bf16[N], C: bf16[M,N]) -> None + """ + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((N, K), T.int8), + Scales: T.Tensor((N,), T.bfloat16), + # NOTE: keep Bias as fp16 to avoid adapter issues observed with 1D bf16 inputs. + Bias: T.Tensor((N,), T.float16), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_bf16 = tir.const(0, T.bfloat16) + + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) + B_shared = T.alloc_shared((block_N, block_K), T.int8) + + B_local = T.alloc_fragment((block_N, block_K), T.int8) + B_bf16_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + B_bf16_prev_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + + C_local = T.alloc_fragment((block_M, block_N), T.float32) + C_out = T.alloc_fragment((block_M, block_N), T.bfloat16) + + T.clear(C_local) + + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_local[i, j].astype(T.float32).astype(T.bfloat16) + T.copy(B_bf16_local, B_bf16_prev_local) + + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + else: + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else( + (m < M) & (kk < K), + A[m, kk], + zero_bf16, + ) + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + B_shared[i, j] = T.if_then_else( + (n < N) & (kk < K), + B[n, kk], + zero_i8, + ) + + T.copy(B_shared, B_local) + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_local[i, j].astype(T.float32).astype(T.bfloat16) + T.copy(B_bf16_local, B_bf16_prev_local) + + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + + # Apply per-channel scale and bias at output: + # C[m,n] = (A@q^T)[m,n] * Scales[n] + Bias[n] + if aligned: + for i, j in T.Parallel(block_M, block_N): + n = bx * block_N + j + scale_f32 = Scales[n].astype(T.float32) + bias_f32 = Bias[n].astype(T.float32) + C_out[i, j] = (C_local[i, j] * scale_f32 + bias_f32).astype(T.bfloat16) + T.copy( + C_out, + C[ + by * block_M : (by + 1) * block_M, + bx * block_N : (bx + 1) * block_N, + ], + ) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + scale_bf16 = T.if_then_else(n < N, Scales[n], zero_bf16) + bias_f16 = T.if_then_else(n < N, Bias[n], tir.const(0, T.float16)) + scale_f32 = scale_bf16.astype(T.float32) + bias_f32 = bias_f16.astype(T.float32) + val = (C_local[i, j] * scale_f32 + bias_f32).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = val + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[3]) +def w4a16_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """W4A16 GEMM kernel: bf16 activation × int4 weight (packed in int8, per-channel dequantized). + + Args: + M: Number of rows in activation matrix A + N: Number of output channels (rows in weight matrix B) + K: Inner dimension (columns in A, rows in B) + block_M: Block size for M dimension + block_N: Block size for N dimension + block_K: Block size for K dimension + num_stages: Number of pipeline stages + threads: Number of threads per block + + Returns: + Compiled TileLang kernel function with signature: + kernel(A: bf16[M, K], B_packed: int8[N, (K+1)//2], Scales: bf16[N], C: bf16[M, N]) -> None + + Note: + B_packed is int4 weights packed into int8 format. Each int8 byte contains 2 int4 values: + - Lower 4 bits: first int4 value (in range [0, 15], representing [-8, 7]) + - Upper 4 bits: second int4 value (in range [0, 15], representing [-8, 7]) + """ + # Fast path: only generate the simple copy-based kernel when all dims are perfectly tiled. + # Otherwise, generate a masked (tail-safe) kernel to avoid falling back for non-multiple sizes. + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + # Packed size: (K + 1) // 2 + packed_K = (K + 1) // 2 + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), # activation, shape (M, K) + B_packed: T.Tensor((N, packed_K), T.int8), # packed int4 weight, shape (N, (K+1)//2) + Scales: T.Tensor((N,), T.bfloat16), # per-channel scales, shape (N,) + C: T.Tensor((M, N), T.bfloat16), # output, shape (M, N) + ): + """W4A16 GEMM kernel implementation. + + Computes C = A @ B_dequant^T where: + - B_packed[i, j] contains 2 int4 values (packed in int8) + - Each int4 value is unpacked to q in [-8, 7] + - Per-channel dequantization is applied as: (A @ q^T) * Scales[n] (Scales is per-output-channel) + + This implementation avoids per-element dequantization inside the K loop by + factoring the per-channel scale to an output-side column scaling step, which + substantially reduces work vs. dequantizing every weight element. + """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_bf16 = tir.const(0, T.bfloat16) + + # Constants for int4 unpacking + int4_offset = tir.const(8, T.int8) # Offset to convert [0, 15] to [-8, 7] + mask_lower = tir.const(0x0F, T.int8) # Mask for lower 4 bits + mask_upper_shift = tir.const(4, T.int8) # Shift for upper 4 bits + + # Allocate shared memory buffers + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) + B_packed_shared = T.alloc_shared((block_N, (block_K + 1) // 2), T.int8) + + # Allocate fragments (matching W8A16 pattern for proper pipelining) + B_packed_local = T.alloc_fragment((block_N, (block_K + 1) // 2), T.int8) + B_unpacked_local = T.alloc_fragment((block_N, block_K), T.int8) # Unpacked int4 (as int8) + B_bf16_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + B_bf16_prev_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + + # Allocate fragment for accumulation (use float32 for precision) + C_local = T.alloc_fragment((block_M, block_N), T.float32) + C_scaled = T.alloc_fragment((block_M, block_N), T.bfloat16) + + # Clear accumulation buffer + T.clear(C_local) + + # Pipeline over K dimension + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + # Load A tile to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + + # Load B_packed tile to shared memory + packed_k_start = (k * block_K) // 2 # Packed index for K dimension + T.copy(B_packed[bx * block_N, packed_k_start], B_packed_shared) + + # Copy B_packed_shared to local fragment + T.copy(B_packed_shared, B_packed_local) + + # Unpack int4 from packed int8 (TileLang-friendly indexing): + # B_unpacked_local is indexed by (i, j) directly to avoid indices-mismatch issues. + for i, j in T.Parallel(block_N, block_K): + j_packed = j // 2 + packed_byte = B_packed_local[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + B_unpacked_local[i, j] = T.if_then_else(is_lower, lower_int4, upper_int4) + + # Cast int4 (stored as int8) -> bf16 once per element (no scale here). + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_unpacked_local[i, j].astype(T.float32).astype(T.bfloat16) + + # Copy to prev_local (required for pipeline synchronization) + T.copy(B_bf16_local, B_bf16_prev_local) + + # GEMM: C = A @ B_dequant^T + # Here B is q (int4) cast to bf16; scale is applied once after K-accumulation. + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + else: + # Tail-safe kernel: mask-load A/B_packed, unpack, dequantize, store C with mask + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + # Masked load A -> A_shared + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else( + (m < M) & (kk < K), + A[m, kk], + zero_bf16, + ) + + # Masked load B_packed -> B_packed_shared + packed_k_start = (k * block_K) // 2 + packed_k_size = (block_K + 1) // 2 + for i, j_packed in T.Parallel(block_N, packed_k_size): + n = bx * block_N + i + packed_idx = packed_k_start + j_packed + B_packed_shared[i, j_packed] = T.if_then_else( + (n < N) & (packed_idx < packed_K), + B_packed[n, packed_idx], + zero_i8, + ) + + # Copy B_packed_shared to local fragment + T.copy(B_packed_shared, B_packed_local) + + # Unpack int4 from int8 with boundary checks + for i, j in T.Parallel(block_N, block_K): + kk = k * block_K + j + # Convert to local packed index within this block + j_packed = j // 2 + packed_byte = B_packed_local[i, j_packed] + + # Extract both lower and upper 4 bits + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset # Convert [0, 15] to [-8, 7] + upper_int4 = upper_uint - int4_offset # Convert [0, 15] to [-8, 7] + + # Select the appropriate value based on whether j is even (lower) or odd (upper) + is_lower = (j % 2) == 0 + int4_val = T.if_then_else(is_lower, lower_int4, upper_int4) + + # Mask out-of-bound values to zero + in_bounds = (kk < K) & (j < block_K) + B_unpacked_local[i, j] = T.if_then_else(in_bounds, int4_val, zero_i8) + + # Cast int4 -> bf16 (no scale here). + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_unpacked_local[i, j].astype(T.float32).astype(T.bfloat16) + + # Copy to prev_local (required for pipeline synchronization) + T.copy(B_bf16_local, B_bf16_prev_local) + + # GEMM (padded with zeros for out-of-range A/B) + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + + # Apply per-channel scale at output (equivalent to weight-side dequantization): + # C[m, n] = (A @ q^T)[m, n] * Scales[n] + if aligned: + for i, j in T.Parallel(block_M, block_N): + scale_f32 = Scales[bx * block_N + j].astype(T.float32) + C_scaled[i, j] = (C_local[i, j] * scale_f32).astype(T.bfloat16) + T.copy( + C_scaled, + C[ + by * block_M : (by + 1) * block_M, + bx * block_N : (bx + 1) * block_N, + ], + ) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + scale_bf16 = T.if_then_else(n < N, Scales[n], zero_bf16) + scale_f32 = scale_bf16.astype(T.float32) + C_scaled[i, j] = (C_local[i, j] * scale_f32).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = C_scaled[i, j] + + return main + + +@tilelang.jit(out_idx=[2]) +def w8a8_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """W8A8 GEMM kernel: int8 activation × int8 weight matrix multiplication. + + Args: + M: Number of rows in activation matrix A + N: Number of output channels (rows in weight matrix B) + K: Inner dimension (columns in A, rows in B) + block_M: Block size for M dimension + block_N: Block size for N dimension + block_K: Block size for K dimension + num_stages: Number of pipeline stages + threads: Number of threads per block + + Returns: + Compiled TileLang kernel function with signature: + kernel(A: int8[M, K], B: int8[N, K], C: int32[M, N]) -> None + + Note: + - Input A is int8 quantized activation [M, K] + - Input B is int8 quantized weight [N, K] (GEMM uses transpose_B=True internally) + - Output C is int32 accumulator [M, N] + - Scales (activation scales and weight scales) are applied externally after this kernel + """ + # Fast path: only generate the simple copy-based kernel when all dims are perfectly tiled. + # Otherwise, generate a masked (tail-safe) kernel to avoid falling back for non-multiple sizes. + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.int8), # quantized activation, shape (M, K) + B: T.Tensor((N, K), T.int8), # quantized weight, shape (N, K) + C: T.Tensor((M, N), T.int32), # output accumulator, shape (M, N) + ): + """W8A8 GEMM kernel implementation. + + Computes C = A @ B where all inputs are int8 and output is int32. + This avoids overflow during accumulation by using int32 intermediate results. + """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_i32 = tir.const(0, T.int32) + + # Allocate shared memory buffers + A_shared = T.alloc_shared((block_M, block_K), T.int8) + B_shared = T.alloc_shared((block_N, block_K), T.int8) + + # Allocate fragments for pipelining + A_local = T.alloc_fragment((block_M, block_K), T.int8) + B_local = T.alloc_fragment((block_N, block_K), T.int8) + A_local_prev = T.alloc_fragment((block_M, block_K), T.int8) + B_local_prev = T.alloc_fragment((block_N, block_K), T.int8) + + # Allocate fragment for accumulation (use int32 for precision) + C_local = T.alloc_fragment((block_M, block_N), T.int32) + + # Clear accumulation buffer + T.clear(C_local) + + # Pipeline over K dimension + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + # Load A and B tiles to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + # B is stored as [N, K]; GEMM uses transpose_B=True. + T.copy(B[bx * block_N, k * block_K], B_shared) + + # Copy to local fragments (required for proper pipelining) + T.copy(A_shared, A_local) + T.copy(B_shared, B_local) + + # Copy to prev_local (required for pipeline synchronization) + T.copy(A_local, A_local_prev) + T.copy(B_local, B_local_prev) + + # GEMM: C = A @ B^T (int8 x int8 -> int32 accumulation). + # Important: use int8 operands; TileLang lowers to the appropriate int8 GEMM path. + T.gemm(A_local_prev, B_local_prev, C_local, transpose_B=True) + else: + # Tail-safe kernel: mask-load A/B, store C with mask + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + # Masked load A -> A_shared + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else( + (m < M) & (kk < K), + A[m, kk], + zero_i8, + ) + + # Masked load B -> B_shared + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + B_shared[i, j] = T.if_then_else( + (kk < K) & (n < N), + B[n, kk], + zero_i8, + ) + + # Copy to local fragments + T.copy(A_shared, A_local) + T.copy(B_shared, B_local) + + # Copy to prev_local (required for pipeline synchronization) + T.copy(A_local, A_local_prev) + T.copy(B_local, B_local_prev) + + # GEMM (padded with zeros for out-of-range A/B) + T.gemm(A_local_prev, B_local_prev, C_local, transpose_B=True) + + # Store result to output + if aligned: + T.copy( + C_local, + C[ + by * block_M : (by + 1) * block_M, + bx * block_N : (bx + 1) * block_N, + ], + ) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + if (m < M) & (n < N): + C[m, n] = C_local[i, j] + + return main + + +@tilelang.jit(out_idx=[1, 2]) +def w8a8_act_quant( + M: int, + K: int, + block_M: int = 64, + block_K: int = 256, + threads: int = 128, +): + """Fused per-row symmetric int8 activation quantization (BF16 -> INT8 + per-row scales). + + This kernel replaces the Python aten chain: + abs -> amax(reduce) -> div -> round -> clamp -> to(int8) + + For each row m: + absmax = max(abs(x[m, :])) + scale[m] = max(absmax, eps) / 127 + x_q[m, k] = clamp(round(x[m, k] / scale[m]), -127, 127).astype(int8) + + Returns: + kernel(A: bf16[M, K], A_q: int8[M, K], Scales: float32[M]) -> None + With out_idx=[1,2], the Python wrapper returns (A_q, Scales). + """ + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), + A_q: T.Tensor((M, K), T.int8), + Scales: T.Tensor((M,), T.float32), + ): + with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx,): + zero_f32 = tir.const(0.0, T.float32) + eps_f32 = tir.const(1e-8, T.float32) + inv127 = tir.const(1.0 / 127.0, T.float32) + neg127 = tir.const(-127.0, T.float32) + pos127 = tir.const(127.0, T.float32) + + # Tile buffers for abs/max reduction and scale broadcasting. + abs_tile = T.alloc_fragment((block_M, block_K), T.float32) + tile_max = T.alloc_fragment((block_M,), T.float32) + row_max = T.alloc_fragment((block_M,), T.float32) + scales_local = T.alloc_fragment((block_M,), T.float32) + + # Initialize running max to 0 (absmax is >=0). + T.fill(row_max, zero_f32) + + # Pass 1: compute per-row absmax. + for k0 in range(T.ceildiv(K, block_K)): + for i, j in T.Parallel(block_M, block_K): + m = bx * block_M + i + kk = k0 * block_K + j + v = T.if_then_else( + (m < M) & (kk < K), + A[m, kk].astype(T.float32), + zero_f32, + ) + # abs(v) without relying on optional intrinsics + abs_tile[i, j] = T.if_then_else(v < zero_f32, -v, v) + + T.fill(tile_max, zero_f32) + T.reduce_max(abs_tile, tile_max, dim=1, clear=True) + + for i in T.Parallel(block_M): + row_max[i] = T.max(row_max[i], tile_max[i]) + + # Compute scales once and optionally store to global output. + for i in T.Parallel(block_M): + m = bx * block_M + i + s = T.max(row_max[i], eps_f32) * inv127 + scales_local[i] = s + if m < M: + Scales[m] = s + + # Pass 2: quantize using the computed per-row scales. + for k0 in range(T.ceildiv(K, block_K)): + for i, j in T.Parallel(block_M, block_K): + m = bx * block_M + i + kk = k0 * block_K + j + if (m < M) & (kk < K): + s = scales_local[i] + x = A[m, kk].astype(T.float32) / s + q = T.min(T.max(T.round(x), neg127), pos127) + A_q[m, kk] = q.astype(T.int8) + + return main + + +@tilelang.jit(out_idx=[4]) +def w8a8_scaled_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """W8A8 GEMM kernel with fused scaling: int8 activation × int8 weight -> bf16 output. + + This kernel computes: + C[m, n] = (sum_k A_i8[m, k] * B_i8[k, n]) * x_scale[m] * w_scale[n] + + Args: + M, N, K: GEMM sizes + x_scales: float32[M] per-row scales for activation quantization + w_scales: bf16[N] per-output-channel scales for weight quantization + + Returns: + kernel(A: int8[M,K], B: int8[K,N], x_scales: float32[M], w_scales: bf16[N], C: bf16[M,N]) -> None + """ + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.int8), + B: T.Tensor((N, K), T.int8), + XScales: T.Tensor((M,), T.float32), + WScales: T.Tensor((N,), T.float16), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_i32 = tir.const(0, T.int32) + zero_f32 = tir.const(0.0, T.float32) + zero_bf16 = tir.const(0, T.bfloat16) + zero_f16 = tir.const(0, T.float16) + + A_shared = T.alloc_shared((block_M, block_K), T.int8) + B_shared = T.alloc_shared((block_N, block_K), T.int8) + + A_local = T.alloc_fragment((block_M, block_K), T.int8) + B_local = T.alloc_fragment((block_N, block_K), T.int8) + A_local_prev = T.alloc_fragment((block_M, block_K), T.int8) + B_local_prev = T.alloc_fragment((block_N, block_K), T.int8) + + C_local = T.alloc_fragment((block_M, block_N), T.int32) + C_out = T.alloc_fragment((block_M, block_N), T.bfloat16) + + T.clear(C_local) + + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + # B is stored as [N, K]; GEMM uses transpose_B=True. + T.copy(B[bx * block_N, k * block_K], B_shared) + + T.copy(A_shared, A_local) + T.copy(B_shared, B_local) + + T.copy(A_local, A_local_prev) + T.copy(B_local, B_local_prev) + + # int8 x int8 -> int32 accumulation + T.gemm(A_local_prev, B_local_prev, C_local, transpose_B=True) + else: + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else((m < M) & (kk < K), A[m, kk], zero_i8) + + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + B_shared[i, j] = T.if_then_else((kk < K) & (n < N), B[n, kk], zero_i8) + + T.copy(A_shared, A_local) + T.copy(B_shared, B_local) + + T.copy(A_local, A_local_prev) + T.copy(B_local, B_local_prev) + + T.gemm(A_local_prev, B_local_prev, C_local, transpose_B=True) + + # Fused scaling + store + if aligned: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = XScales[m] # float32 + w_s = WScales[n].astype(T.float32) + C_out[i, j] = (C_local[i, j].astype(T.float32) * x_s * w_s).astype(T.bfloat16) + T.copy( + C_out, + C[ + by * block_M : (by + 1) * block_M, + bx * block_N : (bx + 1) * block_N, + ], + ) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = T.if_then_else(m < M, XScales[m], zero_f32) + w_s_f16 = T.if_then_else(n < N, WScales[n], zero_f16) + w_s = w_s_f16.astype(T.float32) + val = (C_local[i, j].astype(T.float32) * x_s * w_s).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = val + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[3]) +def w8a8_fused_act_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 3, + threads: int = 128, +): + """W8A8 GEMM with fused activation quantization: bf16 activation -> int8 GEMM -> bf16 output. + + This kernel computes per-row scales internally (absmax / 127), quantizes A on the fly, + then runs int8 GEMM against B (int8) and applies per-row/per-channel scaling. + + Optimizations: + - Removed unnecessary fragment copies (A_local, A_local_prev, B_local, B_local_prev) + - Direct GEMM from shared memory (A_shared, B_shared -> C_local) + - Added swizzled layout for shared memory to reduce bank conflicts + - Increased num_stages to 3 for better latency hiding + """ + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((N, K), T.int8), + WScales: T.Tensor((N,), T.float16), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_i32 = tir.const(0, T.int32) + zero_f32 = tir.const(0.0, T.float32) + zero_bf16 = tir.const(0, T.bfloat16) + zero_f16 = tir.const(0, T.float16) + eps_f32 = tir.const(1e-8, T.float32) + inv127 = tir.const(1.0 / 127.0, T.float32) + neg127 = tir.const(-127.0, T.float32) + pos127 = tir.const(127.0, T.float32) + + A_shared = T.alloc_shared((block_M, block_K), T.int8) + B_shared = T.alloc_shared((block_N, block_K), T.int8) + + C_local = T.alloc_fragment((block_M, block_N), T.int32) + C_out = T.alloc_fragment((block_M, block_N), T.bfloat16) + + row_max = T.alloc_reducer((block_M,), T.float32, op="max") + scales_smem = T.alloc_shared((block_M,), T.float32) + + # Add swizzled layout for shared memory to reduce bank conflicts + T.annotate_layout({ + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_shared: tilelang.layout.make_swizzled_layout(B_shared), + }) + + T.clear(C_local) + # absmax is non-negative; 0 is a safe initializer for max-reduction. + T.fill(row_max, zero_f32) + + # Pass 1: compute per-row absmax. + if aligned: + num_k_blocks = K // block_K + for k0 in range(num_k_blocks): + for i, j in T.Parallel(block_M, block_K): + v = A[by * block_M + i, k0 * block_K + j].astype(T.float32) + av = T.if_then_else(v < zero_f32, -v, v) + row_max[i] = T.max(row_max[i], av) + else: + for k0 in range(T.ceildiv(K, block_K)): + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k0 * block_K + j + v = T.if_then_else((m < M) & (kk < K), A[m, kk].astype(T.float32), zero_f32) + av = T.if_then_else(v < zero_f32, -v, v) + row_max[i] = T.max(row_max[i], av) + + # Materialize reducer results. + T.finalize_reducer(row_max) + + # Compute per-row scales. + for i in T.Parallel(block_M): + scales_smem[i] = T.max(row_max[i], eps_f32) * inv127 + + # Pass 2: quantize A on the fly and GEMM. + # Optimization: removed A_local, A_local_prev, B_local, B_local_prev + # Direct GEMM from shared memory saves 4 fragment copies per iteration! + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + # Quantize A directly into A_shared + for i, j in T.Parallel(block_M, block_K): + s = scales_smem[i] + x = A[by * block_M + i, k * block_K + j].astype(T.float32) / s + q = T.min(T.max(T.round(x), neg127), pos127) + A_shared[i, j] = q.astype(T.int8) + + # Load B directly into B_shared + # B is stored as [N, K]; GEMM uses transpose_B=True. + T.copy(B[bx * block_N, k * block_K], B_shared) + + # Direct GEMM from shared memory - no fragment copies! + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + else: + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + # Quantize A directly into A_shared with bounds checking + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + if (m < M) & (kk < K): + s = scales_smem[i] + x = A[m, kk].astype(T.float32) / s + q = T.min(T.max(T.round(x), neg127), pos127) + A_shared[i, j] = q.astype(T.int8) + else: + A_shared[i, j] = zero_i8 + + # Load B directly into B_shared with bounds checking + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + B_shared[i, j] = T.if_then_else((kk < K) & (n < N), B[n, kk], zero_i8) + + # Direct GEMM from shared memory - no fragment copies! + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + # Fused scaling + store + if aligned: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = scales_smem[i] + w_s = WScales[n].astype(T.float32) + C_out[i, j] = (C_local[i, j].astype(T.float32) * x_s * w_s).astype(T.bfloat16) + T.copy( + C_out, + C[ + by * block_M : (by + 1) * block_M, + bx * block_N : (bx + 1) * block_N, + ], + ) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = T.if_then_else(m < M, scales_smem[i], zero_f32) + w_s_f16 = T.if_then_else(n < N, WScales[n], zero_f16) + w_s = w_s_f16.astype(T.float32) + val = (C_local[i, j].astype(T.float32) * x_s * w_s).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = val + + return main + + +@tilelang.jit(out_idx=[2]) +def w4a8_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """W4A8 GEMM kernel: int8 activation × int4 weight (packed in int8) matrix multiplication. + + Args: + M: Number of rows in activation matrix A + N: Number of output channels (rows in weight matrix B) + K: Inner dimension (columns in A, rows in B) + block_M: Block size for M dimension + block_N: Block size for N dimension + block_K: Block size for K dimension + num_stages: Number of pipeline stages + threads: Number of threads per block + + Returns: + Compiled TileLang kernel function with signature: + kernel(A: int8[M, K], B_packed: int8[N, (K+1)//2], C: int32[M, N]) -> None + + Note: + - Input A is int8 quantized activation [M, K] + - Input B_packed is int4 weights packed into int8 format [N, (K+1)//2] + - Output C is int32 accumulator [M, N] + - Scales (activation scales and weight scales) are applied externally after this kernel + - B_packed is int4 weights packed into int8 format. Each int8 byte contains 2 int4 values: + - Lower 4 bits: first int4 value (in range [0, 15], representing [-8, 7]) + - Upper 4 bits: second int4 value (in range [0, 15], representing [-8, 7]) + """ + # Fast path: only generate the simple copy-based kernel when all dims are perfectly tiled. + # Otherwise, generate a masked (tail-safe) kernel to avoid falling back for non-multiple sizes. + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + # Packed size: (K + 1) // 2 + packed_K = (K + 1) // 2 + + @T.prim_func + def main( + A: T.Tensor((M, K), T.int8), # quantized activation, shape (M, K) + B_packed: T.Tensor((N, packed_K), T.int8), # packed int4 weight, shape (N, (K+1)//2) + C: T.Tensor((M, N), T.int32), # output accumulator, shape (M, N) + ): + """W4A8 GEMM kernel implementation. + + Computes C = A @ B_unpacked^T where: + - B_packed[i, j] contains 2 int4 values (packed in int8) + - Each int4 value is unpacked to q in [-8, 7] + - All operations use int8/int32 to avoid overflow during accumulation + """ + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_i32 = tir.const(0, T.int32) + + # Constants for int4 unpacking + int4_offset = tir.const(8, T.int8) # Offset to convert [0, 15] to [-8, 7] + mask_lower = tir.const(0x0F, T.int8) # Mask for lower 4 bits + mask_upper_shift = tir.const(4, T.int8) # Shift for upper 4 bits + + # Allocate shared memory buffers + A_shared = T.alloc_shared((block_M, block_K), T.int8) + B_packed_shared = T.alloc_shared((block_N, (block_K + 1) // 2), T.int8) + + # Allocate fragments for pipelining + A_local = T.alloc_fragment((block_M, block_K), T.int8) + B_packed_local = T.alloc_fragment((block_N, (block_K + 1) // 2), T.int8) + B_unpacked_local = T.alloc_fragment((block_N, block_K), T.int8) # Unpacked int4 (as int8) + A_local_prev = T.alloc_fragment((block_M, block_K), T.int8) + B_unpacked_local_prev = T.alloc_fragment((block_N, block_K), T.int8) + + # Allocate fragment for accumulation (use int32 for precision) + C_local = T.alloc_fragment((block_M, block_N), T.int32) + + # Clear accumulation buffer + T.clear(C_local) + + # Pipeline over K dimension + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + # Load A tile to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + + # Load B_packed tile to shared memory + packed_k_start = (k * block_K) // 2 # Packed index for K dimension + T.copy(B_packed[bx * block_N, packed_k_start], B_packed_shared) + + # Copy to local fragments + T.copy(A_shared, A_local) + T.copy(B_packed_shared, B_packed_local) + + # Unpack int4 from packed int8 + for i, j in T.Parallel(block_N, block_K): + j_packed = j // 2 + packed_byte = B_packed_local[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + B_unpacked_local[i, j] = T.if_then_else(is_lower, lower_int4, upper_int4) + + # Copy to prev_local (required for pipeline synchronization) + T.copy(A_local, A_local_prev) + T.copy(B_unpacked_local, B_unpacked_local_prev) + + # GEMM: C = A @ B_unpacked^T (int8 x int8 -> int32 accumulation). + # Use int8 operands; TileLang lowers to the proper int8 GEMM path. + T.gemm(A_local_prev, B_unpacked_local_prev, C_local, transpose_B=True) + else: + # Tail-safe kernel: mask-load A/B_packed, unpack, store C with mask + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + # Masked load A -> A_shared + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else( + (m < M) & (kk < K), + A[m, kk], + zero_i8, + ) + + # Masked load B_packed -> B_packed_shared + packed_k_start = (k * block_K) // 2 + packed_k_size = (block_K + 1) // 2 + for i, j_packed in T.Parallel(block_N, packed_k_size): + n = bx * block_N + i + packed_idx = packed_k_start + j_packed + B_packed_shared[i, j_packed] = T.if_then_else( + (n < N) & (packed_idx < packed_K), + B_packed[n, packed_idx], + zero_i8, + ) + + # Copy to local fragments + T.copy(A_shared, A_local) + T.copy(B_packed_shared, B_packed_local) + + # Unpack int4 from int8 with boundary checks + for i, j in T.Parallel(block_N, block_K): + kk = k * block_K + j + j_packed = j // 2 + packed_byte = B_packed_local[i, j_packed] + + # Extract both lower and upper 4 bits + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + + # Select the appropriate value based on whether j is even (lower) or odd (upper) + is_lower = (j % 2) == 0 + int4_val = T.if_then_else(is_lower, lower_int4, upper_int4) + + # Mask out-of-bound values to zero + in_bounds = (kk < K) & (j < block_K) + B_unpacked_local[i, j] = T.if_then_else(in_bounds, int4_val, zero_i8) + + # Copy to prev_local (required for pipeline synchronization) + T.copy(A_local, A_local_prev) + T.copy(B_unpacked_local, B_unpacked_local_prev) + + # GEMM (padded with zeros for out-of-range A/B) + T.gemm(A_local_prev, B_unpacked_local_prev, C_local, transpose_B=True) + + # Store result to output + if aligned: + T.copy( + C_local, + C[ + by * block_M : (by + 1) * block_M, + bx * block_N : (bx + 1) * block_N, + ], + ) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + if (m < M) & (n < N): + C[m, n] = C_local[i, j] + + return main + + +@tilelang.jit(out_idx=[4]) +def w4a8_scaled_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """W4A8 GEMM kernel with fused scaling: int8 activation × packed int4 weight -> bf16 output. + + Computes: + C[m, n] = (sum_k A_i8[m,k] * q_i4[n,k]) * x_scale[m] * w_scale[n] + + Where q_i4 is unpacked from B_packed on the fly into int8 in [-8, 7]. + """ + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + packed_K = (K + 1) // 2 + + @T.prim_func + def main( + A: T.Tensor((M, K), T.int8), + B_packed: T.Tensor((N, packed_K), T.int8), + XScales: T.Tensor((M,), T.float32), + WScales: T.Tensor((N,), T.float16), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_i32 = tir.const(0, T.int32) + zero_f32 = tir.const(0.0, T.float32) + zero_bf16 = tir.const(0, T.bfloat16) + zero_f16 = tir.const(0, T.float16) + + int4_offset = tir.const(8, T.int8) + mask_lower = tir.const(0x0F, T.int8) + mask_upper_shift = tir.const(4, T.int8) + + A_shared = T.alloc_shared((block_M, block_K), T.int8) + B_packed_shared = T.alloc_shared((block_N, (block_K + 1) // 2), T.int8) + + A_local = T.alloc_fragment((block_M, block_K), T.int8) + B_packed_local = T.alloc_fragment((block_N, (block_K + 1) // 2), T.int8) + B_unpacked_local = T.alloc_fragment((block_N, block_K), T.int8) + A_local_prev = T.alloc_fragment((block_M, block_K), T.int8) + B_unpacked_local_prev = T.alloc_fragment((block_N, block_K), T.int8) + + C_local = T.alloc_fragment((block_M, block_N), T.int32) + C_out = T.alloc_fragment((block_M, block_N), T.bfloat16) + + T.clear(C_local) + + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + + packed_k_start = (k * block_K) // 2 + T.copy(B_packed[bx * block_N, packed_k_start], B_packed_shared) + + T.copy(A_shared, A_local) + T.copy(B_packed_shared, B_packed_local) + + for i, j in T.Parallel(block_N, block_K): + j_packed = j // 2 + packed_byte = B_packed_local[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + B_unpacked_local[i, j] = T.if_then_else(is_lower, lower_int4, upper_int4) + + T.copy(A_local, A_local_prev) + T.copy(B_unpacked_local, B_unpacked_local_prev) + + T.gemm(A_local_prev, B_unpacked_local_prev, C_local, transpose_B=True) + else: + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else((m < M) & (kk < K), A[m, kk], zero_i8) + + packed_k_start = (k * block_K) // 2 + packed_k_size = (block_K + 1) // 2 + for i, j_packed in T.Parallel(block_N, packed_k_size): + n = bx * block_N + i + packed_idx = packed_k_start + j_packed + B_packed_shared[i, j_packed] = T.if_then_else( + (n < N) & (packed_idx < packed_K), + B_packed[n, packed_idx], + zero_i8, + ) + + T.copy(A_shared, A_local) + T.copy(B_packed_shared, B_packed_local) + + for i, j in T.Parallel(block_N, block_K): + kk = k * block_K + j + j_packed = j // 2 + packed_byte = B_packed_local[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + int4_val = T.if_then_else(is_lower, lower_int4, upper_int4) + in_bounds = (kk < K) & (j < block_K) + B_unpacked_local[i, j] = T.if_then_else(in_bounds, int4_val, zero_i8) + + T.copy(A_local, A_local_prev) + T.copy(B_unpacked_local, B_unpacked_local_prev) + + T.gemm(A_local_prev, B_unpacked_local_prev, C_local, transpose_B=True) + + # Fused scaling + store + if aligned: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = XScales[m] + w_s = WScales[n].astype(T.float32) + C_out[i, j] = (C_local[i, j].astype(T.float32) * x_s * w_s).astype(T.bfloat16) + T.copy( + C_out, + C[ + by * block_M : (by + 1) * block_M, + bx * block_N : (bx + 1) * block_N, + ], + ) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = T.if_then_else(m < M, XScales[m], zero_f32) + w_s_f16 = T.if_then_else(n < N, WScales[n], zero_f16) + w_s = w_s_f16.astype(T.float32) + val = (C_local[i, j].astype(T.float32) * x_s * w_s).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = val + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[3]) +def w4a8_fused_act_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 3, + threads: int = 128, +): + """W4A8 GEMM with fused activation quantization: bf16 activation -> int8 GEMM -> bf16 output. + + This kernel computes per-row scales internally (absmax / 127), quantizes A on the fly, + unpacks packed int4 weights, then applies fused scaling. + + Optimizations: + - Reduced fragment copies: unpack B directly in shared memory + - Added swizzled layout for shared memory + - Increased num_stages to 3 for better latency hiding + """ + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + packed_K = (K + 1) // 2 + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), + B_packed: T.Tensor((N, packed_K), T.int8), + WScales: T.Tensor((N,), T.float16), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_i32 = tir.const(0, T.int32) + zero_f32 = tir.const(0.0, T.float32) + zero_bf16 = tir.const(0, T.bfloat16) + zero_f16 = tir.const(0, T.float16) + eps_f32 = tir.const(1e-8, T.float32) + inv127 = tir.const(1.0 / 127.0, T.float32) + neg127 = tir.const(-127.0, T.float32) + pos127 = tir.const(127.0, T.float32) + + int4_offset = tir.const(8, T.int8) + mask_lower = tir.const(0x0F, T.int8) + mask_upper_shift = tir.const(4, T.int8) + + A_shared = T.alloc_shared((block_M, block_K), T.int8) + B_packed_shared = T.alloc_shared((block_N, (block_K + 1) // 2), T.int8) + B_unpacked_shared = T.alloc_shared((block_N, block_K), T.int8) + + C_local = T.alloc_fragment((block_M, block_N), T.int32) + C_out = T.alloc_fragment((block_M, block_N), T.bfloat16) + + row_max = T.alloc_reducer((block_M,), T.float32, op="max") + scales_smem = T.alloc_shared((block_M,), T.float32) + + # Add swizzled layout for shared memory + T.annotate_layout({ + A_shared: tilelang.layout.make_swizzled_layout(A_shared), + B_unpacked_shared: tilelang.layout.make_swizzled_layout(B_unpacked_shared), + }) + + T.clear(C_local) + # absmax is non-negative; 0 is a safe initializer for max-reduction. + T.fill(row_max, zero_f32) + + # Pass 1: compute per-row absmax. + if aligned: + num_k_blocks = K // block_K + for k0 in range(num_k_blocks): + for i, j in T.Parallel(block_M, block_K): + v = A[by * block_M + i, k0 * block_K + j].astype(T.float32) + av = T.if_then_else(v < zero_f32, -v, v) + row_max[i] = T.max(row_max[i], av) + else: + for k0 in range(T.ceildiv(K, block_K)): + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k0 * block_K + j + v = T.if_then_else((m < M) & (kk < K), A[m, kk].astype(T.float32), zero_f32) + av = T.if_then_else(v < zero_f32, -v, v) + row_max[i] = T.max(row_max[i], av) + + # Materialize reducer results. + T.finalize_reducer(row_max) + + # Compute per-row scales. + for i in T.Parallel(block_M): + scales_smem[i] = T.max(row_max[i], eps_f32) * inv127 + + # Pass 2: quantize A, unpack B, GEMM. + # Optimization: unpack B directly in shared memory, avoid fragment copies + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + # Quantize A directly into A_shared + for i, j in T.Parallel(block_M, block_K): + s = scales_smem[i] + x = A[by * block_M + i, k * block_K + j].astype(T.float32) / s + q = T.min(T.max(T.round(x), neg127), pos127) + A_shared[i, j] = q.astype(T.int8) + + # Load B_packed into shared memory + packed_k_start = (k * block_K) // 2 + T.copy(B_packed[bx * block_N, packed_k_start], B_packed_shared) + + # Unpack B directly in shared memory + for i, j in T.Parallel(block_N, block_K): + j_packed = j // 2 + packed_byte = B_packed_shared[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + # NOTE: Avoid introducing a let-bound var (e.g., `is_lower`) inside a fused/vectorized + # Parallel loop. Some TileLang/TVM lower passes may attempt to re-bind the same Var + # with different loop symbols and fail with: + # "Trying to update var 'is_lower' with a different value" + B_unpacked_shared[i, j] = T.if_then_else((j % 2) == 0, lower_int4, upper_int4) + + # Direct GEMM from shared memory - no fragment copies! + T.gemm(A_shared, B_unpacked_shared, C_local, transpose_B=True) + else: + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + # Quantize A directly into A_shared with bounds checking + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + if (m < M) & (kk < K): + s = scales_smem[i] + x = A[m, kk].astype(T.float32) / s + q = T.min(T.max(T.round(x), neg127), pos127) + A_shared[i, j] = q.astype(T.int8) + else: + A_shared[i, j] = zero_i8 + + # Load B_packed into shared memory with bounds checking + packed_k_start = (k * block_K) // 2 + packed_k_size = (block_K + 1) // 2 + for i, j_packed in T.Parallel(block_N, packed_k_size): + n = bx * block_N + i + packed_idx = packed_k_start + j_packed + B_packed_shared[i, j_packed] = T.if_then_else( + (n < N) & (packed_idx < packed_K), + B_packed[n, packed_idx], + zero_i8, + ) + + # Unpack B directly in shared memory with bounds checking + for i, j in T.Parallel(block_N, block_K): + kk = k * block_K + j + j_packed = j // 2 + packed_byte = B_packed_shared[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + int4_val = T.if_then_else((j % 2) == 0, lower_int4, upper_int4) + in_bounds = (kk < K) & (j < block_K) + B_unpacked_shared[i, j] = T.if_then_else(in_bounds, int4_val, zero_i8) + + # Direct GEMM from shared memory - no fragment copies! + T.gemm(A_shared, B_unpacked_shared, C_local, transpose_B=True) + + # Fused scaling + store + if aligned: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = scales_smem[i] + w_s = WScales[n].astype(T.float32) + C_out[i, j] = (C_local[i, j].astype(T.float32) * x_s * w_s).astype(T.bfloat16) + T.copy( + C_out, + C[ + by * block_M : (by + 1) * block_M, + bx * block_N : (bx + 1) * block_N, + ], + ) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = T.if_then_else(m < M, scales_smem[i], zero_f32) + w_s_f16 = T.if_then_else(n < N, WScales[n], zero_f16) + w_s = w_s_f16.astype(T.float32) + val = (C_local[i, j].astype(T.float32) * x_s * w_s).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = val + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[3]) +def fp8_e4m3_w8a16_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """FP8 E4M3 W8A16 GEMM kernel: bf16 activation × FP8 E4M3 weight (uint8 storage, per-channel dequantized).""" + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), + # IMPORTANT: pass fp8 tensors from PyTorch by using `uint8_tensor.view(torch_fp8_dtype)`. + # Do NOT pass raw uint8 storage here, otherwise we would need reinterpret logic and lose performance. + B: T.Tensor((N, K), T.float8_e4m3fn), + Scales: T.Tensor((N,), T.float32), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_bf16 = tir.const(0, T.bfloat16) + zero_f32 = tir.const(0.0, T.float32) + zero_fp8 = tir.const(0, T.float8_e4m3fn) + + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) + B_shared = T.alloc_shared((block_N, block_K), T.float8_e4m3fn) + + # Follow the same pipeline pattern as int8 `w8a16_gemm`: + # B_shared -> B_local -> (cast) B_bf16_local -> B_bf16_prev_local -> GEMM + B_local = T.alloc_fragment((block_N, block_K), T.float8_e4m3fn) + B_bf16_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + B_bf16_prev_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + + C_local = T.alloc_fragment((block_M, block_N), T.float32) + C_scaled = T.alloc_fragment((block_M, block_N), T.bfloat16) + + T.clear(C_local) + + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.copy(B_shared, B_local) + + # Cast fp8 -> fp32 -> bf16 (avoid fp16/half path, which can trigger cutlass bf16 ambiguity). + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_local[i, j].astype(T.float32).astype(T.bfloat16) + + T.copy(B_bf16_local, B_bf16_prev_local) + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + else: + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else((m < M) & (kk < K), A[m, kk], zero_bf16) + + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + B_shared[i, j] = T.if_then_else((n < N) & (kk < K), B[n, kk], zero_fp8) + + T.copy(B_shared, B_local) + + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_local[i, j].astype(T.float32).astype(T.bfloat16) + + T.copy(B_bf16_local, B_bf16_prev_local) + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + + # Apply per-channel scale at output: C[m, n] = (A @ q_fp8^T)[m, n] * Scales[n] + if aligned: + for i, j in T.Parallel(block_M, block_N): + scale_f32 = Scales[bx * block_N + j] + C_scaled[i, j] = (C_local[i, j] * scale_f32).astype(T.bfloat16) + T.copy(C_scaled, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + scale_f32 = T.if_then_else(n < N, Scales[n], zero_f32) + val = (C_local[i, j] * scale_f32).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = val + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[3]) +def fp8_e5m2_w8a16_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """FP8 E5M2 W8A16 GEMM kernel: bf16 activation × FP8 E5M2 weight (uint8 storage, per-channel dequantized).""" + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((N, K), T.float8_e5m2), + Scales: T.Tensor((N,), T.float32), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_bf16 = tir.const(0, T.bfloat16) + zero_f32 = tir.const(0.0, T.float32) + zero_fp8 = tir.const(0, T.float8_e5m2) + + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) + B_shared = T.alloc_shared((block_N, block_K), T.float8_e5m2) + + B_local = T.alloc_fragment((block_N, block_K), T.float8_e5m2) + B_bf16_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + B_bf16_prev_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + + C_local = T.alloc_fragment((block_M, block_N), T.float32) + C_scaled = T.alloc_fragment((block_M, block_N), T.bfloat16) + + T.clear(C_local) + + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.copy(B_shared, B_local) + + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_local[i, j].astype(T.float32).astype(T.bfloat16) + + T.copy(B_bf16_local, B_bf16_prev_local) + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + else: + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else((m < M) & (kk < K), A[m, kk], zero_bf16) + + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + B_shared[i, j] = T.if_then_else((n < N) & (kk < K), B[n, kk], zero_fp8) + + T.copy(B_shared, B_local) + + for i, j in T.Parallel(block_N, block_K): + B_bf16_local[i, j] = B_local[i, j].astype(T.float32).astype(T.bfloat16) + + T.copy(B_bf16_local, B_bf16_prev_local) + T.gemm(A_shared, B_bf16_prev_local, C_local, transpose_B=True) + + if aligned: + for i, j in T.Parallel(block_M, block_N): + scale_f32 = Scales[bx * block_N + j] + C_scaled[i, j] = (C_local[i, j] * scale_f32).astype(T.bfloat16) + T.copy(C_scaled, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + scale_f32 = T.if_then_else(n < N, Scales[n], zero_f32) + val = (C_local[i, j] * scale_f32).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = val + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[4]) +def fp8_e4m3_w8a8_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """FP8 E4M3 W8A8 GEMM kernel: FP8 E4M3 activation × FP8 E4M3 weight with fused scaling.""" + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.float8_e4m3fn), + B: T.Tensor((N, K), T.float8_e4m3fn), + XScales: T.Tensor((M,), T.float32), + WScales: T.Tensor((N,), T.float16), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_f32 = tir.const(0.0, T.float32) + zero_f16 = tir.const(0, T.float16) + zero_fp8 = tir.const(0, T.float8_e4m3fn) + + A_shared = T.alloc_shared((block_M, block_K), T.float8_e4m3fn) + B_shared = T.alloc_shared((block_N, block_K), T.float8_e4m3fn) + + C_local = T.alloc_fragment((block_M, block_N), T.float32) + C_out = T.alloc_fragment((block_M, block_N), T.bfloat16) + + T.clear(C_local) + + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + else: + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else((m < M) & (kk < K), A[m, kk], zero_fp8) + + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + B_shared[i, j] = T.if_then_else((n < N) & (kk < K), B[n, kk], zero_fp8) + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + # Fused scaling + store: C = (A@B^T) * x_scale[m] * w_scale[n] + if aligned: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = XScales[m] + w_s = WScales[n].astype(T.float32) + C_out[i, j] = (C_local[i, j] * x_s * w_s).astype(T.bfloat16) + T.copy(C_out, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = T.if_then_else(m < M, XScales[m], zero_f32) + w_s_f16 = T.if_then_else(n < N, WScales[n], zero_f16) + w_s = w_s_f16.astype(T.float32) + val = (C_local[i, j] * x_s * w_s).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = val + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[4]) +def fp8_e5m2_w8a8_gemm( + M: int, + N: int, + K: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """FP8 E5M2 W8A8 GEMM kernel: FP8 E5M2 activation × FP8 E5M2 weight with fused scaling.""" + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.float8_e5m2), + B: T.Tensor((N, K), T.float8_e5m2), + XScales: T.Tensor((M,), T.float32), + WScales: T.Tensor((N,), T.float16), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_f32 = tir.const(0.0, T.float32) + zero_f16 = tir.const(0, T.float16) + zero_fp8 = tir.const(0, T.float8_e5m2) + + A_shared = T.alloc_shared((block_M, block_K), T.float8_e5m2) + B_shared = T.alloc_shared((block_N, block_K), T.float8_e5m2) + + C_local = T.alloc_fragment((block_M, block_N), T.float32) + C_out = T.alloc_fragment((block_M, block_N), T.bfloat16) + + T.clear(C_local) + + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + else: + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else((m < M) & (kk < K), A[m, kk], zero_fp8) + + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + B_shared[i, j] = T.if_then_else((n < N) & (kk < K), B[n, kk], zero_fp8) + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + if aligned: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = XScales[m] + w_s = WScales[n].astype(T.float32) + C_out[i, j] = (C_local[i, j] * x_s * w_s).astype(T.bfloat16) + T.copy(C_out, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N]) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + x_s = T.if_then_else(m < M, XScales[m], zero_f32) + w_s_f16 = T.if_then_else(n < N, WScales[n], zero_f16) + w_s = w_s_f16.astype(T.float32) + val = (C_local[i, j] * x_s * w_s).astype(T.bfloat16) + if (m < M) & (n < N): + C[m, n] = val + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[5]) +def gptq_w4a16_gemm( + M: int, + N: int, + K: int, + num_groups: int, + group_size: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """GPTQ W4A16 GEMM kernel: bf16 activation × GPTQ int4 weight (packed in int8, groupwise dequantized). + + Args: + M: Number of rows in activation matrix A + N: Number of output channels (rows in weight matrix B) + K: Inner dimension (columns in A, rows in B) + num_groups: Number of quantization groups + group_size: Size of each group + block_M: Block size for M dimension + block_N: Block size for N dimension + block_K: Block size for K dimension + num_stages: Number of pipeline stages + threads: Number of threads per block + + Returns: + Compiled TileLang kernel function with signature: + kernel(A: bf16[M, K], QWeight: int8[N, (K+1)//2], QZeros: int8[num_groups, (K+1)//2], + Scales: float32[num_groups, K], GIdx: int32[N], C: bf16[M, N]) -> None + """ + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + packed_K = (K + 1) // 2 + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), + QWeight: T.Tensor((N, packed_K), T.int8), + QZeros: T.Tensor((num_groups, packed_K), T.int8), + Scales: T.Tensor((num_groups, K), T.float32), + GIdx: T.Tensor((N,), T.int32), + C: T.Tensor((M, N), T.bfloat16), + ): + """GPTQ W4A16 GEMM kernel implementation with groupwise dequantization.""" + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_bf16 = tir.const(0, T.bfloat16) + zero_f32 = tir.const(0, T.float32) + + # Constants for int4 unpacking + int4_offset = tir.const(8, T.int8) + mask_lower = tir.const(0x0F, T.int8) + mask_upper_shift = tir.const(4, T.int8) + + # Allocate shared memory buffers + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) + QWeight_shared = T.alloc_shared((block_N, (block_K + 1) // 2), T.int8) + QZeros_shared = T.alloc_shared((num_groups, (block_K + 1) // 2), T.int8) + + # Allocate fragments + QWeight_local = T.alloc_fragment((block_N, (block_K + 1) // 2), T.int8) + QZeros_local = T.alloc_fragment((num_groups, (block_K + 1) // 2), T.int8) + W_unpacked_local = T.alloc_fragment((block_N, block_K), T.int8) + Z_unpacked_local = T.alloc_fragment((num_groups, block_K), T.int8) + W_dequant_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + W_dequant_prev_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + + # Allocate fragment for accumulation + C_local = T.alloc_fragment((block_M, block_N), T.float32) + + # Clear accumulation buffer + T.clear(C_local) + + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + # Load A tile + T.copy(A[by * block_M, k * block_K], A_shared) + + # Load QWeight and QZeros tiles + packed_k_start = (k * block_K) // 2 + T.copy(QWeight[bx * block_N, packed_k_start], QWeight_shared) + T.copy(QZeros[0:num_groups, packed_k_start], QZeros_shared) + + # Copy to local fragments + T.copy(QWeight_shared, QWeight_local) + T.copy(QZeros_shared, QZeros_local) + + # Unpack QWeight int4 -> int8 + for i, j in T.Parallel(block_N, block_K): + j_packed = j // 2 + packed_byte = QWeight_local[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + W_unpacked_local[i, j] = T.if_then_else(is_lower, lower_int4, upper_int4) + + # Unpack QZeros int4 -> int8 + for g, j in T.Parallel(num_groups, block_K): + j_packed = j // 2 + packed_byte = QZeros_local[g, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + Z_unpacked_local[g, j] = T.if_then_else(is_lower, lower_int4, upper_int4) + + # Dequantize weights: weight = quantized_int4 * scale + zero + # where zero = zero_quantized_int4 * scale + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + # Get group_id from GIdx, clamp to [0, num_groups-1] + group_id = GIdx[n] + group_id = T.if_then_else(group_id < 0, 0, group_id) + group_id = T.if_then_else(group_id >= num_groups, num_groups - 1, group_id) + + # Get scale and zero_quantized + scale = Scales[group_id, kk] + zero_quantized = Z_unpacked_local[group_id, j].astype(T.float32) + weight_quantized = W_unpacked_local[i, j].astype(T.float32) + + # Dequantize: weight = weight_quantized * scale + zero_quantized * scale + zero = zero_quantized * scale + weight_dequant = weight_quantized * scale + zero + W_dequant_local[i, j] = weight_dequant.astype(T.bfloat16) + + # Copy to prev_local for pipeline synchronization + T.copy(W_dequant_local, W_dequant_prev_local) + + # GEMM: C = A @ W_dequant^T + T.gemm(A_shared, W_dequant_prev_local, C_local, transpose_B=True) + else: + # Tail-safe kernel + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + # Masked load A + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else((m < M) & (kk < K), A[m, kk], zero_bf16) + + # Masked load QWeight + packed_k_start = (k * block_K) // 2 + packed_k_size = (block_K + 1) // 2 + for i, j_packed in T.Parallel(block_N, packed_k_size): + n = bx * block_N + i + packed_idx = packed_k_start + j_packed + QWeight_shared[i, j_packed] = T.if_then_else( + (n < N) & (packed_idx < packed_K), + QWeight[n, packed_idx], + zero_i8, + ) + + # Masked load QZeros + for g, j_packed in T.Parallel(num_groups, packed_k_size): + packed_idx = packed_k_start + j_packed + QZeros_shared[g, j_packed] = T.if_then_else( + (g < num_groups) & (packed_idx < packed_K), + QZeros[g, packed_idx], + zero_i8, + ) + + # Copy to local fragments + T.copy(QWeight_shared, QWeight_local) + T.copy(QZeros_shared, QZeros_local) + + # Unpack QWeight with boundary checks + for i, j in T.Parallel(block_N, block_K): + kk = k * block_K + j + j_packed = j // 2 + packed_byte = QWeight_local[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + int4_val = T.if_then_else(is_lower, lower_int4, upper_int4) + in_bounds = (kk < K) & (j < block_K) + W_unpacked_local[i, j] = T.if_then_else(in_bounds, int4_val, zero_i8) + + # Unpack QZeros with boundary checks + for g, j in T.Parallel(num_groups, block_K): + kk = k * block_K + j + j_packed = j // 2 + packed_byte = QZeros_local[g, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + int4_val = T.if_then_else(is_lower, lower_int4, upper_int4) + in_bounds = (kk < K) & (j < block_K) & (g < num_groups) + Z_unpacked_local[g, j] = T.if_then_else(in_bounds, int4_val, zero_i8) + + # Dequantize weights with boundary checks + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + in_bounds = (n < N) & (kk < K) + n = bx * block_N + i + kk = k * block_K + j + in_bounds = (n < N) & (kk < K) + in_bounds = (n < N) & (kk < K) + + # Get group_id from GIdx, clamp to [0, num_groups-1] + group_id = GIdx[n] + group_id = T.if_then_else(group_id < 0, 0, group_id) + group_id = T.if_then_else(group_id >= num_groups, num_groups - 1, group_id) + + # Get scale and zero_quantized (use safe values when out of bounds) + scale = T.if_then_else(in_bounds, Scales[group_id, kk], zero_f32) + zero_quantized = Z_unpacked_local[group_id, j].astype(T.float32) + weight_quantized = W_unpacked_local[i, j].astype(T.float32) + + # Dequantize + zero = zero_quantized * scale + weight_dequant = weight_quantized * scale + zero + W_dequant_local[i, j] = T.if_then_else( + in_bounds, + weight_dequant.astype(T.bfloat16), + zero_bf16 + ) + + # Copy to prev_local + T.copy(W_dequant_local, W_dequant_prev_local) + + # GEMM + T.gemm(A_shared, W_dequant_prev_local, C_local, transpose_B=True) + + # Store output + if aligned: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + C[m, n] = C_local[i, j].astype(T.bfloat16) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + if (m < M) & (n < N): + C[m, n] = C_local[i, j].astype(T.bfloat16) + + return main + + +@tilelang.autotune(configs=build_linear_configs()) +@tilelang.jit(out_idx=[4]) +def awq_w4a16_gemm( + M: int, + N: int, + K: int, + num_groups: int, + group_size: int, + block_M: int = 64, + block_N: int = 64, + block_K: int = 128, + num_stages: int = 2, + threads: int = 128, +): + """AWQ W4A16 GEMM kernel: bf16 activation × AWQ int4 weight (packed in int8, groupwise dequantized). + + Args: + M: Number of rows in activation matrix A + N: Number of output channels (rows in weight matrix B) + K: Inner dimension (columns in A, rows in B) + num_groups: Number of quantization groups + group_size: Size of each group + block_M: Block size for M dimension + block_N: Block size for N dimension + block_K: Block size for K dimension + num_stages: Number of pipeline stages + threads: Number of threads per block + + Returns: + Compiled TileLang kernel function with signature: + kernel(A: bf16[M, K], QWeight: int8[N, (K+1)//2], QZeros: int8[num_groups, (K+1)//2], + Scales: float32[num_groups, K], C: bf16[M, N]) -> None + """ + aligned = (M % block_M == 0) and (N % block_N == 0) and (K % block_K == 0) + packed_K = (K + 1) // 2 + + @T.prim_func + def main( + A: T.Tensor((M, K), T.bfloat16), + QWeight: T.Tensor((N, packed_K), T.int8), + QZeros: T.Tensor((num_groups, packed_K), T.int8), + Scales: T.Tensor((num_groups, K), T.float32), + C: T.Tensor((M, N), T.bfloat16), + ): + """AWQ W4A16 GEMM kernel implementation with groupwise dequantization (sequential grouping).""" + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + zero_i8 = tir.const(0, T.int8) + zero_bf16 = tir.const(0, T.bfloat16) + zero_f32 = tir.const(0, T.float32) + + # Constants for int4 unpacking + int4_offset = tir.const(8, T.int8) + mask_lower = tir.const(0x0F, T.int8) + mask_upper_shift = tir.const(4, T.int8) + + # Allocate shared memory buffers + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) + QWeight_shared = T.alloc_shared((block_N, (block_K + 1) // 2), T.int8) + QZeros_shared = T.alloc_shared((num_groups, (block_K + 1) // 2), T.int8) + + # Allocate fragments + QWeight_local = T.alloc_fragment((block_N, (block_K + 1) // 2), T.int8) + QZeros_local = T.alloc_fragment((num_groups, (block_K + 1) // 2), T.int8) + W_unpacked_local = T.alloc_fragment((block_N, block_K), T.int8) + Z_unpacked_local = T.alloc_fragment((num_groups, block_K), T.int8) + W_dequant_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + W_dequant_prev_local = T.alloc_fragment((block_N, block_K), T.bfloat16) + + # Allocate fragment for accumulation + C_local = T.alloc_fragment((block_M, block_N), T.float32) + + # Clear accumulation buffer + T.clear(C_local) + + if aligned: + num_k_blocks = K // block_K + for k in T.Pipelined(num_k_blocks, num_stages=num_stages): + # Load A tile + T.copy(A[by * block_M, k * block_K], A_shared) + + # Load QWeight and QZeros tiles + packed_k_start = (k * block_K) // 2 + T.copy(QWeight[bx * block_N, packed_k_start], QWeight_shared) + T.copy(QZeros[0:num_groups, packed_k_start], QZeros_shared) + + # Copy to local fragments + T.copy(QWeight_shared, QWeight_local) + T.copy(QZeros_shared, QZeros_local) + + # Unpack QWeight int4 -> int8 + for i, j in T.Parallel(block_N, block_K): + j_packed = j // 2 + packed_byte = QWeight_local[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + W_unpacked_local[i, j] = T.if_then_else(is_lower, lower_int4, upper_int4) + + # Unpack QZeros int4 -> int8 + for g, j in T.Parallel(num_groups, block_K): + j_packed = j // 2 + packed_byte = QZeros_local[g, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + Z_unpacked_local[g, j] = T.if_then_else(is_lower, lower_int4, upper_int4) + + # Dequantize weights: weight = quantized_int4 * scale + zero + # where zero = zero_quantized_int4 * scale + # AWQ uses sequential grouping: group_id = n // group_size + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + # Compute group_id using sequential grouping + group_id = n // group_size + # Clamp to [0, num_groups-1] + group_id = T.if_then_else(group_id < 0, 0, group_id) + group_id = T.if_then_else(group_id >= num_groups, num_groups - 1, group_id) + + # Get scale and zero_quantized + scale = Scales[group_id, kk] + zero_quantized = Z_unpacked_local[group_id, j].astype(T.float32) + weight_quantized = W_unpacked_local[i, j].astype(T.float32) + + # Dequantize: weight = weight_quantized * scale + zero_quantized * scale + zero = zero_quantized * scale + weight_dequant = weight_quantized * scale + zero + W_dequant_local[i, j] = weight_dequant.astype(T.bfloat16) + + # Copy to prev_local for pipeline synchronization + T.copy(W_dequant_local, W_dequant_prev_local) + + # GEMM: C = A @ W_dequant^T + T.gemm(A_shared, W_dequant_prev_local, C_local, transpose_B=True) + else: + # Tail-safe kernel + for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=num_stages): + # Masked load A + for i, j in T.Parallel(block_M, block_K): + m = by * block_M + i + kk = k * block_K + j + A_shared[i, j] = T.if_then_else((m < M) & (kk < K), A[m, kk], zero_bf16) + + # Masked load QWeight + packed_k_start = (k * block_K) // 2 + packed_k_size = (block_K + 1) // 2 + for i, j_packed in T.Parallel(block_N, packed_k_size): + n = bx * block_N + i + packed_idx = packed_k_start + j_packed + QWeight_shared[i, j_packed] = T.if_then_else( + (n < N) & (packed_idx < packed_K), + QWeight[n, packed_idx], + zero_i8, + ) + + # Masked load QZeros + for g, j_packed in T.Parallel(num_groups, packed_k_size): + packed_idx = packed_k_start + j_packed + QZeros_shared[g, j_packed] = T.if_then_else( + (g < num_groups) & (packed_idx < packed_K), + QZeros[g, packed_idx], + zero_i8, + ) + + # Copy to local fragments + T.copy(QWeight_shared, QWeight_local) + T.copy(QZeros_shared, QZeros_local) + + # Unpack QWeight with boundary checks + for i, j in T.Parallel(block_N, block_K): + kk = k * block_K + j + j_packed = j // 2 + packed_byte = QWeight_local[i, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + int4_val = T.if_then_else(is_lower, lower_int4, upper_int4) + in_bounds = (kk < K) & (j < block_K) + W_unpacked_local[i, j] = T.if_then_else(in_bounds, int4_val, zero_i8) + + # Unpack QZeros with boundary checks + for g, j in T.Parallel(num_groups, block_K): + kk = k * block_K + j + j_packed = j // 2 + packed_byte = QZeros_local[g, j_packed] + lower_uint = (packed_byte & mask_lower).astype(T.int8) + upper_uint = ((packed_byte >> mask_upper_shift) & mask_lower).astype(T.int8) + lower_int4 = lower_uint - int4_offset + upper_int4 = upper_uint - int4_offset + is_lower = (j % 2) == 0 + int4_val = T.if_then_else(is_lower, lower_int4, upper_int4) + in_bounds = (kk < K) & (j < block_K) & (g < num_groups) + Z_unpacked_local[g, j] = T.if_then_else(in_bounds, int4_val, zero_i8) + + # Dequantize weights with boundary checks + # AWQ uses sequential grouping: group_id = n // group_size + for i, j in T.Parallel(block_N, block_K): + n = bx * block_N + i + kk = k * block_K + j + in_bounds = (n < N) & (kk < K) + # Compute group_id using sequential grouping + group_id = n // group_size + # Clamp to [0, num_groups-1] + group_id = T.if_then_else(group_id < 0, 0, group_id) + group_id = T.if_then_else(group_id >= num_groups, num_groups - 1, group_id) + + # Get scale and zero_quantized + scale = T.if_then_else(in_bounds, Scales[group_id, kk], zero_f32) + zero_quantized = Z_unpacked_local[group_id, j].astype(T.float32) + weight_quantized = W_unpacked_local[i, j].astype(T.float32) + + # Dequantize + zero = zero_quantized * scale + weight_dequant = weight_quantized * scale + zero + W_dequant_local[i, j] = T.if_then_else( + in_bounds, + weight_dequant.astype(T.bfloat16), + zero_bf16 + ) + + # Copy to prev_local + T.copy(W_dequant_local, W_dequant_prev_local) + + # GEMM + T.gemm(A_shared, W_dequant_prev_local, C_local, transpose_B=True) + + # Store output + if aligned: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + C[m, n] = C_local[i, j].astype(T.bfloat16) + else: + for i, j in T.Parallel(block_M, block_N): + m = by * block_M + i + n = bx * block_N + j + if (m < M) & (n < N): + C[m, n] = C_local[i, j].astype(T.bfloat16) + + return main diff --git a/diffulex_kernel/python/paged_attn_decode_triton.py b/diffulex_kernel/python/paged_attn_decode_triton.py new file mode 100644 index 0000000..1fabf19 --- /dev/null +++ b/diffulex_kernel/python/paged_attn_decode_triton.py @@ -0,0 +1,661 @@ +import torch +import triton +import triton.language as tl + +import os + +from diffulex.attention.metadata import AttnMetaDataBase + + +@triton.jit +def _paged_decode_attn_unified_bf16_cache_kernel( + q_ptr, + k_ptr, + v_ptr, + k_cache_ptr, + v_cache_ptr, + block_tables_ptr, + context_lens_ptr, + cu_seqlens_q_ptr, + o_ptr, + softmax_scale, # fp32 scalar + # q/k/v/o strides + q_stride_s, + q_stride_h, + q_stride_d, + kv_stride_s, + kv_stride_h, + kv_stride_d, + o_stride_s, + o_stride_h, + o_stride_d, + # cache strides: [nblks, page, kvh, d] + k_cache_stride_nblks, + k_cache_stride_page, + k_cache_stride_h, + k_cache_stride_d, + v_cache_stride_nblks, + v_cache_stride_page, + v_cache_stride_h, + v_cache_stride_d, + # block_tables strides + block_tables_stride_s, + block_tables_stride_b, + # misc + NUM_GROUPS: tl.constexpr, + HEAD_DIM: tl.constexpr, + HEAD_DIM_PADDED: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_seq = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + kv_head = pid_head // NUM_GROUPS + + q_start = tl.load(cu_seqlens_q_ptr + pid_seq).to(tl.int32) + q_end = tl.load(cu_seqlens_q_ptr + pid_seq + 1).to(tl.int32) + q_len = q_end - q_start + new_len = q_len # decode path: current-step KV length matches query length + context_len = tl.load(context_lens_ptr + pid_seq).to(tl.int32) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_DIM_PADDED) + mask_m = offs_m < q_len + mask_d = offs_d < HEAD_DIM + + q_offs = (q_start + offs_m[:, None]) * q_stride_s + pid_head * q_stride_h + offs_d[None, :] * q_stride_d + q = tl.load(q_ptr + q_offs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.bfloat16) + + m = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM_PADDED], dtype=tl.float32) + + # Cache stage: iterate only needed blocks (dynamic loop, like vLLM kernels). + offs_n_cache = tl.arange(0, BLOCK_N) + tok_off_cache = offs_n_cache + mask_n_cache = offs_n_cache < PAGE_SIZE + + num_cache_blocks = (context_len + PAGE_SIZE - 1) // PAGE_SIZE + for blk in range(0, num_cache_blocks): + page = tl.load(block_tables_ptr + pid_seq * block_tables_stride_s + blk * block_tables_stride_b).to(tl.int32) + tok_base = blk * PAGE_SIZE + tok_idx = tok_base + tok_off_cache + valid_tok = (page >= 0) & (tok_idx < context_len) & mask_n_cache + + k_offs = ( + page * k_cache_stride_nblks + + tok_off_cache[:, None] * k_cache_stride_page + + kv_head * k_cache_stride_h + + offs_d[None, :] * k_cache_stride_d + ) + k_blk = tl.load( + k_cache_ptr + k_offs, + mask=valid_tok[:, None] & mask_d[None, :], + other=0.0, + ).to(tl.bfloat16) + + scores = tl.dot(q, tl.trans(k_blk)).to(tl.float32) * softmax_scale + scores = tl.where(mask_m[:, None] & valid_tok[None, :], scores, float("-inf")) + + m_new = tl.maximum(m, tl.max(scores, axis=1)) + p = tl.exp(scores - m_new[:, None]) + l_new = l * tl.exp(m - m_new) + tl.sum(p, axis=1) + alpha = tl.exp(m - m_new) + acc *= alpha[:, None] + + v_offs = ( + page * v_cache_stride_nblks + + tok_off_cache[:, None] * v_cache_stride_page + + kv_head * v_cache_stride_h + + offs_d[None, :] * v_cache_stride_d + ) + v_blk = tl.load( + v_cache_ptr + v_offs, + mask=valid_tok[:, None] & mask_d[None, :], + other=0.0, + ).to(tl.bfloat16) + + acc += tl.dot(p.to(tl.bfloat16), v_blk).to(tl.float32) + m = m_new + l = l_new + + # New KV stage (dynamic tiles) + kv_start = q_start + for start_n in range(0, new_len, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + valid_tok = offs_n < new_len + + k_offs = (kv_start + offs_n[None, :]) * kv_stride_s + kv_head * kv_stride_h + offs_d[:, None] * kv_stride_d + k_blk = tl.load( + k_ptr + k_offs, + mask=valid_tok[None, :] & mask_d[:, None], + other=0.0, + ).to(tl.bfloat16) + + scores = tl.dot(q, k_blk).to(tl.float32) * softmax_scale + scores = tl.where(mask_m[:, None] & valid_tok[None, :], scores, float("-inf")) + + m_new = tl.maximum(m, tl.max(scores, axis=1)) + p = tl.exp(scores - m_new[:, None]) + l_new = l * tl.exp(m - m_new) + tl.sum(p, axis=1) + alpha = tl.exp(m - m_new) + acc *= alpha[:, None] + + v_offs = (kv_start + offs_n[:, None]) * kv_stride_s + kv_head * kv_stride_h + offs_d[None, :] * kv_stride_d + v_blk = tl.load( + v_ptr + v_offs, + mask=valid_tok[:, None] & mask_d[None, :], + other=0.0, + ).to(tl.bfloat16) + + acc += tl.dot(p.to(tl.bfloat16), v_blk).to(tl.float32) + m = m_new + l = l_new + + out = acc / l[:, None] + o_offs = (q_start + offs_m[:, None]) * o_stride_s + pid_head * o_stride_h + offs_d[None, :] * o_stride_d + tl.store(o_ptr + o_offs, out.to(tl.bfloat16), mask=mask_m[:, None] & mask_d[None, :]) + + +@triton.jit +def _paged_decode_attn_unified_fp8_cache_kernel_legacy( + q_ptr, + k_ptr, + v_ptr, + k_cache_ptr, + v_cache_ptr, + k_scale_ptr, + v_scale_ptr, + block_tables_ptr, + context_lens_ptr, + cu_seqlens_q_ptr, + o_ptr, + softmax_scale, # fp32 scalar + # q/k/v/o strides + q_stride_s, + q_stride_h, + q_stride_d, + kv_stride_s, + kv_stride_h, + kv_stride_d, + o_stride_s, + o_stride_h, + o_stride_d, + # cache strides: [nblks, page, kvh, d] + k_cache_stride_nblks, + k_cache_stride_page, + k_cache_stride_h, + k_cache_stride_d, + v_cache_stride_nblks, + v_cache_stride_page, + v_cache_stride_h, + v_cache_stride_d, + # block_tables strides + block_tables_stride_s, + block_tables_stride_b, + # misc + NUM_GROUPS: tl.constexpr, + HEAD_DIM: tl.constexpr, + HEAD_DIM_PADDED: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_seq = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + kv_head = pid_head // NUM_GROUPS + k_scale = tl.load(k_scale_ptr + kv_head).to(tl.float32) + v_scale = tl.load(v_scale_ptr + kv_head).to(tl.float32) + + q_start = tl.load(cu_seqlens_q_ptr + pid_seq).to(tl.int32) + q_end = tl.load(cu_seqlens_q_ptr + pid_seq + 1).to(tl.int32) + q_len = q_end - q_start + new_len = q_len + context_len = tl.load(context_lens_ptr + pid_seq).to(tl.int32) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_DIM_PADDED) + mask_m = offs_m < q_len + mask_d = offs_d < HEAD_DIM + + q_offs = (q_start + offs_m[:, None]) * q_stride_s + pid_head * q_stride_h + offs_d[None, :] * q_stride_d + q = tl.load(q_ptr + q_offs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.bfloat16) + + m = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM_PADDED], dtype=tl.float32) + + offs_n_cache = tl.arange(0, BLOCK_N) + tok_off_cache = offs_n_cache + mask_n_cache = offs_n_cache < PAGE_SIZE + + num_cache_blocks = (context_len + PAGE_SIZE - 1) // PAGE_SIZE + for blk in range(0, num_cache_blocks): + page = tl.load(block_tables_ptr + pid_seq * block_tables_stride_s + blk * block_tables_stride_b).to(tl.int32) + tok_base = blk * PAGE_SIZE + tok_idx = tok_base + tok_off_cache + valid_tok = (page >= 0) & (tok_idx < context_len) & mask_n_cache + + k_offs = ( + page * k_cache_stride_nblks + + tok_off_cache[:, None] * k_cache_stride_page + + kv_head * k_cache_stride_h + + offs_d[None, :] * k_cache_stride_d + ) + # fp8 cache values: dot(Q, K_fp8) * k_scale == dot(Q, (K_fp8*k_scale)) + k_blk = tl.load( + k_cache_ptr + k_offs, + mask=valid_tok[:, None] & mask_d[None, :], + other=0.0, + ).to(tl.bfloat16) + + scores = tl.dot(q, tl.trans(k_blk)).to(tl.float32) * (softmax_scale * k_scale) + scores = tl.where(mask_m[:, None] & valid_tok[None, :], scores, float("-inf")) + + m_new = tl.maximum(m, tl.max(scores, axis=1)) + p = tl.exp(scores - m_new[:, None]) + l_new = l * tl.exp(m - m_new) + tl.sum(p, axis=1) + alpha = tl.exp(m - m_new) + acc *= alpha[:, None] + + v_offs = ( + page * v_cache_stride_nblks + + tok_off_cache[:, None] * v_cache_stride_page + + kv_head * v_cache_stride_h + + offs_d[None, :] * v_cache_stride_d + ) + v_blk = tl.load( + v_cache_ptr + v_offs, + mask=valid_tok[:, None] & mask_d[None, :], + other=0.0, + ).to(tl.bfloat16) + + # Apply v_scale on weights (cheaper than scaling V elementwise). + acc += tl.dot((p * v_scale).to(tl.bfloat16), v_blk).to(tl.float32) + m = m_new + l = l_new + + kv_start = q_start + for start_n in range(0, new_len, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + valid_tok = offs_n < new_len + + k_offs = (kv_start + offs_n[None, :]) * kv_stride_s + kv_head * kv_stride_h + offs_d[:, None] * kv_stride_d + k_blk = tl.load( + k_ptr + k_offs, + mask=valid_tok[None, :] & mask_d[:, None], + other=0.0, + ).to(tl.bfloat16) + + scores = tl.dot(q, k_blk).to(tl.float32) * softmax_scale + scores = tl.where(mask_m[:, None] & valid_tok[None, :], scores, float("-inf")) + + m_new = tl.maximum(m, tl.max(scores, axis=1)) + p = tl.exp(scores - m_new[:, None]) + l_new = l * tl.exp(m - m_new) + tl.sum(p, axis=1) + alpha = tl.exp(m - m_new) + acc *= alpha[:, None] + + v_offs = (kv_start + offs_n[:, None]) * kv_stride_s + kv_head * kv_stride_h + offs_d[None, :] * kv_stride_d + v_blk = tl.load( + v_ptr + v_offs, + mask=valid_tok[:, None] & mask_d[None, :], + other=0.0, + ).to(tl.bfloat16) + + acc += tl.dot(p.to(tl.bfloat16), v_blk).to(tl.float32) + m = m_new + l = l_new + + out = acc / l[:, None] + o_offs = (q_start + offs_m[:, None]) * o_stride_s + pid_head * o_stride_h + offs_d[None, :] * o_stride_d + tl.store(o_ptr + o_offs, out.to(tl.bfloat16), mask=mask_m[:, None] & mask_d[None, :]) + + +@triton.jit +def _paged_decode_attn_unified_fp8_cache_fused_dot_kernel( + q_ptr, + k_ptr, + v_ptr, + k_cache_ptr, + v_cache_ptr, + k_scale_ptr, + v_scale_ptr, + block_tables_ptr, + context_lens_ptr, + cu_seqlens_q_ptr, + o_ptr, + softmax_scale, # fp32 scalar + # q/k/v/o strides + q_stride_s, + q_stride_h, + q_stride_d, + kv_stride_s, + kv_stride_h, + kv_stride_d, + o_stride_s, + o_stride_h, + o_stride_d, + # cache strides: [nblks, page, kvh, d] + k_cache_stride_nblks, + k_cache_stride_page, + k_cache_stride_h, + k_cache_stride_d, + v_cache_stride_nblks, + v_cache_stride_page, + v_cache_stride_h, + v_cache_stride_d, + # block_tables strides + block_tables_stride_s, + block_tables_stride_b, + # misc + KV_FORMAT: tl.constexpr, + NUM_GROUPS: tl.constexpr, + HEAD_DIM: tl.constexpr, + HEAD_DIM_PADDED: tl.constexpr, + PAGE_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + FP8-cache decode kernel with *fused* fp8 math: + - Keep KV cache tiles in float8 (via fp8 view tensor) + - Use tl.dot_scaled(..., rhs_format="e4m3/e5m2") to consume fp8 without explicit dequant tensors + - Apply per-head scalar scales (k_scale/v_scale) without elementwise dequantization + """ + pid_seq = tl.program_id(0) + pid_head = tl.program_id(1) + pid_m = tl.program_id(2) + + kv_head = pid_head // NUM_GROUPS + k_scale = tl.load(k_scale_ptr + kv_head).to(tl.float32) + v_scale = tl.load(v_scale_ptr + kv_head).to(tl.float32) + + q_start = tl.load(cu_seqlens_q_ptr + pid_seq).to(tl.int32) + q_end = tl.load(cu_seqlens_q_ptr + pid_seq + 1).to(tl.int32) + q_len = q_end - q_start + new_len = q_len + context_len = tl.load(context_lens_ptr + pid_seq).to(tl.int32) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_DIM_PADDED) + mask_m = offs_m < q_len + mask_d = offs_d < HEAD_DIM + + # Load Q (bf16). Note: triton 3.5 `tl.dot` does not support mixed bf16/fp16 x fp8. + # We use `tl.dot_scaled` (microscaling) to accept fp8 operands. + q_offs = (q_start + offs_m[:, None]) * q_stride_s + pid_head * q_stride_h + offs_d[None, :] * q_stride_d + q = tl.load(q_ptr + q_offs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.bfloat16) + + m = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM_PADDED], dtype=tl.float32) + + offs_n_cache = tl.arange(0, BLOCK_N) + tok_off_cache = offs_n_cache + mask_n_cache = offs_n_cache < PAGE_SIZE + + num_cache_blocks = (context_len + PAGE_SIZE - 1) // PAGE_SIZE + for blk in range(0, num_cache_blocks): + page = tl.load(block_tables_ptr + pid_seq * block_tables_stride_s + blk * block_tables_stride_b).to(tl.int32) + tok_base = blk * PAGE_SIZE + tok_idx = tok_base + tok_off_cache + valid_tok = (page >= 0) & (tok_idx < context_len) & mask_n_cache + + # K cache: keep fp8 element type; load as [K, N] to match dot_scaled rhs layout. + k_offs = ( + page * k_cache_stride_nblks + + tok_off_cache[None, :] * k_cache_stride_page + + kv_head * k_cache_stride_h + + offs_d[:, None] * k_cache_stride_d + ) + k_blk = tl.load( + k_cache_ptr + k_offs, + mask=mask_d[:, None] & valid_tok[None, :], + other=0.0, + ) + + # scores = QK^T * softmax_scale, with scalar k_scale applied after dot: + # dot(Q, K_true) == dot(Q, K_fp8) * k_scale (per-head scalar scale). + scores = tl.dot_scaled( + q, + None, + "bf16", + k_blk, + None, + KV_FORMAT, + ) * (softmax_scale * k_scale) + scores = tl.where(mask_m[:, None] & valid_tok[None, :], scores, float("-inf")) + + m_new = tl.maximum(m, tl.max(scores, axis=1)) + p = tl.exp(scores - m_new[:, None]) + l_new = l * tl.exp(m - m_new) + tl.sum(p, axis=1) + alpha = tl.exp(m - m_new) + acc *= alpha[:, None] + + # V cache: keep fp8 element type for tl.dot. + v_offs = ( + page * v_cache_stride_nblks + + tok_off_cache[:, None] * v_cache_stride_page + + kv_head * v_cache_stride_h + + offs_d[None, :] * v_cache_stride_d + ) + v_blk = tl.load( + v_cache_ptr + v_offs, + mask=valid_tok[:, None] & mask_d[None, :], + other=0.0, + ) + + # acc += P @ V_true == (P @ V_fp8) * v_scale + acc += tl.dot_scaled( + p.to(tl.float16), + None, + "fp16", + v_blk, + None, + KV_FORMAT, + ) * v_scale + m = m_new + l = l_new + + # New KV stage (bf16 tensors, unchanged) + kv_start = q_start + for start_n in range(0, new_len, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + valid_tok = offs_n < new_len + + k_offs = (kv_start + offs_n[None, :]) * kv_stride_s + kv_head * kv_stride_h + offs_d[:, None] * kv_stride_d + k_blk = tl.load( + k_ptr + k_offs, + mask=valid_tok[None, :] & mask_d[:, None], + other=0.0, + ).to(tl.bfloat16) + + scores = tl.dot(q, k_blk, out_dtype=tl.float32) * softmax_scale + scores = tl.where(mask_m[:, None] & valid_tok[None, :], scores, float("-inf")) + + m_new = tl.maximum(m, tl.max(scores, axis=1)) + p = tl.exp(scores - m_new[:, None]) + l_new = l * tl.exp(m - m_new) + tl.sum(p, axis=1) + alpha = tl.exp(m - m_new) + acc *= alpha[:, None] + + v_offs = (kv_start + offs_n[:, None]) * kv_stride_s + kv_head * kv_stride_h + offs_d[None, :] * kv_stride_d + v_blk = tl.load( + v_ptr + v_offs, + mask=valid_tok[:, None] & mask_d[None, :], + other=0.0, + ).to(tl.bfloat16) + + acc += tl.dot(p.to(tl.bfloat16), v_blk, out_dtype=tl.float32) + m = m_new + l = l_new + + out = acc / l[:, None] + o_offs = (q_start + offs_m[:, None]) * o_stride_s + pid_head * o_stride_h + offs_d[None, :] * o_stride_d + tl.store(o_ptr + o_offs, out.to(tl.bfloat16), mask=mask_m[:, None] & mask_d[None, :]) + + +def paged_attn_decode_unified_triton( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: AttnMetaDataBase, + *, + softmax_scale: float, + fp8_cache: bool, +) -> torch.Tensor: + """ + Triton paged-attention decode for unified KV cache layout. + + q: [total_q, num_heads, head_dim] (bf16) + k/v: [total_q, num_kv_heads, head_dim] (bf16), aligned with cu_seqlens_q + k_cache/v_cache: + - bf16: [num_page_blocks, page_size, num_kv_heads, head_dim] + - fp8 : same shape but dtype must be float8 view for triton (strategy.view_kv_cache_for_kernels) + """ + assert q.is_cuda and k.is_cuda and v.is_cuda and k_cache.is_cuda and v_cache.is_cuda + assert q.dtype == torch.bfloat16 and k.dtype == torch.bfloat16 and v.dtype == torch.bfloat16 + assert attn_metadata.block_tables is not None and attn_metadata.context_lens is not None and attn_metadata.cu_seqlens_q is not None + assert attn_metadata.kv_cache_layout == "unified", f"only unified layout supported, got {attn_metadata.kv_cache_layout}" + + # Be robust to different metadata implementations (dataclass vs SimpleNamespace in tests). + num_seqs = int(attn_metadata.cu_seqlens_q.numel() - 1) + num_heads = q.shape[1] + head_dim = q.shape[2] + num_kv_heads = k.shape[1] + assert num_heads % num_kv_heads == 0 + num_groups = num_heads // num_kv_heads + + page_size = int(attn_metadata.page_block_size) + + # Heuristics: BLOCK_M = 64 (supports diffusion_block_size=32/64), BLOCK_N = page_size/new-tile + BLOCK_M = 64 + BLOCK_N = 32 if page_size <= 32 else 64 + # Cache stage requires BLOCK_N == PAGE_SIZE to simplify; enforce. + if BLOCK_N != page_size: + BLOCK_N = page_size + + head_dim_padded = 1 << (head_dim - 1).bit_length() + + o = torch.empty_like(q) + grid = (num_seqs, num_heads, triton.cdiv(int(attn_metadata.max_seqlen_q), BLOCK_M)) + + if fp8_cache: + if attn_metadata.k_scale is None or attn_metadata.v_scale is None: + raise ValueError("fp8_cache=True requires attn_metadata.k_scale/v_scale") + # Default to fused fp8-dot kernel; fallback to legacy on compile/runtime failures. + # Set DIFFULEX_PAGED_DECODE_FP8_FUSED_DOT=0 to force legacy. + # Set DIFFULEX_PAGED_DECODE_FP8_FUSED_DOT_STRICT=1 to raise instead of fallback. + use_fused_dot = os.getenv("DIFFULEX_PAGED_DECODE_FP8_FUSED_DOT", "1") != "0" + strict_fused = os.getenv("DIFFULEX_PAGED_DECODE_FP8_FUSED_DOT_STRICT", "0") == "1" + if use_fused_dot: + # `tl.dot_scaled` needs the fp8 format string to interpret raw bytes correctly. + # Derive from the fp8 view dtype (torch.float8_*). + dt = str(k_cache.dtype) + if "e4m3" in dt: + kv_format = "e4m3" + elif "e5m2" in dt: + kv_format = "e5m2" + else: + raise ValueError(f"Unsupported fp8 k_cache dtype for fused-dot: {k_cache.dtype}") + try: + _paged_decode_attn_unified_fp8_cache_fused_dot_kernel[grid]( + q, k, v, + k_cache, v_cache, + attn_metadata.k_scale, attn_metadata.v_scale, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.cu_seqlens_q, + o, + softmax_scale, + *q.stride(), *k.stride(), *o.stride(), + *k_cache.stride(), *v_cache.stride(), + *attn_metadata.block_tables.stride(), + KV_FORMAT=kv_format, + NUM_GROUPS=num_groups, + HEAD_DIM=head_dim, + HEAD_DIM_PADDED=head_dim_padded, + PAGE_SIZE=page_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + ) + except Exception: + if strict_fused: + raise + _paged_decode_attn_unified_fp8_cache_kernel_legacy[grid]( + q, k, v, + k_cache, v_cache, + attn_metadata.k_scale, attn_metadata.v_scale, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.cu_seqlens_q, + o, + softmax_scale, + *q.stride(), *k.stride(), *o.stride(), + *k_cache.stride(), *v_cache.stride(), + *attn_metadata.block_tables.stride(), + NUM_GROUPS=num_groups, + HEAD_DIM=head_dim, + HEAD_DIM_PADDED=head_dim_padded, + PAGE_SIZE=page_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + ) + else: + _paged_decode_attn_unified_fp8_cache_kernel_legacy[grid]( + q, k, v, + k_cache, v_cache, + attn_metadata.k_scale, attn_metadata.v_scale, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.cu_seqlens_q, + o, + softmax_scale, + *q.stride(), *k.stride(), *o.stride(), + *k_cache.stride(), *v_cache.stride(), + *attn_metadata.block_tables.stride(), + NUM_GROUPS=num_groups, + HEAD_DIM=head_dim, + HEAD_DIM_PADDED=head_dim_padded, + PAGE_SIZE=page_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + ) + else: + _paged_decode_attn_unified_bf16_cache_kernel[grid]( + q, k, v, + k_cache, v_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.cu_seqlens_q, + o, + softmax_scale, + *q.stride(), *k.stride(), *o.stride(), + *k_cache.stride(), *v_cache.stride(), + *attn_metadata.block_tables.stride(), + NUM_GROUPS=num_groups, + HEAD_DIM=head_dim, + HEAD_DIM_PADDED=head_dim_padded, + PAGE_SIZE=page_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + ) + + return o + diff --git a/diffulex_legacy/config.py b/diffulex_legacy/config.py index ecc1408..a5b1dd6 100755 --- a/diffulex_legacy/config.py +++ b/diffulex_legacy/config.py @@ -42,6 +42,7 @@ class Config: num_kvcache_blocks: int = -1 k_cache_hdim_split_factor_x: int = 8 kv_cache_layout: str = "unified" # "unified" or "distinct" + kv_cache_dtype: str = "bf16" # "bf16", "fp16", "fp32", "fp8", "fp8_e4m3", "fp8_e5m2" def __post_init__(self): assert os.path.isdir(self.model) diff --git a/diffulex_legacy/engine/dp_engine.py b/diffulex_legacy/engine/dp_engine.py index 70f8e82..b9da2b7 100755 --- a/diffulex_legacy/engine/dp_engine.py +++ b/diffulex_legacy/engine/dp_engine.py @@ -115,12 +115,10 @@ def __init__(self, model, **kwargs): need_gpus = self.dp_size * cfg.tensor_parallel_size assert len(vis) >= need_gpus, f"Require {need_gpus} GPUs (dp={self.dp_size}, tp={cfg.tensor_parallel_size}), visible {len(vis)}" - # Optional overrides: kwargs['device_ids'] or env D2F_DEVICE_MAP + # Optional overrides: kwargs['device_ids'] override = None if 'device_ids' in kwargs and kwargs['device_ids']: override = list(kwargs['device_ids']) - elif os.environ.get('D2F_DEVICE_MAP'): - override = [int(x) for x in os.environ['D2F_DEVICE_MAP'].split(',') if x.strip() != ''] if override is not None: assert len(override) >= need_gpus, f"device_ids length {len(override)} < required {need_gpus}" # All override devices must be in visible list diff --git a/diffulex_legacy/engine/model_runner.py b/diffulex_legacy/engine/model_runner.py index e7fcd0e..b4fc1d5 100755 --- a/diffulex_legacy/engine/model_runner.py +++ b/diffulex_legacy/engine/model_runner.py @@ -22,6 +22,26 @@ get_context_diffusion_lm, reset_context_diffusion_lm ) +from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + + +def _get_kv_cache_storage_info(kv_cache_dtype: str) -> tuple[torch.dtype, int]: + """ + Returns (storage_dtype, itemsize) for KV cache allocation. + - For FP8: returns (torch.uint8, 1) because FP8 uses uint8 storage + - For other dtypes: returns (torch.bfloat16/fp16/fp32, itemsize) + """ + spec = parse_kv_cache_dtype(kv_cache_dtype) + if spec.is_fp8: + return torch.uint8, 1 + elif spec.enum.value == 0: # BF16 + return torch.bfloat16, 2 + elif spec.enum.value == 1: # FP16 + return torch.float16, 2 + elif spec.enum.value == 2: # FP32 + return torch.float32, 4 + else: + raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}") class ModelRunnerBase(ABC): @@ -186,6 +206,7 @@ def warmup_model(self): def allocate_kv_cache(self): config = self.config hf_config = config.hf_config + storage_dtype, itemsize = _get_kv_cache_storage_info(config.kv_cache_dtype) free, total = torch.cuda.mem_get_info() used = total - free peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"] @@ -200,14 +221,14 @@ def allocate_kv_cache(self): raise AttributeError(f"Cannot determine head_dim from config: {type(hf_config)}") block_bytes = (2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * - head_dim * hf_config.torch_dtype.itemsize) + head_dim * itemsize) config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes assert config.num_kvcache_blocks > 0 # [kv_separated, layer_id, block_id, block_size(segmented seq_len), head, head_dim] self.kv_cache = torch.zeros( 2, hf_config.num_hidden_layers, config.num_kvcache_blocks, - self.block_size, num_kv_heads, head_dim) + self.block_size, num_kv_heads, head_dim, dtype=storage_dtype) layer_id = 0 for module in self.model.modules(): if hasattr(module, "k_cache") and hasattr(module, "v_cache"): @@ -250,7 +271,7 @@ def prepare_prefill(self, seqs: List[SequenceForCausalLM]): cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - set_context_causal_lm(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables) + set_context_causal_lm(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables, kv_cache_dtype=self.config.kv_cache_dtype) return input_ids, positions def prepare_decode(self, seqs: List[SequenceForCausalLM]): @@ -271,7 +292,7 @@ def prepare_decode(self, seqs: List[SequenceForCausalLM]): slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) block_tables = self.prepare_block_tables(seqs) - set_context_causal_lm(False, cu_seqlens_k=cu_seqlens_k, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables) + set_context_causal_lm(False, cu_seqlens_k=cu_seqlens_k, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables, kv_cache_dtype=self.config.kv_cache_dtype) return input_ids, positions @torch.inference_mode() @@ -336,7 +357,7 @@ def capture_cudagraph(self): for bs in reversed(self.graph_bs): graph = torch.cuda.CUDAGraph() - set_context_causal_lm(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs]) + set_context_causal_lm(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs], kv_cache_dtype=self.config.kv_cache_dtype) outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup with torch.cuda.graph(graph, self.graph_pool): outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture @@ -381,6 +402,7 @@ def warmup_model(self): def allocate_kv_cache(self): config = self.config hf_config = config.hf_config + storage_dtype, itemsize = _get_kv_cache_storage_info(config.kv_cache_dtype) free, total = torch.cuda.mem_get_info() used = total - free peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"] @@ -394,8 +416,7 @@ def allocate_kv_cache(self): else: raise AttributeError(f"Cannot determine head_dim from config: {type(hf_config)}") - dtype = hf_config.torch_dtype if hasattr(hf_config, "torch_dtype") and hf_config.torch_dtype else torch.bfloat16 - block_bytes = (2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * dtype.itemsize) + block_bytes = (2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * itemsize) get_num_kvcache_blocks = lambda gpu_memory_utilization: int(total * gpu_memory_utilization - # noqa: E731 used - peak + current) // block_bytes try: @@ -421,11 +442,11 @@ def allocate_kv_cache(self): self.k_cache = torch.zeros( hf_config.num_hidden_layers, config.num_kvcache_blocks, - num_kv_heads, head_dim // x, self.block_size, x + num_kv_heads, head_dim // x, self.block_size, x, dtype=storage_dtype ) self.v_cache = torch.zeros( hf_config.num_hidden_layers, config.num_kvcache_blocks, - num_kv_heads, head_dim, self.block_size + num_kv_heads, head_dim, self.block_size, dtype=storage_dtype ) layer_id = 0 for module in self.model.modules(): @@ -437,7 +458,7 @@ def allocate_kv_cache(self): # [kv_separated, layer_id, block_id, block_size(segmented seq_len), head, head_dim] self.kv_cache = torch.zeros( 2, hf_config.num_hidden_layers, config.num_kvcache_blocks, - self.block_size, num_kv_heads, head_dim) + self.block_size, num_kv_heads, head_dim, dtype=storage_dtype) layer_id = 0 for module in self.model.modules(): if hasattr(module, "k_cache") and hasattr(module, "v_cache"): diff --git a/diffulex_legacy/layers/attention/attention_v4.py b/diffulex_legacy/layers/attention/attention_v4.py index 88e624f..4f1d031 100755 --- a/diffulex_legacy/layers/attention/attention_v4.py +++ b/diffulex_legacy/layers/attention/attention_v4.py @@ -17,6 +17,13 @@ from diffulex_legacy.utils.context import ContextForDiffusionLM, get_context_causal_lm, get_context_diffusion_lm +def _get_kv_cache_dtype(context: ContextForDiffusionLM, model_type: str) -> str: + if model_type == 'diffusion_lm': + return context.seqs[0].config.kv_cache_dtype + else: # causal_lm + return getattr(context, 'kv_cache_dtype', 'bf16') # fallback for backward compatibility + + class Attention(nn.Module): def __init__( self, @@ -47,6 +54,11 @@ def __init__( partial(flex_attention, kernel_options=kernel_options, enable_gqa=True, return_lse=False, training=False), dynamic=True) self._block_mask_cache = {} + # FP8 scale management: maintain running max per head + self.k_max_abs: torch.Tensor | None = None # [num_kv_heads] + self.v_max_abs: torch.Tensor | None = None # [num_kv_heads] + self.q_max_abs: torch.Tensor | None = None # [num_heads] + self.kv_cache_dtype_cache: str | None = None @lru_cache(maxsize=32) def dllm_block_mask(self, block_mask: torch.Tensor, @@ -81,6 +93,108 @@ def _mask_mod(batch, head, token_q, token_kv): ) return self._block_mask_cache[cache_key] + def _update_and_compute_fp8_scales( + self, + k: torch.Tensor, + v: torch.Tensor, + kv_cache_dtype: str, + device: torch.device + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Update running max and compute FP8 scales. + Returns (k_scale, v_scale) or (None, None) if not FP8. + """ + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + if not spec.is_fp8: + return None, None + + # Reset running max if dtype changed + if self.kv_cache_dtype_cache != kv_cache_dtype: + self.k_max_abs = None + self.v_max_abs = None + self.q_max_abs = None + self.kv_cache_dtype_cache = kv_cache_dtype + + # Compute current batch absmax: [num_kv_heads] + k_absmax = k.to(torch.float32).abs().amax(dim=(0, 2)) # [num_kv_heads] + v_absmax = v.to(torch.float32).abs().amax(dim=(0, 2)) # [num_kv_heads] + + # Update running max + if self.k_max_abs is None: + self.k_max_abs = k_absmax.clone().detach() + self.v_max_abs = v_absmax.clone().detach() + else: + self.k_max_abs = torch.maximum(self.k_max_abs, k_absmax) + self.v_max_abs = torch.maximum(self.v_max_abs, v_absmax) + + # Compute scale from running max + eps = 1e-8 + fp8_max = spec.fp8_max + k_scale = (self.k_max_abs / fp8_max).clamp_min(eps) + v_scale = (self.v_max_abs / fp8_max).clamp_min(eps) + + return k_scale, v_scale + + def _update_and_compute_q_fp8_scale( + self, + q: torch.Tensor, + kv_cache_dtype: str, + device: torch.device + ) -> torch.Tensor | None: + """ + Update running max for Q and compute FP8 scale. + Returns q_scale or None if not FP8. + """ + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + if not spec.is_fp8: + return None + + # Reset running max if dtype changed + if self.kv_cache_dtype_cache != kv_cache_dtype: + self.q_max_abs = None + self.kv_cache_dtype_cache = kv_cache_dtype + + # Compute current batch absmax: [num_heads] + q_absmax = q.to(torch.float32).abs().amax(dim=(0, 2)) # [num_heads] + + # Update running max + if self.q_max_abs is None: + self.q_max_abs = q_absmax.clone().detach() + else: + self.q_max_abs = torch.maximum(self.q_max_abs, q_absmax) + + # Compute scale from running max + eps = 1e-8 + fp8_max = spec.fp8_max + q_scale = (self.q_max_abs / fp8_max).clamp_min(eps) + + return q_scale + + def _get_fp8_scales_from_max(self, kv_cache_dtype: str) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Convert running max to scales. Returns (None, None) if not FP8 or max not initialized.""" + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + if not spec.is_fp8 or self.k_max_abs is None or self.v_max_abs is None: + return None, None + eps = 1e-8 + fp8_max = spec.fp8_max + k_scale = (self.k_max_abs / fp8_max).clamp_min(eps) + v_scale = (self.v_max_abs / fp8_max).clamp_min(eps) + return k_scale, v_scale + + def _get_q_fp8_scale_from_max(self, kv_cache_dtype: str) -> torch.Tensor | None: + """Convert running max to Q scale. Returns None if not FP8 or max not initialized.""" + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + if not spec.is_fp8 or self.q_max_abs is None: + return None + eps = 1e-8 + fp8_max = spec.fp8_max + q_scale = (self.q_max_abs / fp8_max).clamp_min(eps) + return q_scale + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: List[torch.Tensor] | None = None) -> torch.Tensor: # Reshape @@ -95,8 +209,16 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # Fast Store KV cache if k_cache.numel() and v_cache.numel(): if not (self.model_type == 'diffusion_lm' and not context.need_kv_cache_store): + kv_cache_dtype = _get_kv_cache_dtype(context, self.model_type) + k_scale, v_scale = self._update_and_compute_fp8_scales(k, v, kv_cache_dtype, k.device) store_kvcache = store_kvcache_unified_layout if is_unified_layout else store_kvcache_distinct_layout - store_kvcache(k, v, k_cache, v_cache, context.slot_mapping, self.model_type, context) + store_kvcache( + k, v, k_cache, v_cache, context.slot_mapping, self.model_type, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + context=context + ) # CHECK_STORING(k_cache, v_cache, k, v, context) transpose_fn = lambda x: rearrange(x, 's h d -> 1 h s d').contiguous() @@ -128,16 +250,80 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, config = context.seqs[0].config diffusion_block_size = config.diffusion_block_size if is_unified_layout: - q_t = transpose_fn(q) - k_comb, v_comb = load_kvcache(self.k_cache, self.v_cache, context, k, v) - # k_comb, v_comb = CHECK_LOADING(k_comb, v_comb, k, v, k_cache, v_cache, context)`` - k_t, v_t = transpose_fn(k_comb), transpose_fn(v_comb) + kv_cache_dtype = _get_kv_cache_dtype(context, self.model_type) + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + + # Check if using FP8 KV cache - if so, use FP8 attention kernel + if spec.is_fp8: + from diffulex_legacy.layers.attention.ops.triton_flash_attention import triton_flash_attention + + # Get K, V scales + k_scale, v_scale = self._get_fp8_scales_from_max(kv_cache_dtype) + if k_scale is None and v_scale is None: + k_scale, v_scale = self._update_and_compute_fp8_scales(k, v, kv_cache_dtype, k.device) + + # Get Q scale + q_scale = self._get_q_fp8_scale_from_max(kv_cache_dtype) + if q_scale is None: + q_scale = self._update_and_compute_q_fp8_scale(q, kv_cache_dtype, q.device) + + # Load K, V in FP8 format (no dequantization) + fp8_dtype = spec.fp8_view_dtype + k_comb, v_comb = load_kvcache( + self.k_cache, self.v_cache, context, k, v, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + out_dtype=fp8_dtype # Key: output FP8 + ) + + # Prepare output tensor (FP16/BF16, kernel will handle dequantization) + # Output shape must match q.shape (triton_flash_attention requirement) + o = torch.empty_like(q) + + # Compute output scale for FP8 output (we use None to output in FP16/BF16) + fp8_out_scale = None # None means output in FP16/BF16 + + # Prepare fp8_scales tuple: (q_scale, k_scale, v_scale, p_scale) + # p_scale is typically 1.0 for softmax scale + p_scale = torch.ones(1, device=q.device, dtype=torch.float32) + fp8_scales = (q_scale, k_scale, v_scale, p_scale) + + # Call triton_flash_attention with FP8 support + o = triton_flash_attention( + q, k_comb, v_comb, o, + context.cu_seqlens_q, context.cu_seqlens_k, + context.max_seqlen_q, context.max_seqlen_k, + causal=False, # diffusion_lm is not causal + softmax_scale=self.scale, + bias=None, + fp8_scales=fp8_scales, + fp8_out_scale=fp8_out_scale, + block_table=None, + ) + + # Reshape output: [total_tokens, num_heads, head_dim] -> [1, num_heads, total_tokens, head_dim] + o = rearrange(o, 's h d -> 1 h s d').contiguous() + else: + # Original path for non-FP8 (BF16/FP16/FP32) + q_t = transpose_fn(q) + k_scale, v_scale = self._get_fp8_scales_from_max(kv_cache_dtype) + if k_scale is None and v_scale is None: + k_scale, v_scale = self._update_and_compute_fp8_scales(k, v, kv_cache_dtype, k.device) + k_comb, v_comb = load_kvcache( + self.k_cache, self.v_cache, context, k, v, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale + ) + k_t, v_t = transpose_fn(k_comb), transpose_fn(v_comb) - B, H, Sq, _ = q_t.shape - _, _, Skv, _ = k_t.shape - block_mask = self.dllm_block_mask(context.block_mask, B, H, Sq, Skv, str(q.device)) + B, H, Sq, _ = q_t.shape + _, _, Skv, _ = k_t.shape + block_mask = self.dllm_block_mask(context.block_mask, B, H, Sq, Skv, str(q.device)) - o = self.attention(q_t, k_t, v_t, block_mask=block_mask) + o = self.attention(q_t, k_t, v_t, block_mask=block_mask) else: # FIXME: Kernel not ok... o = torch.empty_like(q).to(q.device).to(q.dtype) diff --git a/diffulex_legacy/layers/attention/attention_v5.py b/diffulex_legacy/layers/attention/attention_v5.py index 4ac0727..fc33c6e 100644 --- a/diffulex_legacy/layers/attention/attention_v5.py +++ b/diffulex_legacy/layers/attention/attention_v5.py @@ -18,6 +18,13 @@ from diffulex_legacy.utils.context import ContextForDiffusionLM, get_context_causal_lm, get_context_diffusion_lm +def _get_kv_cache_dtype(context: ContextForDiffusionLM, model_type: str) -> str: + if model_type == 'diffusion_lm': + return context.seqs[0].config.kv_cache_dtype + else: # causal_lm + return getattr(context, 'kv_cache_dtype', 'bf16') # fallback for backward compatibility + + class Attention(nn.Module): def __init__( self, @@ -48,6 +55,11 @@ def __init__( partial(flex_attention, kernel_options=kernel_options, enable_gqa=True, return_lse=False, training=False), dynamic=True) self._block_mask_cache = {} + # FP8 scale management: maintain running max per head + self.k_max_abs: torch.Tensor | None = None # [num_kv_heads] + self.v_max_abs: torch.Tensor | None = None # [num_kv_heads] + self.q_max_abs: torch.Tensor | None = None # [num_heads] + self.kv_cache_dtype_cache: str | None = None @lru_cache(maxsize=32) def dllm_block_mask(self, block_mask: torch.Tensor, @@ -82,6 +94,108 @@ def _mask_mod(batch, head, token_q, token_kv): ) return self._block_mask_cache[cache_key] + def _update_and_compute_fp8_scales( + self, + k: torch.Tensor, + v: torch.Tensor, + kv_cache_dtype: str, + device: torch.device + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Update running max and compute FP8 scales. + Returns (k_scale, v_scale) or (None, None) if not FP8. + """ + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + if not spec.is_fp8: + return None, None + + # Reset running max if dtype changed + if self.kv_cache_dtype_cache != kv_cache_dtype: + self.k_max_abs = None + self.v_max_abs = None + self.q_max_abs = None + self.kv_cache_dtype_cache = kv_cache_dtype + + # Compute current batch absmax: [num_kv_heads] + k_absmax = k.to(torch.float32).abs().amax(dim=(0, 2)) # [num_kv_heads] + v_absmax = v.to(torch.float32).abs().amax(dim=(0, 2)) # [num_kv_heads] + + # Update running max + if self.k_max_abs is None: + self.k_max_abs = k_absmax.clone().detach() + self.v_max_abs = v_absmax.clone().detach() + else: + self.k_max_abs = torch.maximum(self.k_max_abs, k_absmax) + self.v_max_abs = torch.maximum(self.v_max_abs, v_absmax) + + # Compute scale from running max + eps = 1e-8 + fp8_max = spec.fp8_max + k_scale = (self.k_max_abs / fp8_max).clamp_min(eps) + v_scale = (self.v_max_abs / fp8_max).clamp_min(eps) + + return k_scale, v_scale + + def _update_and_compute_q_fp8_scale( + self, + q: torch.Tensor, + kv_cache_dtype: str, + device: torch.device + ) -> torch.Tensor | None: + """ + Update running max for Q and compute FP8 scale. + Returns q_scale or None if not FP8. + """ + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + if not spec.is_fp8: + return None + + # Reset running max if dtype changed + if self.kv_cache_dtype_cache != kv_cache_dtype: + self.q_max_abs = None + self.kv_cache_dtype_cache = kv_cache_dtype + + # Compute current batch absmax: [num_heads] + q_absmax = q.to(torch.float32).abs().amax(dim=(0, 2)) # [num_heads] + + # Update running max + if self.q_max_abs is None: + self.q_max_abs = q_absmax.clone().detach() + else: + self.q_max_abs = torch.maximum(self.q_max_abs, q_absmax) + + # Compute scale from running max + eps = 1e-8 + fp8_max = spec.fp8_max + q_scale = (self.q_max_abs / fp8_max).clamp_min(eps) + + return q_scale + + def _get_fp8_scales_from_max(self, kv_cache_dtype: str) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Convert running max to scales. Returns (None, None) if not FP8 or max not initialized.""" + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + if not spec.is_fp8 or self.k_max_abs is None or self.v_max_abs is None: + return None, None + eps = 1e-8 + fp8_max = spec.fp8_max + k_scale = (self.k_max_abs / fp8_max).clamp_min(eps) + v_scale = (self.v_max_abs / fp8_max).clamp_min(eps) + return k_scale, v_scale + + def _get_q_fp8_scale_from_max(self, kv_cache_dtype: str) -> torch.Tensor | None: + """Convert running max to Q scale. Returns None if not FP8 or max not initialized.""" + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + if not spec.is_fp8 or self.q_max_abs is None: + return None + eps = 1e-8 + fp8_max = spec.fp8_max + q_scale = (self.q_max_abs / fp8_max).clamp_min(eps) + return q_scale + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: List[torch.Tensor] | None = None) -> torch.Tensor: # Reshape @@ -96,8 +210,16 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, # Fast Store KV cache if k_cache.numel() and v_cache.numel(): if not (self.model_type == 'diffusion_lm' and not context.need_kv_cache_store): + kv_cache_dtype = _get_kv_cache_dtype(context, self.model_type) + k_scale, v_scale = self._update_and_compute_fp8_scales(k, v, kv_cache_dtype, k.device) store_kvcache = store_kvcache_unified_layout if is_unified_layout else store_kvcache_distinct_layout - store_kvcache(k, v, k_cache, v_cache, context.slot_mapping, self.model_type, context) + store_kvcache( + k, v, k_cache, v_cache, context.slot_mapping, self.model_type, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + context=context + ) # CHECK_STORING(k_cache, v_cache, k, v, context) transpose_fn = lambda x: rearrange(x, 's h d -> 1 h s d').contiguous() @@ -129,11 +251,76 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, config = context.seqs[0].config diffusion_block_size = config.diffusion_block_size if is_unified_layout: - k_comb, v_comb = load_kvcache(self.k_cache, self.v_cache, context, k, v) - o = flash_attn_varlen_func(q, k_comb, v_comb, - context.cu_seqlens_q, context.cu_seqlens_k, - context.max_seqlen_q, context.max_seqlen_k, - softmax_scale=self.scale, block_table=None) + kv_cache_dtype = _get_kv_cache_dtype(context, self.model_type) + from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype + spec = parse_kv_cache_dtype(kv_cache_dtype) + + # Check if using FP8 KV cache - if so, use FP8 attention kernel + if spec.is_fp8: + from diffulex_legacy.layers.attention.ops.triton_flash_attention import triton_flash_attention + + # Get K, V scales + k_scale, v_scale = self._get_fp8_scales_from_max(kv_cache_dtype) + if k_scale is None and v_scale is None: + k_scale, v_scale = self._update_and_compute_fp8_scales(k, v, kv_cache_dtype, k.device) + + # Get Q scale + q_scale = self._get_q_fp8_scale_from_max(kv_cache_dtype) + if q_scale is None: + q_scale = self._update_and_compute_q_fp8_scale(q, kv_cache_dtype, q.device) + + # Load K, V in FP8 format (no dequantization) + fp8_dtype = spec.fp8_view_dtype + k_comb, v_comb = load_kvcache( + self.k_cache, self.v_cache, context, k, v, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + out_dtype=fp8_dtype # Key: output FP8 + ) + + # Prepare output tensor (FP16/BF16, kernel will handle dequantization) + # Output shape must match q.shape (triton_flash_attention requirement) + o = torch.empty_like(q) + + # Compute output scale for FP8 output (we use None to output in FP16/BF16) + fp8_out_scale = None # None means output in FP16/BF16 + + # Prepare fp8_scales tuple: (q_scale, k_scale, v_scale, p_scale) + # p_scale is typically 1.0 for softmax scale + p_scale = torch.ones(1, device=q.device, dtype=torch.float32) + fp8_scales = (q_scale, k_scale, v_scale, p_scale) + + # Call triton_flash_attention with FP8 support + o = triton_flash_attention( + q, k_comb, v_comb, o, + context.cu_seqlens_q, context.cu_seqlens_k, + context.max_seqlen_q, context.max_seqlen_k, + causal=False, # diffusion_lm is not causal + softmax_scale=self.scale, + bias=None, + fp8_scales=fp8_scales, + fp8_out_scale=fp8_out_scale, + block_table=None, + ) + + # attention_v5.py output format is already [total_tokens, num_heads, head_dim] + # So no reshape needed here + else: + # Original path for non-FP8 (BF16/FP16/FP32) + k_scale, v_scale = self._get_fp8_scales_from_max(kv_cache_dtype) + if k_scale is None and v_scale is None: + k_scale, v_scale = self._update_and_compute_fp8_scales(k, v, kv_cache_dtype, k.device) + k_comb, v_comb = load_kvcache( + self.k_cache, self.v_cache, context, k, v, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale + ) + o = flash_attn_varlen_func(q, k_comb, v_comb, + context.cu_seqlens_q, context.cu_seqlens_k, + context.max_seqlen_q, context.max_seqlen_k, + softmax_scale=self.scale, block_table=None) else: # FIXME: Kernel not ok... o = torch.empty_like(q).to(q.device).to(q.dtype) diff --git a/diffulex_legacy/layers/attention/ops/kv_cache_kernels.py b/diffulex_legacy/layers/attention/ops/kv_cache_kernels.py index fcd6c22..4726379 100755 --- a/diffulex_legacy/layers/attention/ops/kv_cache_kernels.py +++ b/diffulex_legacy/layers/attention/ops/kv_cache_kernels.py @@ -8,6 +8,12 @@ from diffulex_legacy.utils.context import ContextForDiffusionLM from diffulex_legacy.engine.sequence import SequenceForDiffusionLM +from diffulex.utils.kv_cache_dtype import ( + KvCacheDType, + ensure_scale_tensor, + parse_kv_cache_dtype, + view_fp8_cache, +) @triton.jit def store_kvcache_kernel_causal_lm( @@ -18,15 +24,35 @@ def store_kvcache_kernel_causal_lm( k_cache_ptr, v_cache_ptr, slot_mapping_ptr, - D: tl.constexpr + k_scale_ptr, + v_scale_ptr, + HEAD_DIM: tl.constexpr, + D: tl.constexpr, + KV_CACHE_DTYPE: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr ): idx = tl.program_id(0) - key_offsets = idx * key_stride + tl.arange(0, D) - value_offsets = idx * value_stride + tl.arange(0, D) - key = tl.load(key_ptr + key_offsets) - value = tl.load(value_ptr + value_offsets) + offs_d = tl.arange(0, D) + key_offsets = idx * key_stride + offs_d + value_offsets = idx * value_stride + offs_d + key = tl.load(key_ptr + key_offsets).to(tl.float32) + value = tl.load(value_ptr + value_offsets).to(tl.float32) slot = tl.load(slot_mapping_ptr + idx) - cache_offsets = slot * D + tl.arange(0, D) + cache_offsets = slot * D + offs_d + + # Triton kernels cannot reference Python globals (e.g., Enum). Use constexpr ints. + if KV_CACHE_DTYPE == 3 or KV_CACHE_DTYPE == 4: + head_id = offs_d // HEAD_DIM + k_scale = tl.load(k_scale_ptr + head_id).to(tl.float32) + v_scale = tl.load(v_scale_ptr + head_id).to(tl.float32) + k_scale = tl.maximum(k_scale, 1e-8) + v_scale = tl.maximum(v_scale, 1e-8) + key = key / k_scale + value = value / v_scale + key = tl.maximum(tl.minimum(key, FP8_MAX), FP8_MIN) + value = tl.maximum(tl.minimum(value, FP8_MAX), FP8_MIN) + tl.store(k_cache_ptr + cache_offsets, key) tl.store(v_cache_ptr + cache_offsets, value) @@ -40,17 +66,36 @@ def store_kvcache_kernel_diffusion_lm( k_cache_ptr, v_cache_ptr, slot_mapping_ptr, - D: tl.constexpr + k_scale_ptr, + v_scale_ptr, + HEAD_DIM: tl.constexpr, + D: tl.constexpr, + KV_CACHE_DTYPE: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr ): token_idx = tl.program_id(0) slot = tl.load(slot_mapping_ptr + token_idx) if slot < 0: return - key_offsets = token_idx * key_stride + tl.arange(0, D) - value_offsets = token_idx * value_stride + tl.arange(0, D) - key = tl.load(key_ptr + key_offsets) - value = tl.load(value_ptr + value_offsets) - cache_offsets = slot * D + tl.arange(0, D) + offs_d = tl.arange(0, D) + key_offsets = token_idx * key_stride + offs_d + value_offsets = token_idx * value_stride + offs_d + key = tl.load(key_ptr + key_offsets).to(tl.float32) + value = tl.load(value_ptr + value_offsets).to(tl.float32) + cache_offsets = slot * D + offs_d + + if KV_CACHE_DTYPE == 3 or KV_CACHE_DTYPE == 4: + head_id = offs_d // HEAD_DIM + k_scale = tl.load(k_scale_ptr + head_id).to(tl.float32) + v_scale = tl.load(v_scale_ptr + head_id).to(tl.float32) + k_scale = tl.maximum(k_scale, 1e-8) + v_scale = tl.maximum(v_scale, 1e-8) + key = key / k_scale + value = value / v_scale + key = tl.maximum(tl.minimum(key, FP8_MAX), FP8_MIN) + value = tl.maximum(tl.minimum(value, FP8_MAX), FP8_MIN) + tl.store(k_cache_ptr + cache_offsets, key) tl.store(v_cache_ptr + cache_offsets, value) @@ -61,8 +106,13 @@ def store_kvcache_kernel_diffusion_lm_distinct( k_stride, v_stride, k_cache_stride_nblks, k_cache_stride_h, k_cache_stride_dx, k_cache_stride_blk_sz, k_cache_stride_x, v_cache_stride_nblks, v_cache_stride_h, v_cache_stride_d, v_cache_stride_blk_sz, + k_scale_ptr, v_scale_ptr, nheads, hdim, blk_sz, - x: tl.constexpr, D: tl.constexpr + x: tl.constexpr, + D: tl.constexpr, + KV_CACHE_DTYPE: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr ): # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: D2F @@ -88,8 +138,8 @@ def store_kvcache_kernel_diffusion_lm_distinct( offs_d = tl.arange(0, D) offs_k = token_idx * k_stride + offs_d offs_v = token_idx * v_stride + offs_d - k = tl.load(k_ptr + offs_k) - v = tl.load(v_ptr + offs_v) + k = tl.load(k_ptr + offs_k).to(tl.float32) + v = tl.load(v_ptr + offs_v).to(tl.float32) h_ids = offs_d // hdim h_offs = offs_d % hdim @@ -102,6 +152,16 @@ def store_kvcache_kernel_diffusion_lm_distinct( v_cache_offs = (blk_idx * v_cache_stride_nblks + h_ids * v_cache_stride_h + h_offs * v_cache_stride_d + off_blk * v_cache_stride_blk_sz) + if KV_CACHE_DTYPE == 3 or KV_CACHE_DTYPE == 4: + k_scale = tl.load(k_scale_ptr + h_ids).to(tl.float32) + v_scale = tl.load(v_scale_ptr + h_ids).to(tl.float32) + k_scale = tl.maximum(k_scale, 1e-8) + v_scale = tl.maximum(v_scale, 1e-8) + k = k / k_scale + v = v / v_scale + k = tl.maximum(tl.minimum(k, FP8_MAX), FP8_MIN) + v = tl.maximum(tl.minimum(v, FP8_MAX), FP8_MIN) + tl.store(k_cache_ptr + k_cache_offs, k) tl.store(v_cache_ptr + v_cache_offs, v) @@ -109,7 +169,11 @@ def store_kvcache_kernel_diffusion_lm_distinct( def store_kvcache_distinct_layout(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor, model_type: str = 'causal_lm', + kv_cache_dtype: str = "bf16", + k_scale=None, + v_scale=None, context: ContextForDiffusionLM = None) -> None: + spec = parse_kv_cache_dtype(kv_cache_dtype) if model_type == 'causal_lm': # k_cache: [num_blks, blk_sz, h, hdim] @@ -120,10 +184,20 @@ def store_kvcache_distinct_layout(key: torch.Tensor, value: torch.Tensor, assert key.stride(1) == head_dim and value.stride(1) == head_dim assert k_cache.stride(1) == D and v_cache.stride(1) == D assert N == slot_mapping.numel() + k_cache_view = view_fp8_cache(k_cache, kv_cache_dtype) + v_cache_view = view_fp8_cache(v_cache, kv_cache_dtype) + k_scale_ts = ensure_scale_tensor(k_scale, num_kv_heads=num_heads, device=key.device) + v_scale_ts = ensure_scale_tensor(v_scale, num_kv_heads=num_heads, device=key.device) store_kvcache_kernel_causal_lm[(N,)]( key, key.stride(0), value, value.stride(0), - k_cache, v_cache, slot_mapping, D + k_cache_view, v_cache_view, slot_mapping, + k_scale_ts, v_scale_ts, + head_dim, + D, + int(spec.enum), + float(spec.fp8_min or 0.0), + float(spec.fp8_max or 0.0), ) else: # TODO: implement diffusion lm kv cache store @@ -134,41 +208,70 @@ def store_kvcache_distinct_layout(key: torch.Tensor, value: torch.Tensor, N = key.shape[0] assert HDim == key.shape[-1] and NHeads == key.shape[1] assert N == slot_mapping.numel() + k_cache_view = view_fp8_cache(k_cache, kv_cache_dtype) + v_cache_view = view_fp8_cache(v_cache, kv_cache_dtype) + k_scale_ts = ensure_scale_tensor(k_scale, num_kv_heads=NHeads, device=key.device) + v_scale_ts = ensure_scale_tensor(v_scale, num_kv_heads=NHeads, device=key.device) GRID = (N, ) store_kvcache_kernel_diffusion_lm_distinct[GRID]( key, value, - k_cache, v_cache, + k_cache_view, v_cache_view, slot_mapping, key.stride(0), value.stride(0), *k_cache.stride(), *v_cache.stride(), + k_scale_ts, v_scale_ts, NHeads, HDim, Blk_sz, - x, HDim * NHeads + x, + HDim * NHeads, + int(spec.enum), + float(spec.fp8_min or 0.0), + float(spec.fp8_max or 0.0), ) def store_kvcache_unified_layout(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor, model_type: str = 'causal_lm', + kv_cache_dtype: str = "bf16", + k_scale=None, + v_scale=None, context: ContextForDiffusionLM = None) -> None: + spec = parse_kv_cache_dtype(kv_cache_dtype) N, num_heads, head_dim = key.shape D = num_heads * head_dim assert key.stride(-1) == 1 and value.stride(-1) == 1 assert key.stride(1) == head_dim and value.stride(1) == head_dim assert k_cache.stride(1) == D and v_cache.stride(1) == D assert N == slot_mapping.numel(), f"`N`: {N}, `slot_mapping.numel()`: {slot_mapping.numel()}" + k_cache_view = view_fp8_cache(k_cache, kv_cache_dtype) + v_cache_view = view_fp8_cache(v_cache, kv_cache_dtype) + k_scale_ts = ensure_scale_tensor(k_scale, num_kv_heads=num_heads, device=key.device) + v_scale_ts = ensure_scale_tensor(v_scale, num_kv_heads=num_heads, device=key.device) if model_type == 'causal_lm': store_kvcache_kernel_causal_lm[(N,)]( key, key.stride(0), value, value.stride(0), - k_cache, v_cache, slot_mapping, D + k_cache_view, v_cache_view, slot_mapping, + k_scale_ts, v_scale_ts, + head_dim, + D, + int(spec.enum), + float(spec.fp8_min or 0.0), + float(spec.fp8_max or 0.0), ) elif model_type == 'diffusion_lm': store_kvcache_kernel_diffusion_lm[(N,)]( key, key.stride(0), value, value.stride(0), - k_cache, v_cache, slot_mapping, D + k_cache_view, v_cache_view, slot_mapping, + k_scale_ts, v_scale_ts, + head_dim, + D, + int(spec.enum), + float(spec.fp8_min or 0.0), + float(spec.fp8_max or 0.0), ) @@ -179,6 +282,7 @@ def load_kvcache_kernel_kv(k_cache_ptr, v_cache_ptr, k_out_ptr, v_out_ptr, seqlens_ptr, ctxlens_ptr, cu_seqlens_q_ptr, cu_seqlens_k_ptr, + k_scale_ptr, v_scale_ptr, kv_cache_stride_nblks, kv_cache_stride_blk, kv_cache_stride_h, kv_cache_stride_d, kv_new_stride_s, kv_new_stride_h, kv_new_stride_d, block_table_stride_nseqs, block_table_stride_maxblks, @@ -189,7 +293,11 @@ def load_kvcache_kernel_kv(k_cache_ptr, v_cache_ptr, HEAD_DIM: tl.constexpr, PAGE_SIZE: tl.constexpr, DIFFUSION_BLOCK_SIZE: tl.constexpr, - KV_LOAD_UNROLL_FACTOR: tl.constexpr): + KV_LOAD_UNROLL_FACTOR: tl.constexpr, + KV_CACHE_DTYPE: tl.constexpr, + OUT_DTYPE: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr): # BUG FIX # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: D2F @@ -226,6 +334,25 @@ def load_kvcache_kernel_kv(k_cache_ptr, v_cache_ptr, kv_cache_mask = offs_kv_cache_seq[None, :] < local_ctxlen k_cache = tl.load(k_cache_ptr + offs_kv_cache, mask=kv_cache_mask, other=0.0) v_cache = tl.load(v_cache_ptr + offs_kv_cache, mask=kv_cache_mask, other=0.0) + + # If output is FP8, don't dequantize (keep FP8) + # If output is not FP8, dequantize to float32 first + if OUT_DTYPE == 3 or OUT_DTYPE == 4: + # FP8 output: store directly without dequantization + k_cache_out = k_cache + v_cache_out = v_cache + else: + # Non-FP8 output: dequantize if needed + if KV_CACHE_DTYPE == 3 or KV_CACHE_DTYPE == 4: + k_cache = k_cache.to(tl.float32) + v_cache = v_cache.to(tl.float32) + k_scale = tl.load(k_scale_ptr + kv_head_idx).to(tl.float32) + v_scale = tl.load(v_scale_ptr + kv_head_idx).to(tl.float32) + k_cache_out = k_cache * k_scale + v_cache_out = v_cache * v_scale + else: + k_cache_out = k_cache.to(tl.float32) + v_cache_out = v_cache.to(tl.float32) # Store KV cache into output KV tensors off_cu_seqlens_k = seq_idx * cu_seqlens_k_stride @@ -236,8 +363,8 @@ def load_kvcache_kernel_kv(k_cache_ptr, v_cache_ptr, kv_head_idx * kv_out_stride_h + # Hkv: HeadId offs_kv_cache_hdim[:, None] * kv_out_stride_d # Hdim: HeadDim Elems ) - tl.store(k_out_ptr + offs_kv_cache_to_out, k_cache, mask=kv_cache_mask) - tl.store(v_out_ptr + offs_kv_cache_to_out, v_cache, mask=kv_cache_mask) + tl.store(k_out_ptr + offs_kv_cache_to_out, k_cache_out, mask=kv_cache_mask) + tl.store(v_out_ptr + offs_kv_cache_to_out, v_cache_out, mask=kv_cache_mask) # Load and store active KV only once when first meet if local_blk_idx == LAST_BLK_ID: @@ -271,16 +398,47 @@ def load_kvcache_kernel_kv(k_cache_ptr, v_cache_ptr, kv_head_idx * kv_out_stride_h + # Hkv: HeadId offs_kv_new_hdim[:, None] * kv_out_stride_d # Hdim: HeadDim Elems ) - tl.store(k_out_ptr + offs_cur_kv_new_to_out, k_new) - tl.store(v_out_ptr + offs_cur_kv_new_to_out, v_new) + # IMPORTANT: + # - When OUT_DTYPE is FP8, the output K/V are consumed by FP8 attention kernels + # which assume K/V are in the *quantized domain* (value / scale). + # - Cached K/V are already stored in quantized FP8 domain. + # - But k_new/v_new are BF16/FP16 values. If we store them directly into FP8 output, + # they would be cast without dividing by scale, causing a scale mismatch and large errors. + if OUT_DTYPE == 3 or OUT_DTYPE == 4: + if KV_CACHE_DTYPE == 3 or KV_CACHE_DTYPE == 4: + k_s = tl.load(k_scale_ptr + kv_head_idx).to(tl.float32) + v_s = tl.load(v_scale_ptr + kv_head_idx).to(tl.float32) + k_s = tl.maximum(k_s, 1e-8) + v_s = tl.maximum(v_s, 1e-8) + k_new_q = (k_new.to(tl.float32) / k_s) + v_new_q = (v_new.to(tl.float32) / v_s) + else: + # If cache isn't FP8, treat scale as 1.0 for quantized output. + k_new_q = k_new.to(tl.float32) + v_new_q = v_new.to(tl.float32) + # Clamp to FP8 representable range before storing + k_new_q = tl.maximum(tl.minimum(k_new_q, FP8_MAX), FP8_MIN) + v_new_q = tl.maximum(tl.minimum(v_new_q, FP8_MAX), FP8_MIN) + tl.store(k_out_ptr + offs_cur_kv_new_to_out, k_new_q) + tl.store(v_out_ptr + offs_cur_kv_new_to_out, v_new_q) + else: + tl.store(k_out_ptr + offs_cur_kv_new_to_out, k_new) + tl.store(v_out_ptr + offs_cur_kv_new_to_out, v_new) def load_kvcache(k_cache: torch.Tensor, v_cache: torch.Tensor, context: ContextForDiffusionLM, - k_new: torch.Tensor, v_new: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + k_new: torch.Tensor, v_new: torch.Tensor, + kv_cache_dtype: str = "bf16", + k_scale=None, + v_scale=None, + out_dtype: torch.dtype | None = None) -> Tuple[torch.Tensor, torch.Tensor]: + spec = parse_kv_cache_dtype(kv_cache_dtype) assert k_cache.shape == v_cache.shape assert k_new.shape == v_new.shape - N_BLOCKS, PAGE_SIZE, H_KV, HEAD_DIM = k_cache.shape + k_cache_view = view_fp8_cache(k_cache, kv_cache_dtype) + v_cache_view = view_fp8_cache(v_cache, kv_cache_dtype) + N_BLOCKS, PAGE_SIZE, H_KV, HEAD_DIM = k_cache_view.shape NUM_SEQS, MAX_SEQ_BLOCKS = context.block_tables.shape ctxlens = context.context_lens @@ -298,17 +456,43 @@ def load_kvcache(k_cache: torch.Tensor, v_cache: torch.Tensor, assert cu_seqlens_q.shape[0] == NUM_SEQS + 1 kv_output_shape = (sum(total_lens).item(), H_KV, HEAD_DIM) - k_output = torch.empty(kv_output_shape, device=k_cache.device, dtype=k_cache.dtype) + out_dtype = k_new.dtype if out_dtype is None else out_dtype + + # Determine OUT_DTYPE for kernel (constexpr int) + from diffulex.utils.kv_cache_dtype import KvCacheDType + if out_dtype == torch.bfloat16: + out_dtype_enum = int(KvCacheDType.BF16) # 0 + elif out_dtype == torch.float16: + out_dtype_enum = int(KvCacheDType.FP16) # 1 + elif out_dtype == torch.float32: + out_dtype_enum = int(KvCacheDType.FP32) # 2 + elif spec.is_fp8 and out_dtype == spec.fp8_view_dtype: + out_dtype_enum = int(spec.enum) # 3 or 4 + else: + # Default: use k_new.dtype + if k_new.dtype == torch.bfloat16: + out_dtype_enum = int(KvCacheDType.BF16) + elif k_new.dtype == torch.float16: + out_dtype_enum = int(KvCacheDType.FP16) + elif k_new.dtype == torch.float32: + out_dtype_enum = int(KvCacheDType.FP32) + else: + raise ValueError(f"Unsupported out_dtype: {out_dtype}") + + k_output = torch.empty(kv_output_shape, device=k_cache.device, dtype=out_dtype) v_output = torch.empty_like(k_output) + k_scale_ts = ensure_scale_tensor(k_scale, num_kv_heads=H_KV, device=k_cache.device) + v_scale_ts = ensure_scale_tensor(v_scale, num_kv_heads=H_KV, device=k_cache.device) GRID = (NUM_SEQS, MAX_SEQ_BLOCKS, H_KV) load_kvcache_kernel_kv[GRID]( - k_cache, v_cache, + k_cache_view, v_cache_view, k_new, v_new, context.block_tables, k_output, v_output, seqlens, ctxlens, cu_seqlens_q, cu_seqlens_k, + k_scale_ts, v_scale_ts, *k_cache.stride(), *k_new.stride(), *context.block_tables.stride(), @@ -321,7 +505,11 @@ def load_kvcache(k_cache: torch.Tensor, v_cache: torch.Tensor, HEAD_DIM=HEAD_DIM, PAGE_SIZE=PAGE_SIZE, DIFFUSION_BLOCK_SIZE=DIFFUSION_BLOCK_SIZE, - KV_LOAD_UNROLL_FACTOR=2 + KV_LOAD_UNROLL_FACTOR=2, + KV_CACHE_DTYPE=int(spec.enum), + OUT_DTYPE=out_dtype_enum, + FP8_MIN=float(spec.fp8_min) if spec.is_fp8 and spec.fp8_min is not None else 0.0, + FP8_MAX=float(spec.fp8_max) if spec.is_fp8 and spec.fp8_max is not None else 0.0, ) return k_output, v_output @@ -330,6 +518,9 @@ def load_kvcache(k_cache: torch.Tensor, v_cache: torch.Tensor, def CHECK_STORING(k_cache: torch.Tensor, v_cache: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context: ContextForDiffusionLM) -> None: + # FP8 cache uses uint8 storage; exact bitwise match is not expected here. + if k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8: + return k_list, v_list = [torch.split(tensor, context.seq_lens, dim=0) for tensor in (k, v)] for seq_idx, seq in enumerate(context.seqs): cached_num_tokens = seq.cached_num_tokens @@ -366,6 +557,8 @@ def CHECK_LOADING(k_comb: torch.Tensor, v_comb: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, context: ContextForDiffusionLM) -> Tuple[torch.Tensor, torch.Tensor]: + if k_cache.dtype == torch.uint8 or v_cache.dtype == torch.uint8: + return k_comb, v_comb try: k_list, v_list = [torch.split(tensor, context.seq_lens, dim=0) for tensor in (k_new, v_new)] cat_k_list = [] diff --git a/diffulex_legacy/layers/attention/ops/triton_flash_attention.py b/diffulex_legacy/layers/attention/ops/triton_flash_attention.py index 37dd535..1a3febb 100755 --- a/diffulex_legacy/layers/attention/ops/triton_flash_attention.py +++ b/diffulex_legacy/layers/attention/ops/triton_flash_attention.py @@ -229,12 +229,13 @@ def _attn_fwd_inner( def get_cdna_autotune_configs(): + # For Nvidia GPU (CUDA), waves_per_eu is not supported and should not be included return [ triton.Config( { 'BLOCK_M': 256, 'BLOCK_N': 64, - 'waves_per_eu': 2, + # 'waves_per_eu': 2, 'PRE_LOAD_V': False }, num_stages=1, @@ -243,7 +244,7 @@ def get_cdna_autotune_configs(): { 'BLOCK_M': 128, 'BLOCK_N': 128, - 'waves_per_eu': 2, + # 'waves_per_eu': 2, 'PRE_LOAD_V': False }, num_stages=1, @@ -252,7 +253,7 @@ def get_cdna_autotune_configs(): { 'BLOCK_M': 256, 'BLOCK_N': 128, - 'waves_per_eu': 2, + # 'waves_per_eu': 2, 'PRE_LOAD_V': False }, num_stages=1, @@ -261,7 +262,7 @@ def get_cdna_autotune_configs(): { 'BLOCK_M': 128, 'BLOCK_N': 64, - 'waves_per_eu': 1, + # 'waves_per_eu': 1, 'PRE_LOAD_V': False }, num_stages=1, @@ -270,7 +271,7 @@ def get_cdna_autotune_configs(): { 'BLOCK_M': 128, 'BLOCK_N': 64, - 'waves_per_eu': 3, + # 'waves_per_eu': 3, 'PRE_LOAD_V': True }, num_stages=1, @@ -279,7 +280,7 @@ def get_cdna_autotune_configs(): { 'BLOCK_M': 128, 'BLOCK_N': 64, - 'waves_per_eu': 3, + # 'waves_per_eu': 3, 'PRE_LOAD_V': False }, num_stages=1, @@ -288,7 +289,7 @@ def get_cdna_autotune_configs(): { 'BLOCK_M': 64, 'BLOCK_N': 64, - 'waves_per_eu': 4, + # 'waves_per_eu': 4, 'PRE_LOAD_V': False }, num_stages=1, @@ -297,7 +298,7 @@ def get_cdna_autotune_configs(): { 'BLOCK_M': 32, 'BLOCK_N': 32, - 'waves_per_eu': 4, + # 'waves_per_eu': 4, 'PRE_LOAD_V': False }, num_stages=1, @@ -536,6 +537,7 @@ def attn_fwd( return # If MQA / GQA, set the K and V head offsets appropriately. + # Compute GROUP_SIZE and off_h_k early so they can be used in both FP8 and non-FP8 paths GROUP_SIZE: tl.constexpr = HQ // HK off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q @@ -620,10 +622,18 @@ def attn_fwd( q = load_fn(Q_block_ptr, True, padded_head, "zero") if not USE_FP8: q = (q * qk_scale).to(Q_block_ptr.type.element_ty) - acc_scale = 1.0 + acc_scale: tl.float32 = 1.0 else: - qk_scale *= q_scale * k_scale - acc_scale = p_scale * v_scale + # Load per-head scale values from scale tensors + # q_scale, k_scale, v_scale are per-head tensors, so we load the value for current head + # Note: off_h_k is already computed above (line 542) + q_scale_val = tl.load(q_scale + off_h_q) + k_scale_val = tl.load(k_scale + off_h_k) + v_scale_val = tl.load(v_scale + off_h_k) + # p_scale is a scalar tensor (torch.ones(1, ...)), so we load from index 0 + p_scale_val = tl.load(p_scale) + qk_scale *= q_scale_val * k_scale_val + acc_scale = p_scale_val * v_scale_val # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -855,6 +865,10 @@ def forward( def check_and_convert(t, scale): if t.dtype != float8: descale = 1.0 / scale + # Reshape scale to broadcast correctly: [num_heads] -> [1, num_heads, 1] + # This allows broadcasting with t of shape [seq_len, num_heads, head_dim] + if descale.dim() == 1 and t.dim() == 3: + descale = descale.view(1, -1, 1) ts = (t * descale).clamp(min=float8_info.min, max=float8_info.max) return ts.to(float8) @@ -930,7 +944,17 @@ def check_and_convert(t, scale): else: bias_strides = (0, 0, 0, 0) - p_descale = 1.0 / p_scale + # IMPORTANT: `p_descale` is used as a scalar inside Triton kernel (e.g. `p *= p_descale`). + # If we pass a 1-element tensor here (common when fp8_scales provides `p_scale` as Tensor), + # Triton will treat it as pointer and fail to compile. + if isinstance(p_scale, torch.Tensor): + if p_scale.numel() != 1: + raise ValueError( + f"Expected p_scale to be a scalar tensor (numel==1), got shape={tuple(p_scale.shape)}" + ) + p_descale = float(1.0 / float(p_scale.item())) + else: + p_descale = float(1.0 / float(p_scale)) o_descale = 1.0 / fp8_out_scale.item( ) if fp8_out_scale is not None else 1.0 @@ -1004,7 +1028,8 @@ def triton_flash_attention( fp8_out_scale=None, block_table=None, ): - _attention.apply( + # _attention.apply returns (o, encoded_softmax). Most callsites expect only `o`. + o, _ = _attention.apply( q, k, v, @@ -1019,4 +1044,5 @@ def triton_flash_attention( fp8_scales, fp8_out_scale, block_table, - ) \ No newline at end of file + ) + return o \ No newline at end of file diff --git a/diffulex_legacy/utils/context.py b/diffulex_legacy/utils/context.py index 8986276..0258430 100755 --- a/diffulex_legacy/utils/context.py +++ b/diffulex_legacy/utils/context.py @@ -21,6 +21,7 @@ class ContextBase: class ContextForCausalLM(ContextBase): kv_cache_layout: str = "unified" # Only "unified" is supported for Causal LM currently need_kv_cache_store: bool = True + kv_cache_dtype: str = "bf16" # "bf16", "fp16", "fp32", "fp8", "fp8_e4m3", "fp8_e5m2" _CONTEXT_FOR_CAUSAL_LM = ContextForCausalLM() @@ -31,14 +32,16 @@ def set_context_causal_lm( is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, - slot_mapping=None, context_lens=None, block_tables=None + slot_mapping=None, context_lens=None, block_tables=None, + kv_cache_dtype: str = "bf16" ) -> None: global _CONTEXT_FOR_CAUSAL_LM _CONTEXT_FOR_CAUSAL_LM = ContextForCausalLM( is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - slot_mapping, context_lens, block_tables + slot_mapping, context_lens, block_tables, + kv_cache_dtype=kv_cache_dtype ) def reset_context_causal_lm() -> None: diff --git a/diffulex_profiler/README.md b/diffulex_profiler/README.md new file mode 100644 index 0000000..3fa25a7 --- /dev/null +++ b/diffulex_profiler/README.md @@ -0,0 +1,327 @@ +# Diffulex Profiler + +A modular profiling framework for performance analysis of the Diffulex inference engine. This module provides comprehensive performance metrics collection, multiple profiling backends, and flexible result export capabilities. + +## Features + +- **Multiple Profiling Backends**: Support for simple timing, VizTracer, and PyTorch Profiler +- **Comprehensive Metrics**: Collect timing, throughput, GPU utilization, memory usage, and custom metrics +- **Flexible Export**: Export results in JSON, CSV, or human-readable summary formats +- **Easy Integration**: Simple context manager API for seamless integration with existing code +- **Modular Design**: Extensible architecture for adding custom backends and exporters + +## Installation + +The profiler is included as part of the Diffulex package. No additional installation is required beyond the standard Diffulex dependencies. + +Optional dependencies for advanced features: +- `viztracer`: For detailed function call tracing (already in dependencies) +- `pynvml`: For detailed GPU utilization metrics (optional) + +## Quick Start + +### Basic Usage + +```python +from diffulex_profiler import DiffulexProfiler, ProfilerConfig +from diffulex import Diffulex, SamplingParams + +# Initialize profiler +profiler = DiffulexProfiler( + config=ProfilerConfig( + enabled=True, + backend="simple", + output_dir="log/profiles" + ) +) + +# Initialize Diffulex engine +llm = Diffulex(model_path, model_name="dream", ...) + +# Profile inference +with profiler.profile("inference", metadata={"batch_size": 10}): + outputs = llm.generate(prompts, sampling_params) + total_tokens = sum(len(o['token_ids']) for o in outputs) + profiler.record_throughput(total_tokens) + +# Export results +profiler.export("log/profiles/inference_profile.json") +``` + +### Advanced Usage with Multiple Sections + +```python +profiler = DiffulexProfiler( + config=ProfilerConfig( + enabled=True, + backend="simple", + collect_gpu_metrics=True, + collect_memory_metrics=True, + export_formats=["json", "csv", "summary"] + ) +) + +# Profile different sections +with profiler.profile("model_loading"): + llm = Diffulex(model_path, ...) + +with profiler.profile("prefill", metadata={"num_prompts": len(prompts)}): + # Prefill phase + pass + +with profiler.profile("decode"): + outputs = llm.generate(prompts, sampling_params) + profiler.record_throughput(sum(len(o['token_ids']) for o in outputs)) + +# Get summary +summary = profiler.get_summary() +print(f"Total duration: {summary['total_duration_sec']:.2f}s") +print(f"Average throughput: {summary['avg_throughput_tokens_per_sec']:.2f} tok/s") + +# Export all results +profiler.export() +``` + +## Configuration + +### ProfilerConfig + +The `ProfilerConfig` class provides comprehensive configuration options: + +```python +@dataclass +class ProfilerConfig: + enabled: bool = True # Enable/disable profiling + backend: str = "simple" # Backend: "simple", "viztracer", "pytorch" + output_dir: str = "log/profiles" # Output directory for results + output_file: Optional[str] = None # Optional custom output filename + collect_gpu_metrics: bool = True # Collect GPU metrics + collect_memory_metrics: bool = True # Collect memory metrics + collect_timing: bool = True # Collect timing information + export_formats: List[str] = ["json", "summary"] # Export formats + viztracer_config: Optional[Dict] = None # VizTracer-specific config + pytorch_profiler_config: Optional[Dict] = None # PyTorch Profiler config +``` + +## Profiling Backends + +### Simple Timer Backend (Default) + +The simplest backend that only tracks execution time. No additional dependencies required. + +```python +profiler = DiffulexProfiler( + config=ProfilerConfig(backend="simple") +) +``` + +### VizTracer Backend + +For detailed function call tracing and visualization. Requires `viztracer` package. + +```python +profiler = DiffulexProfiler( + config=ProfilerConfig( + backend="viztracer", + viztracer_config={ + "output_file": "trace.json", + "file_info": True, + } + ) +) +``` + +### PyTorch Profiler Backend + +For GPU/CPU operation-level profiling. Built into PyTorch. + +```python +profiler = DiffulexProfiler( + config=ProfilerConfig( + backend="pytorch", + pytorch_profiler_config={ + "activities": [ProfilerActivity.CPU, ProfilerActivity.CUDA], + "record_shapes": True, + "profile_memory": True, + } + ) +) +``` + +## Metrics Collection + +The profiler automatically collects: + +- **Timing**: Start time, end time, duration +- **Throughput**: Tokens per second (when recorded via `record_throughput()`) +- **GPU Metrics**: Utilization, memory usage, device information +- **Memory Metrics**: System memory usage and deltas +- **Custom Metrics**: User-defined metrics via `record_metric()` + +### Recording Custom Metrics + +```python +with profiler.profile("custom_section"): + # Your code here + profiler.record_metric("num_sequences", 10) + profiler.record_metric("avg_length", 128.5) + profiler.record_throughput(total_tokens=1000) +``` + +## Export Formats + +### JSON Export + +Structured JSON format suitable for programmatic analysis: + +```python +profiler = DiffulexProfiler( + config=ProfilerConfig(export_formats=["json"]) +) +profiler.export("results.json") +``` + +### CSV Export + +Tabular format for spreadsheet analysis: + +```python +profiler = DiffulexProfiler( + config=ProfilerConfig(export_formats=["csv"]) +) +profiler.export("results.csv") +``` + +### Summary Export + +Human-readable text summary: + +```python +profiler = DiffulexProfiler( + config=ProfilerConfig(export_formats=["summary"]) +) +profiler.export("results.txt") +``` + +## Integration Examples + +### Integration with Diffulex Engine + +```python +from diffulex_profiler import DiffulexProfiler, ProfilerConfig +from diffulex import Diffulex, SamplingParams + +# Setup +profiler = DiffulexProfiler(ProfilerConfig(enabled=True)) +llm = Diffulex(model_path, model_name="dream", ...) +sampling_params = SamplingParams(temperature=0.0, max_tokens=256) + +# Profile generation +prompts = ["What is 2+2?", "Explain quantum computing"] +with profiler.profile("generate", metadata={"num_prompts": len(prompts)}): + outputs = llm.generate(prompts, sampling_params) + total_tokens = sum(len(o['token_ids']) for o in outputs) + profiler.record_throughput(total_tokens) + profiler.record_metric("num_outputs", len(outputs)) + profiler.record_metric("avg_diff_steps", + sum(o['n_diff_steps'] for o in outputs) / len(outputs)) + +# Export +profiler.export("generation_profile.json") +summary = profiler.get_summary() +print(f"Throughput: {summary['avg_throughput_tokens_per_sec']:.2f} tok/s") +``` + +### Batch Profiling + +```python +profiler = DiffulexProfiler(ProfilerConfig(enabled=True)) + +for batch_idx, batch in enumerate(batches): + with profiler.profile(f"batch_{batch_idx}", metadata={"batch_size": len(batch)}): + outputs = llm.generate(batch, sampling_params) + profiler.record_throughput(sum(len(o['token_ids']) for o in outputs)) + +profiler.export("batch_profiles.json") +``` + +## API Reference + +### DiffulexProfiler + +Main profiler class. + +#### Methods + +- `profile(name: str, metadata: Optional[Dict] = None)`: Context manager for profiling +- `start(name: str, metadata: Optional[Dict] = None)`: Start profiling a section +- `stop()`: Stop profiling current section +- `record_metric(name: str, value: Any)`: Record a custom metric +- `record_throughput(tokens: int, duration: Optional[float] = None)`: Record throughput +- `export(output_path: Optional[str] = None)`: Export results +- `get_summary() -> Dict[str, Any]`: Get summary statistics +- `clear()`: Clear all collected metrics + +### PerformanceMetrics + +Container for performance metrics. + +#### Attributes + +- `name`: Section name +- `duration`: Duration in seconds +- `total_tokens`: Total tokens processed +- `throughput_tokens_per_sec`: Throughput in tokens/second +- `gpu_utilization`: GPU utilization percentage +- `memory_delta_mb`: Memory usage delta in MB +- `custom_metrics`: Dictionary of custom metrics +- `metadata`: User-provided metadata + +## Best Practices + +1. **Use Context Managers**: Always use the `profile()` context manager for automatic cleanup +2. **Record Throughput**: Call `record_throughput()` after inference to get accurate throughput metrics +3. **Add Metadata**: Include relevant metadata (batch size, model config, etc.) for better analysis +4. **Choose Appropriate Backend**: Use "simple" for basic timing, "viztracer" for detailed tracing, "pytorch" for GPU profiling +5. **Export Regularly**: Export results periodically for long-running experiments +6. **Clear When Needed**: Use `clear()` to reset metrics between different profiling sessions + +## Troubleshooting + +### Profiler Not Collecting Metrics + +- Ensure `enabled=True` in `ProfilerConfig` +- Check that you're using the context manager correctly +- Verify that `start()` and `stop()` are called in pairs + +### GPU Metrics Not Available + +- Ensure CUDA is available: `torch.cuda.is_available()` +- Install `pynvml` for detailed GPU utilization: `pip install pynvml` + +### Backend Import Errors + +- Simple backend is always available +- VizTracer backend requires: `pip install viztracer` +- PyTorch Profiler is built into PyTorch + +## Contributing + +To add a new profiling backend: + +1. Create a new class inheriting from `ProfilerBackend` +2. Implement `start()` and `stop()` methods +3. Add it to `backends/__init__.py` +4. Update `DiffulexProfiler._init_backend()` to support it + +To add a new exporter: + +1. Create a new class inheriting from `ProfilerExporter` +2. Implement `export()` method +3. Add it to `exporters/__init__.py` +4. Update `DiffulexProfiler._init_exporters()` to support it + +## License + +Same as the main Diffulex project. + diff --git a/diffulex_profiler/__init__.py b/diffulex_profiler/__init__.py new file mode 100644 index 0000000..67c812a --- /dev/null +++ b/diffulex_profiler/__init__.py @@ -0,0 +1,41 @@ +""" +Diffulex Profiler - Modular profiling framework for performance analysis of Diffulex inference engine +""" + +from diffulex_profiler.profiler import DiffulexProfiler, ProfilerConfig +from diffulex_profiler.metrics import ( + PerformanceMetrics, + collect_gpu_metrics, + collect_cpu_metrics, + collect_memory_metrics, +) +from diffulex_profiler.backends import ( + ProfilerBackend, + SimpleTimerBackend, + VizTracerBackend, + PyTorchProfilerBackend, +) +from diffulex_profiler.exporters import ( + ProfilerExporter, + JSONExporter, + CSVExporter, + SummaryExporter, +) + +__all__ = [ + "DiffulexProfiler", + "ProfilerConfig", + "PerformanceMetrics", + "collect_gpu_metrics", + "collect_cpu_metrics", + "collect_memory_metrics", + "ProfilerBackend", + "SimpleTimerBackend", + "VizTracerBackend", + "PyTorchProfilerBackend", + "ProfilerExporter", + "JSONExporter", + "CSVExporter", + "SummaryExporter", +] + diff --git a/diffulex_profiler/backends/__init__.py b/diffulex_profiler/backends/__init__.py new file mode 100644 index 0000000..65bdb2c --- /dev/null +++ b/diffulex_profiler/backends/__init__.py @@ -0,0 +1,24 @@ +""" +Profiling backends for different profiling tools. +""" +from diffulex_profiler.backends.base import ProfilerBackend +from diffulex_profiler.backends.simple import SimpleTimerBackend + +__all__ = [ + "ProfilerBackend", + "SimpleTimerBackend", +] + +# Optional backends +try: + from diffulex_profiler.backends.viztracer import VizTracerBackend + __all__.append("VizTracerBackend") +except ImportError: + pass + +try: + from diffulex_profiler.backends.pytorch import PyTorchProfilerBackend + __all__.append("PyTorchProfilerBackend") +except ImportError: + pass + diff --git a/diffulex_profiler/backends/base.py b/diffulex_profiler/backends/base.py new file mode 100644 index 0000000..ed77513 --- /dev/null +++ b/diffulex_profiler/backends/base.py @@ -0,0 +1,30 @@ +""" +Base class for profiling backends. +""" +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any + + +class ProfilerBackend(ABC): + """Abstract base class for profiling backends.""" + + @abstractmethod + def start(self, name: str) -> None: + """Start profiling a section.""" + pass + + @abstractmethod + def stop(self) -> Optional[Dict[str, Any]]: + """Stop profiling and return collected data.""" + pass + + @abstractmethod + def __enter__(self): + """Context manager entry.""" + pass + + @abstractmethod + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + pass + diff --git a/diffulex_profiler/backends/pytorch.py b/diffulex_profiler/backends/pytorch.py new file mode 100644 index 0000000..1a4dc59 --- /dev/null +++ b/diffulex_profiler/backends/pytorch.py @@ -0,0 +1,147 @@ +""" +PyTorch Profiler backend. +""" +from typing import Optional, Dict, Any +from pathlib import Path + +try: + import torch + from torch.profiler import profile, record_function, ProfilerActivity + PYTORCH_PROFILER_AVAILABLE = True +except ImportError: + PYTORCH_PROFILER_AVAILABLE = False + profile = None + record_function = None + ProfilerActivity = None + +from diffulex_profiler.backends.base import ProfilerBackend +from diffulex.logger import get_logger + +logger = get_logger(__name__) + + +class PyTorchProfilerBackend(ProfilerBackend): + """PyTorch Profiler-based backend for GPU/CPU operation profiling.""" + + def __init__( + self, + output_dir: Optional[str] = None, + activities: Optional[list] = None, + *, + export_stacks: bool = True, + stacks_metric: str = "self_cuda_time_total", + export_table: bool = True, + table_sort_by: Optional[str] = None, + table_row_limit: int = 50, + **kwargs, + ): + if not PYTORCH_PROFILER_AVAILABLE: + raise ImportError("PyTorch Profiler is not available") + + self.output_dir = Path(output_dir) if output_dir else Path("log/profiles") + self.output_dir.mkdir(parents=True, exist_ok=True) + + if activities is None: + activities = [ProfilerActivity.CPU] + if torch.cuda.is_available(): + activities.append(ProfilerActivity.CUDA) + + self.activities = activities + self.export_stacks = export_stacks + self.stacks_metric = stacks_metric + self.export_table = export_table + self.table_sort_by = table_sort_by + self.table_row_limit = table_row_limit + self.config = kwargs + self.profiler: Optional[profile] = None + self.current_name: Optional[str] = None + + def start(self, name: str) -> None: + """Start PyTorch Profiler.""" + if self.profiler is not None: + logger.warning("PyTorch Profiler already started, stopping previous instance") + self.stop() + + self.current_name = name + # Remove explicitly set parameters from config to avoid conflicts + config_filtered = {k: v for k, v in self.config.items() + if k not in ('record_shapes', 'profile_memory', 'with_stack', 'activities')} + self.profiler = profile( + activities=self.activities, + record_shapes=True, + profile_memory=True, + with_stack=True, + **config_filtered + ) + self.profiler.__enter__() + + def stop(self) -> Optional[Dict[str, Any]]: + """Stop PyTorch Profiler and export artifacts (trace/stacks/table).""" + if self.profiler is None: + return None + + self.profiler.__exit__(None, None, None) + + trace_file = self.output_dir / f"pytorch_trace_{self.current_name}.json" + stacks_file = self.output_dir / f"pytorch_stacks_{self.current_name}.stacks" + table_file = self.output_dir / f"pytorch_top_{self.current_name}.txt" + try: + self.profiler.export_chrome_trace(str(trace_file)) + except Exception as e: + logger.warning(f"Failed to export PyTorch trace: {e}") + trace_file = None + + # Export stacks for flamegraph (Brendan Gregg format). + if self.export_stacks: + try: + metric = self.stacks_metric + # If user requested a CUDA metric but CUDA isn't available, fall back to CPU. + if (not torch.cuda.is_available()) and ("cuda" in metric): + metric = "self_cpu_time_total" + self.profiler.export_stacks(str(stacks_file), metric) + except Exception as e: + logger.warning(f"Failed to export PyTorch stacks: {e}") + stacks_file = None + + # Export top table for quick inspection. + if self.export_table: + try: + sort_by = self.table_sort_by + if not sort_by: + sort_by = "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total" + top = self.profiler.key_averages().table(sort_by=sort_by, row_limit=int(self.table_row_limit)) + table_file.write_text(top, encoding="utf-8") + except Exception as e: + logger.warning(f"Failed to export PyTorch top table: {e}") + table_file = None + + result = { + "backend": "pytorch", + "trace_file": str(trace_file) if trace_file else None, + "stacks_file": str(stacks_file) if stacks_file else None, + "top_table_file": str(table_file) if table_file else None, + "name": self.current_name, + } + + try: + events = self.profiler.key_averages() + result["summary"] = { + "total_events": len(events), + "cpu_time_total_ms": sum(e.cpu_time_total_us for e in events) / 1000, + "cuda_time_total_ms": sum(e.cuda_time_total_us for e in events) / 1000 if torch.cuda.is_available() else 0, + } + except Exception as e: + logger.warning(f"Failed to get profiler summary: {e}") + + self.profiler = None + self.current_name = None + + return result + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.profiler is not None: + self.stop() + diff --git a/diffulex_profiler/backends/simple.py b/diffulex_profiler/backends/simple.py new file mode 100644 index 0000000..c4128f2 --- /dev/null +++ b/diffulex_profiler/backends/simple.py @@ -0,0 +1,44 @@ +""" +Simple timer-based profiling backend. +""" +import time +from typing import Optional, Dict, Any + +from diffulex_profiler.backends.base import ProfilerBackend + + +class SimpleTimerBackend(ProfilerBackend): + """Simple timer-based profiling backend that only tracks time.""" + + def __init__(self): + self.start_time: Optional[float] = None + self.current_name: Optional[str] = None + + def start(self, name: str) -> None: + """Start timing.""" + self.current_name = name + self.start_time = time.perf_counter() + + def stop(self) -> Optional[Dict[str, Any]]: + """Stop timing and return duration.""" + if self.start_time is None: + return None + + duration = time.perf_counter() - self.start_time + result = { + "duration_sec": duration, + "name": self.current_name, + } + + self.start_time = None + self.current_name = None + + return result + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.start_time is not None: + self.stop() + diff --git a/diffulex_profiler/backends/viztracer.py b/diffulex_profiler/backends/viztracer.py new file mode 100644 index 0000000..22cf38e --- /dev/null +++ b/diffulex_profiler/backends/viztracer.py @@ -0,0 +1,75 @@ +""" +VizTracer profiling backend. +""" +from typing import Optional, Dict, Any +from pathlib import Path + +try: + from viztracer import VizTracer + VIZTRACER_AVAILABLE = True +except ImportError: + VIZTRACER_AVAILABLE = False + VizTracer = None + +from diffulex_profiler.backends.base import ProfilerBackend +from diffulex.logger import get_logger + +logger = get_logger(__name__) + + +class VizTracerBackend(ProfilerBackend): + """VizTracer-based profiling backend for detailed function call tracing.""" + + def __init__(self, output_file: Optional[str] = None, output_dir: Optional[str] = None, **kwargs): + if not VIZTRACER_AVAILABLE: + raise ImportError("VizTracer is not installed. Install it with: pip install viztracer") + + self.output_file = output_file + self.output_dir = output_dir + self.tracer: Optional[VizTracer] = None + self.config = kwargs + + def start(self, name: str) -> None: + """Start VizTracer.""" + if self.tracer is not None: + logger.warning("VizTracer already started, stopping previous instance") + self.stop() + + if self.output_file: + output_file = self.output_file + else: + output_file = f"viztracer_{name}.json" + + # If output_dir is specified, prepend it to the output_file path + if self.output_dir: + output_file = str(Path(self.output_dir) / Path(output_file).name) + # Ensure output directory exists + Path(self.output_dir).mkdir(parents=True, exist_ok=True) + + logger.info(f"VizTracer output file: {output_file}") + self.tracer = VizTracer(output_file=output_file, **self.config) + self.tracer.start() + + def stop(self) -> Optional[Dict[str, Any]]: + """Stop VizTracer and return trace file path.""" + if self.tracer is None: + return None + + self.tracer.stop() + output_file = self.tracer.output_file + + result = { + "backend": "viztracer", + "output_file": str(output_file), + } + + self.tracer = None + return result + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.tracer is not None: + self.stop() + diff --git a/diffulex_profiler/example.py b/diffulex_profiler/example.py new file mode 100644 index 0000000..64e07f5 --- /dev/null +++ b/diffulex_profiler/example.py @@ -0,0 +1,141 @@ +""" +Example usage of Diffulex Profiler. + +This example demonstrates how to use the profiler to collect performance metrics +during Diffulex inference. +""" +from diffulex_profiler import DiffulexProfiler, ProfilerConfig +from diffulex import Diffulex, SamplingParams + + +def example_basic_usage(): + """Basic profiling example.""" + # Initialize profiler + profiler = DiffulexProfiler( + config=ProfilerConfig( + enabled=True, + backend="simple", + output_dir="log/profiles", + collect_gpu_metrics=True, + collect_memory_metrics=True, + ) + ) + + # Initialize Diffulex engine + model_path = "/path/to/your/model" + llm = Diffulex( + model_path, + model_name="dream", + tensor_parallel_size=1, + data_parallel_size=1, + gpu_memory_utilization=0.25, + max_model_len=2048, + decoding_strategy="d2f", + ) + + # Prepare prompts + prompts = ["What is 2+2?", "Explain quantum computing"] + sampling_params = SamplingParams(temperature=0.0, max_tokens=256) + + # Profile inference + with profiler.profile("inference", metadata={"num_prompts": len(prompts)}): + outputs = llm.generate(prompts, sampling_params) + total_tokens = sum(len(o['token_ids']) for o in outputs) + profiler.record_throughput(total_tokens) + profiler.record_metric("num_outputs", len(outputs)) + profiler.record_metric("avg_diff_steps", + sum(o['n_diff_steps'] for o in outputs) / len(outputs)) + + # Export results + profiler.export("inference_profile.json") + + # Get summary + summary = profiler.get_summary() + print(f"Total duration: {summary['total_duration_sec']:.2f}s") + print(f"Average throughput: {summary['avg_throughput_tokens_per_sec']:.2f} tok/s") + + +def example_multiple_sections(): + """Example with multiple profiling sections.""" + profiler = DiffulexProfiler( + config=ProfilerConfig( + enabled=True, + backend="simple", + export_formats=["json", "csv", "summary"] + ) + ) + + # Profile model loading + with profiler.profile("model_loading"): + model_path = "/path/to/your/model" + llm = Diffulex( + model_path, + model_name="dream", + tensor_parallel_size=1, + data_parallel_size=1, + gpu_memory_utilization=0.25, + max_model_len=2048, + decoding_strategy="d2f", + ) + + # Profile prefill + prompts = ["Prompt 1", "Prompt 2"] + with profiler.profile("prefill", metadata={"num_prompts": len(prompts)}): + # Prefill operations + pass + + # Profile decode + with profiler.profile("decode"): + outputs = llm.generate(prompts, SamplingParams()) + profiler.record_throughput(sum(len(o['token_ids']) for o in outputs)) + + # Export all results + profiler.export("multi_section_profile.json") + + +def example_viztracer_backend(): + """Example using VizTracer backend for detailed tracing.""" + profiler = DiffulexProfiler( + config=ProfilerConfig( + enabled=True, + backend="viztracer", + viztracer_config={ + "output_file": "trace.json", + "file_info": True, + } + ) + ) + + with profiler.profile("detailed_trace"): + # Your code here + pass + + profiler.export() + + +def example_pytorch_profiler(): + """Example using PyTorch Profiler for GPU/CPU profiling.""" + from torch.profiler import ProfilerActivity + + profiler = DiffulexProfiler( + config=ProfilerConfig( + enabled=True, + backend="pytorch", + pytorch_profiler_config={ + "activities": [ProfilerActivity.CPU, ProfilerActivity.CUDA], + "record_shapes": True, + "profile_memory": True, + } + ) + ) + + with profiler.profile("gpu_profiling"): + # Your code here + pass + + profiler.export() + + +if __name__ == "__main__": + example_basic_usage() + diff --git a/diffulex_profiler/exporters/__init__.py b/diffulex_profiler/exporters/__init__.py new file mode 100644 index 0000000..a0019f4 --- /dev/null +++ b/diffulex_profiler/exporters/__init__.py @@ -0,0 +1,19 @@ +""" +Exporters for profiling results. +""" +from diffulex_profiler.exporters.base import ProfilerExporter +from diffulex_profiler.exporters.json import JSONExporter +from diffulex_profiler.exporters.summary import SummaryExporter + +__all__ = [ + "ProfilerExporter", + "JSONExporter", + "SummaryExporter", +] + +try: + from diffulex_profiler.exporters.csv import CSVExporter + __all__.append("CSVExporter") +except ImportError: + pass + diff --git a/diffulex_profiler/exporters/base.py b/diffulex_profiler/exporters/base.py new file mode 100644 index 0000000..07badad --- /dev/null +++ b/diffulex_profiler/exporters/base.py @@ -0,0 +1,24 @@ +""" +Base class for profiler exporters. +""" +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List + +from diffulex_profiler.metrics import PerformanceMetrics + + +class ProfilerExporter(ABC): + """Abstract base class for exporting profiling results.""" + + @abstractmethod + def export(self, metrics: List[PerformanceMetrics], output_path: Path) -> None: + """ + Export metrics to a file. + + Args: + metrics: List of performance metrics to export + output_path: Base path for output (exporter may add extension) + """ + pass + diff --git a/diffulex_profiler/exporters/csv.py b/diffulex_profiler/exporters/csv.py new file mode 100644 index 0000000..ee26767 --- /dev/null +++ b/diffulex_profiler/exporters/csv.py @@ -0,0 +1,46 @@ +""" +CSV exporter for profiling results. +""" +import csv +from pathlib import Path +from typing import List + +from diffulex_profiler.exporters.base import ProfilerExporter +from diffulex_profiler.metrics import PerformanceMetrics + + +class CSVExporter(ProfilerExporter): + """Export profiling results to CSV format.""" + + def export(self, metrics: List[PerformanceMetrics], output_path: Path) -> None: + """Export metrics to CSV file.""" + output_file = output_path.with_suffix(".csv") + + if not metrics: + return + + fieldnames = set(["name", "duration_sec", "total_tokens", "throughput_tokens_per_sec"]) + + for m in metrics: + fieldnames.update(m.custom_metrics.keys()) + if m.metadata: + fieldnames.update(f"metadata_{k}" for k in m.metadata.keys()) + + fieldnames = sorted(list(fieldnames)) + + with open(output_file, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + + for m in metrics: + row = { + "name": m.name, + "duration_sec": m.duration, + "total_tokens": m.total_tokens, + "throughput_tokens_per_sec": m.throughput_tokens_per_sec, + } + row.update(m.custom_metrics) + for k, v in m.metadata.items(): + row[f"metadata_{k}"] = v + writer.writerow(row) + diff --git a/diffulex_profiler/exporters/json.py b/diffulex_profiler/exporters/json.py new file mode 100644 index 0000000..19fc641 --- /dev/null +++ b/diffulex_profiler/exporters/json.py @@ -0,0 +1,43 @@ +""" +JSON exporter for profiling results. +""" +import json +from pathlib import Path +from typing import List + +from diffulex_profiler.exporters.base import ProfilerExporter +from diffulex_profiler.metrics import PerformanceMetrics + + +class JSONExporter(ProfilerExporter): + """Export profiling results to JSON format.""" + + def export(self, metrics: List[PerformanceMetrics], output_path: Path) -> None: + """Export metrics to JSON file.""" + output_file = output_path.with_suffix(".json") + + data = { + "metrics": [m.to_dict() for m in metrics], + "summary": self._compute_summary(metrics), + } + + with open(output_file, "w") as f: + json.dump(data, f, indent=2) + + def _compute_summary(self, metrics: List[PerformanceMetrics]) -> dict: + """Compute summary statistics.""" + if not metrics: + return {} + + total_duration = sum(m.duration for m in metrics if m.duration) + total_tokens = sum(m.total_tokens for m in metrics if m.total_tokens) + + return { + "total_sections": len(metrics), + "total_duration_sec": total_duration, + "total_tokens": total_tokens, + "avg_throughput_tokens_per_sec": ( + total_tokens / total_duration if total_duration > 0 else 0 + ), + } + diff --git a/diffulex_profiler/exporters/summary.py b/diffulex_profiler/exporters/summary.py new file mode 100644 index 0000000..4569402 --- /dev/null +++ b/diffulex_profiler/exporters/summary.py @@ -0,0 +1,75 @@ +""" +Summary exporter for profiling results (human-readable text output). +""" +from pathlib import Path +from typing import List + +from diffulex_profiler.exporters.base import ProfilerExporter +from diffulex_profiler.metrics import PerformanceMetrics +from diffulex.logger import get_logger + +logger = get_logger(__name__) + + +class SummaryExporter(ProfilerExporter): + """Export profiling results as a human-readable summary.""" + + def export(self, metrics: List[PerformanceMetrics], output_path: Path) -> None: + """Export metrics as a text summary.""" + output_file = output_path.with_suffix(".txt") + + summary_lines = [] + summary_lines.append("=" * 80) + summary_lines.append("Diffulex Profiling Summary") + summary_lines.append("=" * 80) + summary_lines.append("") + + total_duration = sum(m.duration for m in metrics if m.duration) + total_tokens = sum(m.total_tokens for m in metrics if m.total_tokens) + avg_throughput = ( + total_tokens / total_duration if total_duration > 0 and total_tokens > 0 else 0 + ) + + summary_lines.append(f"Total Sections: {len(metrics)}") + summary_lines.append(f"Total Duration: {total_duration:.2f} seconds") + summary_lines.append(f"Total Tokens: {total_tokens}") + summary_lines.append(f"Average Throughput: {avg_throughput:.2f} tokens/sec") + summary_lines.append("") + + summary_lines.append("-" * 80) + summary_lines.append("Section Details:") + summary_lines.append("-" * 80) + + for m in metrics: + summary_lines.append(f"\nSection: {m.name}") + summary_lines.append(f" Duration: {m.duration:.4f} seconds") + if m.total_tokens > 0: + summary_lines.append(f" Tokens: {m.total_tokens}") + summary_lines.append(f" Throughput: {m.throughput_tokens_per_sec:.2f} tokens/sec") + if m.gpu_utilization != 0: + summary_lines.append(f" GPU Utilization: {m.gpu_utilization:.2f}%") + if m.memory_delta_mb != 0: + summary_lines.append(f" Memory Delta: {m.memory_delta_mb:.2f} MB") + if m.custom_metrics: + summary_lines.append(f" Custom Metrics: {m.custom_metrics}") + if m.metadata: + summary_lines.append(f" Metadata: {m.metadata}") + if m.backend_data and m.backend_data.get("backend") == "viztracer": + output_file = m.backend_data.get("output_file", "N/A") + summary_lines.append(f" VizTracer Output: {output_file}") + if m.backend_data and m.backend_data.get("backend") == "pytorch": + trace_file = m.backend_data.get("trace_file", "N/A") + stacks_file = m.backend_data.get("stacks_file", "N/A") + top_table_file = m.backend_data.get("top_table_file", "N/A") + summary_lines.append(f" PyTorch Trace: {trace_file}") + summary_lines.append(f" PyTorch Stacks: {stacks_file}") + summary_lines.append(f" PyTorch Top Table: {top_table_file}") + + summary_lines.append("") + summary_lines.append("=" * 80) + + with open(output_file, "w") as f: + f.write("\n".join(summary_lines)) + + logger.info("\n".join(summary_lines)) + diff --git a/diffulex_profiler/metrics.py b/diffulex_profiler/metrics.py new file mode 100644 index 0000000..f3678ed --- /dev/null +++ b/diffulex_profiler/metrics.py @@ -0,0 +1,113 @@ +""" +Performance metrics collection and data structures. +""" +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Dict, Any, Optional + +import torch + +try: + import psutil + PSUTIL_AVAILABLE = True +except ImportError: + PSUTIL_AVAILABLE = False + + +@dataclass +class PerformanceMetrics: + """Container for performance metrics collected during profiling.""" + name: str + metadata: Dict[str, Any] = field(default_factory=dict) + start_time: float = 0.0 + end_time: float = 0.0 + duration: float = 0.0 + total_tokens: int = 0 + throughput_tokens_per_sec: float = 0.0 + gpu_metrics_start: Optional[Dict[str, Any]] = None + gpu_metrics_end: Optional[Dict[str, Any]] = None + gpu_utilization: float = 0.0 + memory_metrics_start: Optional[Dict[str, Any]] = None + memory_metrics_end: Optional[Dict[str, Any]] = None + memory_delta_mb: float = 0.0 + custom_metrics: Dict[str, Any] = field(default_factory=dict) + backend_data: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert metrics to dictionary for serialization.""" + return { + "name": self.name, + "metadata": self.metadata, + "duration_sec": self.duration, + "total_tokens": self.total_tokens, + "throughput_tokens_per_sec": self.throughput_tokens_per_sec, + "gpu_utilization": self.gpu_utilization, + "memory_delta_mb": self.memory_delta_mb, + "custom_metrics": self.custom_metrics, + "backend_data": self.backend_data, + } + + +def collect_gpu_metrics() -> Dict[str, Any]: + """Collect current GPU metrics.""" + if not torch.cuda.is_available(): + return {} + + metrics = {} + try: + device = torch.cuda.current_device() + metrics["device"] = device + metrics["device_name"] = torch.cuda.get_device_name(device) + + memory_stats = torch.cuda.memory_stats(device) + metrics["allocated_mb"] = memory_stats.get("allocated_bytes.all.current", 0) / (1024 ** 2) + metrics["reserved_mb"] = memory_stats.get("reserved_bytes.all.current", 0) / (1024 ** 2) + metrics["peak_allocated_mb"] = memory_stats.get("allocated_bytes.all.peak", 0) / (1024 ** 2) + + try: + import pynvml + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + util = pynvml.nvmlDeviceGetUtilizationRates(handle) + metrics["utilization"] = util.gpu + metrics["memory_utilization"] = util.memory + except (ImportError, Exception): + pass + + except Exception: + pass + + return metrics + + +def collect_cpu_metrics() -> Dict[str, Any]: + """Collect current CPU metrics.""" + if not PSUTIL_AVAILABLE: + return {} + try: + return { + "cpu_percent": psutil.cpu_percent(interval=0.1), + "cpu_count": psutil.cpu_count(), + "load_avg": psutil.getloadavg() if hasattr(psutil, "getloadavg") else None, + } + except Exception: + return {} + + +def collect_memory_metrics() -> Dict[str, Any]: + """Collect current memory metrics.""" + if not PSUTIL_AVAILABLE: + return {} + try: + mem = psutil.virtual_memory() + return { + "total_mb": mem.total / (1024 ** 2), + "available_mb": mem.available / (1024 ** 2), + "used_mb": mem.used / (1024 ** 2), + "percent": mem.percent, + } + except Exception: + return {} + diff --git a/diffulex_profiler/profiler.py b/diffulex_profiler/profiler.py new file mode 100644 index 0000000..a165dcb --- /dev/null +++ b/diffulex_profiler/profiler.py @@ -0,0 +1,261 @@ +""" +Core profiler implementation for Diffulex. +""" +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any, Optional, Dict, List +from contextlib import contextmanager +from pathlib import Path + +import torch + +from diffulex_profiler.metrics import PerformanceMetrics, collect_gpu_metrics, collect_memory_metrics +from diffulex_profiler.backends import ProfilerBackend, SimpleTimerBackend +from diffulex_profiler.exporters import ProfilerExporter, JSONExporter, SummaryExporter +from diffulex.logger import get_logger + +logger = get_logger(__name__) + + +@dataclass +class ProfilerConfig: + """Configuration for the profiler.""" + enabled: bool = True + backend: str = "simple" # "simple", "viztracer", "pytorch" + output_dir: str = "log/profiles" + output_file: Optional[str] = None + collect_gpu_metrics: bool = True + collect_memory_metrics: bool = True + collect_timing: bool = True + export_formats: List[str] = field(default_factory=lambda: ["json", "summary"]) + viztracer_config: Optional[Dict[str, Any]] = None + pytorch_profiler_config: Optional[Dict[str, Any]] = None + + +class DiffulexProfiler: + """ + Main profiler class for collecting performance metrics during Diffulex inference. + + Example: + >>> profiler = DiffulexProfiler(config=ProfilerConfig(enabled=True)) + >>> with profiler.profile("inference"): + ... outputs = llm.generate(prompts, sampling_params) + >>> profiler.export("log/profiles/result.json") + """ + + def __init__(self, config: Optional[ProfilerConfig] = None): + self.config = config or ProfilerConfig() + self.metrics: List[PerformanceMetrics] = [] + self.current_metrics: Optional[PerformanceMetrics] = None + self.backend: Optional[ProfilerBackend] = None + self.exporters: List[ProfilerExporter] = [] + + if not self.config.enabled: + return + + self._init_backend() + self._init_exporters() + Path(self.config.output_dir).mkdir(parents=True, exist_ok=True) + + def _init_backend(self): + """Initialize the profiling backend.""" + if self.config.backend == "simple": + self.backend = SimpleTimerBackend() + elif self.config.backend == "viztracer": + try: + from diffulex_profiler.backends import VizTracerBackend + viztracer_config = self.config.viztracer_config or {} + # Pass output_dir to VizTracerBackend so it can save files in the correct location + if "output_dir" not in viztracer_config: + viztracer_config["output_dir"] = self.config.output_dir + self.backend = VizTracerBackend(**viztracer_config) + except ImportError: + logger.warning("VizTracer not available, falling back to simple timer") + self.backend = SimpleTimerBackend() + elif self.config.backend == "pytorch": + try: + from diffulex_profiler.backends import PyTorchProfilerBackend + pytorch_config = self.config.pytorch_profiler_config or {} + # Keep output dir consistent across backends. + if "output_dir" not in pytorch_config: + pytorch_config["output_dir"] = self.config.output_dir + self.backend = PyTorchProfilerBackend(**pytorch_config) + except ImportError: + logger.warning("PyTorch Profiler not available, falling back to simple timer") + self.backend = SimpleTimerBackend() + else: + logger.warning(f"Unknown backend '{self.config.backend}', using simple timer") + self.backend = SimpleTimerBackend() + + def _init_exporters(self): + """Initialize exporters based on config.""" + for fmt in self.config.export_formats: + if fmt == "json": + self.exporters.append(JSONExporter()) + elif fmt == "csv": + from diffulex_profiler.exporters import CSVExporter + self.exporters.append(CSVExporter()) + elif fmt == "summary": + self.exporters.append(SummaryExporter()) + else: + logger.warning(f"Unknown export format '{fmt}', skipping") + + @contextmanager + def profile(self, name: str, metadata: Optional[Dict[str, Any]] = None): + """ + Context manager for profiling a code block. + + Args: + name: Name of the profiling section + metadata: Optional metadata to attach to the metrics + + Example: + >>> with profiler.profile("model_forward", {"batch_size": 32}): + ... output = model(input_ids) + """ + if not self.config.enabled: + yield + return + + self.start(name, metadata) + try: + yield + finally: + self.stop() + + def start(self, name: str, metadata: Optional[Dict[str, Any]] = None): + """Start profiling a section.""" + if not self.config.enabled: + return + + self.current_metrics = PerformanceMetrics( + name=name, + metadata=metadata or {}, + ) + + if self.config.collect_timing: + self.current_metrics.start_time = time.perf_counter() + + if self.backend: + self.backend.start(name) + + if self.config.collect_gpu_metrics and torch.cuda.is_available(): + self.current_metrics.gpu_metrics_start = collect_gpu_metrics() + + if self.config.collect_memory_metrics: + self.current_metrics.memory_metrics_start = collect_memory_metrics() + + def stop(self): + """Stop profiling the current section.""" + if not self.config.enabled or self.current_metrics is None: + return + + if self.config.collect_timing: + self.current_metrics.end_time = time.perf_counter() + self.current_metrics.duration = ( + self.current_metrics.end_time - self.current_metrics.start_time + ) + + if self.backend: + backend_data = self.backend.stop() + if backend_data: + self.current_metrics.backend_data = backend_data + + if self.config.collect_gpu_metrics and torch.cuda.is_available(): + self.current_metrics.gpu_metrics_end = collect_gpu_metrics() + if self.current_metrics.gpu_metrics_start and self.current_metrics.gpu_metrics_end: + self.current_metrics.gpu_utilization = ( + self.current_metrics.gpu_metrics_end.get("utilization", 0) - + self.current_metrics.gpu_metrics_start.get("utilization", 0) + ) + + if self.config.collect_memory_metrics: + self.current_metrics.memory_metrics_end = collect_memory_metrics() + if (self.current_metrics.memory_metrics_start and + self.current_metrics.memory_metrics_end): + start_mem = self.current_metrics.memory_metrics_start.get("used_mb", 0) + end_mem = self.current_metrics.memory_metrics_end.get("used_mb", 0) + self.current_metrics.memory_delta_mb = end_mem - start_mem + + self.metrics.append(self.current_metrics) + self.current_metrics = None + + def record_metric(self, name: str, value: Any): + """Record a custom metric.""" + if not self.config.enabled or self.current_metrics is None: + return + self.current_metrics.custom_metrics[name] = value + + def record_throughput(self, tokens: int, duration: Optional[float] = None): + """Record throughput in tokens per second.""" + if not self.config.enabled or self.current_metrics is None: + return + if duration is None: + duration = self.current_metrics.duration + if duration and duration > 0: + self.current_metrics.throughput_tokens_per_sec = tokens / duration + self.current_metrics.total_tokens = tokens + + def export(self, output_path: Optional[str] = None): + """ + Export profiling results using configured exporters. + + Args: + output_path: Optional custom output path. If not provided, uses config output_file + or generates one based on timestamp. + """ + if not self.config.enabled or not self.metrics: + logger.warning("No metrics to export") + return + + if output_path is None: + if self.config.output_file: + output_path = str(Path(self.config.output_dir) / self.config.output_file) + else: + timestamp = time.strftime("%Y%m%d_%H%M%S") + output_path = str(Path(self.config.output_dir) / f"profile_{timestamp}") + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + for exporter in self.exporters: + try: + exporter.export(self.metrics, output_path) + except Exception as e: + logger.error(f"Failed to export using {exporter.__class__.__name__}: {e}") + + def get_summary(self) -> Dict[str, Any]: + """Get a summary of all collected metrics.""" + if not self.metrics: + return {} + + total_duration = sum(m.duration for m in self.metrics if m.duration) + total_tokens = sum(m.total_tokens for m in self.metrics if m.total_tokens) + avg_throughput = ( + total_tokens / total_duration + if total_duration > 0 and total_tokens > 0 + else 0 + ) + + return { + "total_sections": len(self.metrics), + "total_duration_sec": total_duration, + "total_tokens": total_tokens, + "avg_throughput_tokens_per_sec": avg_throughput, + "sections": [ + { + "name": m.name, + "duration_sec": m.duration, + "throughput_tokens_per_sec": m.throughput_tokens_per_sec, + "total_tokens": m.total_tokens, + } + for m in self.metrics + ], + } + + def clear(self): + """Clear all collected metrics.""" + self.metrics.clear() + self.current_metrics = None \ No newline at end of file diff --git a/docs/make.bat b/docs/make.bat index 2034948..51d3652 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -1,35 +1,35 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=. -set BUILDDIR=_build - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.https://www.sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "" goto help - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/examples/test_bf16_kernel_e2e.py b/examples/test_bf16_kernel_e2e.py new file mode 100755 index 0000000..f61f0a6 --- /dev/null +++ b/examples/test_bf16_kernel_e2e.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +"""简单的端到端测试脚本,验证 BF16 kernel 功能""" +import os +import time + +from transformers import AutoTokenizer +from diffulex import Diffulex, SamplingParams + + +def main(): + # 模型配置 + model = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + + print("=" * 60) + print("初始化 Diffulex 模型...") + print("=" * 60) + + llm = Diffulex( + model, + lora_path="/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora", + use_lora=True, + model_name="dream", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.25, + max_num_batched_tokens=2048, + max_num_seqs=10, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="unified", + decoding_strategy="d2f" + ) + + print("✓ 模型初始化完成\n") + + # 加载 tokenizer + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + sampling_params = SamplingParams(temperature=0.0, max_tokens=128) + + # 定义几个测试 prompt + test_prompts = [ + "The capital of France is", + "1 + 1 equals", + "Python is a programming language that", + ] + + # 添加 BOS token + prompts = [tokenizer.bos_token + p for p in test_prompts] + + print("=" * 60) + print(f"运行生成测试 ({len(prompts)} 个 prompt)...") + print("=" * 60) + + start_time = time.time() + outputs = llm.generate(prompts, sampling_params) + end_time = time.time() + + print("\n" + "=" * 60) + print("生成结果:") + print("=" * 60) + + total_tokens = sum(len(o['token_ids']) for o in outputs) + total_time = end_time - start_time + avg_tps = total_tokens / total_time if total_time > 0 else 0 + avg_diff_steps = sum(o['n_diff_steps'] for o in outputs) / len(outputs) if outputs else 0 + + print(f"\n总计:") + print(f" - 生成输出数: {len(outputs)}") + print(f" - 总 token 数: {total_tokens}") + print(f" - 总时间: {total_time:.2f} 秒") + print(f" - 平均 TPS: {avg_tps:.2f} tok/s") + print(f" - 平均扩散步数: {avg_diff_steps:.2f}") + + print("\n" + "=" * 60) + print("详细输出:") + print("=" * 60) + + for idx, (prompt, output) in enumerate(zip(test_prompts, outputs)): + print(f"\n[Prompt {idx + 1}]") + print(f"输入: {prompt}") + print(f"输出: {output['text']}") + print(f"Token IDs 长度: {len(output['token_ids'])}") + print(f"扩散步数: {output['n_diff_steps']}") + print("-" * 60) + + print("\n✓ BF16 kernel 端到端测试完成!") + + +if __name__ == "__main__": + main() + diff --git a/examples/test_dllm_kv_cache_store.py b/examples/test_dllm_kv_cache_store.py index b2b4113..9e58351 100755 --- a/examples/test_dllm_kv_cache_store.py +++ b/examples/test_dllm_kv_cache_store.py @@ -3,7 +3,7 @@ from einops import rearrange -from diffulex_legacy.layers.attention.attention_v4 import store_kvcache_distinct_layout, store_kvcache_unified +from diffulex_legacy.layers.attention.ops import store_kvcache_distinct_layout, store_kvcache_unified_layout if __name__ == "__main__": @@ -34,7 +34,7 @@ # unified_layout s = time.time() - store_kvcache_unified(k, v, k_cache, v_cache, slot_mapping, model_type='diffusion_lm') + store_kvcache_unified_layout(k, v, k_cache, v_cache, slot_mapping, model_type='diffusion_lm') print(f"Unified layout KV cache stored in {time.time() - s:.4f} seconds.") start_idx = 0 diff --git a/examples/test_dream_dvllm_gsm8k.py b/examples/test_dream_diffulex_gsm8k.py similarity index 80% rename from examples/test_dream_dvllm_gsm8k.py rename to examples/test_dream_diffulex_gsm8k.py index 39fa506..5cb18aa 100755 --- a/examples/test_dream_dvllm_gsm8k.py +++ b/examples/test_dream_diffulex_gsm8k.py @@ -10,30 +10,6 @@ from transformers import AutoTokenizer from diffulex import Diffulex, SamplingParams -import diffulex.model.dream - -def summarize_profiling(csv_path: str) -> dict: - totals = {} - total_nums = {} - avgs = {} - with open(csv_path, 'r', newline='') as f: - reader = csv.dictReader(f) - for row in reader: - for k, v in row.items(): - try: - val = float(v) - except ValueError: - continue - if val != 0.0: - total_nums[k] = total_nums.get(k, 0) + 1 - totals[k] = totals.get(k, 0.0) + val - print(pd.DataFrame([totals]).T) - for k, v in totals.items(): - if k in total_nums and total_nums[k] > 0: - avgs[k] = v / total_nums[k] - else: - avgs[k] = 0.0 - print(pd.DataFrame([avgs]).T) FEW_SHOTS=""" @@ -49,7 +25,7 @@ def summarize_profiling(csv_path: str) -> dict: use_lora=True, model_name="dream", enforce_eager=True, - data_parallel_size=1, + data_parallel_size=8, tensor_parallel_size=1, gpu_memory_utilization=0.25, max_num_batched_tokens=2048, @@ -85,6 +61,7 @@ def summarize_profiling(csv_path: str) -> dict: f"Avg TPS: {sum(len(o['token_ids']) for o in outputs) / (e - s):.2f} tok/s.\n" f"AVG Number of Diffusion Steps: {sum(o['n_diff_steps'] for o in outputs) / len(outputs):.2f}\n", "=*=" * 30) - # for idx, o in enumerate(outputs): - # print("\n", "=*=" * 30) - # print(f"[Prompt {idx} Result] \n{prompts[idx] + "\n----------\n" + o['text']}\n") \ No newline at end of file + for idx, o in enumerate(outputs): + print("\n", "=*=" * 30) + resp = prompts[idx] + "\n----------\n" + o["text"] + print(f"[Prompt {idx} Result]\n{resp}\n") \ No newline at end of file diff --git a/examples/test_dream_dvllm_human_eval.py b/examples/test_dream_dvllm_human_eval.py index 2d95f00..9e72be6 100755 --- a/examples/test_dream_dvllm_human_eval.py +++ b/examples/test_dream_dvllm_human_eval.py @@ -84,4 +84,5 @@ def summarize_profiling(csv_path: str) -> dict: "=*=" * 30) for idx, o in enumerate(outputs): print("\n", "=*=" * 30) - print(f"[Prompt {idx} Result] \n{prompts[idx] + "\n----------\n" + o['text']}\n") \ No newline at end of file + resp = prompts[idx] + "\n----------\n" + o["text"] + print(f"[Prompt {idx} Result]\n{resp}\n") \ No newline at end of file diff --git a/examples/test_fastdllmv2_diffulex_gsm8k.py b/examples/test_fastdllmv2_diffulex_gsm8k.py index 110c090..2aeb222 100755 --- a/examples/test_fastdllmv2_diffulex_gsm8k.py +++ b/examples/test_fastdllmv2_diffulex_gsm8k.py @@ -39,13 +39,14 @@ def summarize_profiling(csv_path: str) -> dict: FEW_SHOTS = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" if __name__ == "__main__": + PROFILE = False model = "/data1/ckpts/Efficient-Large-Model/Fast_dLLM_v2_7B" local_data_path = "/data1/LargeData/gsm8k" LLM = Diffulex( model, use_lora=False, model_name="fast_dllm_v2", - enforce_eager=True, + enforce_eager=False, data_parallel_size=1, tensor_parallel_size=1, gpu_memory_utilization=0.25, @@ -59,21 +60,21 @@ def summarize_profiling(csv_path: str) -> dict: tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) sampling_params = SamplingParams(temperature=0.0, max_tokens=256) - # dataset = load_dataset("gsm8k", "main", split="test")["question"][:10] - dataset = load_dataset(local_data_path, "main", split="test", trust_remote_code=True)["question"][:10] + dataset = load_dataset("gsm8k", "main", split="test")["question"][:15] prompts = [ FEW_SHOTS + f"<|im_start|>user\nQuestion: {question}\nAnswer:<|im_end|>\n<|im_start|>assistant\n" for question in tqdm(dataset) ] - - output_file = "log/profiles/perf_dvllm_fastdllmv2_7B.json" - if os.path.exists(output_file): - os.remove(output_file) - # with VizTracer(output_file=output_file, file_info=True) as tracer: - # outputs = llm.generate(prompts[:5], sampling_params) - # time.sleep(60) s = time.time() - outputs = LLM.generate(prompts, sampling_params) + if PROFILE: + output_file = "log/profiles/perf_dvllm_dream_7B.json" + if os.path.exists(output_file): + os.remove(output_file) + + with VizTracer(output_file=output_file, file_info=True) as tracer: + outputs = LLM.generate(prompts, sampling_params) + else: + outputs = LLM.generate(prompts, sampling_params) e = time.time() print("=*=" * 30, "\nProfiling Results\n", @@ -86,4 +87,5 @@ def summarize_profiling(csv_path: str) -> dict: "=*=" * 30) for idx, o in enumerate(outputs): print("\n", "=*=" * 30) - print(f"[Prompt {idx} Result] \n{prompts[idx] + "\n----------\n" + o['text']}\n") \ No newline at end of file + resp = prompts[idx] + "\n----------\n" + o["text"] + print(f"[Prompt {idx} Result]\n{resp}\n") \ No newline at end of file diff --git a/examples/test_fp8_kernel_e2e.py b/examples/test_fp8_kernel_e2e.py new file mode 100755 index 0000000..481a3dd --- /dev/null +++ b/examples/test_fp8_kernel_e2e.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +"""端到端测试:使用FP8 kernel(在kernel内部进行FP8到BF16转换)""" +import os +import time + +from transformers import AutoTokenizer +from diffulex import Diffulex, SamplingParams + + +def main(): + # 模型配置 + model = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + + print("=" * 60) + print("初始化 Diffulex 模型 (FP8 KV Cache with FP8 Kernel)...") + print("=" * 60) + + llm = Diffulex( + model, + lora_path="/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora", + use_lora=True, + model_name="dream", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.25, + max_num_batched_tokens=2048, + max_num_seqs=10, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="unified", # FP8 kernel只支持unified layout + kv_cache_dtype="fp8_e4m3", # 使用FP8 KV cache + decoding_strategy="d2f" + ) + + print("✓ 模型初始化完成 (FP8 KV Cache with FP8 Kernel)\n") + + # 加载 tokenizer + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + sampling_params = SamplingParams(temperature=0.0, max_tokens=128) + + # 定义几个测试 prompt + test_prompts = [ + "The capital of France is", + "1 + 1 equals", + "Python is a programming language that", + ] + + # 添加 BOS token + prompts = [tokenizer.bos_token + p for p in test_prompts] + + print("=" * 60) + print(f"运行生成测试 ({len(prompts)} 个 prompt)...") + print("使用FP8 KV cache,FP8 kernel在内部进行转换") + print("=" * 60) + + start_time = time.time() + outputs = llm.generate(prompts, sampling_params) + end_time = time.time() + + print("\n" + "=" * 60) + print("生成结果:") + print("=" * 60) + + total_tokens = sum(len(o['token_ids']) for o in outputs) + total_time = end_time - start_time + avg_tps = total_tokens / total_time if total_time > 0 else 0 + avg_diff_steps = sum(o['n_diff_steps'] for o in outputs) / len(outputs) if outputs else 0 + + print(f"\n总计:") + print(f" - 生成输出数: {len(outputs)}") + print(f" - 总 token 数: {total_tokens}") + print(f" - 总时间: {total_time:.2f} 秒") + print(f" - 平均 TPS: {avg_tps:.2f} tok/s") + print(f" - 平均扩散步数: {avg_diff_steps:.2f}") + + print("\n" + "=" * 60) + print("详细输出:") + print("=" * 60) + + for idx, (prompt, output) in enumerate(zip(test_prompts, outputs)): + print(f"\n[Prompt {idx + 1}]") + print(f"输入: {prompt}") + print(f"输出: {output['text']}") + print(f"Token IDs 长度: {len(output['token_ids'])}") + print(f"扩散步数: {output['n_diff_steps']}") + print("-" * 60) + + print("\n✓ FP8 Kernel 端到端测试完成!") + + +if __name__ == "__main__": + main() + diff --git a/examples/test_fp8_kv_cache_comprehensive.py b/examples/test_fp8_kv_cache_comprehensive.py new file mode 100755 index 0000000..be2fd5b --- /dev/null +++ b/examples/test_fp8_kv_cache_comprehensive.py @@ -0,0 +1,1309 @@ +#!/usr/bin/env python3 +""" +FP8 KV Cache 综合测试脚本 + +该脚本整合了所有 FP8 KV Cache 相关的测试,可以通过命令行参数选择运行哪些测试。 + +测试类别: +1. kernel - Kernel 层 FP8 支持测试(roundtrip tests) +2. integration - Attention layer FP8 集成测试 +3. pipeline - 完整 pipeline 测试(需要模型) +4. memory - 内存使用验证测试(需要模型) +5. speed - 速度对比测试(需要模型) +6. quality - 质量和速度对比测试(需要模型,较耗时) +7. attention_kernel - FP8 attention kernel 单元测试 +8. attention_e2e - FP8 attention kernel 端到端测试(需要模型) +9. attention_numerics - FP8 attention kernel 数值验证测试 +10. all - 运行所有测试(不包括 quality 和 root_cause,因为需要较长时间) + +用法示例: + # 运行所有测试(除了 quality) + python test_fp8_kv_cache_comprehensive.py --tests all + + # 运行特定测试 + python test_fp8_kv_cache_comprehensive.py --tests kernel integration + + # 运行 speed 和 quality 测试 + python test_fp8_kv_cache_comprehensive.py --tests speed quality +""" + +import os +import sys +import argparse +import traceback +import torch +import time +import gc +import numpy as np +from typing import List, Dict, Any, Tuple +from dataclasses import dataclass +from einops import rearrange + +# 添加项目路径 +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, PROJECT_ROOT) + +# 导入必要的模块 +from vllm.platforms import current_platform +from diffulex_legacy.layers.attention.ops import ( + store_kvcache_unified_layout, + store_kvcache_distinct_layout, + load_kvcache +) +from diffulex_legacy.layers.attention.attention_v4 import Attention +from diffulex_legacy.utils.context import ( + set_context_diffusion_lm, + get_context_diffusion_lm, + ContextForDiffusionLM, + ContextForCausalLM, + set_context_causal_lm, + get_context_causal_lm +) +from diffulex_legacy.config import Config +from diffulex_legacy import LLM, SamplingParams +from diffulex.utils.kv_cache_dtype import parse_kv_cache_dtype +from diffulex_legacy.layers.attention.ops.triton_flash_attention import triton_flash_attention +from transformers import AutoTokenizer + + +# ============================================================================ +# 测试辅助函数和类 +# ============================================================================ + +@dataclass +class _Seq: + diffusion_block_size: int = 32 + + +@dataclass +class _Ctx: + seq_lens_ts: torch.Tensor + context_lens: torch.Tensor + total_lens: torch.Tensor + block_tables: torch.Tensor + cu_seqlens_q: torch.Tensor + cu_seqlens_k: torch.Tensor + seq_lens: List[int] = None + seqs: List[_Seq] = None + + def __post_init__(self): + self.seq_lens = self.seq_lens_ts.tolist() + self.seqs = [_Seq()] + + +def _build_cu_seqlens(x: torch.Tensor) -> torch.Tensor: + return torch.tensor( + [0] + list(torch.cumsum(x, dim=0).cpu().numpy()), + dtype=torch.int32, + device="cuda", + ) + + +def get_gpu_memory_info(): + """Get current GPU memory usage in MB.""" + if torch.cuda.is_available(): + allocated = torch.cuda.memory_allocated() / 1024**2 # MB + reserved = torch.cuda.memory_reserved() / 1024**2 # MB + max_allocated = torch.cuda.max_memory_allocated() / 1024**2 # MB + return { + "allocated_mb": allocated, + "reserved_mb": reserved, + "max_allocated_mb": max_allocated, + } + return None + + +# ============================================================================ +# 测试函数 - Kernel 层测试 +# ============================================================================ + +def test_kv_cache_fp8_unified_roundtrip(): + """测试 FP8 unified layout KV cache 的存储和加载往返""" + torch.random.manual_seed(114514) + + num_seqs = 4 + blk_sz = 256 + H = 4 + head_dim = 128 + + seq_lens = torch.tensor([64, 32, 64, 32], dtype=torch.int32, device="cuda") + ctx_lens = torch.tensor([119, 110, 81, 114], dtype=torch.int32, device="cuda") + assert seq_lens.numel() == num_seqs and ctx_lens.numel() == num_seqs + total_lens = seq_lens + ctx_lens + + kv_shape = (int(total_lens.sum().item()), H, head_dim) + k_all = torch.randn(kv_shape, device="cuda", dtype=torch.bfloat16) + v_all = torch.randn_like(k_all) + + slot_mapping: list[int] = [] + start = 0 + for seq_idx in range(num_seqs): + ctx = int(ctx_lens[seq_idx].item()) + new = int(seq_lens[seq_idx].item()) + slot_mapping.extend(list(range(seq_idx * blk_sz, seq_idx * blk_sz + ctx))) + slot_mapping.extend([-1] * new) + start += ctx + new + slot_mapping_ts = torch.tensor(slot_mapping, dtype=torch.int64, device="cuda") + assert slot_mapping_ts.numel() == kv_shape[0] + + kv_cache_shape = (num_seqs, blk_sz, H, head_dim) + k_cache_u8 = torch.zeros(kv_cache_shape, device="cuda", dtype=torch.uint8) + v_cache_u8 = torch.zeros_like(k_cache_u8) + + fp8 = current_platform.fp8_dtype() + fp8_max = float(torch.finfo(fp8).max) + eps = 1e-6 + k_absmax = k_all.to(torch.float32).abs().amax(dim=(0, 2)) + v_absmax = v_all.to(torch.float32).abs().amax(dim=(0, 2)) + k_scale = (k_absmax / fp8_max).clamp_min(eps) + v_scale = (v_absmax / fp8_max).clamp_min(eps) + + store_kvcache_unified_layout( + k_all, v_all, k_cache_u8, v_cache_u8, slot_mapping_ts, + model_type="diffusion_lm", + kv_cache_dtype="fp8_e4m3", + k_scale=k_scale, + v_scale=v_scale, + ) + + k_cache_fp8 = k_cache_u8.view(fp8).to(torch.float32) * k_scale[None, None, :, None] + v_cache_fp8 = v_cache_u8.view(fp8).to(torch.float32) * v_scale[None, None, :, None] + start = 0 + for seq_idx in range(num_seqs): + ctx = int(ctx_lens[seq_idx].item()) + new = int(seq_lens[seq_idx].item()) + k_ctx_ref = k_all[start : start + ctx].to(torch.float32) + v_ctx_ref = v_all[start : start + ctx].to(torch.float32) + k_ctx_got = k_cache_fp8[seq_idx, :ctx] + v_ctx_got = v_cache_fp8[seq_idx, :ctx] + assert torch.allclose(k_ctx_got, k_ctx_ref, atol=1e-1, rtol=1e-1) + assert torch.allclose(v_ctx_got, v_ctx_ref, atol=1e-1, rtol=1e-1) + start += ctx + new + + k_new_list = [] + v_new_list = [] + start = 0 + for seq_idx in range(num_seqs): + ctx = int(ctx_lens[seq_idx].item()) + new = int(seq_lens[seq_idx].item()) + k_new_list.append(k_all[start + ctx : start + ctx + new]) + v_new_list.append(v_all[start + ctx : start + ctx + new]) + start += ctx + new + k_new = torch.cat(k_new_list, dim=0).contiguous() + v_new = torch.cat(v_new_list, dim=0).contiguous() + + block_tables = torch.arange(num_seqs, dtype=torch.int32, device="cuda").view(num_seqs, 1) + cu_seqlens_q = _build_cu_seqlens(seq_lens) + cu_seqlens_k = _build_cu_seqlens(total_lens) + ctx = _Ctx( + seq_lens_ts=seq_lens, + context_lens=ctx_lens, + total_lens=total_lens, + block_tables=block_tables, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + + k_out, v_out = load_kvcache( + k_cache_u8, v_cache_u8, ctx, k_new, v_new, + kv_cache_dtype="fp8_e4m3", + k_scale=k_scale, + v_scale=v_scale, + ) + + out_splits = torch.split(k_out, total_lens.tolist(), dim=0) + new_splits = torch.split(k_new, seq_lens.tolist(), dim=0) + start = 0 + for seq_idx in range(num_seqs): + ctx_len = int(ctx_lens[seq_idx].item()) + new_len = int(seq_lens[seq_idx].item()) + k_ref_ctx = k_all[start : start + ctx_len].to(k_out.dtype) + k_got_ctx = out_splits[seq_idx][:ctx_len] + assert torch.allclose(k_got_ctx, k_ref_ctx, atol=1e-1, rtol=1e-1) + assert torch.equal(out_splits[seq_idx][ctx_len : ctx_len + new_len], new_splits[seq_idx]) + start += ctx_len + new_len + + print("FP8 unified KV cache store/load roundtrip: OK") + + +def test_kv_cache_fp8_distinct_roundtrip(): + """测试 FP8 distinct layout KV cache 的存储""" + torch.random.manual_seed(114514) + + num_seqs = 4 + blk_sz = 256 + H = 4 + head_dim = 128 + x = 8 + + seq_lens = torch.tensor([64, 32, 64, 32], dtype=torch.int32, device="cuda") + ctx_lens = torch.tensor([119, 110, 81, 114], dtype=torch.int32, device="cuda") + total_lens = seq_lens + ctx_lens + + kv_shape = (int(total_lens.sum().item()), H, head_dim) + k_all = torch.randn(kv_shape, device="cuda", dtype=torch.bfloat16) + v_all = torch.randn_like(k_all) + + slot_mapping: list[int] = [] + start = 0 + for seq_idx in range(num_seqs): + ctx = int(ctx_lens[seq_idx].item()) + new = int(seq_lens[seq_idx].item()) + slot_mapping.extend(list(range(seq_idx * blk_sz, seq_idx * blk_sz + ctx))) + slot_mapping.extend([-1] * new) + start += ctx + new + slot_mapping_ts = torch.tensor(slot_mapping, dtype=torch.int64, device="cuda") + + k_cache_u8 = torch.zeros((num_seqs, H, head_dim // x, blk_sz, x), device="cuda", dtype=torch.uint8) + v_cache_u8 = torch.zeros((num_seqs, H, head_dim, blk_sz), device="cuda", dtype=torch.uint8) + + fp8 = current_platform.fp8_dtype() + fp8_max = float(torch.finfo(fp8).max) + eps = 1e-6 + k_absmax = k_all.to(torch.float32).abs().amax(dim=(0, 2)) + v_absmax = v_all.to(torch.float32).abs().amax(dim=(0, 2)) + k_scale = (k_absmax / fp8_max).clamp_min(eps) + v_scale = (v_absmax / fp8_max).clamp_min(eps) + + store_kvcache_distinct_layout( + k_all, v_all, k_cache_u8, v_cache_u8, slot_mapping_ts, + model_type="diffusion_lm", + kv_cache_dtype="fp8_e4m3", + k_scale=k_scale, + v_scale=v_scale, + ) + + k_cache_fp8 = k_cache_u8.view(fp8).to(torch.float32) + v_cache_fp8 = v_cache_u8.view(fp8).to(torch.float32) + k_cache_deq = k_cache_fp8 * k_scale[None, :, None, None, None] + v_cache_deq = v_cache_fp8 * v_scale[None, :, None, None] + k_cache_unified = rearrange(k_cache_deq, "b h n s x -> b s h (n x)").contiguous() + v_cache_unified = rearrange(v_cache_deq, "b h d s -> b s h d").contiguous() + + start = 0 + for seq_idx in range(num_seqs): + ctx = int(ctx_lens[seq_idx].item()) + new = int(seq_lens[seq_idx].item()) + k_ctx_ref = k_all[start : start + ctx].to(torch.float32) + v_ctx_ref = v_all[start : start + ctx].to(torch.float32) + assert torch.allclose(k_cache_unified[seq_idx, :ctx], k_ctx_ref, atol=1e-1, rtol=1e-1) + assert torch.allclose(v_cache_unified[seq_idx, :ctx], v_ctx_ref, atol=1e-1, rtol=1e-1) + start += ctx + new + + print("FP8 distinct KV cache store roundtrip (ctx portion): OK") + + +# ============================================================================ +# 测试函数 - Integration 测试 +# ============================================================================ + +def test_running_max_update(attn: Attention): + """Test running max update in FP8 scale computation.""" + num_heads = 8 + num_kv_heads = 4 + head_dim = 128 + seq_len = 64 + + device = 'cuda' + k1 = torch.randn(seq_len, num_kv_heads, head_dim, device=device) * 0.5 + v1 = torch.randn(seq_len, num_kv_heads, head_dim, device=device) * 0.5 + + kv_cache_dtype = "fp8_e4m3" + k_scale1, v_scale1 = attn._update_and_compute_fp8_scales(k1, v1, kv_cache_dtype, device) + + assert k_scale1 is not None and v_scale1 is not None + assert attn.k_max_abs is not None and attn.v_max_abs is not None + + k2 = torch.randn(seq_len, num_kv_heads, head_dim, device=device) * 1.5 + v2 = torch.randn(seq_len, num_kv_heads, head_dim, device=device) * 1.5 + + k_max_abs_before = attn.k_max_abs.clone() + v_max_abs_before = attn.v_max_abs.clone() + + k_scale2, v_scale2 = attn._update_and_compute_fp8_scales(k2, v2, kv_cache_dtype, device) + + assert torch.all(attn.k_max_abs >= k_max_abs_before) + assert torch.all(attn.v_max_abs >= v_max_abs_before) + + k_scale3, v_scale3 = attn._update_and_compute_fp8_scales(k1, v1, "bf16", device) + assert k_scale3 is None and v_scale3 is None + + k_scale4, v_scale4 = attn._update_and_compute_fp8_scales(k1, v1, "fp8_e5m2", device) + assert attn.kv_cache_dtype_cache == "fp8_e5m2" + + +def test_scale_computation(attn: Attention): + """Test scale computation from running max.""" + device = 'cuda' + seq_len = 64 + num_kv_heads = 4 + head_dim = 128 + k = torch.randn(seq_len, num_kv_heads, head_dim, device=device) + v = torch.randn(seq_len, num_kv_heads, head_dim, device=device) + + kv_cache_dtype = "fp8_e4m3" + k_scale, v_scale = attn._update_and_compute_fp8_scales(k, v, kv_cache_dtype, device) + + assert k_scale.shape == (num_kv_heads,) + assert v_scale.shape == (num_kv_heads,) + assert torch.all(k_scale > 0) + assert torch.all(v_scale > 0) + + k_scale2, v_scale2 = attn._get_fp8_scales_from_max(kv_cache_dtype) + assert k_scale2 is not None and v_scale2 is not None + assert torch.allclose(k_scale, k_scale2) + assert torch.allclose(v_scale, v_scale2) + + k_scale3, v_scale3 = attn._get_fp8_scales_from_max("bf16") + assert k_scale3 is None and v_scale3 is None + + +def test_context_kv_cache_dtype(): + """Test context kv_cache_dtype access.""" + ctx_causal = ContextForCausalLM() + assert ctx_causal.kv_cache_dtype == "bf16" + + set_context_causal_lm(True, kv_cache_dtype="fp8_e4m3") + ctx_causal2 = get_context_causal_lm() + assert ctx_causal2.kv_cache_dtype == "fp8_e4m3" + + from diffulex_legacy.layers.attention.attention_v4 import _get_kv_cache_dtype + + class MockConfig: + kv_cache_dtype = "fp8_e4m3" + + class MockSeq: + def __init__(self): + self.config = MockConfig() + + ctx_diff = ContextForDiffusionLM.__new__(ContextForDiffusionLM) + ctx_diff.seqs = [MockSeq()] + ctx_diff.seq_lens = None + ctx_diff.seq_lens_ts = None + ctx_diff.kv_cache_layout = "unified" + ctx_diff.need_kv_cache_store = True + ctx_diff.d2f_pp = False + ctx_diff.block_mask = None + ctx_diff.is_prefill = False + ctx_diff.cu_seqlens_q = None + ctx_diff.cu_seqlens_k = None + ctx_diff.max_seqlen_q = 0 + ctx_diff.max_seqlen_k = 0 + ctx_diff.slot_mapping = None + ctx_diff.context_lens = None + ctx_diff.block_tables = None + + dtype1 = _get_kv_cache_dtype(ctx_diff, "diffusion_lm") + assert dtype1 == "fp8_e4m3" + + dtype2 = _get_kv_cache_dtype(ctx_causal2, "causal_lm") + assert dtype2 == "fp8_e4m3" + + +# ============================================================================ +# 测试函数 - Pipeline 测试 +# ============================================================================ + +def test_fp8_kv_cache_pipeline(): + """Test FP8 KV cache in a complete inference pipeline.""" + model = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + + llm = LLM( + model, + lora_path="/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora", + use_lora=True, + model_name="dream", + model_type="diffusion_lm", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.25, + max_num_batched_tokens=2048, + max_num_seqs=20, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="unified", + kv_cache_dtype="fp8_e4m3", + ) + + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + test_prompts = [ + tokenizer.bos_token + "Hello, how are you?", + tokenizer.bos_token + "The capital of France is", + tokenizer.bos_token + "Python is a programming language that", + ] + + sampling_params = SamplingParams(temperature=0.7, max_tokens=50) + outputs = llm.generate(test_prompts, sampling_params) + + for i, (prompt, output) in enumerate(zip(test_prompts, outputs)): + generated_text = output.get("text", "") + token_ids = output.get("token_ids", []) + + if not generated_text.strip(): + raise ValueError(f"Generated text is empty for prompt {i+1}") + if len(token_ids) == 0: + raise ValueError(f"No tokens generated for prompt {i+1}") + + +# ============================================================================ +# 测试函数 - Memory 测试 +# ============================================================================ + +def test_kv_cache_memory(kv_cache_dtype="bf16"): + """Test KV cache memory usage with specified dtype.""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + gc.collect() + + model = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + + llm = LLM( + model, + lora_path="/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora", + use_lora=True, + model_name="dream", + model_type="diffusion_lm", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.25, + max_num_batched_tokens=2048, + max_num_seqs=20, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="unified", + kv_cache_dtype=kv_cache_dtype, + ) + + model_runner = llm.model_runner + if hasattr(model_runner, 'kv_cache') and model_runner.kv_cache is not None: + kv_cache = model_runner.kv_cache + kv_cache_size_mb = kv_cache.element_size() * kv_cache.numel() / 1024**2 + + config = model_runner.config + if hasattr(config, 'num_kvcache_blocks') and config.num_kvcache_blocks > 0: + hf_config = config.hf_config + num_layers = hf_config.num_hidden_layers + block_size = config.kvcache_block_size + num_blocks = config.num_kvcache_blocks + + if hasattr(hf_config, 'head_dim'): + head_dim = hf_config.head_dim + elif hasattr(hf_config, 'hidden_size') and hasattr(hf_config, 'num_attention_heads'): + head_dim = hf_config.hidden_size // hf_config.num_attention_heads + else: + head_dim = 128 + + num_kv_heads = getattr(hf_config, 'num_key_value_heads', getattr(hf_config, 'num_attention_heads', 32)) + + spec = parse_kv_cache_dtype(kv_cache_dtype) + itemsize = 1 if spec.is_fp8 else (2 if kv_cache_dtype in ['bf16', 'fp16'] else 4) + + elements_per_block = 2 * num_layers * block_size * num_kv_heads * head_dim + size_per_block_mb = elements_per_block * itemsize / 1024**2 + + print(f" num_blocks: {num_blocks}") + print(f" Size per block: {size_per_block_mb:.2f} MB") + print(f" Total size: {kv_cache_size_mb:.2f} MB") + + +# ============================================================================ +# 测试函数 - Speed 测试 +# ============================================================================ + +def test_kv_cache_speed(kv_cache_dtype="bf16", num_prompts=3): + """Test generation speed with specified KV cache dtype.""" + model = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + + llm = LLM( + model, + lora_path="/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora", + use_lora=True, + model_name="dream", + model_type="diffusion_lm", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.25, + max_num_batched_tokens=2048, + max_num_seqs=20, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="unified", + kv_cache_dtype=kv_cache_dtype, + ) + + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + test_prompts = [ + tokenizer.bos_token + "Hello, how are you?", + tokenizer.bos_token + "The capital of France is", + tokenizer.bos_token + "Python is a programming language that", + ][:num_prompts] + + sampling_params = SamplingParams(temperature=0.7, max_tokens=50) + + start_gen = time.time() + outputs = llm.generate(test_prompts, sampling_params) + gen_time = time.time() - start_gen + + total_tokens = sum(len(o.get("token_ids", [])) for o in outputs) + throughput = total_tokens / gen_time + + print(f" - Generation time: {gen_time:.2f}s") + print(f" - Total tokens: {total_tokens}") + print(f" - Throughput: {throughput:.2f} tok/s") + + return { + "kv_cache_dtype": kv_cache_dtype, + "gen_time": gen_time, + "total_tokens": total_tokens, + "throughput": throughput, + } + + +# ============================================================================ +# 测试函数 - Quality 测试 +# ============================================================================ + +TEST_PROMPTS = [ + "The capital of France is", + "In a world where technology", + "The importance of education", + "Climate change is one of", + "Artificial intelligence has the potential", +] + +def run_inference(llm: LLM, prompts: List[str], sampling_params: SamplingParams, num_runs: int = 3) -> Dict: + """运行推理并收集性能和质量指标""" + results = { + 'total_time': 0.0, + 'total_tokens': 0, + 'outputs': [], + } + + for run in range(num_runs): + start_time = time.time() + outputs = llm.generate(prompts, sampling_params) + elapsed_time = time.time() - start_time + + total_tokens = sum(len(output.get("token_ids", [])) for output in outputs) + + results['total_time'] += elapsed_time + results['total_tokens'] += total_tokens + results['outputs'].append(outputs) + + results['avg_time'] = results['total_time'] / num_runs + results['avg_tokens'] = results['total_tokens'] / num_runs + results['avg_throughput'] = results['avg_tokens'] / results['avg_time'] + + return results + +def compare_outputs(bf16_outputs: List, fp8_outputs: List, prompts: List[str]) -> Dict: + """比较两种配置的输出""" + comparison = { + 'text_similarity': [], + 'texts_bf16': [], + 'texts_fp8': [], + } + + for bf16_out, fp8_out, prompt in zip(bf16_outputs, fp8_outputs, prompts): + bf16_text = bf16_out.get("text", "") + fp8_text = fp8_out.get("text", "") + + comparison['texts_bf16'].append(bf16_text) + comparison['texts_fp8'].append(fp8_text) + + if bf16_text and fp8_text: + min_len = min(len(bf16_text), len(fp8_text)) + if min_len > 0: + matches = sum(1 for a, b in zip(bf16_text[:min_len], fp8_text[:min_len]) if a == b) + similarity = matches / min_len + comparison['text_similarity'].append(similarity) + else: + comparison['text_similarity'].append(0.0) + else: + comparison['text_similarity'].append(0.0) + + comparison['avg_similarity'] = np.mean(comparison['text_similarity']) if comparison['text_similarity'] else 0.0 + + return comparison + + +# ============================================================================ +# 测试函数 - Attention Kernel 测试 +# ============================================================================ + +def test_q_scale_computation(attn: Attention): + """Test Q scale computation and running max update.""" + device = 'cuda' + seq_len = 64 + num_heads = 8 + head_dim = 128 + kv_cache_dtype = "fp8_e4m3" + + q1 = torch.randn(seq_len, num_heads, head_dim, device=device) * 0.5 + + q_scale1 = attn._update_and_compute_q_fp8_scale(q1, kv_cache_dtype, device) + + assert q_scale1 is not None + assert attn.q_max_abs is not None + assert q_scale1.shape == (num_heads,) + assert torch.all(q_scale1 > 0) + + q2 = torch.randn(seq_len, num_heads, head_dim, device=device) * 1.5 + q_max_abs_before = attn.q_max_abs.clone() + + q_scale2 = attn._update_and_compute_q_fp8_scale(q2, kv_cache_dtype, device) + + assert torch.all(attn.q_max_abs >= q_max_abs_before) + + q_scale3 = attn._get_q_fp8_scale_from_max(kv_cache_dtype) + assert q_scale3 is not None + assert torch.allclose(q_scale2, q_scale3) + + q_scale4 = attn._update_and_compute_q_fp8_scale(q1, "bf16", device) + assert q_scale4 is None + + q_scale5 = attn._update_and_compute_q_fp8_scale(q1, "fp8_e5m2", device) + assert attn.kv_cache_dtype_cache == "fp8_e5m2" + + +def test_q_kv_scale_consistency(attn: Attention): + """Test that Q, K, V scales are computed consistently.""" + device = 'cuda' + seq_len = 64 + num_heads = 8 + num_kv_heads = 4 + head_dim = 128 + kv_cache_dtype = "fp8_e4m3" + + scale_factor = 1.0 + q = torch.randn(seq_len, num_heads, head_dim, device=device) * scale_factor + k = torch.randn(seq_len, num_kv_heads, head_dim, device=device) * scale_factor + v = torch.randn(seq_len, num_kv_heads, head_dim, device=device) * scale_factor + + q_scale = attn._update_and_compute_q_fp8_scale(q, kv_cache_dtype, device) + k_scale, v_scale = attn._update_and_compute_fp8_scales(k, v, kv_cache_dtype, device) + + assert q_scale is not None + assert k_scale is not None and v_scale is not None + + assert q_scale.shape == (num_heads,) + assert k_scale.shape == (num_kv_heads,) + assert v_scale.shape == (num_kv_heads,) + + assert torch.all(q_scale > 0) + assert torch.all(k_scale > 0) + assert torch.all(v_scale > 0) + + q_scale_retrieved = attn._get_q_fp8_scale_from_max(kv_cache_dtype) + k_scale_retrieved, v_scale_retrieved = attn._get_fp8_scales_from_max(kv_cache_dtype) + + assert torch.allclose(q_scale, q_scale_retrieved) + assert torch.allclose(k_scale, k_scale_retrieved) + assert torch.allclose(v_scale, v_scale_retrieved) + + +def test_fp8_attention_kernel_integration(attn: Attention): + """Test FP8 attention kernel integration in decode path.""" + device = 'cuda' + seq_len = 32 + num_heads = 8 + num_kv_heads = 4 + head_dim = 128 + kv_cache_dtype = "fp8_e4m3" + + q = torch.randn(seq_len, num_heads, head_dim, device=device, dtype=torch.bfloat16) + k = torch.randn(seq_len, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16) + v = torch.randn(seq_len, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16) + + q_scale = attn._update_and_compute_q_fp8_scale(q, kv_cache_dtype, device) + assert q_scale is not None + assert q_scale.shape == (num_heads,) + + k_scale, v_scale = attn._update_and_compute_fp8_scales(k, v, kv_cache_dtype, device) + assert k_scale is not None and v_scale is not None + + q_scale_retrieved = attn._get_q_fp8_scale_from_max(kv_cache_dtype) + k_scale_retrieved, v_scale_retrieved = attn._get_fp8_scales_from_max(kv_cache_dtype) + + assert q_scale_retrieved is not None + assert k_scale_retrieved is not None and v_scale_retrieved is not None + + +def test_fp8_attention_pipeline(): + """Test FP8 attention kernel in full pipeline.""" + model = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + + llm = LLM( + model, + lora_path="/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora", + use_lora=True, + model_name="dream", + model_type="diffusion_lm", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.25, + max_num_batched_tokens=2048, + max_num_seqs=20, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="unified", + kv_cache_dtype="fp8_e4m3", + ) + + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + test_prompts = [ + tokenizer.bos_token + "Hello, how are you?", + tokenizer.bos_token + "The capital of France is", + ] + + sampling_params = SamplingParams(temperature=0.7, max_tokens=30) + outputs = llm.generate(test_prompts, sampling_params) + + for i, (prompt, output) in enumerate(zip(test_prompts, outputs)): + generated_text = output.get("text", "") + token_ids = output.get("token_ids", []) + + if not generated_text.strip(): + raise ValueError(f"Generated text is empty for prompt {i+1}") + if len(token_ids) == 0: + raise ValueError(f"No tokens generated for prompt {i+1}") + + +@torch.no_grad() +def test_fp8_attention_kernel_numerics(): + """数值对齐测试:验证 triton_flash_attention 的 FP8 路径在 scale 约定下是否正确。""" + torch.manual_seed(0) + device = "cuda" + + def run_case(seqlen_q: int, seqlen_k: int, nheads_q: int, nheads_kv: int, head_dim: int): + assert head_dim in (32, 64, 128, 256) + assert nheads_q % nheads_kv == 0 + group = nheads_q // nheads_kv + + q = (torch.randn(seqlen_q, nheads_q, head_dim, device=device, dtype=torch.bfloat16) * 0.2).contiguous() + k = (torch.randn(seqlen_k, nheads_kv, head_dim, device=device, dtype=torch.bfloat16) * 0.2).contiguous() + v = (torch.randn(seqlen_k, nheads_kv, head_dim, device=device, dtype=torch.bfloat16) * 0.2).contiguous() + + spec = parse_kv_cache_dtype("fp8_e4m3") + fp8_dtype = spec.fp8_view_dtype + assert fp8_dtype is not None and spec.fp8_max is not None and spec.fp8_min is not None + + fp8_max = float(spec.fp8_max) + eps = 1e-8 + q_max = q.float().abs().amax(dim=(0, 2)) + k_max = k.float().abs().amax(dim=(0, 2)) + v_max = v.float().abs().amax(dim=(0, 2)) + q_scale = (q_max / fp8_max).clamp_min(eps).float() + k_scale = (k_max / fp8_max).clamp_min(eps).float() + v_scale = (v_max / fp8_max).clamp_min(eps).float() + p_scale = torch.ones(1, device=device, dtype=torch.float32) + + cu_seqlens_q = torch.tensor([0, seqlen_q], device=device, dtype=torch.int32) + cu_seqlens_k = torch.tensor([0, seqlen_k], device=device, dtype=torch.int32) + + o = torch.empty_like(q) + out = triton_flash_attention( + q, k, v, o, + cu_seqlens_q, cu_seqlens_k, + seqlen_q, seqlen_k, + causal=False, + softmax_scale=(head_dim ** -0.5), + bias=None, + fp8_scales=(q_scale, k_scale, v_scale, p_scale), + fp8_out_scale=None, + block_table=None, + ) + + def quantize_to_fp8(t_bf16: torch.Tensor, scale: torch.Tensor, heads: int) -> torch.Tensor: + descale = (1.0 / scale).view(1, heads, 1) + t_q = (t_bf16.float() * descale).clamp(min=float(spec.fp8_min), max=float(spec.fp8_max)) + return t_q.to(fp8_dtype).float() + + q_q = quantize_to_fp8(q, q_scale, nheads_q) + k_q = quantize_to_fp8(k, k_scale, nheads_kv) + v_q = quantize_to_fp8(v, v_scale, nheads_kv) + + sm_scale = head_dim ** -0.5 + kv_for_q = torch.arange(nheads_q, device=device) // group + k_q_mapped = k_q[:, kv_for_q, :] + v_q_mapped = v_q[:, kv_for_q, :] + k_scale_mapped = k_scale[kv_for_q] + v_scale_mapped = v_scale[kv_for_q] + + scores = torch.einsum("qhd,khd->hqk", q_q, k_q_mapped) + restore = (q_scale * k_scale_mapped) * sm_scale + scores = scores * restore.view(-1, 1, 1) + p = torch.softmax(scores, dim=-1) + out_ref = torch.einsum("hqk,khd->qhd", p, v_q_mapped) + out_ref = out_ref * v_scale_mapped.view(1, -1, 1) + + out_f = out.float() + diff = (out_f - out_ref).abs() + rel = diff / (out_ref.abs() + 1e-6) + print("=" * 80) + print(f"FP8 attention kernel numerics check (Q={seqlen_q}, K={seqlen_k}, Hq={nheads_q}, Hkv={nheads_kv}, D={head_dim})") + print(f"abs diff: mean={diff.mean().item():.6f} max={diff.max().item():.6f}") + print(f"rel diff: mean={rel.mean().item():.6f} max={rel.max().item():.6f}") + + run_case(seqlen_q=32, seqlen_k=32, nheads_q=4, nheads_kv=4, head_dim=64) + run_case(seqlen_q=32, seqlen_k=64, nheads_q=32, nheads_kv=4, head_dim=64) + + +# ============================================================================ +# 测试运行函数 +# ============================================================================ + +def run_kernel_tests() -> Dict: + """运行 Kernel 层 FP8 支持测试""" + print("\n" + "=" * 80) + print("测试类别 1: Kernel 层 FP8 支持测试") + print("=" * 80) + + results = { + 'unified_roundtrip': False, + 'distinct_roundtrip': False, + } + + print("\n[1.1] Unified Layout Roundtrip Test") + print("-" * 80) + try: + test_kv_cache_fp8_unified_roundtrip() + print("✅ Unified layout roundtrip test PASSED") + results['unified_roundtrip'] = True + except Exception as e: + print(f"❌ Unified layout roundtrip test FAILED: {e}") + traceback.print_exc() + + print("\n[1.2] Distinct Layout Roundtrip Test") + print("-" * 80) + try: + test_kv_cache_fp8_distinct_roundtrip() + print("✅ Distinct layout roundtrip test PASSED") + results['distinct_roundtrip'] = True + except Exception as e: + print(f"❌ Distinct layout roundtrip test FAILED: {e}") + traceback.print_exc() + + return results + +def run_integration_tests() -> Dict: + """运行 Attention layer FP8 集成测试""" + print("\n" + "=" * 80) + print("测试类别 2: Attention Layer FP8 集成测试") + print("=" * 80) + + results = {'integration': False} + + print("\n[2.1] Attention Layer FP8 Integration Test") + print("-" * 80) + try: + num_heads = 8 + num_kv_heads = 4 + head_dim = 128 + + attn = Attention( + num_heads=num_heads, + head_dim=head_dim, + scale=1.0 / (head_dim ** 0.5), + num_kv_heads=num_kv_heads, + model_type='diffusion_lm' + ) + + test_context_kv_cache_dtype() + test_running_max_update(attn) + test_scale_computation(attn) + + print("✅ Attention layer FP8 integration test PASSED") + results['integration'] = True + except Exception as e: + print(f"❌ Attention layer FP8 integration test FAILED: {e}") + traceback.print_exc() + + return results + +def run_pipeline_tests() -> Dict: + """运行完整 Pipeline 测试""" + print("\n" + "=" * 80) + print("测试类别 3: 完整 Pipeline 测试") + print("=" * 80) + + results = {'pipeline': False} + + print("\n[3.1] FP8 KV Cache Pipeline Test") + print("-" * 80) + print("注意:此测试需要模型 checkpoint,可能需要较长时间...") + try: + test_fp8_kv_cache_pipeline() + print("✅ FP8 KV cache pipeline test PASSED") + results['pipeline'] = True + except Exception as e: + print(f"❌ FP8 KV cache pipeline test FAILED: {e}") + traceback.print_exc() + + return results + +def run_memory_tests() -> Dict: + """运行内存使用验证测试""" + print("\n" + "=" * 80) + print("测试类别 4: 内存使用验证测试") + print("=" * 80) + + results = { + 'memory_bf16': False, + 'memory_fp8': False, + } + + print("\n[4.1] BF16 Memory Usage Test") + print("-" * 80) + try: + test_kv_cache_memory("bf16") + print("✅ BF16 memory usage test PASSED") + results['memory_bf16'] = True + except Exception as e: + print(f"❌ BF16 memory usage test FAILED: {e}") + traceback.print_exc() + + print("\n[4.2] FP8 Memory Usage Test") + print("-" * 80) + try: + test_kv_cache_memory("fp8_e4m3") + print("✅ FP8 memory usage test PASSED") + results['memory_fp8'] = True + except Exception as e: + print(f"❌ FP8 memory usage test FAILED: {e}") + traceback.print_exc() + + return results + +def run_speed_tests() -> Dict: + """运行速度对比测试""" + print("\n" + "=" * 80) + print("测试类别 5: 速度对比测试") + print("=" * 80) + + results = { + 'speed_bf16': False, + 'speed_fp8': False, + } + + print("\n[5.1] BF16 Speed Test") + print("-" * 80) + try: + test_kv_cache_speed("bf16", num_prompts=3) + print("✅ BF16 speed test PASSED") + results['speed_bf16'] = True + except Exception as e: + print(f"❌ BF16 speed test FAILED: {e}") + traceback.print_exc() + + print("\n[5.2] FP8 Speed Test") + print("-" * 80) + try: + test_kv_cache_speed("fp8_e4m3", num_prompts=3) + print("✅ FP8 speed test PASSED") + results['speed_fp8'] = True + except Exception as e: + print(f"❌ FP8 speed test FAILED: {e}") + traceback.print_exc() + + return results + +def run_quality_tests() -> Dict: + """运行质量和速度对比测试""" + print("\n" + "=" * 80) + print("测试类别 6: 质量和速度对比测试") + print("=" * 80) + print("注意:此测试需要较长时间(可能需要 10-20 分钟)...") + + results = {'quality': False} + + print("\n[6.1] FP8 vs BF16 Quality and Speed Comparison") + print("-" * 80) + try: + import torch.distributed as dist + + model_path = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + lora_path = "/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora" + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + test_prompts = [tokenizer.bos_token + prompt for prompt in TEST_PROMPTS] + + sampling_params = SamplingParams(temperature=0.7, max_tokens=50) + num_runs = 3 + + llm_bf16 = LLM( + model_path, + lora_path=lora_path, + use_lora=True, + model_name="dream", + model_type="diffusion_lm", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.4, + max_num_batched_tokens=2048, + max_num_seqs=20, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="unified", + kv_cache_dtype="bf16", + ) + + bf16_results = run_inference(llm_bf16, test_prompts, sampling_params, num_runs) + print(f"\n[BF16 结果汇总]") + print(f" 平均吞吐量: {bf16_results['avg_throughput']:.2f} tok/s") + + del llm_bf16 + torch.cuda.empty_cache() + if dist.is_initialized(): + dist.destroy_process_group() + + llm_fp8 = LLM( + model_path, + lora_path=lora_path, + use_lora=True, + model_name="dream", + model_type="diffusion_lm", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.4, + max_num_batched_tokens=2048, + max_num_seqs=20, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="unified", + kv_cache_dtype="fp8_e4m3", + ) + + fp8_results = run_inference(llm_fp8, test_prompts, sampling_params, num_runs) + print(f"\n[FP8 结果汇总]") + print(f" 平均吞吐量: {fp8_results['avg_throughput']:.2f} tok/s") + + speedup = fp8_results['avg_throughput'] / bf16_results['avg_throughput'] + print(f"\n 速度比: {speedup:.2f}x") + + bf16_outputs_last = bf16_results['outputs'][-1] + fp8_outputs_last = fp8_results['outputs'][-1] + + comparison = compare_outputs(bf16_outputs_last, fp8_outputs_last, test_prompts) + print(f"\n平均文本相似度: {comparison['avg_similarity']:.4f}") + + del llm_fp8 + torch.cuda.empty_cache() + if dist.is_initialized(): + dist.destroy_process_group() + + print("✅ Quality and speed comparison test PASSED") + results['quality'] = True + except Exception as e: + print(f"❌ Quality and speed comparison test FAILED: {e}") + traceback.print_exc() + + return results + +def run_attention_kernel_tests() -> Dict: + """运行 FP8 Attention Kernel 单元测试""" + print("\n" + "=" * 80) + print("测试类别 7: FP8 Attention Kernel 单元测试") + print("=" * 80) + + results = {'attention_kernel': False} + + print("\n[7.1] FP8 Attention Kernel Unit Test") + print("-" * 80) + try: + num_heads = 8 + num_kv_heads = 4 + head_dim = 128 + + attn = Attention( + num_heads=num_heads, + head_dim=head_dim, + scale=1.0 / (head_dim ** 0.5), + num_kv_heads=num_kv_heads, + model_type='diffusion_lm' + ) + + test_q_scale_computation(attn) + test_q_kv_scale_consistency(attn) + test_fp8_attention_kernel_integration(attn) + + print("✅ FP8 attention kernel unit test PASSED") + results['attention_kernel'] = True + except Exception as e: + print(f"❌ FP8 attention kernel unit test FAILED: {e}") + traceback.print_exc() + + return results + +def run_attention_e2e_tests() -> Dict: + """运行 FP8 Attention Kernel 端到端测试""" + print("\n" + "=" * 80) + print("测试类别 8: FP8 Attention Kernel 端到端测试") + print("=" * 80) + + results = {'attention_e2e': False} + + print("\n[8.1] FP8 Attention Kernel End-to-End Test") + print("-" * 80) + print("注意:此测试需要模型 checkpoint,可能需要较长时间...") + try: + test_fp8_attention_pipeline() + print("✅ FP8 attention kernel end-to-end test PASSED") + results['attention_e2e'] = True + except Exception as e: + print(f"❌ FP8 attention kernel end-to-end test FAILED: {e}") + traceback.print_exc() + + return results + +def run_attention_numerics_tests() -> Dict: + """运行 FP8 Attention Kernel 数值验证测试""" + print("\n" + "=" * 80) + print("测试类别 9: FP8 Attention Kernel 数值验证测试") + print("=" * 80) + + results = {'attention_numerics': False} + + print("\n[9.1] FP8 Attention Kernel Numerics Test") + print("-" * 80) + try: + test_fp8_attention_kernel_numerics() + print("✅ FP8 attention kernel numerics test PASSED") + results['attention_numerics'] = True + except Exception as e: + print(f"❌ FP8 attention kernel numerics test FAILED: {e}") + traceback.print_exc() + + return results + +def print_summary(all_results: Dict): + """打印测试结果摘要""" + print("\n" + "=" * 80) + print("测试结果摘要") + print("=" * 80) + + total_tests = 0 + passed_tests = 0 + + for category, results in all_results.items(): + print(f"\n{category.upper()}:") + for test_name, passed in results.items(): + total_tests += 1 + status = "✅ PASSED" if passed else "❌ FAILED" + print(f" {test_name}: {status}") + if passed: + passed_tests += 1 + + print(f"\n总计: {passed_tests}/{total_tests} 测试通过 ({passed_tests/total_tests*100:.1f}%)") + + if passed_tests == total_tests: + print("\n🎉 所有测试通过!") + return 0 + else: + print(f"\n⚠️ 有 {total_tests - passed_tests} 个测试失败") + return 1 + +def main(): + parser = argparse.ArgumentParser( + description='FP8 KV Cache 综合测试脚本', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +测试类别说明: + kernel - Kernel 层 FP8 支持测试(roundtrip tests) + integration - Attention layer FP8 集成测试 + pipeline - 完整 pipeline 测试(需要模型) + memory - 内存使用验证测试(需要模型) + speed - 速度对比测试(需要模型) + quality - 质量和速度对比测试(需要模型,较耗时) + attention_kernel - FP8 attention kernel 单元测试 + attention_e2e - FP8 attention kernel 端到端测试(需要模型) + attention_numerics - FP8 attention kernel 数值验证测试 + all - 运行所有测试(除了 quality,因为需要较长时间) + +示例: + # 运行所有测试(除了 quality) + python test_fp8_kv_cache_comprehensive.py --tests all + + # 运行特定测试 + python test_fp8_kv_cache_comprehensive.py --tests kernel integration + + # 运行 speed 和 quality 测试 + python test_fp8_kv_cache_comprehensive.py --tests speed quality + """ + ) + parser.add_argument( + '--tests', + nargs='+', + default=['all'], + choices=['kernel', 'integration', 'pipeline', 'memory', 'speed', + 'quality', 'attention_kernel', 'attention_e2e', 'attention_numerics', 'all'], + help='要运行的测试类别(默认: all)' + ) + + args = parser.parse_args() + + if 'all' in args.tests: + test_categories = ['kernel', 'integration', 'pipeline', 'memory', 'speed', + 'attention_kernel', 'attention_e2e', 'attention_numerics'] + else: + test_categories = args.tests + + print("=" * 80) + print("FP8 KV Cache 综合测试") + print("=" * 80) + print(f"测试类别: {', '.join(test_categories)}") + print(f"工作目录: {PROJECT_ROOT}") + + all_results = {} + + if 'kernel' in test_categories: + all_results['kernel'] = run_kernel_tests() + + if 'integration' in test_categories: + all_results['integration'] = run_integration_tests() + + if 'pipeline' in test_categories: + all_results['pipeline'] = run_pipeline_tests() + + if 'memory' in test_categories: + all_results['memory'] = run_memory_tests() + + if 'speed' in test_categories: + all_results['speed'] = run_speed_tests() + + if 'quality' in test_categories: + all_results['quality'] = run_quality_tests() + + if 'attention_kernel' in test_categories: + all_results['attention_kernel'] = run_attention_kernel_tests() + + if 'attention_e2e' in test_categories: + all_results['attention_e2e'] = run_attention_e2e_tests() + + if 'attention_numerics' in test_categories: + all_results['attention_numerics'] = run_attention_numerics_tests() + + exit_code = print_summary(all_results) + sys.exit(exit_code) + +if __name__ == '__main__': + main() diff --git a/examples/test_fp8_kv_cache_distinct.py b/examples/test_fp8_kv_cache_distinct.py new file mode 100755 index 0000000..9bd6096 --- /dev/null +++ b/examples/test_fp8_kv_cache_distinct.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +"""测试FP8 KV cache distinct layout:使用FP8存储,在Python层显式反量化为BF16,然后用原来的BF16算子""" +import os +import time + +from transformers import AutoTokenizer +from diffulex import Diffulex, SamplingParams + + +def main(): + # 模型配置 + model = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + + print("=" * 60) + print("初始化 Diffulex 模型 (FP8 KV Cache, Distinct Layout)...") + print("=" * 60) + + llm = Diffulex( + model, + lora_path="/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora", + use_lora=True, + model_name="dream", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.25, + max_num_batched_tokens=2048, + max_num_seqs=10, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="distinct", # 使用 distinct layout + kv_cache_dtype="fp8_e4m3", # 使用FP8 KV cache + decoding_strategy="d2f" + ) + + print("✓ 模型初始化完成 (FP8 KV Cache, Distinct Layout)\n") + + # 加载 tokenizer + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + sampling_params = SamplingParams(temperature=0.0, max_tokens=128) + + # 定义几个测试 prompt + test_prompts = [ + "The capital of France is", + "1 + 1 equals", + "Python is a programming language that", + ] + + # 添加 BOS token + prompts = [tokenizer.bos_token + p for p in test_prompts] + + print("=" * 60) + print(f"运行生成测试 ({len(prompts)} 个 prompt)...") + print("使用FP8 KV cache (distinct layout),Python层反量化为BF16后使用BF16算子") + print("=" * 60) + + start_time = time.time() + outputs = llm.generate(prompts, sampling_params) + end_time = time.time() + + print("\n" + "=" * 60) + print("生成结果:") + print("=" * 60) + + total_tokens = sum(len(o['token_ids']) for o in outputs) + total_time = end_time - start_time + avg_tps = total_tokens / total_time if total_time > 0 else 0 + avg_diff_steps = sum(o['n_diff_steps'] for o in outputs) / len(outputs) if outputs else 0 + + print(f"\n总计:") + print(f" - 生成输出数: {len(outputs)}") + print(f" - 总 token 数: {total_tokens}") + print(f" - 总时间: {total_time:.2f} 秒") + print(f" - 平均 TPS: {avg_tps:.2f} tok/s") + print(f" - 平均扩散步数: {avg_diff_steps:.2f}") + + print("\n" + "=" * 60) + print("详细输出:") + print("=" * 60) + + for idx, (prompt, output) in enumerate(zip(test_prompts, outputs)): + print(f"\n[Prompt {idx + 1}]") + print(f"输入: {prompt}") + print(f"输出: {output['text']}") + print(f"Token IDs 长度: {len(output['token_ids'])}") + print(f"扩散步数: {output['n_diff_steps']}") + print("-" * 60) + + print("\n✓ FP8 KV Cache (Distinct Layout, Python层反量化) 端到端测试完成!") + + +if __name__ == "__main__": + main() + diff --git a/examples/test_fp8_kv_cache_python_dequant.py b/examples/test_fp8_kv_cache_python_dequant.py new file mode 100755 index 0000000..556d97a --- /dev/null +++ b/examples/test_fp8_kv_cache_python_dequant.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +"""测试FP8 KV cache:使用FP8存储,在Python层显式反量化为BF16,然后用原来的BF16算子""" +import os +import time + +from transformers import AutoTokenizer +from diffulex import Diffulex, SamplingParams + + +def main(): + # 模型配置 + model = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + + print("=" * 60) + print("初始化 Diffulex 模型 (FP8 KV Cache)...") + print("=" * 60) + + llm = Diffulex( + model, + lora_path="/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora", + use_lora=True, + model_name="dream", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.25, + max_num_batched_tokens=2048, + max_num_seqs=10, + max_model_len=2048, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + kv_cache_layout="unified", + kv_cache_dtype="fp8_e4m3", # 使用FP8 KV cache + decoding_strategy="d2f" + ) + + print("✓ 模型初始化完成 (FP8 KV Cache)\n") + + # 加载 tokenizer + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + sampling_params = SamplingParams(temperature=0.0, max_tokens=128) + + # 定义几个测试 prompt + test_prompts = [ + "The capital of France is", + "1 + 1 equals", + "Python is a programming language that", + ] + + # 添加 BOS token + prompts = [tokenizer.bos_token + p for p in test_prompts] + + print("=" * 60) + print(f"运行生成测试 ({len(prompts)} 个 prompt)...") + print("使用FP8 KV cache,Python层反量化为BF16后使用BF16算子") + print("=" * 60) + + start_time = time.time() + outputs = llm.generate(prompts, sampling_params) + end_time = time.time() + + print("\n" + "=" * 60) + print("生成结果:") + print("=" * 60) + + total_tokens = sum(len(o['token_ids']) for o in outputs) + total_time = end_time - start_time + avg_tps = total_tokens / total_time if total_time > 0 else 0 + avg_diff_steps = sum(o['n_diff_steps'] for o in outputs) / len(outputs) if outputs else 0 + + print(f"\n总计:") + print(f" - 生成输出数: {len(outputs)}") + print(f" - 总 token 数: {total_tokens}") + print(f" - 总时间: {total_time:.2f} 秒") + print(f" - 平均 TPS: {avg_tps:.2f} tok/s") + print(f" - 平均扩散步数: {avg_diff_steps:.2f}") + + print("\n" + "=" * 60) + print("详细输出:") + print("=" * 60) + + for idx, (prompt, output) in enumerate(zip(test_prompts, outputs)): + print(f"\n[Prompt {idx + 1}]") + print(f"输入: {prompt}") + print(f"输出: {output['text']}") + print(f"Token IDs 长度: {len(output['token_ids'])}") + print(f"扩散步数: {output['n_diff_steps']}") + print("-" * 60) + + print("\n✓ FP8 KV Cache (Python层反量化) 端到端测试完成!") + + +if __name__ == "__main__": + main() + diff --git a/examples/test_fp8_linear.py b/examples/test_fp8_linear.py new file mode 100644 index 0000000..bbafa1b --- /dev/null +++ b/examples/test_fp8_linear.py @@ -0,0 +1,174 @@ +""" +End-to-end test for FP8 Linear quantization. + +This script tests FP8 Linear strategies in a complete inference pipeline. +Note: This is a basic smoke test. For full model inference, see the main +test scripts in the examples directory. +""" + +import torch +import torch.nn.functional as F + +from diffulex.utils.quantization.registry import create_linear_strategy +from diffulex.utils.quantization.context import get_quantization_context + + +def test_fp8_w8a16_e2e(): + """End-to-end test for FP8 W8A16 strategy.""" + print("Testing FP8 W8A16 (e4m3) strategy...") + + # Create strategy + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + ctx = get_quantization_context() + ctx.set_linear_strategy("attn", strategy) + + # Simulate a small attention projection + M, K, N = 32, 512, 256 # batch_size=32, hidden_size=512, num_heads*head_dim=256 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda" if torch.cuda.is_available() else "cpu") + weight = torch.randn(N, K, dtype=torch.bfloat16, device=x.device) + bias = torch.randn(N, dtype=torch.bfloat16, device=x.device) + + # Reference output + ref_out = F.linear(x, weight, bias) + + # FP8 quantized output + fp8_out = strategy.linear_forward(x, weight, bias, quant_kind="attn") + + # Check output + assert fp8_out.shape == ref_out.shape + assert fp8_out.dtype == torch.bfloat16 + + # Compute error metrics + max_error = torch.abs(fp8_out - ref_out).max().item() + mean_error = torch.abs(fp8_out - ref_out).mean().item() + relative_error = (torch.abs(fp8_out - ref_out) / (ref_out.abs() + 1e-8)).mean().item() + + print(f" Max error: {max_error:.6f}") + print(f" Mean error: {mean_error:.6f}") + print(f" Mean relative error: {relative_error:.6f}") + print(f" Output range: [{fp8_out.min().item():.3f}, {fp8_out.max().item():.3f}]") + print(" ✓ FP8 W8A16 test passed") + + return { + "max_error": max_error, + "mean_error": mean_error, + "relative_error": relative_error, + } + + +def test_fp8_w8a8_e2e(): + """End-to-end test for FP8 W8A8 strategy.""" + print("Testing FP8 W8A8 (e4m3) strategy...") + + # Create strategy + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="fp8_e4m3") + ctx = get_quantization_context() + ctx.set_linear_strategy("attn", strategy) + + # Simulate a small attention projection + M, K, N = 32, 512, 256 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda" if torch.cuda.is_available() else "cpu") + weight = torch.randn(N, K, dtype=torch.bfloat16, device=x.device) + bias = torch.randn(N, dtype=torch.bfloat16, device=x.device) + + # Reference output + ref_out = F.linear(x, weight, bias) + + # FP8 quantized output + fp8_out = strategy.linear_forward(x, weight, bias, quant_kind="attn") + + # Check output + assert fp8_out.shape == ref_out.shape + assert fp8_out.dtype == torch.bfloat16 + + # Compute error metrics + max_error = torch.abs(fp8_out - ref_out).max().item() + mean_error = torch.abs(fp8_out - ref_out).mean().item() + relative_error = (torch.abs(fp8_out - ref_out) / (ref_out.abs() + 1e-8)).mean().item() + + print(f" Max error: {max_error:.6f}") + print(f" Mean error: {mean_error:.6f}") + print(f" Mean relative error: {relative_error:.6f}") + print(f" Output range: [{fp8_out.min().item():.3f}, {fp8_out.max().item():.3f}]") + print(" ✓ FP8 W8A8 test passed") + + return { + "max_error": max_error, + "mean_error": mean_error, + "relative_error": relative_error, + } + + +def test_memory_usage(): + """Test memory usage comparison (basic check).""" + print("Testing memory usage...") + + if not torch.cuda.is_available(): + print(" Skipping memory test (CUDA not available)") + return + + device = torch.device("cuda") + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # BF16 baseline + M, K, N = 32, 512, 256 + weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16, device=device) + mem_bf16 = torch.cuda.memory_allocated() + + # FP8 quantized + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + weight_fp8, scales = strategy.quantize_weight_for_kernel(weight_bf16, device=device) + mem_fp8 = torch.cuda.memory_allocated() + + # Memory reduction + weight_size_bf16 = weight_bf16.numel() * 2 # bf16 = 2 bytes + weight_size_fp8 = weight_fp8.numel() * 1 + scales.numel() * 4 # uint8 = 1 byte, float32 = 4 bytes + reduction = (1 - weight_size_fp8 / weight_size_bf16) * 100 + + print(f" BF16 weight size: {weight_size_bf16 / 1024:.2f} KB") + print(f" FP8 weight size: {weight_size_fp8 / 1024:.2f} KB") + print(f" Memory reduction: {reduction:.1f}%") + print(" ✓ Memory test passed") + + +def main(): + """Run all end-to-end tests.""" + print("=" * 60) + print("FP8 Linear Quantization End-to-End Tests") + print("=" * 60) + print() + + try: + # Test FP8 W8A16 + w8a16_metrics = test_fp8_w8a16_e2e() + print() + + # Test FP8 W8A8 + w8a8_metrics = test_fp8_w8a8_e2e() + print() + + # Test memory usage + test_memory_usage() + print() + + print("=" * 60) + print("All tests passed!") + print("=" * 60) + print() + print("Summary:") + print(f" FP8 W8A16 - Max error: {w8a16_metrics['max_error']:.6f}") + print(f" FP8 W8A8 - Max error: {w8a8_metrics['max_error']:.6f}") + + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + return 1 + + return 0 + + +if __name__ == "__main__": + exit(main()) + diff --git a/examples/test_gptq_awq_loading.py b/examples/test_gptq_awq_loading.py new file mode 100644 index 0000000..3cb8eed --- /dev/null +++ b/examples/test_gptq_awq_loading.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 +"""测试 GPTQ/AWQ 离线量化权重加载功能 + +此脚本演示如何加载 GPTQ/AWQ 格式的量化 checkpoint 并验证权重是否正确加载。 + +使用方法: + # 测试 GPTQ checkpoint 加载 + python test_gptq_awq_loading.py --format gptq --model-path /path/to/gptq/checkpoint + + # 测试 AWQ checkpoint 加载 + python test_gptq_awq_loading.py --format awq --model-path /path/to/awq/checkpoint + + # 列出所有线性层及其量化状态 + python test_gptq_awq_loading.py --format gptq --model-path /path/to/checkpoint --list-layers +""" +import os +import sys +import argparse +from pathlib import Path + +# Make stdout/stderr line-buffered so progress logs are visible even when redirected/captured. +try: + sys.stdout.reconfigure(line_buffering=True) + sys.stderr.reconfigure(line_buffering=True) +except Exception: + pass + +# 确保从当前仓库导入 +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from diffulex.config import Config +from diffulex.model import AutoModelForDiffusionLM +from diffulex.utils.quantization.context import get_linear_strategy + + +def list_quantized_layers(model, format_name: str): + """列出所有线性层及其量化状态.""" + print("\n" + "=" * 80) + print(f"线性层量化状态 ({format_name.upper()})") + print("=" * 80) + print(f"{'模块名称':<50} {'类型':<15} {'量化状态':<15}") + print("-" * 80) + + gptq_count = 0 + awq_count = 0 + other_count = 0 + no_quant_count = 0 + + for name, module in model.named_modules(): + if hasattr(module, "has_offline_quantized_weight"): + if module.has_offline_quantized_weight(): + format_val = int(module._offline_quant_format.item()) if module._offline_quant_format.numel() > 0 else 0 + if format_val == 1: + quant_status = "GPTQ (离线)" + gptq_count += 1 + elif format_val == 2: + quant_status = "AWQ (离线)" + awq_count += 1 + else: + quant_status = "未知" + other_count += 1 + module_type = module.__class__.__name__ + print(f"{name:<50} {module_type:<15} {quant_status:<15}") + elif hasattr(module, "has_quantized_weight") and module.has_quantized_weight(): + quant_status = "运行时量化" + module_type = module.__class__.__name__ + print(f"{name:<50} {module_type:<15} {quant_status:<15}") + other_count += 1 + elif hasattr(module, "weight") and module.weight is not None: + quant_status = "未量化" + module_type = module.__class__.__name__ + if "Linear" in module_type: + print(f"{name:<50} {module_type:<15} {quant_status:<15}") + no_quant_count += 1 + + print("-" * 80) + print(f"\n统计:") + print(f" - GPTQ 离线量化层: {gptq_count}") + print(f" - AWQ 离线量化层: {awq_count}") + print(f" - 运行时量化层: {other_count}") + print(f" - 未量化层: {no_quant_count}") + print(f" - 总计: {gptq_count + awq_count + other_count + no_quant_count}") + + +def test_model_forward(model, config, num_test_inputs: int = 2): + """测试模型前向传播.""" + print("\n" + "=" * 80) + print("测试模型前向传播") + print("=" * 80) + + # 获取模型的输入大小(从第一个线性层的 input_size 推断) + hidden_size = None + for name, module in model.named_modules(): + if hasattr(module, "input_size"): + hidden_size = module.input_size + break + + if hidden_size is None: + print("⚠ 无法确定模型的 hidden_size,跳过前向传播测试") + return + + print(f"使用 hidden_size={hidden_size}") + + try: + import torch + import torch.nn.functional as F + + # 创建测试输入 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + test_inputs = [] + for i in range(num_test_inputs): + x = torch.randn(1, hidden_size, dtype=torch.bfloat16, device=device) + test_inputs.append(x) + + print(f"\n运行 {len(test_inputs)} 个测试输入...") + for i, x in enumerate(test_inputs): + print(f"\n 测试输入 {i+1}/{len(test_inputs)}: shape={x.shape}, dtype={x.dtype}") + + # 测试第一个线性层的 forward + found_linear = False + for name, module in model.named_modules(): + if hasattr(module, "forward") and hasattr(module, "quant_kind"): + try: + output = module(x) + print(f" ✓ {name}: output shape={output.shape}, dtype={output.dtype}") + found_linear = True + break + except Exception as e: + print(f" ✗ {name}: 错误 - {e}") + import traceback + traceback.print_exc() + break + + if not found_linear: + print(f" ⚠ 未找到可测试的线性层") + + print("\n✓ 前向传播测试完成") + + except Exception as e: + print(f"\n✗ 前向传播测试失败: {e}") + import traceback + traceback.print_exc() + + +def main(): + parser = argparse.ArgumentParser( + description="测试 GPTQ/AWQ 离线量化权重加载功能", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例用法: + %(prog)s --format gptq --model-path /path/to/gptq/checkpoint + %(prog)s --format awq --model-path /path/to/awq/checkpoint + %(prog)s --format gptq --model-path /path/to/checkpoint --list-layers --test-forward + """ + ) + + parser.add_argument( + "--format", + type=str, + choices=["gptq", "awq"], + required=True, + help="量化格式: 'gptq' 或 'awq'" + ) + parser.add_argument( + "--model-path", + type=str, + required=True, + help="模型 checkpoint 路径(包含 .safetensors 文件)" + ) + parser.add_argument( + "--model-name", + type=str, + default="dream", + help="模型名称(默认: 'dream')" + ) + parser.add_argument( + "--list-layers", + action="store_true", + help="列出所有线性层及其量化状态" + ) + parser.add_argument( + "--test-forward", + action="store_true", + help="测试模型前向传播" + ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=1, + help="Tensor parallel size (默认: 1,仅 TP=1 支持离线量化权重加载)" + ) + + args = parser.parse_args() + + # 验证模型路径 + model_path = Path(args.model_path) + if not model_path.exists(): + print(f"错误: 模型路径不存在: {model_path}") + sys.exit(1) + + safetensors_files = list(model_path.glob("*.safetensors")) + if not safetensors_files: + print(f"警告: 在 {model_path} 中未找到 .safetensors 文件") + + print("=" * 80) + print("GPTQ/AWQ 离线量化权重加载测试") + print("=" * 80) + print(f"量化格式: {args.format.upper()}") + print(f"模型路径: {model_path}") + print(f"模型名称: {args.model_name}") + print(f"Tensor Parallel Size: {args.tensor_parallel_size}") + print(f"找到 {len(safetensors_files)} 个 .safetensors 文件") + print("=" * 80) + + # 检查 safetensors 文件中是否包含 GPTQ/AWQ keys + if safetensors_files: + print("\n检查 checkpoint 中的量化 keys...") + gptq_keys = [] + awq_keys = [] + for file in safetensors_files: + from safetensors import safe_open + with safe_open(file, "pt", "cpu") as f: + for key in f.keys(): + if key.endswith(".qweight"): + gptq_keys.append(key) + awq_keys.append(key) + elif key.endswith(".qzeros"): + gptq_keys.append(key) + awq_keys.append(key) + elif key.endswith(".scales"): + gptq_keys.append(key) + awq_keys.append(key) + elif key.endswith(".g_idx"): + gptq_keys.append(key) + + print(f" 找到 {len(set(k.rsplit('.', 1)[0] for k in gptq_keys if k.endswith('.qweight')))} 个可能的量化层") + if gptq_keys and args.format == "gptq": + print(f" 找到 {len([k for k in gptq_keys if k.endswith('.g_idx')])} 个 g_idx keys (GPTQ)") + + # 创建配置 + try: + config = Config( + model=str(model_path), + model_name=args.model_name, + tensor_parallel_size=args.tensor_parallel_size, + data_parallel_size=1, + linear_attn_weight_dtype=args.format, + linear_mlp_weight_dtype=args.format, + linear_attn_act_dtype="bf16", + linear_mlp_act_dtype="bf16", + use_lora=False, + gpu_memory_utilization=0.3, + max_num_batched_tokens=1024, + max_num_seqs=4, + max_model_len=1024, + decoding_strategy="d2f", + enforce_eager=True, + ) + print("\n✓ 配置创建成功") + except Exception as e: + print(f"\n✗ 配置创建失败: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + + # 检查 TP 支持 + if args.tensor_parallel_size > 1: + print("\n⚠ 警告: Tensor Parallel > 1 目前不完全支持离线量化权重加载") + print(" 如果遇到问题,请使用 --tensor-parallel-size 1") + + # 加载模型 + print("\n加载模型...") + try: + model = AutoModelForDiffusionLM.from_config(config) + print("✓ 模型加载成功") + except Exception as e: + print(f"\n✗ 模型加载失败: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + + # 列出量化层 + if args.list_layers: + list_quantized_layers(model, args.format) + + # 测试前向传播 + if args.test_forward: + test_model_forward(model, config) + + print("\n" + "=" * 80) + print("测试完成") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/test_llada_dvllm_human_eval.py b/examples/test_llada_dvllm_human_eval.py index 5f52c72..a2c9704 100755 --- a/examples/test_llada_dvllm_human_eval.py +++ b/examples/test_llada_dvllm_human_eval.py @@ -83,4 +83,5 @@ def summarize_profiling(csv_path: str) -> dict: "=*=" * 30) for idx, o in enumerate(outputs): print("\n", "=*=" * 30) - print(f"[Prompt {idx} Result] \n{prompts[idx] + "\n----------\n" + o['text']}\n") \ No newline at end of file + resp = prompts[idx] + "\n----------\n" + o["text"] + print(f"[Prompt {idx} Result]\n{resp}\n") \ No newline at end of file diff --git a/examples/test_quantization_generation.py b/examples/test_quantization_generation.py new file mode 100755 index 0000000..7ffd26f --- /dev/null +++ b/examples/test_quantization_generation.py @@ -0,0 +1,851 @@ +#!/usr/bin/env python3 +"""统一的量化策略文本生成测试脚本 + +支持测试以下量化策略组合: +- BF16 + BF16 KV +- BF16 + FP8 KV +- W8A16 + BF16 KV +- W8A16 + FP8 KV +- W4A16 + BF16 KV +- W4A16 + FP8 KV +- W8A8 + BF16 KV +- W8A8 + FP8 KV +- W4A8 + BF16 KV +- W4A8 + FP8 KV +- FP8 W8A16 (e4m3) + BF16 KV +- FP8 W8A16 (e4m3) + FP8 KV +- FP8 W8A16 (e5m2) + BF16 KV +- FP8 W8A16 (e5m2) + FP8 KV +- FP8 W8A8 (e4m3) + BF16 KV +- FP8 W8A8 (e4m3) + FP8 KV +- FP8 W8A8 (e5m2) + BF16 KV +- FP8 W8A8 (e5m2) + FP8 KV +- GPTQ W4A16 (离线量化) + BF16 KV +- GPTQ W4A16 (离线量化) + FP8 KV +- AWQ W4A16 (离线量化) + BF16 KV +- AWQ W4A16 (离线量化) + FP8 KV + +使用方法: + # 运行所有策略 + python test_quantization_generation.py --all + + # 只运行 BF16 相关策略 + python test_quantization_generation.py --bf16 + + # 只运行 W8A16 相关策略 + python test_quantization_generation.py --w8a16 + + # 只运行 W4A16 相关策略 + python test_quantization_generation.py --w4a16 + + # 只运行 W8A8 相关策略 + python test_quantization_generation.py --w8a8 + + # 只运行 W4A8 相关策略 + python test_quantization_generation.py --w4a8 + + # 只运行 FP8 W8A16 相关策略 + python test_quantization_generation.py --fp8_w8a16 + + # 只运行 FP8 W8A8 相关策略 + python test_quantization_generation.py --fp8_w8a8 + + # 只运行 GPTQ 相关策略(需要先运行量化脚本生成离线权重) + python test_quantization_generation.py --gptq + + # 只运行 AWQ 相关策略(需要先运行量化脚本生成离线权重) + python test_quantization_generation.py --awq + + # 自定义选择(用逗号分隔) + python test_quantization_generation.py --strategies bf16_bf16kv,w8a16_bf16kv + + # 只测试某个策略 + python test_quantization_generation.py --strategies w4a16_fp8kv + + # 使用量化后的模型路径(如果先运行了量化脚本) + python test_quantization_generation.py --gptq --model-path /path/to/quantized/model +""" +import os +import sys +import time +import argparse +import gc +import json +import subprocess +from pathlib import Path +from typing import Dict, Optional, List, Tuple + +# Make stdout/stderr line-buffered so progress logs are visible even when redirected/captured. +try: + sys.stdout.reconfigure(line_buffering=True) + sys.stderr.reconfigure(line_buffering=True) +except Exception: + pass + +# +# NOTE: +# 这个脚本不应假设本机 CUDA 安装路径或默认 GPU 号。 +# 如需指定 CUDA/设备,请在运行前自行设置: +# - CUDA_HOME / CUDA_PATH / PATH / LD_LIBRARY_PATH +# - CUDA_VISIBLE_DEVICES +# 或者在你自己的 wrapper 脚本里处理。 + +# 确保从当前仓库导入 +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from transformers import AutoTokenizer +from diffulex import Diffulex, SamplingParams + + +# 支持的策略配置 +STRATEGY_CONFIGS = { + 'bf16_bf16kv': { + 'name': 'BF16 + BF16 KV', + 'linear_attn_weight_dtype': 'bf16', + 'linear_mlp_weight_dtype': 'bf16', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'bf16', + }, + 'bf16_fp8kv': { + 'name': 'BF16 + FP8 KV', + 'linear_attn_weight_dtype': 'bf16', + 'linear_mlp_weight_dtype': 'bf16', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'fp8', + }, + 'w8a16_bf16kv': { + 'name': 'W8A16 + BF16 KV', + 'linear_attn_weight_dtype': 'int8', + 'linear_mlp_weight_dtype': 'int8', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'bf16', + }, + 'w8a16_fp8kv': { + 'name': 'W8A16 + FP8 KV', + 'linear_attn_weight_dtype': 'int8', + 'linear_mlp_weight_dtype': 'int8', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'fp8', + }, + 'w4a16_bf16kv': { + 'name': 'W4A16 + BF16 KV', + 'linear_attn_weight_dtype': 'int4', + 'linear_mlp_weight_dtype': 'int4', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'bf16', + }, + 'w4a16_fp8kv': { + 'name': 'W4A16 + FP8 KV', + 'linear_attn_weight_dtype': 'int4', + 'linear_mlp_weight_dtype': 'int4', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'fp8', + }, + 'w8a8_bf16kv': { + 'name': 'W8A8 + BF16 KV', + 'linear_attn_weight_dtype': 'int8', + 'linear_mlp_weight_dtype': 'int8', + 'linear_attn_act_dtype': 'int8', + 'linear_mlp_act_dtype': 'int8', + 'kv_cache_dtype': 'bf16', + }, + 'w8a8_fp8kv': { + 'name': 'W8A8 + FP8 KV', + 'linear_attn_weight_dtype': 'int8', + 'linear_mlp_weight_dtype': 'int8', + 'linear_attn_act_dtype': 'int8', + 'linear_mlp_act_dtype': 'int8', + 'kv_cache_dtype': 'fp8', + }, + 'w4a8_bf16kv': { + 'name': 'W4A8(MLP A8) + W4A16(Attn A16) + BF16 KV', + 'linear_attn_weight_dtype': 'int4', + 'linear_mlp_weight_dtype': 'int4', + # Pure W4A8 (int4 weight + int8 act) tends to severely hurt generation quality without calibration. + # Minimal quality-first tweak: keep attention activation at bf16 (W4A16), while keeping MLP at int8 act (W4A8). + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'int8', + 'kv_cache_dtype': 'bf16', + }, + 'w4a8_fp8kv': { + 'name': 'W4A8(MLP A8) + W4A16(Attn A16) + FP8 KV', + 'linear_attn_weight_dtype': 'int4', + 'linear_mlp_weight_dtype': 'int4', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'int8', + 'kv_cache_dtype': 'fp8', + }, + # FP8 W8A16 strategies + 'fp8_w8a16_e4m3_bf16kv': { + 'name': 'FP8 W8A16 (e4m3) + BF16 KV', + 'linear_attn_weight_dtype': 'fp8_e4m3', + 'linear_mlp_weight_dtype': 'fp8_e4m3', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'bf16', + }, + 'fp8_w8a16_e4m3_fp8kv': { + 'name': 'FP8 W8A16 (e4m3) + FP8 KV', + 'linear_attn_weight_dtype': 'fp8_e4m3', + 'linear_mlp_weight_dtype': 'fp8_e4m3', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'fp8', + }, + 'fp8_w8a16_e5m2_bf16kv': { + 'name': 'FP8 W8A16 (e5m2) + BF16 KV', + 'linear_attn_weight_dtype': 'fp8_e5m2', + 'linear_mlp_weight_dtype': 'fp8_e5m2', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'bf16', + }, + 'fp8_w8a16_e5m2_fp8kv': { + 'name': 'FP8 W8A16 (e5m2) + FP8 KV', + 'linear_attn_weight_dtype': 'fp8_e5m2', + 'linear_mlp_weight_dtype': 'fp8_e5m2', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'fp8', + }, + # FP8 W8A8 strategies + 'fp8_w8a8_e4m3_bf16kv': { + 'name': 'FP8 W8A8 (e4m3) + BF16 KV', + 'linear_attn_weight_dtype': 'fp8_e4m3', + 'linear_mlp_weight_dtype': 'fp8_e4m3', + 'linear_attn_act_dtype': 'fp8_e4m3', + 'linear_mlp_act_dtype': 'fp8_e4m3', + 'kv_cache_dtype': 'bf16', + }, + 'fp8_w8a8_e4m3_fp8kv': { + 'name': 'FP8 W8A8 (e4m3) + FP8 KV', + 'linear_attn_weight_dtype': 'fp8_e4m3', + 'linear_mlp_weight_dtype': 'fp8_e4m3', + 'linear_attn_act_dtype': 'fp8_e4m3', + 'linear_mlp_act_dtype': 'fp8_e4m3', + 'kv_cache_dtype': 'fp8', + }, + 'fp8_w8a8_e5m2_bf16kv': { + 'name': 'FP8 W8A8 (e5m2) + BF16 KV', + 'linear_attn_weight_dtype': 'fp8_e5m2', + 'linear_mlp_weight_dtype': 'fp8_e5m2', + 'linear_attn_act_dtype': 'fp8_e5m2', + 'linear_mlp_act_dtype': 'fp8_e5m2', + 'kv_cache_dtype': 'bf16', + }, + 'fp8_w8a8_e5m2_fp8kv': { + 'name': 'FP8 W8A8 (e5m2) + FP8 KV', + 'linear_attn_weight_dtype': 'fp8_e5m2', + 'linear_mlp_weight_dtype': 'fp8_e5m2', + 'linear_attn_act_dtype': 'fp8_e5m2', + 'linear_mlp_act_dtype': 'fp8_e5m2', + 'kv_cache_dtype': 'fp8', + }, + # GPTQ W4A16 strategies (offline quantized) + 'gptq_w4a16_bf16kv': { + 'name': 'GPTQ W4A16 (离线量化) + BF16 KV', + 'linear_attn_weight_dtype': 'gptq', + 'linear_mlp_weight_dtype': 'gptq', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'bf16', + }, + 'gptq_w4a16_fp8kv': { + 'name': 'GPTQ W4A16 (离线量化) + FP8 KV', + 'linear_attn_weight_dtype': 'gptq', + 'linear_mlp_weight_dtype': 'gptq', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'fp8', + }, + # AWQ W4A16 strategies (offline quantized) + 'awq_w4a16_bf16kv': { + 'name': 'AWQ W4A16 (离线量化) + BF16 KV', + 'linear_attn_weight_dtype': 'awq', + 'linear_mlp_weight_dtype': 'awq', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'bf16', + }, + 'awq_w4a16_fp8kv': { + 'name': 'AWQ W4A16 (离线量化) + FP8 KV', + 'linear_attn_weight_dtype': 'awq', + 'linear_mlp_weight_dtype': 'awq', + 'linear_attn_act_dtype': 'bf16', + 'linear_mlp_act_dtype': 'bf16', + 'kv_cache_dtype': 'fp8', + }, +} + +# 策略组定义 +STRATEGY_GROUPS = { + 'bf16': ['bf16_bf16kv', 'bf16_fp8kv'], + 'w8a16': ['w8a16_bf16kv', 'w8a16_fp8kv'], + 'w4a16': ['w4a16_bf16kv', 'w4a16_fp8kv'], + 'w8a8': ['w8a8_bf16kv', 'w8a8_fp8kv'], + 'w4a8': ['w4a8_bf16kv', 'w4a8_fp8kv'], + 'fp8_w8a16': [ + 'fp8_w8a16_e4m3_bf16kv', + 'fp8_w8a16_e4m3_fp8kv', + 'fp8_w8a16_e5m2_bf16kv', + 'fp8_w8a16_e5m2_fp8kv', + ], + 'fp8_w8a8': [ + 'fp8_w8a8_e4m3_bf16kv', + 'fp8_w8a8_e4m3_fp8kv', + 'fp8_w8a8_e5m2_bf16kv', + 'fp8_w8a8_e5m2_fp8kv', + ], + 'gptq': [ + 'gptq_w4a16_bf16kv', + 'gptq_w4a16_fp8kv', + ], + 'awq': [ + 'awq_w4a16_bf16kv', + 'awq_w4a16_fp8kv', + ], + 'all': list(STRATEGY_CONFIGS.keys()), +} + + +def test_generation( + llm: Diffulex, + tokenizer: AutoTokenizer, + test_name: str, + prompts: List[str], + warmup: bool = False, + max_tokens: int = 30, +) -> Optional[Dict[str, float]]: + """运行文本生成测试 + + Args: + llm: Diffulex 模型实例 + tokenizer: Tokenizer 实例 + test_name: 测试名称 + prompts: 输入 prompts 列表 + warmup: 如果为 True,只运行 warmup,不报告详细结果 + max_tokens: 最大生成 token 数 + + Returns: + 如果是 warmup,返回 True/False + 如果不是 warmup,返回包含性能指标的字典,或 None(如果失败) + """ + if not warmup: + print("\n" + "=" * 70) + print(f"测试: {test_name}") + print("=" * 70) + else: + print("\n" + "=" * 70) + print(f"Warmup: {test_name} (排除 kernel 编译影响)") + print("=" * 70) + + sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens) + + # 添加 BOS token(如果需要) + prompts_with_bos = [] + for p in prompts: + if tokenizer.bos_token and not p.startswith(tokenizer.bos_token): + prompts_with_bos.append(tokenizer.bos_token + p) + else: + prompts_with_bos.append(p) + + if not warmup: + print(f"输入 prompts ({len(prompts_with_bos)} 个):") + for i, p in enumerate(prompts_with_bos, 1): + print(f" {i}. {p[:60]}...") + print(f"\n开始生成...") + else: + print(f"运行 warmup 生成(kernel 编译中,不报告速度)...") + + start_time = time.time() + + try: + outputs = llm.generate(prompts_with_bos, sampling_params) + end_time = time.time() + + total_time = end_time - start_time + total_tokens = sum(len(o.get('token_ids', [])) for o in outputs) + + if warmup: + print(f"✓ Warmup 完成 (kernel 已编译,耗时 {total_time:.2f} 秒)") + return True + + avg_tps = total_tokens / total_time if total_time > 0 else 0 + + print(f"\n✓ 生成成功!") + print(f" - 总时间: {total_time:.2f} 秒") + print(f" - 总 token 数: {total_tokens}") + print(f" - 平均 TPS: {avg_tps:.2f} tok/s") + + print(f"\n生成结果:") + for i, output in enumerate(outputs, 1): + generated_text = output.get('text', '') + token_ids = output.get('token_ids', []) + print(f"\n [{i}] 输入: {prompts[i-1][:50]}...") + print(f" 输出: {generated_text[:150]}...") + print(f" Token数: {len(token_ids)}") + + return { + 'total_time': total_time, + 'total_tokens': total_tokens, + 'avg_tps': avg_tps, + } + except Exception as e: + print(f"\n✗ 生成失败: {e}") + import traceback + traceback.print_exc() + return None + + +def _cleanup_llm(llm: Optional[Diffulex], force_cleanup: bool = False): + """Best-effort cleanup to release GPU memory and NCCL resources even on exceptions. + + Args: + llm: Diffulex instance to cleanup + force_cleanup: If True, performs more aggressive cleanup including delays + """ + try: + if llm is not None: + llm.exit() + except Exception: + pass + + try: + import torch + import torch.distributed as dist + if dist.is_initialized(): + dist.destroy_process_group() + torch.cuda.empty_cache() + if force_cleanup: + # Force synchronization to ensure cleanup is complete + torch.cuda.synchronize() + except Exception: + pass + + # Clear quantization strategy caches if available + if force_cleanup: + try: + from diffulex.utils.quantization.context import get_quantization_context + ctx = get_quantization_context() + # QuantizationContext stores strategies in ctx._strategies (linear_attn/linear_mlp/linear_other/...). + if hasattr(ctx, "_strategies") and isinstance(ctx._strategies, dict): + for strategy in ctx._strategies.values(): + if strategy is not None and hasattr(strategy, "_weight_cache"): + strategy._weight_cache.clear() + except Exception: + pass + + try: + gc.collect() + if force_cleanup: + # Additional cleanup pass + gc.collect() + except Exception: + pass + + if force_cleanup: + # Small delay to allow resources to be released + import time + time.sleep(0.5) + + +def run_strategy( + strategy_key: str, + model_path: str, + tokenizer: AutoTokenizer, + prompts: List[str], + common_kwargs: Dict, + max_tokens: int = 30, +) -> Tuple[str, Optional[Dict[str, float]]]: + """运行单个策略的测试 + + Returns: + (strategy_name, result_dict) 或 (strategy_name, None) 如果失败 + """ + if strategy_key not in STRATEGY_CONFIGS: + print(f"✗ 未知策略: {strategy_key}") + return (strategy_key, None) + + config = STRATEGY_CONFIGS[strategy_key] + strategy_name = config['name'] + is_w4a16 = 'w4a16' in strategy_key.lower() + is_w4a8 = 'w4a8' in strategy_key.lower() + is_gptq = 'gptq' in strategy_key.lower() + is_awq = 'awq' in strategy_key.lower() + needs_special_cleanup = is_w4a16 or is_w4a8 or is_gptq or is_awq # W4A16/W4A8/GPTQ/AWQ may need extra cleanup + + print("\n" + "=" * 70) + print(f"测试: {strategy_name}") + print("=" * 70) + + # For W4A16/W4A8 strategies, add a delay before starting to ensure previous strategy is fully cleaned up + if needs_special_cleanup: + import time + print("等待资源清理...") + # Additional cleanup before W4A16/W4A8 + _cleanup_llm(None, force_cleanup=True) + time.sleep(2.0) + + llm = None + try: + # 构建 Diffulex 配置 + llm_kwargs = { + **common_kwargs, + 'kv_cache_dtype': config['kv_cache_dtype'], + 'kv_cache_layout': 'unified', # FP8 kernel 只支持 unified layout + 'linear_attn_weight_dtype': config['linear_attn_weight_dtype'], + 'linear_mlp_weight_dtype': config['linear_mlp_weight_dtype'], + 'linear_attn_act_dtype': config['linear_attn_act_dtype'], + 'linear_mlp_act_dtype': config['linear_mlp_act_dtype'], + } + + llm = Diffulex(model_path, **llm_kwargs) + print(f"✓ {strategy_name} 模型初始化成功") + + # 第一轮:Warmup(排除 kernel 编译影响) + test_generation(llm, tokenizer, strategy_name, prompts, warmup=True, max_tokens=max_tokens) + + # 第二轮:实际测试(kernel 已编译,看稳态性能) + result = test_generation(llm, tokenizer, strategy_name, prompts, warmup=False, max_tokens=max_tokens) + return (strategy_name, result) + + except Exception as e: + print(f"✗ {strategy_name} 路径测试失败: {e}") + import traceback + traceback.print_exc() + + # For W4A16/W4A8 strategies, provide more detailed error information + if needs_special_cleanup and 'shape' in str(e).lower(): + strategy_type = "W4A16/W4A8" + print(f"\n提示: {strategy_type} 策略失败可能是由于资源清理不彻底导致的。") + print(" 建议:") + print(" 1. 单独运行测试脚本") + print(" 2. 或者增加策略之间的清理延迟时间") + + return (strategy_name, None) + finally: + # Use force_cleanup=True for W4A16/W4A8 strategies to ensure complete cleanup + _cleanup_llm(llm, force_cleanup=needs_special_cleanup) + llm = None + # Additional cleanup delay for W4A16/W4A8 to ensure resources are fully released + if needs_special_cleanup: + import time + time.sleep(2.0) # Increased delay for W4A16/W4A8 + + +def _run_strategy_in_subprocess( + strategy_key: str, + *, + model_path: str, + max_tokens: int, + gpu_memory_utilization: float, +) -> Tuple[str, Optional[Dict[str, float]]]: + """Run a single strategy in a fresh subprocess to avoid cross-strategy state (CUDA/NCCL/cache/fragmentation).""" + cmd = [ + sys.executable, + "-u", # unbuffered stdout/stderr so parent can stream logs in real time + str(Path(__file__).resolve()), + "--strategies", + strategy_key, + "--max-tokens", + str(max_tokens), + "--model-path", + model_path, + "--gpu-memory-utilization", + str(gpu_memory_utilization), + "--_emit-json", + ] + # NOTE: don't use capture_output=True here, otherwise the parent appears to "hang" + # during long model init/compilation because no logs are printed until the subprocess exits. + print(f"\n[INFO] 启动子进程运行策略: {strategy_key}") + # Ensure CUDA env is present *before Python starts* in the subprocess. + # This matters because TileLang caches CUDA_HOME at import time (and can be imported very early). + child_env = os.environ.copy() + if _CUDA_12_2_PATH.exists(): + child_env["CUDA_HOME"] = str(_CUDA_12_2_PATH) + child_env["CUDA_PATH"] = str(_CUDA_12_2_PATH) + child_env["PATH"] = f"{_CUDA_12_2_PATH}/bin:{child_env.get('PATH', '')}" + child_env["LD_LIBRARY_PATH"] = f"{_CUDA_12_2_PATH}/lib64:{child_env.get('LD_LIBRARY_PATH', '')}" + child_env["LIBRARY_PATH"] = f"{_CUDA_12_2_PATH}/lib64:{child_env.get('LIBRARY_PATH', '')}" + child_env["CPATH"] = f"{_CUDA_12_2_PATH}/include:{child_env.get('CPATH', '')}" + child_env["CUDACXX"] = str(_CUDA_12_2_PATH / "bin" / "nvcc") + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True, + env=child_env, + ) + + marker = "__RESULT_JSON__:" + captured_lines: List[str] = [] + try: + assert proc.stdout is not None + for line in proc.stdout: + # Stream logs live so the user can see progress. + print(line, end="") + captured_lines.append(line.rstrip("\n")) + finally: + # Ensure process termination is observed. + returncode = proc.wait() + + # Parse the result marker from captured stdout. + for line in reversed(captured_lines): + if line.startswith(marker): + payload = json.loads(line[len(marker):]) + return payload["strategy_name"], payload["result"] + + # If we can't find the marker, treat as failure. + print(f"✗ 子进程未返回结果标记(strategy={strategy_key}, returncode={returncode})") + return STRATEGY_CONFIGS.get(strategy_key, {}).get("name", strategy_key), None + + +def print_summary(results: Dict[str, Dict[str, float]]): + """打印汇总结果表格""" + if not results: + print("\n⚠ 没有成功完成的测试") + return + + print("\n" + "=" * 90) + print("性能汇总(第二轮,kernel 已编译)") + print("=" * 90) + print(f"{'策略':<25} {'总时间 (秒)':<15} {'总 Token 数':<15} {'平均 TPS (tok/s)':<20}") + print("-" * 90) + + # 按策略名称排序 + sorted_results = sorted(results.items()) + for name, result in sorted_results: + print(f"{name:<25} {result['total_time']:<15.2f} {result['total_tokens']:<15} {result['avg_tps']:<20.2f}") + + # 计算性能对比(如果有多个结果) + if len(results) > 1: + print("\n" + "-" * 90) + print("性能对比(相对于第一个策略):") + print("-" * 90) + + baseline_name = sorted_results[0][0] + baseline_result = sorted_results[0][1] + baseline_tps = baseline_result['avg_tps'] + + for name, result in sorted_results[1:]: + tps_diff = ((result['avg_tps'] - baseline_tps) / baseline_tps) * 100 + time_diff = ((result['total_time'] - baseline_result['total_time']) / baseline_result['total_time']) * 100 + + tps_indicator = "↑" if tps_diff > 0 else "↓" if tps_diff < 0 else "≈" + time_indicator = "↓" if time_diff < 0 else "↑" if time_diff > 0 else "≈" + + print(f" {name:<25} TPS: {tps_diff:+.1f}% {tps_indicator} 时间: {time_diff:+.1f}% {time_indicator}") + + +def parse_strategies(args) -> List[str]: + """解析命令行参数,返回要运行的策略列表""" + strategies = [] + + if args.all: + strategies = STRATEGY_GROUPS['all'] + elif args.bf16: + strategies = STRATEGY_GROUPS['bf16'] + elif args.w8a16: + strategies = STRATEGY_GROUPS['w8a16'] + elif args.w4a16: + strategies = STRATEGY_GROUPS['w4a16'] + elif args.w8a8: + strategies = STRATEGY_GROUPS['w8a8'] + elif args.w4a8: + strategies = STRATEGY_GROUPS['w4a8'] + elif args.fp8_w8a16: + strategies = STRATEGY_GROUPS['fp8_w8a16'] + elif args.fp8_w8a8: + strategies = STRATEGY_GROUPS['fp8_w8a8'] + elif args.gptq: + strategies = STRATEGY_GROUPS['gptq'] + elif args.awq: + strategies = STRATEGY_GROUPS['awq'] + elif args.strategies: + # 手动指定策略,支持逗号分隔 + strategies = [s.strip() for s in args.strategies.split(',')] + # 验证策略是否有效 + invalid = [s for s in strategies if s not in STRATEGY_CONFIGS] + if invalid: + print(f"✗ 无效的策略: {invalid}") + print(f" 支持的策略: {', '.join(STRATEGY_CONFIGS.keys())}") + sys.exit(1) + else: + # 默认运行所有策略 + print("未指定策略,默认运行所有策略(使用 --all 显式指定)") + strategies = STRATEGY_GROUPS['all'] + + return strategies + + +def main(): + parser = argparse.ArgumentParser( + description='Diffulex 量化策略文本生成测试', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例用法: + %(prog)s --all # 运行所有策略 + %(prog)s --bf16 # 只运行 BF16 相关策略 + %(prog)s --w8a16 # 只运行 W8A16 相关策略 + %(prog)s --w4a16 # 只运行 W4A16 相关策略 + %(prog)s --w8a8 # 只运行 W8A8 相关策略 + %(prog)s --w4a8 # 只运行 W4A8 相关策略 + %(prog)s --fp8_w8a16 # 只运行 FP8 W8A16 相关策略 + %(prog)s --fp8_w8a8 # 只运行 FP8 W8A8 相关策略 + %(prog)s --gptq # 只运行 GPTQ W4A16 相关策略(需要先运行量化脚本) + %(prog)s --awq # 只运行 AWQ W4A16 相关策略(需要先运行量化脚本) + %(prog)s --strategies bf16_bf16kv,w8a16_bf16kv # 自定义选择 + %(prog)s --strategies w4a16_fp8kv --max-tokens 50 # 指定策略和参数 + %(prog)s --gptq --model-path /path/to/quantized/model # 使用量化后的模型路径 + """ + ) + + # 策略选择选项(互斥) + strategy_group = parser.add_mutually_exclusive_group() + strategy_group.add_argument('--all', action='store_true', help='运行所有策略') + strategy_group.add_argument('--bf16', action='store_true', help='只运行 BF16 相关策略') + strategy_group.add_argument('--w8a16', action='store_true', help='只运行 W8A16 相关策略') + strategy_group.add_argument('--w4a16', action='store_true', help='只运行 W4A16 相关策略') + strategy_group.add_argument('--w8a8', action='store_true', help='只运行 W8A8 相关策略') + strategy_group.add_argument('--w4a8', action='store_true', help='只运行 W4A8 相关策略') + strategy_group.add_argument('--fp8_w8a16', action='store_true', help='只运行 FP8 W8A16 相关策略') + strategy_group.add_argument('--fp8_w8a8', action='store_true', help='只运行 FP8 W8A8 相关策略') + strategy_group.add_argument('--gptq', action='store_true', help='只运行 GPTQ W4A16 相关策略(需要先运行量化脚本生成离线权重)') + strategy_group.add_argument('--awq', action='store_true', help='只运行 AWQ W4A16 相关策略(需要先运行量化脚本生成离线权重)') + strategy_group.add_argument('--strategies', type=str, help='手动指定策略(逗号分隔),例如: bf16_bf16kv,w8a16_fp8kv') + + # 其他选项 + parser.add_argument('--max-tokens', type=int, default=30, help='最大生成 token 数(默认: 30)') + parser.add_argument('--model-path', type=str, required=True, help='模型路径(必填)') + parser.add_argument('--lora-path', type=str, default="", help='LoRA 路径(可选)') + parser.add_argument('--use-lora', action='store_true', help='启用 LoRA(需同时提供 --lora-path)') + parser.add_argument('--gpu-memory-utilization', type=float, default=0.3, help='GPU 内存利用率(默认: 0.3)') + parser.add_argument('--no-isolate', action='store_true', help='多策略运行时不使用子进程隔离(调试用,可能导致状态串扰/性能波动)') + # Internal: emit a single JSON result line for parent process parsing. + parser.add_argument('--_emit-json', action='store_true', help=argparse.SUPPRESS) + + args = parser.parse_args() + + # 确定模型路径 + model_path = args.model_path + if not os.path.exists(model_path): + print(f"错误: 模型路径不存在: {model_path}") + print("请使用 --model-path 指向有效的模型路径") + return + + # 解析要运行的策略 + strategies = parse_strategies(args) + + print("=" * 90) + print("Diffulex 量化策略文本生成测试") + print("=" * 90) + print(f"模型路径: {model_path}") + print(f"要测试的策略 ({len(strategies)} 个): {', '.join(STRATEGY_CONFIGS[s]['name'] for s in strategies)}") + print(f"最大生成 token 数: {args.max_tokens}") + print("=" * 90) + + # 测试 prompts (10个样例) + test_prompts = [ + "The capital of France is", + "Python is a programming language", + "The largest planet in our solar system is", + "Machine learning is a subset of", + "The speed of light is approximately", + "Artificial intelligence has applications in", + "The Great Wall of China was built", + "Quantum computing uses principles from", + "The human brain contains approximately", + "Climate change is caused by", + ] + + # 加载 tokenizer + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + print(f"✓ Tokenizer 加载成功") + except Exception as e: + print(f"✗ Tokenizer 加载失败: {e}") + return + + # 通用 Diffulex 配置 + common_kwargs = { + 'lora_path': args.lora_path, + 'use_lora': bool(args.use_lora and args.lora_path), + 'model_name': 'dream', + 'enforce_eager': True, + 'data_parallel_size': 1, + 'tensor_parallel_size': 1, + 'gpu_memory_utilization': args.gpu_memory_utilization, + 'max_num_batched_tokens': 1024, + 'max_num_seqs': 4, + 'max_model_len': 1024, + 'decoding_strategy': 'd2f', + 'decode_mode': 'varlen', # 统一设置为 varlen 模式 + } + + # 运行所有选定的策略 + # 对于 W4A16/W4A8/GPTQ/AWQ 策略,调整运行顺序:先运行其他策略,再运行这些策略 + # 这样可以避免在运行其他策略后资源状态不一致导致的问题 + w4a16_strategies = [s for s in strategies if 'w4a16' in s.lower() and 'gptq' not in s.lower() and 'awq' not in s.lower()] + w4a8_strategies = [s for s in strategies if 'w4a8' in s.lower()] + gptq_strategies = [s for s in strategies if 'gptq' in s.lower()] + awq_strategies = [s for s in strategies if 'awq' in s.lower()] + other_strategies = [s for s in strategies if 'w4a16' not in s.lower() and 'w4a8' not in s.lower() and 'gptq' not in s.lower() and 'awq' not in s.lower()] + # 先运行其他策略,再运行 W4A16 策略,然后 W4A8,最后 GPTQ/AWQ 策略(如果存在) + ordered_strategies = other_strategies + w4a16_strategies + w4a8_strategies + gptq_strategies + awq_strategies + + results = {} + isolate = (len(ordered_strategies) > 1) and (not args.no_isolate) and (not args._emit_json) + for strategy_key in ordered_strategies: + if isolate: + strategy_name, result = _run_strategy_in_subprocess( + strategy_key, + model_path=model_path, + max_tokens=args.max_tokens, + gpu_memory_utilization=args.gpu_memory_utilization, + ) + else: + strategy_name, result = run_strategy( + strategy_key, + model_path, + tokenizer, + test_prompts, + common_kwargs, + max_tokens=args.max_tokens, + ) + if result: + results[strategy_name] = result + + # 打印汇总结果 + if args._emit_json: + # In emit-json mode we should have exactly one strategy; return it as a single machine-readable line. + # If multiple are present for any reason, pick the first. + if results: + name, result = next(iter(results.items())) + print("__RESULT_JSON__:" + json.dumps({"strategy_name": name, "result": result}, ensure_ascii=False)) + else: + # Fallback: map key to display name if possible + only_key = ordered_strategies[0] if ordered_strategies else "unknown" + only_name = STRATEGY_CONFIGS.get(only_key, {}).get("name", only_key) + print("__RESULT_JSON__:" + json.dumps({"strategy_name": only_name, "result": None}, ensure_ascii=False)) + return + + print_summary(results) + + print("\n" + "=" * 90) + print("测试完成") + print("=" * 90) + + +if __name__ == "__main__": + main() + diff --git a/examples/test_sdar_diffulex_gsm8k.py b/examples/test_sdar_diffulex_gsm8k.py new file mode 100755 index 0000000..5d9efe7 --- /dev/null +++ b/examples/test_sdar_diffulex_gsm8k.py @@ -0,0 +1,68 @@ +import os +import csv +import time + +import pandas as pd + +from tqdm import tqdm +from datasets import load_dataset +from viztracer import VizTracer +from transformers import AutoTokenizer + +from diffulex import Diffulex, SamplingParams + + +FEW_SHOTS = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nQuestion: Jen and Tyler are gymnasts practicing flips. Jen is practicing the triple-flip while Tyler is practicing the double-flip. Jen did sixteen triple-flips during practice. Tyler flipped in the air half the number of times Jen did. How many double-flips did Tyler do?\nAnswer:<|im_end|>\n<|im_start|>assistant\nJen did 16 triple-flips, so she did 16 * 3 = <<16*3=48>>48 flips.\nTyler did half the number of flips, so he did 48 / 2 = <<48/2=24>>24 flips.\nA double flip has two flips, so Tyler did 24 / 2 = <<24/2=12>>12 double-flips.\n#### 12<|im_end|>\n<|im_start|>user\nQuestion: Four people in a law firm are planning a party. Mary will buy a platter of pasta for $20 and a loaf of bread for $2. Elle and Andrea will split the cost for buying 4 cans of soda which cost $1.50 each, and chicken wings for $10. Joe will buy a cake that costs $5. How much more will Mary spend than the rest of the firm put together?\nAnswer:<|im_end|>\n<|im_start|>assistant\nMary will spend $20 + $2 = $<<20+2=22>>22.\nElle and Andrea will spend $1.5 x 4 = $<<1.5*4=6>>6 for the soda.\nElle and Andrea will spend $6 + $10 = $<<6+10=16>>16 for the soda and chicken wings.\nElle, Andrea, and Joe together will spend $16 + $5 = $<<16+5=21>>21.\nSo, Mary will spend $22 - $21 = $<<22-21=1>>1 more than all of them combined.\n#### 1<|im_end|>\n<|im_start|>user\nQuestion: A charcoal grill burns fifteen coals to ash every twenty minutes of grilling. The grill ran for long enough to burn three bags of coals. Each bag of coal contains 60 coals. How long did the grill run?\nAnswer:<|im_end|>\n<|im_start|>assistant\nThe grill burned 3 * 60 = <<3*60=180>>180 coals.\nIt takes 20 minutes to burn 15 coals, so the grill ran for 180 / 15 * 20 = <<180/15*20=240>>240 minutes.\n#### 240<|im_end|>\n<|im_start|>user\nQuestion: A bear is preparing to hibernate for the winter and needs to gain 1000 pounds. At the end of summer, the bear feasts on berries and small woodland animals. During autumn, it devours acorns and salmon. It gained a fifth of the weight it needed from berries during summer, and during autumn, it gained twice that amount from acorns. Salmon made up half of the remaining weight it had needed to gain. How many pounds did it gain eating small animals?\nAnswer:<|im_end|>\n<|im_start|>assistant\nThe bear gained 1 / 5 * 1000 = <<1/5*1000=200>>200 pounds from berries.\nIt gained 2 * 200 = <<2*200=400>>400 pounds from acorns.\nIt still needed 1000 - 200 - 400 = <<1000-200-400=400>>400 pounds.\nThus, it gained 400 / 2 = <<400/2=200>>200 pounds from salmon.\nTherefore, the bear gained 400 - 200 = <<400-200=200>>200 pounds from small animals.\n#### 200<|im_end|>\n<|im_start|>user\nQuestion: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nAnswer:<|im_end|>\n<|im_start|>assistant\n" +# FEW_SHOTS = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + +if __name__ == "__main__": + PROFILE = True + # model = "/root/data/ckpts/JetLM/SDAR-1.7B-Chat-b32" + model = "/data1/ckpts/JetLM/SDAR-1.7B-Chat-b32" + dataset = load_dataset("gsm8k", "main", split="test")["question"][:1] + LLM = Diffulex( + model, + use_lora=False, + model_name="sdar", + enforce_eager=True, + data_parallel_size=1, + tensor_parallel_size=1, + gpu_memory_utilization=0.3, + max_num_batched_tokens=2048, + max_num_seqs=20, + max_model_len=2048, + kv_cache_layout="unified", + decoding_strategy="block_diffusion", + mask_token_id=151669, + ) + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + sampling_params = SamplingParams(temperature=0.0, max_tokens=256) + + prompts = [ + FEW_SHOTS + f"<|im_start|>user\nQuestion: {question}\nAnswer:<|im_end|>\n<|im_start|>assistant\n" + for question in tqdm(dataset) + ] + s = time.time() + if PROFILE: + output_file = "log/profiles/perf_dvllm_dream_7B.json" + if os.path.exists(output_file): + os.remove(output_file) + + with VizTracer(output_file=output_file, file_info=True) as tracer: + outputs = LLM.generate(prompts, sampling_params) + else: + outputs = LLM.generate(prompts, sampling_params) + e = time.time() + print("=*=" * 30, + "\nProfiling Results\n", + "=*=" * 30, "\n" + f"Generated {len(outputs)} outputs.\n" + f"Total tokens: {sum(len(o['token_ids']) for o in outputs)}\n" + f"Total time: {e - s:.2f} seconds.\n" + f"Avg TPS: {sum(len(o['token_ids']) for o in outputs) / (e - s):.2f} tok/s.\n" + f"AVG Number of Diffusion Steps: {sum(o['n_diff_steps'] for o in outputs) / len(outputs):.2f}\n", + "=*=" * 30) + for idx, o in enumerate(outputs): + print("\n", "=*=" * 30) + resp = prompts[idx] + "\n----------\n" + o["text"] + print(f"[Prompt {idx} Result]\n{resp}\n") \ No newline at end of file diff --git a/examples/test_sdar_dvllm.py b/examples/test_sdar_dvllm.py index 28aded5..c5c00ba 100644 --- a/examples/test_sdar_dvllm.py +++ b/examples/test_sdar_dvllm.py @@ -97,14 +97,14 @@ def main() -> None: parser.add_argument( "--model", type=str, - default="/data1/ckpts/SDAR/SDAR-1.7B-Chat", + required=True, help="SDAR HF model directory (contains config.json + model.safetensors).", ) parser.add_argument("--device", type=int, default=0) parser.add_argument( "--converted-dir", type=str, - default="/home/ljp/tmp/diffulex_sdar_converted", + default="tmp/diffulex_sdar_converted", help="Output directory for converted checkpoint keys (Diffulex-native).", ) parser.add_argument("--prompt", type=str, default="你好,请用一句话介绍 SDAR。") diff --git a/profile/analyze_trace_bottlenecks.py b/profile/analyze_trace_bottlenecks.py new file mode 100644 index 0000000..41821d3 --- /dev/null +++ b/profile/analyze_trace_bottlenecks.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 +""" +Analyze huge torch chrome trace (streaming) to locate non-GEMM bottlenecks. + +Outputs: +- duration of user_annotation "diffulex.generate(profiled)" (wall-ish) +- GPU active time (union of kernel/memcpy/memset intervals) to estimate GPU idle gaps +- top CUDA runtime/driver API calls by CPU time + +Designed to work without loading the >2GB JSON into memory. +""" + +from __future__ import annotations + +import argparse +import json +from collections import Counter, defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple + + +def _extract_str_after_key(line: str, key: str) -> Optional[str]: + k = f"\"{key}\"" + pos = line.find(k) + if pos < 0: + return None + colon = line.find(":", pos + len(k)) + if colon < 0: + return None + q1 = line.find('"', colon) + if q1 < 0: + return None + q2 = line.find('"', q1 + 1) + if q2 < 0: + return None + return line[q1 + 1 : q2] + + +def _extract_num_after_key(line: str, key: str) -> Optional[float]: + k = f"\"{key}\"" + pos = line.find(k) + if pos < 0: + return None + colon = line.find(":", pos + len(k)) + if colon < 0: + return None + frag = line[colon + 1 :].strip() + comma = frag.find(",") + if comma >= 0: + frag = frag[:comma] + try: + return float(frag.strip()) + except Exception: + return None + + +def _extract_json_object_value(line: str, key: str) -> Optional[Any]: + """ + Extract JSON object/array value following `"key":` on the same line. + Assumes the value is a JSON object {...} or array [...] and is fully contained in the line. + """ + k = f"\"{key}\"" + pos = line.find(k) + if pos < 0: + return None + colon = line.find(":", pos + len(k)) + if colon < 0: + return None + # find first '{' or '[' after colon + start = None + for i in range(colon, len(line)): + if line[i] == "{": + start = i + open_ch, close_ch = "{", "}" + break + if line[i] == "[": + start = i + open_ch, close_ch = "[", "]" + break + if start is None: + return None + depth = 0 + end = None + for i in range(start, len(line)): + ch = line[i] + if ch == open_ch: + depth += 1 + elif ch == close_ch: + depth -= 1 + if depth == 0: + end = i + 1 + break + if end is None: + return None + frag = line[start:end] + try: + return json.loads(frag) + except Exception: + return None + + +@dataclass +class Interval: + start: float + end: float + + +def _merge_intervals(intervals: List[Interval]) -> List[Interval]: + if not intervals: + return [] + intervals.sort(key=lambda x: x.start) + merged: List[Interval] = [intervals[0]] + for it in intervals[1:]: + last = merged[-1] + if it.start <= last.end: + if it.end > last.end: + last.end = it.end + else: + merged.append(it) + return merged + + +def analyze(trace_path: Path) -> Dict[str, Any]: + # union intervals for GPU activity across all streams + gpu_intervals: List[Interval] = [] + gpu_min_ts: Optional[float] = None + gpu_max_end: Optional[float] = None + + # also per stream, to detect if one stream is idle most of the time + gpu_intervals_by_stream: Dict[int, List[Interval]] = defaultdict(list) + + # user annotation + generate_dur_us: Optional[float] = None + + # runtime/driver api durations (cpu-side) + cuda_runtime: Counter[str] = Counter() + cuda_driver: Counter[str] = Counter() + + in_events = False + in_obj = False + depth = 0 + buf: List[str] = [] + + def _consume_event(text: str) -> None: + nonlocal generate_dur_us, gpu_min_ts, gpu_max_end + # quick checks without json parsing + if '"cat"' not in text or '"name"' not in text: + return + cat = None + name = None + # extract cat/name + # cat and name appear on first line typically, but safe on full text. + for line in text.splitlines(): + if cat is None and '"cat"' in line: + v = _extract_str_after_key(line, "cat") + if v: + cat = v + if name is None and '"name"' in line: + v = _extract_str_after_key(line, "name") + if v: + name = v + if cat is not None and name is not None: + break + if cat is None or name is None: + return + + if cat == "user_annotation" and name == "diffulex.generate(profiled)": + # duration in us + for line in text.splitlines(): + if '"dur"' in line: + d = _extract_num_after_key(line, "dur") + if d is not None: + generate_dur_us = d + break + return + + # cuda runtime/driver (CPU) + if cat == "cuda_runtime": + d = None + for line in text.splitlines(): + if '"dur"' in line: + d = _extract_num_after_key(line, "dur") + break + if d is not None: + cuda_runtime[name] += d + return + if cat == "cuda_driver": + d = None + for line in text.splitlines(): + if '"dur"' in line: + d = _extract_num_after_key(line, "dur") + break + if d is not None: + cuda_driver[name] += d + return + + # GPU activity events + if cat in ("kernel", "gpu_memcpy", "gpu_memset"): + ts = None + dur = None + stream = None + for line in text.splitlines(): + if ts is None and '"ts"' in line: + ts = _extract_num_after_key(line, "ts") + if dur is None and '"dur"' in line: + dur = _extract_num_after_key(line, "dur") + if stream is None and '"args"' in line and "stream" in line: + # args is often multi-line; rely on json fragment extraction when seen + pass + # extract args object to fetch stream quickly (safe, small) + args_obj = None + for line in text.splitlines(): + if '"args"' in line: + args_obj = _extract_json_object_value(line, "args") + break + if isinstance(args_obj, dict): + try: + stream = int(args_obj.get("stream", -1)) + except Exception: + stream = None + if ts is None or dur is None: + return + start = ts + end = ts + dur + gpu_intervals.append(Interval(start, end)) + if stream is not None and stream >= 0: + gpu_intervals_by_stream[stream].append(Interval(start, end)) + gpu_min_ts = start if gpu_min_ts is None else min(gpu_min_ts, start) + gpu_max_end = end if gpu_max_end is None else max(gpu_max_end, end) + return + + with trace_path.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + if not in_events: + if '"traceEvents"' in line and "[" in line: + in_events = True + continue + if not in_obj: + if line.lstrip().startswith("{"): + in_obj = True + buf = [line] + depth = line.count("{") - line.count("}") + else: + if line.lstrip().startswith("]"): + break + continue + else: + buf.append(line) + depth += line.count("{") - line.count("}") + if in_obj and depth <= 0: + _consume_event("".join(buf)) + in_obj = False + + merged = _merge_intervals(gpu_intervals) + active_us = sum(it.end - it.start for it in merged) + span_us = (gpu_max_end - gpu_min_ts) if (gpu_min_ts is not None and gpu_max_end is not None) else 0.0 + + per_stream_active: Dict[int, float] = {} + for s, ints in gpu_intervals_by_stream.items(): + m = _merge_intervals(ints) + per_stream_active[s] = sum(it.end - it.start for it in m) + + top_runtime = cuda_runtime.most_common(30) + top_driver = cuda_driver.most_common(30) + + return { + "trace": str(trace_path), + "generate_dur_us": generate_dur_us, + "gpu_active_union_us": active_us, + "gpu_span_us": span_us, + "gpu_active_ratio_union_over_span": (active_us / span_us) if span_us > 0 else None, + "gpu_active_ratio_union_over_generate": (active_us / generate_dur_us) if (generate_dur_us and generate_dur_us > 0) else None, + "gpu_span_over_generate": (span_us / generate_dur_us) if (generate_dur_us and generate_dur_us > 0) else None, + "gpu_event_count": len(gpu_intervals), + "gpu_stream_count": len(per_stream_active), + "top_cuda_runtime_us": top_runtime, + "top_cuda_driver_us": top_driver, + "top_stream_active_us": sorted(per_stream_active.items(), key=lambda kv: kv[1], reverse=True)[:10], + } + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--trace", type=str, required=True) + ap.add_argument("--out", type=str, required=True) + args = ap.parse_args() + + res = analyze(Path(args.trace)) + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(res, indent=2), encoding="utf-8") + print(f"[OK] wrote: {out_path}") + + +if __name__ == "__main__": + main() + diff --git a/profile/analyze_trace_cpu_ops.py b/profile/analyze_trace_cpu_ops.py new file mode 100644 index 0000000..c08b05c --- /dev/null +++ b/profile/analyze_trace_cpu_ops.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +""" +Stream-aggregate CPU-side durations from huge torch chrome traces. + +We aggregate: +- cat=cpu_op +- cat=python_function +- cat=user_annotation + +This helps answer: where is the extra walltime coming from (outside CUDA kernels)? +""" + +from __future__ import annotations + +import argparse +import json +from collections import Counter, defaultdict +from pathlib import Path +from typing import Dict, Optional, Tuple + + +def _extract_str_after_key(s: str, key: str) -> Optional[str]: + k = f"\"{key}\"" + pos = s.find(k) + if pos < 0: + return None + colon = s.find(":", pos + len(k)) + if colon < 0: + return None + q1 = s.find('"', colon) + if q1 < 0: + return None + q2 = s.find('"', q1 + 1) + if q2 < 0: + return None + return s[q1 + 1 : q2] + + +def _extract_num_after_key(s: str, key: str) -> Optional[float]: + k = f"\"{key}\"" + pos = s.find(k) + if pos < 0: + return None + colon = s.find(":", pos + len(k)) + if colon < 0: + return None + frag = s[colon + 1 :].strip() + comma = frag.find(",") + if comma >= 0: + frag = frag[:comma] + try: + return float(frag.strip()) + except Exception: + return None + + +def analyze(trace_path: Path, cats: Tuple[str, ...]) -> Dict[str, Dict[str, Dict[str, float]]]: + # cat -> name -> (dur_us_sum, calls) + dur: Dict[str, Counter[str]] = {c: Counter() for c in cats} + calls: Dict[str, Counter[str]] = {c: Counter() for c in cats} + + in_events = False + in_obj = False + depth = 0 + buf = [] + + def consume(text: str) -> None: + if '"cat"' not in text or '"name"' not in text: + return + cat = None + name = None + d = None + for line in text.splitlines(): + if cat is None and '"cat"' in line: + cat = _extract_str_after_key(line, "cat") + if name is None and '"name"' in line: + name = _extract_str_after_key(line, "name") + if d is None and '"dur"' in line: + d = _extract_num_after_key(line, "dur") + if cat and name and d is not None: + break + if cat not in cats or name is None: + return + calls[cat][name] += 1 + if d is not None: + dur[cat][name] += d + + with trace_path.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + if not in_events: + if '"traceEvents"' in line and "[" in line: + in_events = True + continue + if not in_obj: + if line.lstrip().startswith("{"): + in_obj = True + buf = [line] + depth = line.count("{") - line.count("}") + else: + if line.lstrip().startswith("]"): + break + continue + else: + buf.append(line) + depth += line.count("{") - line.count("}") + if in_obj and depth <= 0: + consume("".join(buf)) + in_obj = False + + out: Dict[str, Dict[str, Dict[str, float]]] = {} + for c in cats: + out[c] = {} + for name, total in dur[c].items(): + out[c][name] = { + "dur_us": float(total), + "calls": float(calls[c][name]), + "avg_us": float(total) / float(calls[c][name]) if calls[c][name] else 0.0, + } + return out + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--trace", type=str, required=True) + ap.add_argument("--out", type=str, required=True) + ap.add_argument("--topk", type=int, default=50) + args = ap.parse_args() + + cats = ("cpu_op", "python_function", "user_annotation") + res = analyze(Path(args.trace), cats) + + # Write a compact report: per-cat topk by dur. + lines = [] + lines.append(f"Trace: {args.trace}") + lines.append("") + for c in cats: + items = sorted(res[c].items(), key=lambda kv: kv[1]["dur_us"], reverse=True)[: args.topk] + lines.append(f"== {c} top {args.topk} by dur_us ==") + for name, st in items: + lines.append(f"{st['dur_us']:.3f} us calls={int(st['calls'])} avg={st['avg_us']:.3f} us {name}") + lines.append("") + + Path(args.out).write_text("\n".join(lines), encoding="utf-8") + print(f"[OK] wrote: {args.out}") + + +if __name__ == "__main__": + main() + diff --git a/profile/analyze_trace_gemm_shapes.py b/profile/analyze_trace_gemm_shapes.py new file mode 100644 index 0000000..98a0a0d --- /dev/null +++ b/profile/analyze_trace_gemm_shapes.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +""" +Stream-parse PyTorch chrome trace JSON (very large) and aggregate GEMM shape +distributions for selected ops. + +This script is designed for traces exported with record_shapes=True, where op +events contain args["Input Dims"]. + +Example: + python profile/analyze_trace_gemm_shapes.py \ + --trace log/torch_profiles/20260125_023133/pytorch_trace_diffulex.generate(profiled).json \ + --out log/torch_profiles/20260125_023133/gemm_shapes_bf16.txt \ + --ops aten::mm aten::addmm +""" + +from __future__ import annotations + +import argparse +import json +import math +from collections import Counter, defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple + + +def _parse_json_value_fragment(fragment: str) -> Any: + # fragment: after ':' in a JSON line, possibly ending with ',' and newline. + frag = fragment.strip() + if frag.endswith(","): + frag = frag[:-1] + return json.loads(frag) + +def _extract_json_array_after_key(line: str, key: str) -> Optional[Any]: + """ + Extract and json-load the array value after `"key":` from a possibly + multi-field JSON line, e.g. + ..."Input Dims": [[1,2],[3,4]], "Ev Idx": 5 + """ + k = f"\"{key}\"" + pos = line.find(k) + if pos < 0: + return None + colon = line.find(":", pos + len(k)) + if colon < 0: + return None + # Find the first '[' after the colon. + start = line.find("[", colon) + if start < 0: + return None + depth = 0 + end = -1 + for i in range(start, len(line)): + ch = line[i] + if ch == "[": + depth += 1 + elif ch == "]": + depth -= 1 + if depth == 0: + end = i + 1 + break + if end < 0: + return None + frag = line[start:end] + try: + return json.loads(frag) + except Exception: + return None + + +def _extract_quoted_value(line: str) -> Optional[str]: + # very small helper: extract first "...". + i = line.find('"') + if i < 0: + return None + j = line.find('"', i + 1) + if j < 0: + return None + return line[i + 1 : j] + + +def _extract_number_after_colon(line: str) -> Optional[float]: + # e.g. "dur": 123.0, + if ":" not in line: + return None + frag = line.split(":", 1)[1].strip() + if frag.endswith(","): + frag = frag[:-1] + try: + return float(frag) + except Exception: + return None + +def _extract_number_after_key(line: str, key: str) -> Optional[float]: + """ + Extract a numeric value after `"key":` from a possibly multi-field JSON line, e.g. + "ts": 123.0, "dur": 34.5, + """ + k = f"\"{key}\"" + pos = line.find(k) + if pos < 0: + return None + colon = line.find(":", pos + len(k)) + if colon < 0: + return None + frag = line[colon + 1 :].strip() + # Cut at next comma if present. + comma = frag.find(",") + if comma >= 0: + frag = frag[:comma] + try: + return float(frag.strip()) + except Exception: + return None + + +def _dims_to_mnk(input_dims: Any) -> Optional[Tuple[int, int, int]]: + """ + Convert args["Input Dims"] into a best-effort (M,N,K). + input_dims is typically a list where each element is [] (non-tensor) or + a list[int] (tensor dims). + """ + if not isinstance(input_dims, list): + return None + + tensor_dims: List[List[int]] = [] + for d in input_dims: + if isinstance(d, list) and len(d) >= 2 and all(isinstance(x, (int, float)) for x in d): + tensor_dims.append([int(x) for x in d]) + if len(tensor_dims) < 2: + return None + + a = tensor_dims[0] + b = tensor_dims[1] + a_m, a_k = a[-2], a[-1] + # b could be [k, n] or [n, k] depending on transpose convention. + if len(b) >= 2 and a_k == b[-2]: + b_k, b_n = b[-2], b[-1] + return (a_m, b_n, a_k) + if len(b) >= 2 and a_k == b[-1]: + # b is [n, k] + b_n, b_k = b[-2], b[-1] + return (a_m, b_n, a_k) + + # fallback: assume [k, n] + return (a_m, b[-1], a_k) + + +@dataclass +class ShapeStats: + calls: int = 0 + dur_us: float = 0.0 + + +def iter_op_events(trace_path: Path, target_ops: set[str]) -> Iterable[Tuple[str, Optional[float], Any]]: + """ + Yields (op_name, dur_us, input_dims) for events whose "name" is in target_ops. + Streaming + brace-depth parsing to avoid loading giant JSON into memory. + """ + in_trace_events = False + in_event = False + depth = 0 + + name: Optional[str] = None + dur: Optional[float] = None + input_dims: Any = None + want = False + + with trace_path.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + if not in_trace_events: + if '"traceEvents"' in line and "[" in line: + in_trace_events = True + continue + + # Start of a JSON object event in traceEvents list. + if not in_event: + stripped = line.lstrip() + if stripped.startswith("{"): + in_event = True + depth = stripped.count("{") - stripped.count("}") + name = None + dur = None + input_dims = None + want = False + else: + # End of traceEvents list. + if line.lstrip().startswith("]"): + break + continue + else: + depth += line.count("{") - line.count("}") + + # Parse fields we care about. + if '"name"' in line: + # Some traces put multiple fields on one line: + # "ph": "X", "cat": "cpu_op", "name": "aten::mm", ... + key = '"name":' + pos = line.find(key) + if pos >= 0: + q1 = line.find('"', pos + len(key)) + if q1 >= 0: + q2 = line.find('"', q1 + 1) + if q2 >= 0: + name = line[q1 + 1 : q2] + want = name in target_ops + + if want and dur is None and '"dur"' in line: + dur = _extract_number_after_key(line, "dur") + + if want and input_dims is None and "Input Dims" in line: + input_dims = _extract_json_array_after_key(line, "Input Dims") + + # End of current event object (also works for single-line events). + if in_event and depth <= 0: + if want and name is not None: + yield (name, dur, input_dims) + in_event = False + + +def _human_int(n: float) -> str: + if n >= 1e9: + return f"{n/1e9:.3f}B" + if n >= 1e6: + return f"{n/1e6:.3f}M" + if n >= 1e3: + return f"{n/1e3:.3f}K" + return f"{int(n)}" + + +def main() -> None: + ap = argparse.ArgumentParser("Aggregate GEMM shapes from huge torch chrome trace") + ap.add_argument("--trace", type=str, required=True, help="Path to pytorch_trace_*.json") + ap.add_argument("--out", type=str, required=True, help="Output report path") + ap.add_argument("--ops", type=str, nargs="+", default=["aten::mm", "aten::addmm"], help="Op names to aggregate") + ap.add_argument("--topk", type=int, default=30) + args = ap.parse_args() + + trace_path = Path(args.trace) + out_path = Path(args.out) + target_ops = set(args.ops) + + # op -> (mnk -> stats) + agg: Dict[str, Dict[Tuple[int, int, int], ShapeStats]] = defaultdict(dict) + op_totals: Dict[str, ShapeStats] = defaultdict(ShapeStats) + op_unknown: Counter[str] = Counter() + + for op, dur_us, input_dims in iter_op_events(trace_path, target_ops): + op_totals[op].calls += 1 + if dur_us is not None: + op_totals[op].dur_us += dur_us + + mnk = _dims_to_mnk(input_dims) + if mnk is None: + op_unknown[op] += 1 + continue + + st = agg[op].get(mnk) + if st is None: + st = ShapeStats() + agg[op][mnk] = st + st.calls += 1 + if dur_us is not None: + st.dur_us += dur_us + + lines: List[str] = [] + lines.append(f"Trace: {trace_path}") + lines.append(f"Ops: {', '.join(sorted(target_ops))}") + lines.append("") + + for op in sorted(target_ops): + tot = op_totals.get(op, ShapeStats()) + lines.append(f"== {op} ==") + lines.append(f"total calls: {tot.calls}") + lines.append(f"total dur(us): {tot.dur_us:.3f}") + lines.append(f"unknown shapes: {op_unknown.get(op, 0)}") + lines.append("") + + if op not in agg or not agg[op]: + lines.append("(no shape stats)\n") + continue + + # Top by total dur + items = list(agg[op].items()) + items_by_dur = sorted(items, key=lambda kv: kv[1].dur_us, reverse=True)[: args.topk] + lines.append(f"-- top {args.topk} shapes by total dur(us) --") + lines.append("M,N,K calls total_dur(us) approx_GFLOP") + for (m, n, k), st in items_by_dur: + gflop = 2.0 * m * n * k / 1e9 + lines.append(f"{m},{n},{k} {st.calls} {st.dur_us:.3f} {gflop:.3f}") + lines.append("") + + # Top by calls + items_by_calls = sorted(items, key=lambda kv: kv[1].calls, reverse=True)[: args.topk] + lines.append(f"-- top {args.topk} shapes by calls --") + lines.append("M,N,K calls total_dur(us) avg_dur(us)") + for (m, n, k), st in items_by_calls: + avg = st.dur_us / st.calls if st.calls else 0.0 + lines.append(f"{m},{n},{k} {st.calls} {st.dur_us:.3f} {avg:.3f}") + lines.append("") + + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text("\n".join(lines), encoding="utf-8") + print(f"[OK] wrote: {out_path}") + + +if __name__ == "__main__": + main() + diff --git a/profile/d2f_dream_profile.py b/profile/d2f_dream_profile.py new file mode 100644 index 0000000..750fe4f --- /dev/null +++ b/profile/d2f_dream_profile.py @@ -0,0 +1,87 @@ +""" +D2F Dream Model Profiling Example + +This example demonstrates how to profile the performance +of Dream model with D2F decoding strategy using nsys. +""" +import os +import time +from pathlib import Path +from diffulex import Diffulex, SamplingParams +from transformers import AutoTokenizer + + +def main(): + model_path = "/data1/ckpts/Dream-org/Dream-v0-Base-7B" + lora_path = "/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora" + + output_dir = Path("log/profiles") + output_dir.mkdir(parents=True, exist_ok=True) + + print("Loading model...") + model_load_start = time.time() + llm = Diffulex( + model_path, + lora_path=lora_path, + use_lora=True, + model_name="dream", + enforce_eager=True, + tensor_parallel_size=1, + data_parallel_size=1, + gpu_memory_utilization=0.25, + max_model_len=2048, + decoding_strategy="d2f", + mask_token_id=151666, + diffusion_block_size=32, + accept_threshold=0.95, + complete_threshold=0.9, + add_new_block_threshold=0.1, + ) + model_load_time = time.time() - model_load_start + print(f"Model loaded in {model_load_time:.2f} seconds") + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + sampling_params = SamplingParams(temperature=0.0, max_tokens=256) + + prompts = [ + "What is 2+2?", + "Explain quantum computing in simple terms.", + "Write a Python function to calculate factorial.", + ] + + print(f"\nStarting inference profiling...") + + inference_start = time.time() + outputs = llm.generate(prompts, sampling_params) + inference_time = time.time() - inference_start + + total_tokens = sum(len(o.get('token_ids', [])) for o in outputs) + num_outputs = len(outputs) + avg_diff_steps = sum(o.get('n_diff_steps', 0) for o in outputs) / num_outputs if outputs else 0 + throughput = total_tokens / inference_time if inference_time > 0 else 0 + + print("\n" + "=" * 80) + print("Profiling Summary") + print("=" * 80) + print(f"Model Loading Time: {model_load_time:.2f} seconds") + print(f"Inference Time: {inference_time:.2f} seconds") + print(f"Total Duration: {model_load_time + inference_time:.2f} seconds") + print(f"\nInference Metrics:") + print(f" Number of Prompts: {num_outputs}") + print(f" Total Tokens: {total_tokens}") + print(f" Average Throughput: {throughput:.2f} tokens/sec") + print(f" Average Diffusion Steps: {avg_diff_steps:.2f}") + print("=" * 80) + + print("\nGenerated Output Preview:") + for idx, output in enumerate(outputs): + print(f"\n[Prompt {idx + 1}]") + print(f"Input: {prompts[idx]}") + print(f"Output: {output.get('text', 'N/A')[:200]}...") + print(f"Token Count: {len(output.get('token_ids', []))}") + if 'n_diff_steps' in output: + print(f"Diffusion Steps: {output['n_diff_steps']}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/profile/torch_d2f_profiler.py b/profile/torch_d2f_profiler.py new file mode 100644 index 0000000..8dfcf18 --- /dev/null +++ b/profile/torch_d2f_profiler.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3 +""" +用 torch.profiler 跑 Diffulex(D2F/Dream) 的性能剖析,并导出 flamegraph 所需 stacks。 + +设计目标: +- 直接复用 Diffulex 的配置入口(kv_cache_dtype / linear_*_dtype / decode_mode 等) +- 默认强制 TP=1/DP=1,避免 tp_worker 的 spawn 子进程导致 profiler 采不到 CUDA kernel +- 两阶段:先编译/初始化 warmup(不计入 profile),再进入 torch.profiler 采集窗口 + +输出: +- Chrome trace: *.json (可用 chrome://tracing 或 Perfetto 打开) +- Stacks: *.stacks (用于生成火焰图,格式兼容 Brendan Gregg flamegraph 工具链) + +示例: + # BF16 基线 + python profile/torch_d2f_profiler.py --tag bf16 --kv-cache-dtype bf16 + + # FP8 KV + W8A16(对比量化为何更慢) + python profile/torch_d2f_profiler.py --tag w8a16_fp8kv --kv-cache-dtype fp8_e4m3 \ + --linear-attn-weight-dtype int8 --linear-mlp-weight-dtype int8 + + # 指定 decode_mode(auto/varlen/static) + python profile/torch_d2f_profiler.py --tag fp8kv_static --kv-cache-dtype fp8_e4m3 --decode-mode static +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import time +from pathlib import Path +from typing import List + +# Make stdout/stderr line-buffered so progress logs are visible even when redirected/captured. +try: + sys.stdout.reconfigure(line_buffering=True) + sys.stderr.reconfigure(line_buffering=True) +except Exception: + pass + +# Ensure import from current repo. +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +import torch +from diffulex import Diffulex, SamplingParams +from diffulex_profiler import DiffulexProfiler, ProfilerConfig + + +def _default_prompts() -> List[str]: + return [ + "What is 2+2?", + "Explain quantum computing in simple terms.", + "Write a Python function to calculate factorial.", + ] + + +def _load_prompts(args: argparse.Namespace) -> List[str]: + if args.prompts_file: + p = Path(args.prompts_file) + data = json.loads(p.read_text(encoding="utf-8")) + if not isinstance(data, list) or not all(isinstance(x, str) for x in data): + raise ValueError("--prompts-file 必须是 JSON list[str]") + return data + if args.prompt: + return args.prompt + return _default_prompts() + + +def _mkdir(p: Path) -> Path: + p.mkdir(parents=True, exist_ok=True) + return p + + +def main() -> None: + parser = argparse.ArgumentParser("Diffulex torch.profiler flamegraph (D2F/Dream)") + + parser.add_argument("--model-path", type=str, required=True, help="模型路径(必填)") + parser.add_argument("--lora-path", type=str, default="", help="LoRA 路径(可选)") + parser.add_argument("--use-lora", action="store_true", help="启用 LoRA(需同时提供 --lora-path)") + parser.add_argument("--cuda-home", type=str, default="", help="(可选)设置 CUDA_HOME/CUDA_PATH 并更新 PATH/LD_LIBRARY_PATH") + + parser.add_argument("--tag", type=str, default="torch_profile", help="输出文件名前缀") + parser.add_argument("--out-dir", type=str, default="log/torch_profiles", help="输出目录(相对仓库根)") + + # Quantization / KV settings + parser.add_argument("--kv-cache-dtype", type=str, default="bf16", help="bf16/fp8_e4m3/fp8_e5m2 (也支持别名 fp8/e4m3/e5m2)") + parser.add_argument("--kv-cache-layout", type=str, default="unified", choices=["unified", "distinct"]) + parser.add_argument("--decode-mode", type=str, default="auto", choices=["auto", "varlen", "static"]) + + parser.add_argument("--linear-attn-weight-dtype", type=str, default="bf16") + parser.add_argument("--linear-mlp-weight-dtype", type=str, default="bf16") + parser.add_argument("--linear-attn-act-dtype", type=str, default="bf16") + parser.add_argument("--linear-mlp-act-dtype", type=str, default="bf16") + + # CUDA Graph + parser.add_argument( + "--use-cudagraph", + action="store_true", + help="启用 CUDA Graph(仅 decode_mode=static 且 shape 稳定时有意义);默认关闭以避免 capture 成本影响分析。", + ) + + # Engine settings (force single-process profiling by default) + parser.add_argument("--tensor-parallel-size", type=int, default=1, help="建议保持 1,否则会 spawn 子进程导致采集不到 CUDA") + parser.add_argument("--data-parallel-size", type=int, default=1) + # Distributed comm (avoid port conflicts with other local runs) + parser.add_argument("--master-addr", type=str, default="localhost") + parser.add_argument("--master-port", type=int, default=2333) + parser.add_argument("--gpu-memory-utilization", type=float, default=0.30) + parser.add_argument("--max-model-len", type=int, default=1024) + + # Prompts / decode + parser.add_argument("--max-tokens", type=int, default=256) + parser.add_argument("--prompt", type=str, action="append", help="可多次传入,作为 prompts 列表;不传则用内置默认 prompts") + parser.add_argument("--prompts-file", type=str, default="", help="JSON list[str] 文件路径") + + # Warmup + profiler schedule + parser.add_argument("--compile-warmup-iters", type=int, default=1, help="用于 kernel 编译/缓存的 warmup 次数(不进入 profiler)") + parser.add_argument("--profile-wait", type=int, default=0) + parser.add_argument("--profile-warmup", type=int, default=1) + parser.add_argument("--profile-active", type=int, default=1) + parser.add_argument("--profile-repeat", type=int, default=1) + parser.add_argument( + "--use-diffulex-profiler", + action="store_true", + help="改用 diffulex_profiler 的 PyTorchProfilerBackend(会导出 trace/stacks/top,并额外导出 summary/json)", + ) + parser.add_argument( + "--no-torch-profiler", + action="store_true", + help="仅运行一次稳态 generate(包含 compile warmup),不启用 torch.profiler。用于配合 ncu 等外部 profiler,避免 CUPTI 冲突。", + ) + parser.add_argument( + "--nvtx-range", + type=str, + default="", + help="(可选)用 NVTX 把 profiled generate 包起来,便于 ncu 用 --nvtx-include 精准过滤。示例:--nvtx-range d2f_generate", + ) + + args = parser.parse_args() + + if args.cuda_home: + cuda_home = Path(args.cuda_home) + if not cuda_home.exists(): + raise FileNotFoundError(f"--cuda-home 不存在: {cuda_home}") + os.environ["CUDA_HOME"] = str(cuda_home) + os.environ["CUDA_PATH"] = str(cuda_home) + os.environ["PATH"] = f"{cuda_home}/bin:{os.environ.get('PATH', '')}" + os.environ["LD_LIBRARY_PATH"] = f"{cuda_home}/lib64:{os.environ.get('LD_LIBRARY_PATH', '')}" + os.environ["LIBRARY_PATH"] = f"{cuda_home}/lib64:{os.environ.get('LIBRARY_PATH', '')}" + os.environ["CPATH"] = f"{cuda_home}/include:{os.environ.get('CPATH', '')}" + os.environ["CUDACXX"] = str(cuda_home / "bin" / "nvcc") + + model_path = Path(args.model_path) + if not model_path.exists(): + raise FileNotFoundError(f"模型路径不存在: {model_path}") + + if args.tensor_parallel_size != 1 or args.data_parallel_size != 1: + print( + "[WARN] 你设置了 TP/DP != 1。Diffulex 会 spawn 子进程运行模型," + "torch.profiler 在父进程里通常采不到子进程里的 CUDA kernel。" + "建议用 TP=1/DP=1 跑 profile。" + ) + + prompts = _load_prompts(args) + sampling_params = SamplingParams(temperature=0.0, max_tokens=args.max_tokens) + + out_root = _mkdir(_REPO_ROOT / args.out_dir) + run_dir = _mkdir(out_root / time.strftime("%Y%m%d_%H%M%S")) + print(f"[INFO] 输出目录: {run_dir}") + + # Build Diffulex + use_lora = args.use_lora or bool(args.lora_path) + llm = Diffulex( + str(model_path), + lora_path=args.lora_path, + use_lora=use_lora, + model_name="dream", + decoding_strategy="d2f", + enforce_eager=not args.use_cudagraph, + tensor_parallel_size=args.tensor_parallel_size, + data_parallel_size=args.data_parallel_size, + master_addr=args.master_addr, + master_port=args.master_port, + gpu_memory_utilization=args.gpu_memory_utilization, + max_model_len=args.max_model_len, + max_num_batched_tokens=max(1024, args.max_model_len), + max_num_seqs=min(4, len(prompts)), + kv_cache_dtype=args.kv_cache_dtype, + kv_cache_layout=args.kv_cache_layout, + decode_mode=None if args.decode_mode == "auto" else args.decode_mode, + linear_attn_weight_dtype=args.linear_attn_weight_dtype, + linear_mlp_weight_dtype=args.linear_mlp_weight_dtype, + linear_attn_act_dtype=args.linear_attn_act_dtype, + linear_mlp_act_dtype=args.linear_mlp_act_dtype, + ) + + try: + # Compile / cache warmup (exclude from profile) + for i in range(max(0, args.compile_warmup_iters)): + print(f"[INFO] compile warmup {i+1}/{args.compile_warmup_iters} ...") + with torch.profiler.record_function("diffulex.generate(warmup)"): + _ = llm.generate(prompts, sampling_params, use_tqdm=False) + torch.cuda.synchronize() + + # For external profilers (e.g., ncu). Avoid enabling torch.profiler (CUPTI) here. + if args.no_torch_profiler: + print("[INFO] --no-torch-profiler: 运行一次稳态 generate(不启用 torch.profiler)...") + nvtx_handle = None + nvtx_pushed = False + if args.nvtx_range and torch.cuda.is_available(): + # Nsight Compute CLI --nvtx-include matches start/end ranges (not push/pop ranges). + # Prefer range_start/range_end if available; fallback to push/pop for other tools. + try: + nvtx_handle = torch.cuda.nvtx.range_start(args.nvtx_range) + except Exception: + try: + torch.cuda.nvtx.range_push(args.nvtx_range) + nvtx_pushed = True + except Exception: + pass + try: + with torch.profiler.record_function("diffulex.generate(profiled)"): + _ = llm.generate(prompts, sampling_params, use_tqdm=False) + torch.cuda.synchronize() + finally: + if args.nvtx_range and torch.cuda.is_available(): + if nvtx_handle is not None: + try: + torch.cuda.nvtx.range_end(nvtx_handle) + except Exception: + pass + elif nvtx_pushed: + try: + torch.cuda.nvtx.range_pop() + except Exception: + pass + print(f"[INFO] 完成(无 torch.profiler 输出)。输出目录: {run_dir}") + return + + # Option A: use Diffulex built-in profiler framework. + if args.use_diffulex_profiler: + profiler = DiffulexProfiler( + config=ProfilerConfig( + enabled=True, + backend="pytorch", + output_dir=str(run_dir), + export_formats=["json", "summary"], + pytorch_profiler_config={ + # Ensure artifacts are written into the same run_dir. + "output_dir": str(run_dir), + "record_shapes": True, + "profile_memory": True, + "with_stack": True, + # Also export stacks/top table for flamegraph + quick inspection. + "export_stacks": True, + "stacks_metric": "self_cuda_time_total", + "export_table": True, + "table_row_limit": 80, + }, + ) + ) + + # In this mode, we don't use torch.profiler schedule; we just profile the steady-state generate. + print("[INFO] 使用 diffulex_profiler(pytorch backend) 采集一次稳态 generate ...") + with profiler.profile( + "diffulex.generate(profiled)", + metadata={ + "tag": args.tag, + "decode_mode": args.decode_mode, + "kv_cache_dtype": args.kv_cache_dtype, + "linear_attn_weight_dtype": args.linear_attn_weight_dtype, + "linear_mlp_weight_dtype": args.linear_mlp_weight_dtype, + "linear_attn_act_dtype": args.linear_attn_act_dtype, + "linear_mlp_act_dtype": args.linear_mlp_act_dtype, + }, + ): + _ = llm.generate(prompts, sampling_params, use_tqdm=False) + torch.cuda.synchronize() + print("[INFO] diffulex_profiler 采集完成(trace/stacks/top 已导出到输出目录)。") + profiler.export(str(run_dir / f"{args.tag}")) + print(f"[INFO] 输出目录: {run_dir}") + return + + # Option B: raw torch.profiler with schedule (more controllable / multi-step). + activities = [torch.profiler.ProfilerActivity.CPU] + if torch.cuda.is_available(): + activities.append(torch.profiler.ProfilerActivity.CUDA) + + def _trace_handler(prof: torch.profiler.profile) -> None: + # One trace per active window. + step = getattr(prof, "step_num", None) + suffix = f"_step{step}" if step is not None else "" + trace_path = run_dir / f"{args.tag}{suffix}.trace.json" + stacks_path = run_dir / f"{args.tag}{suffix}.stacks" + summary_path = run_dir / f"{args.tag}{suffix}.top.txt" + prof.export_chrome_trace(str(trace_path)) + # 用 self_cuda_time_total 更聚焦 kernel 开销;若只关心 CPU 改成 self_cpu_time_total + try: + prof.export_stacks(str(stacks_path), "self_cuda_time_total") + except Exception: + # CUDA 不可用/未编译 kineto 时可能失败,仍保留 trace + pass + try: + top = prof.key_averages().table( + sort_by="self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total", + row_limit=50, + ) + summary_path.write_text(top, encoding="utf-8") + except Exception: + pass + + schedule = torch.profiler.schedule( + wait=max(0, args.profile_wait), + warmup=max(0, args.profile_warmup), + active=max(1, args.profile_active), + repeat=max(1, args.profile_repeat), + ) + total_steps = args.profile_wait + args.profile_warmup + args.profile_active * args.profile_repeat + print( + f"[INFO] profiler schedule: wait={args.profile_wait}, warmup={args.profile_warmup}, " + f"active={args.profile_active}, repeat={args.profile_repeat} -> total_steps={total_steps}" + ) + + with torch.profiler.profile( + activities=activities, + schedule=schedule, + on_trace_ready=_trace_handler, + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as prof: + for step in range(total_steps): + print(f"[INFO] profiled generate step {step+1}/{total_steps} ...") + with torch.profiler.record_function("diffulex.generate(profiled)"): + _ = llm.generate(prompts, sampling_params, use_tqdm=False) + torch.cuda.synchronize() + prof.step() + + print("[INFO] 采集完成。你可以用 trace.json 打开时间线,用 .stacks 生成火焰图。") + print(f"[INFO] 输出目录: {run_dir}") + finally: + try: + llm.exit() + except Exception: + pass + + +if __name__ == "__main__": + main() + diff --git a/pyproject.toml b/pyproject.toml index f2e2607..30a6222 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,11 @@ dependencies = [ "matplotlib>=3.10.5", "fastapi>=0.115.0", "uvicorn>=0.30.0", + "pandas>=2.3.3", + "tilelang>=0.1.7.post1", + "rich>=13.0.0", + "colorama>=0.4.6", + "lm-eval" ] [project.urls] @@ -39,13 +44,16 @@ Homepage = "https://github.com/zhijie-group/D2fEngine" Repository = "https://zhijie-group.github.io/D2fEngine" "Organization" = "https://github.com/zhijie-group" +[[tool.uv.index]] +url = "https://mirrors.aliyun.com/pypi/simple" +default = true + [tool.setuptools.packages.find] include = [ "diffulex", + "diffulex_bench", "diffulex_kernel", "diffulex_legacy", + "diffulex_profiler", + "test" ] - -[[tool.uv.index]] -url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" -default = true \ No newline at end of file diff --git a/scripts/build_docs.sh b/script/build_docs.sh similarity index 100% rename from scripts/build_docs.sh rename to script/build_docs.sh diff --git a/script/d2f_dream_eval_gsm8k.sh b/script/d2f_dream_eval_gsm8k.sh new file mode 100755 index 0000000..7cece76 --- /dev/null +++ b/script/d2f_dream_eval_gsm8k.sh @@ -0,0 +1,16 @@ +#!/usr/bin/zsh + +export HF_HUB_OFFLINE=1 +export HF_DATASETS_OFFLINE=1 +export HF_EVALUATE_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 +export WANDB_DISABLED=true + +export HF_HOME="$(pwd)/cache" +export HF_DATASETS_CACHE="$HF_HOME/datasets" +export HF_METRICS_CACHE="$HF_HOME/metrics" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +python -m diffulex_bench.main \ + --config custom_configs/d2f_dream_eval_gsm8k.yml \ + 2>&1 | tee log/d2f_dream_eval_gsm8k.log \ No newline at end of file diff --git a/scripts/launch_server.sh b/script/launch_server.sh similarity index 100% rename from scripts/launch_server.sh rename to script/launch_server.sh diff --git a/scripts/profile_dvllm_dream.sh b/script/profile_dvllm_dream.sh similarity index 100% rename from scripts/profile_dvllm_dream.sh rename to script/profile_dvllm_dream.sh diff --git a/scripts/test_dvllm_dllm_decoding_kernel.sh b/script/test_dvllm_dllm_decoding_kernel.sh similarity index 100% rename from scripts/test_dvllm_dllm_decoding_kernel.sh rename to script/test_dvllm_dllm_decoding_kernel.sh diff --git a/scripts/test_dvllm_dream_gsm8k.sh b/script/test_dvllm_dream_gsm8k.sh similarity index 100% rename from scripts/test_dvllm_dream_gsm8k.sh rename to script/test_dvllm_dream_gsm8k.sh diff --git a/scripts/test_dvllm_dream_human_eval.sh b/script/test_dvllm_dream_human_eval.sh similarity index 100% rename from scripts/test_dvllm_dream_human_eval.sh rename to script/test_dvllm_dream_human_eval.sh diff --git a/scripts/test_dvllm_qwen.sh b/script/test_dvllm_qwen.sh similarity index 100% rename from scripts/test_dvllm_qwen.sh rename to script/test_dvllm_qwen.sh diff --git a/tests/.gitkeep b/test/.gitkeep similarity index 100% rename from tests/.gitkeep rename to test/.gitkeep diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..a4b72a6 --- /dev/null +++ b/test/__init__.py @@ -0,0 +1,2 @@ +# test package + diff --git a/test/python/__init__.py b/test/python/__init__.py new file mode 100644 index 0000000..da0260f --- /dev/null +++ b/test/python/__init__.py @@ -0,0 +1,2 @@ +# test.python package + diff --git a/tests/python/kernel/test_dllm_flash_attn_decode_kernel.py b/test/python/kernel/test_dllm_flash_attn_decode_kernel.py similarity index 95% rename from tests/python/kernel/test_dllm_flash_attn_decode_kernel.py rename to test/python/kernel/test_dllm_flash_attn_decode_kernel.py index 29200be..01ca7ef 100644 --- a/tests/python/kernel/test_dllm_flash_attn_decode_kernel.py +++ b/test/python/kernel/test_dllm_flash_attn_decode_kernel.py @@ -7,8 +7,7 @@ import torch.nn.functional as F from einops import rearrange -# from diffulex_kernel.python.dllm_flash_attn import dllm_flash_attn_decode_kernel -from diffulex_kernel.python.dllm_flash_attn import dllm_flash_attn_decode_kernel_legacy as dllm_flash_attn_decode_kernel +from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_decode_kernel def naive_sdpa_with_kvcache( @@ -184,7 +183,7 @@ def run_dllm_flash_attn_decode( kernel_source = decode_kernel.get_kernel_source() - cuda_cache_dir = os.getenv("CUDA_CACHE_DIR", "/data1/jyj/Diffulex/cuda_cache") + cuda_cache_dir = os.getenv("CUDA_CACHE_DIR", "./cuda_cache") cache_root = Path(cuda_cache_dir) / "test_dllm_flash_attn_decode_kernel" case_dir = cache_root / ( f"seq{num_seqs}_heads{num_heads}_kv{num_kv_heads}_hd{head_dim}_" @@ -254,6 +253,23 @@ def test_decode_bf16_multi_seq(): ) +def test_decode_bf16_multi_seq_long_context(): + """Test with multiple sequences, bfloat16.""" + run_dllm_flash_attn_decode( + num_seqs=4, + num_heads=32, + num_kv_heads=8, + head_dim=128, + max_q_len=64, + max_kv_len=64, + context_len=1024, + page_block_size=32, + diffusion_block_size=32, + is_block_attn=False, + dtype="bfloat16", + ) + + def test_decode_bf16_block_attn(): """Test with block attention enabled.""" run_dllm_flash_attn_decode( diff --git a/test/python/kernel/test_dllm_flash_attn_decode_kernel_multiround.py b/test/python/kernel/test_dllm_flash_attn_decode_kernel_multiround.py new file mode 100644 index 0000000..09e5b8c --- /dev/null +++ b/test/python/kernel/test_dllm_flash_attn_decode_kernel_multiround.py @@ -0,0 +1,567 @@ +import os +import time +from pathlib import Path + +import torch + +from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_decode_kernel +from test.python.kernel.test_dllm_flash_attn_decode_kernel import naive_sdpa_with_kvcache + + +def test_decode_multiround_context_len(): + """ + Test inference time and compilation behavior across different context_len values and num_seqs. + This test verifies: + 1. Inference time for different context lengths and sequence counts + 2. Whether kernels are recompiled for different context_len values + 3. Block table configurations with trailing -1 entries + """ + # Common parameters (same as test_decode_bf16_multi_seq) + base_params = { + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "max_q_len": 64, + "max_kv_len": 64, + "page_block_size": 32, + "diffusion_block_size": 32, + "is_block_attn": False, + "dtype": "bfloat16", + } + + # Different sequence counts to test + num_seqs_list = [1, 4, 8, 13, 14, 15, 16] + + # Different context lengths to test + max_context_len = 2048 + context_lens = list(range(128, max_context_len + 1, 32)) + + # Track compilation times and inference times + # Key format: (num_seqs, context_len) + compilation_times = {} + inference_times = {} + kernel_paths = {} + kernel_instances = {} + correctness_results = {} # Track correctness verification results + + cuda_cache_dir = os.getenv("CUDA_CACHE_DIR", "./cuda_cache") + cache_root = Path(cuda_cache_dir) / "test_dllm_flash_attn_decode_kernel_multiround" + + print("\n" + "=" * 80) + print("Testing multiple num_seqs and context_len values") + print(f"Testing num_seqs: {num_seqs_list}") + print(f"Testing context_lens: {len(context_lens)} values from {context_lens[0]} to {context_lens[-1]}") + print("=" * 80) + + # Test all combinations of num_seqs and context_len + for num_seqs in num_seqs_list: + # Calculate KV cache size based on max_context_len to ensure consistent allocation + # across all tests for this num_seqs + max_blocks_per_seq = (max_context_len + base_params["page_block_size"] - 1) // base_params["page_block_size"] + max_seq_num_blocks = max_blocks_per_seq + num_page_blocks = num_seqs * max_blocks_per_seq + + print(f"\n{'=' * 80}") + print(f"Testing with num_seqs={num_seqs}") + print(f"KV cache: max_seq_num_blocks={max_seq_num_blocks}, num_page_blocks={num_page_blocks}") + print(f"{'=' * 80}") + + for context_len in context_lens: + print(f"\n--- Testing num_seqs={num_seqs}, context_len={context_len} ---") + + # Check if kernel file already exists (indicates potential cache hit) + case_dir = cache_root / ( + f"seq{num_seqs}_heads{base_params['num_heads']}_" + f"kv{base_params['num_kv_heads']}_hd{base_params['head_dim']}_" + f"ctx{context_len}_pbs{base_params['page_block_size']}_" + f"dbs{base_params['diffusion_block_size']}_" + f"block{int(base_params['is_block_attn'])}_dtype{base_params['dtype']}_" + f"bm64_bn64_stg1_thr128_mq{base_params['max_q_len']}_mk{base_params['max_kv_len']}" + ) + kernel_path = case_dir / "kernel.cu" + + kernel_existed_before = kernel_path.exists() + kernel_mtime_before = kernel_path.stat().st_mtime if kernel_existed_before else None + + # Measure compilation + first inference time + start_time = time.time() + + # Run the test (this includes kernel compilation if needed) + # We'll create the kernel and run it to measure compilation time + torch_dtype = getattr(torch, base_params["dtype"]) + device = "cuda" + num_groups = base_params["num_heads"] // base_params["num_kv_heads"] + total_q_len = num_seqs * base_params["diffusion_block_size"] + total_kv_len = num_seqs * base_params["diffusion_block_size"] + + # Create kernel (this may trigger compilation) + decode_kernel = dllm_flash_attn_decode_kernel( + num_seqs, + num_groups, + num_page_blocks, + total_q_len, + total_kv_len, + base_params["num_heads"], + base_params["head_dim"], + base_params["is_block_attn"], + base_params["diffusion_block_size"], + max_seq_num_blocks, + base_params["page_block_size"], + 64, # block_m + 64, # block_n + 1, # num_stages + 128, # num_threads + ) + + # Save kernel source + kernel_source = decode_kernel.get_kernel_source() + case_dir.mkdir(parents=True, exist_ok=True) + kernel_path.write_text(kernel_source) + + # Prepare input tensors for first run + q = torch.randn(total_q_len, base_params["num_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + k = torch.randn(total_kv_len, base_params["num_kv_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + v = torch.randn(total_kv_len, base_params["num_kv_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + k_cache = torch.randn(num_page_blocks, base_params["page_block_size"], + base_params["num_kv_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + v_cache = torch.randn(num_page_blocks, base_params["page_block_size"], + base_params["num_kv_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + + # Create block_tables with varying configurations + # Some sequences will have trailing -1 entries even when context_len is sufficient + block_tables = torch.zeros(num_seqs, max_seq_num_blocks, + dtype=torch.int32, device=device) + # Calculate actual blocks needed for current context_len + num_blocks_per_seq = (context_len + base_params["page_block_size"] - 1) // base_params["page_block_size"] + + for seq_idx in range(num_seqs): + # Determine how many blocks to actually use for this sequence + # For some sequences, use fewer blocks to create trailing -1 entries + # Pattern: alternate between full blocks and partial blocks + if seq_idx % 2 == 0: + # Even-indexed sequences: use all blocks needed + blocks_to_use = num_blocks_per_seq + else: + # Odd-indexed sequences: use fewer blocks (leave some trailing -1) + # Use at least 1 block, but leave at least 1 trailing -1 if possible + blocks_to_use = max(1, num_blocks_per_seq - 1) + + # Fill in the blocks + for block_idx in range(blocks_to_use): + block_tables[seq_idx, block_idx] = seq_idx * max_blocks_per_seq + block_idx + + # Set remaining blocks to -1 (invalid) + for block_idx in range(blocks_to_use, max_seq_num_blocks): + block_tables[seq_idx, block_idx] = -1 + + context_lens_tensor = torch.full((num_seqs,), context_len, + dtype=torch.int32, device=device) + cu_seqlens_q = torch.arange(0, (num_seqs + 1) * base_params["diffusion_block_size"], + base_params["diffusion_block_size"], dtype=torch.int32, device=device) + cu_seqlens_k = torch.arange(0, (num_seqs + 1) * base_params["diffusion_block_size"], + base_params["diffusion_block_size"], dtype=torch.int32, device=device) + + # First run (includes compilation if needed) + _ = decode_kernel( + q, k, v, k_cache, v_cache, + block_tables, + context_lens_tensor, + cu_seqlens_q, + cu_seqlens_k, + base_params["max_q_len"], + ) + torch.cuda.synchronize() + + compilation_time = time.time() - start_time + key = (num_seqs, context_len) + compilation_times[key] = compilation_time + + # Check if kernel was compiled (file was created, not just loaded from cache) + # Note: This is a heuristic - the actual compilation happens when the kernel + # is first called, and tilelang may have its own caching mechanism + was_compiled = not kernel_existed_before + + kernel_paths[key] = str(kernel_path) + + # Count trailing -1 entries in block_tables + trailing_neg_ones = 0 + for seq_idx in range(num_seqs): + for block_idx in range(max_seq_num_blocks - 1, -1, -1): + if block_tables[seq_idx, block_idx].item() == -1: + trailing_neg_ones += 1 + else: + break + + print(f" Kernel path: {kernel_path}") + print(f" Kernel existed before: {kernel_existed_before}") + print(f" Was compiled: {was_compiled}") + print(f" Compilation + first inference time: {compilation_time:.4f}s") + print(f" Block table trailing -1 entries: {trailing_neg_ones}") + + # Measure pure inference time (warmup + actual measurement) + # Warmup + _ = decode_kernel( + q, k, v, k_cache, v_cache, + block_tables, + context_lens_tensor, + cu_seqlens_q, + cu_seqlens_k, + base_params["max_q_len"], + ) + torch.cuda.synchronize() + + # Measure inference time + num_iterations = 10 + start_time = time.time() + for _ in range(num_iterations): + _ = decode_kernel( + q, k, v, k_cache, v_cache, + block_tables, + context_lens_tensor, + cu_seqlens_q, + cu_seqlens_k, + base_params["max_q_len"], + ) + torch.cuda.synchronize() + inference_time = (time.time() - start_time) / num_iterations + inference_times[key] = inference_time + + print(f" Average inference time ({num_iterations} iterations): {inference_time*1000:.4f}ms") + + # Verify correctness by comparing with reference implementation + print(f" Verifying correctness...") + # Run kernel once more to get output for correctness verification + output = decode_kernel( + q, k, v, k_cache, v_cache, + block_tables, + context_lens_tensor, + cu_seqlens_q, + cu_seqlens_k, + base_params["max_q_len"], + ) + torch.cuda.synchronize() + + scale = 1.0 / (base_params["head_dim"] ** 0.5) + ref_output = naive_sdpa_with_kvcache( + q, k, v, k_cache, v_cache, + block_tables, context_lens_tensor, + cu_seqlens_q, cu_seqlens_k, + scale, num_groups, base_params["page_block_size"], + ) + + try: + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) + correctness_results[key] = True + print(f" ✓ Correctness check passed") + except AssertionError as e: + correctness_results[key] = False + print(f" ✗ Correctness check FAILED: {e}") + + # Store kernel instance for later use + kernel_instances[key] = decode_kernel + + # Print summary + print("\n" + "=" * 80) + print("Summary") + print("=" * 80) + print(f"{'Num Seqs':<12} {'Context Len':<15} {'Compiled':<10} {'Correct':<10} {'Compilation Time (s)':<20} {'Inference Time (ms)':<20}") + print("-" * 100) + for num_seqs in num_seqs_list: + for context_len in context_lens: + key = (num_seqs, context_len) + if key in kernel_paths: + was_compiled = kernel_paths[key] and Path(kernel_paths[key]).exists() + is_correct = correctness_results.get(key, False) + correct_str = "✓" if is_correct else "✗" + print(f"{num_seqs:<12} {context_len:<15} {str(was_compiled):<10} {correct_str:<10} {compilation_times[key]:<20.4f} {inference_times[key]*1000:<20.4f}") + + print("\n" + "=" * 80) + print("Analysis") + print("=" * 80) + + # Check if kernels were recompiled for different (num_seqs, context_len) combinations + unique_kernel_paths = set(kernel_paths.values()) + total_combinations = len(num_seqs_list) * len(context_lens) + print(f"Number of unique kernel paths: {len(unique_kernel_paths)}") + print(f"Number of (num_seqs, context_len) combinations tested: {total_combinations}") + + if len(unique_kernel_paths) == total_combinations: + print("✓ Each (num_seqs, context_len) combination resulted in a unique kernel (expected behavior)") + else: + print(f"⚠ Some combinations shared the same kernel ({len(unique_kernel_paths)} unique kernels for {total_combinations} combinations)") + + # Check inference time scaling by num_seqs + print(f"\nInference time scaling by num_seqs:") + for num_seqs in num_seqs_list: + seq_times = [inference_times[(num_seqs, ctx)] for ctx in context_lens if (num_seqs, ctx) in inference_times] + if seq_times: + base_time = seq_times[0] + print(f" num_seqs={num_seqs}:") + for i, context_len in enumerate(context_lens): + key = (num_seqs, context_len) + if key in inference_times: + ratio = inference_times[key] / base_time + print(f" context_len={context_len}: {ratio:.2f}x (vs context_len={context_lens[0]})") + + # Check inference time scaling by context_len + print(f"\nInference time scaling by context_len:") + for context_len in context_lens[::4]: # Sample every 4th context_len to avoid too much output + ctx_times = [inference_times[(ns, context_len)] for ns in num_seqs_list if (ns, context_len) in inference_times] + if ctx_times: + base_time = ctx_times[0] + print(f" context_len={context_len}:") + for num_seqs in num_seqs_list: + key = (num_seqs, context_len) + if key in inference_times: + ratio = inference_times[key] / base_time + print(f" num_seqs={num_seqs}: {ratio:.2f}x (vs num_seqs={num_seqs_list[0]})") + + # Check correctness summary + print(f"\nCorrectness verification summary:") + passed = sum(1 for v in correctness_results.values() if v) + total = len(correctness_results) + print(f" Passed: {passed}/{total}") + if passed < total: + print(f" Failed (num_seqs, context_len) combinations:") + for key, is_correct in correctness_results.items(): + if not is_correct: + num_seqs, context_len = key + print(f" - num_seqs={num_seqs}, context_len={context_len}") + else: + print(" ✓ All correctness checks passed!") + + +def test_decode_engine_like_scenarios(): + """ + Test decode kernel with scenarios that more closely match engine usage. + This test simulates: + 1. Non-contiguous block_tables (like engine's prepare_block_tables) + 2. Variable cu_seqlens_k based on actual sequence lengths + 3. Memory reuse scenarios + 4. Different block_table patterns (some sequences with fewer blocks) + """ + base_params = { + "num_heads": 32, + "num_kv_heads": 8, + "head_dim": 128, + "max_q_len": 64, + "max_kv_len": 64, + "page_block_size": 32, + "diffusion_block_size": 32, + "is_block_attn": False, + "dtype": "bfloat16", + } + + num_seqs_list = [1, 4, 8, 13, 14, 15, 16] + context_lens_list = [128, 256, 512, 1024, 2048] + + torch_dtype = getattr(torch, base_params["dtype"]) + device = "cuda" + num_groups = base_params["num_heads"] // base_params["num_kv_heads"] + + # Calculate maximum KV cache size to avoid recompilation + max_num_seqs = max(num_seqs_list) + max_context_len = max(context_lens_list) + max_blocks_per_seq = (max_context_len + base_params["page_block_size"] - 1) // base_params["page_block_size"] + max_num_page_blocks = max_num_seqs * max_blocks_per_seq + + # Setup cache directory for saving kernel sources + cuda_cache_dir = os.getenv("CUDA_CACHE_DIR", "./cuda_cache") + cache_root = Path(cuda_cache_dir) / "test_dllm_flash_attn_decode_kernel_multiround" + + # Create fixed-size KV cache (static allocation) + print("\n" + "=" * 80) + print("Testing engine-like scenarios") + print(f"Using fixed large KV cache: num_page_blocks={max_num_page_blocks}") + print("=" * 80) + + k_cache = torch.randn(max_num_page_blocks, base_params["page_block_size"], + base_params["num_kv_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + v_cache = torch.randn(max_num_page_blocks, base_params["page_block_size"], + base_params["num_kv_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + + correctness_results = {} + + for num_seqs in num_seqs_list: + print(f"\n{'=' * 80}") + print(f"Testing with num_seqs={num_seqs}") + print(f"{'=' * 80}") + + for context_len in context_lens_list: + print(f"\n--- Testing num_seqs={num_seqs}, context_len={context_len} ---") + + # Simulate engine's prepare_block_tables behavior + # Each sequence may have different number of blocks + max_blocks_per_seq = (context_len + base_params["page_block_size"] - 1) // base_params["page_block_size"] + max_seq_num_blocks = max_blocks_per_seq + num_page_blocks = num_seqs * max_blocks_per_seq + + # Create block_tables like engine does: each seq may have different lengths + block_tables_list = [] + for seq_idx in range(num_seqs): + # Simulate variable block counts per sequence + # Some sequences use fewer blocks (like engine scenarios) + if seq_idx % 3 == 0: + # Every 3rd sequence uses all blocks + num_blocks = max_blocks_per_seq + elif seq_idx % 3 == 1: + # Use 1 less block + num_blocks = max(1, max_blocks_per_seq - 1) + else: + # Use 2 less blocks + num_blocks = max(1, max_blocks_per_seq - 2) + + seq_block_table = [] + for block_idx in range(num_blocks): + seq_block_table.append(seq_idx * max_blocks_per_seq + block_idx) + # Engine pads with -1 to max_len + seq_block_table.extend([-1] * (max_seq_num_blocks - num_blocks)) + block_tables_list.append(seq_block_table) + + block_tables = torch.tensor(block_tables_list, dtype=torch.int32, device=device) + + # Simulate engine's cu_seqlens calculation + # In engine, cu_seqlens_k is based on actual sequence lengths (total_seqlen) + # cu_seqlens_q is based on query lengths (total_seqlen - cached_num_tokens) + total_q_len = num_seqs * base_params["diffusion_block_size"] + total_kv_len = num_seqs * base_params["diffusion_block_size"] + + cu_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(num_seqs + 1, dtype=torch.int32, device=device) + + # Simulate variable sequence lengths (like in engine) + for seq_idx in range(num_seqs): + seqlen_q = base_params["diffusion_block_size"] # Query length + # KV length = context_len + seqlen_q (simulating cached + new tokens) + seqlen_k = seqlen_q + cu_seqlens_q[seq_idx + 1] = cu_seqlens_q[seq_idx] + seqlen_q + cu_seqlens_k[seq_idx + 1] = cu_seqlens_k[seq_idx] + seqlen_k + + # Adjust total lengths based on actual cu_seqlens + total_q_len = cu_seqlens_q[-1].item() + total_kv_len = cu_seqlens_k[-1].item() + + # Prepare tensors + q = torch.randn(total_q_len, base_params["num_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + k = torch.randn(total_kv_len, base_params["num_kv_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + v = torch.randn(total_kv_len, base_params["num_kv_heads"], base_params["head_dim"], + dtype=torch_dtype, device=device) + # Use the fixed-size KV cache (already allocated above) + + context_lens_tensor = torch.full((num_seqs,), context_len, + dtype=torch.int32, device=device) + + # Create kernel (use max_num_page_blocks for KV cache size) + decode_kernel = dllm_flash_attn_decode_kernel( + num_seqs, + num_groups, + max_num_page_blocks, # Use fixed max size + total_q_len, + total_kv_len, + base_params["num_heads"], + base_params["head_dim"], + base_params["is_block_attn"], + base_params["diffusion_block_size"], + max_seq_num_blocks, + base_params["page_block_size"], + 64, # block_m + 64, # block_n + 1, # num_stages + 128, # num_threads + ) + + # Save kernel source + case_dir = cache_root / ( + f"seq{num_seqs}_heads{base_params['num_heads']}_" + f"kv{base_params['num_kv_heads']}_hd{base_params['head_dim']}_" + f"ctx{context_len}_pbs{base_params['page_block_size']}_" + f"dbs{base_params['diffusion_block_size']}_" + f"block{int(base_params['is_block_attn'])}_dtype{base_params['dtype']}_" + f"bm64_bn64_stg1_thr128_mq{base_params['max_q_len']}_mk{base_params['max_kv_len']}" + ) + kernel_path = case_dir / "kernel.cu" + kernel_source = decode_kernel.get_kernel_source() + case_dir.mkdir(parents=True, exist_ok=True) + kernel_path.write_text(kernel_source) + print(f" Kernel saved to: {kernel_path}") + + # Test with memory reuse (simulate engine's behavior) + # Run multiple times to check for memory corruption + outputs = [] + for run_idx in range(3): + output = decode_kernel( + q, k, v, k_cache, v_cache, + block_tables, + context_lens_tensor, + cu_seqlens_q, + cu_seqlens_k, + base_params["max_q_len"], + ) + torch.cuda.synchronize() + outputs.append(output.clone()) + + # Verify consistency across runs + consistent = True + for i in range(1, len(outputs)): + if not torch.allclose(outputs[0], outputs[i], atol=1e-5, rtol=1e-5): + consistent = False + max_diff = (outputs[0] - outputs[i]).abs().max().item() + print(f" ✗ Output inconsistency detected in run {i}: max_diff={max_diff:.6f}") + break + + if not consistent: + correctness_results[(num_seqs, context_len)] = False + continue + + # Verify correctness against reference + scale = 1.0 / (base_params["head_dim"] ** 0.5) + ref_output = naive_sdpa_with_kvcache( + q, k, v, k_cache, v_cache, + block_tables, context_lens_tensor, + cu_seqlens_q, cu_seqlens_k, + scale, num_groups, base_params["page_block_size"], + ) + + try: + torch.testing.assert_close(outputs[0], ref_output, atol=1e-2, rtol=1e-2) + correctness_results[(num_seqs, context_len)] = True + print(f" ✓ Correctness check passed") + except AssertionError as e: + correctness_results[(num_seqs, context_len)] = False + abs_diff = (outputs[0] - ref_output).abs() + max_diff = abs_diff.max().item() + mean_diff = abs_diff.mean().item() + print(f" ✗ Correctness check FAILED: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") + print(f" Error: {str(e)[:200]}") + + # Print summary + print("\n" + "=" * 80) + print("Engine-like Test Summary") + print("=" * 80) + passed = sum(1 for v in correctness_results.values() if v) + total = len(correctness_results) + print(f" Passed: {passed}/{total}") + if passed < total: + print(f" Failed (num_seqs, context_len) combinations:") + for key, is_correct in correctness_results.items(): + if not is_correct: + num_seqs, context_len = key + print(f" - num_seqs={num_seqs}, context_len={context_len}") + else: + print(" ✓ All correctness checks passed!") + + +if __name__ == "__main__": + # tilelang.testing.main() + # test_decode_multiround_context_len() + # print("\n\n") + test_decode_engine_like_scenarios() \ No newline at end of file diff --git a/test/python/kernel/test_dllm_flash_attn_decode_specified_case.py b/test/python/kernel/test_dllm_flash_attn_decode_specified_case.py new file mode 100644 index 0000000..46756c3 --- /dev/null +++ b/test/python/kernel/test_dllm_flash_attn_decode_specified_case.py @@ -0,0 +1,188 @@ +import os +import pickle +from pathlib import Path + +import torch +import tilelang +import tilelang.testing + +from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_decode_kernel +from test.python.utils.checker import CHECK_FLASH_ATTN_DECODE + + +def get_failed_test_cases_dir(): + """Get the directory containing failed test cases.""" + default_dir = Path(__file__).parent.parent.parent.parent / "failed_test_cases" + return Path(os.getenv("TEST_CASE_SAVE_DIR", str(default_dir))) + + +def find_failed_test_cases(): + """Find all failed test case directories.""" + test_cases_dir = get_failed_test_cases_dir() + if not test_cases_dir.exists(): + return [] + + test_cases = [] + for case_dir in test_cases_dir.iterdir(): + if case_dir.is_dir() and case_dir.name.startswith("decode_kernel_failure_"): + test_data_path = case_dir / "test_data.pkl" + if test_data_path.exists(): + test_cases.append(case_dir) + + return sorted(test_cases) + + +def load_test_case(case_dir: Path): + """Load a test case from directory.""" + test_data_path = case_dir / "test_data.pkl" + if not test_data_path.exists(): + raise FileNotFoundError(f"test_data.pkl not found in {case_dir}") + + with open(test_data_path, "rb") as f: + test_data = pickle.load(f) + + return test_data + + +def run_test_case_from_saved_data(case_dir: Path): + """Run a test case from saved data.""" + # Load test data + test_data = load_test_case(case_dir) + + # Extract inputs and move to device + device = "cuda" + q = test_data['inputs']['q'].to(device) + k = test_data['inputs']['k'].to(device) + v = test_data['inputs']['v'].to(device) + k_cache = test_data['inputs']['k_cache'].to(device) + v_cache = test_data['inputs']['v_cache'].to(device) + block_tables = test_data['inputs']['block_tables'].to(device) + context_lens = test_data['inputs']['context_lens'].to(device) + cu_seqlens_q = test_data['inputs']['cu_seqlens_q'].to(device) + cu_seqlens_k = test_data['inputs']['cu_seqlens_k'].to(device) + + # Extract parameters + params = test_data['parameters'] + max_seqlen_q = params['max_seqlen_q'] + scale = params['scale'] + num_groups = params['num_groups'] + page_block_size = params['page_block_size'] + diffusion_block_size = params['diffusion_block_size'] + is_block_attn = params['is_block_attn'] + + # Extract shapes to infer kernel parameters + q_shape = test_data['shapes']['q_shape'] + k_shape = test_data['shapes']['k_shape'] + k_cache_shape = test_data['shapes']['k_cache_shape'] + block_tables_shape = test_data['shapes']['block_tables_shape'] + + # Infer kernel parameters from shapes + total_q_len = q_shape[0] + total_kv_len = k_shape[0] + num_heads = q_shape[1] + num_kv_heads = k_shape[1] + head_dim = q_shape[2] + num_seqs = len(cu_seqlens_q) - 1 + num_page_blocks = k_cache_shape[0] + max_seq_num_blocks = block_tables_shape[1] + + # Default kernel tuning parameters (can be overridden if saved in test_data) + block_m = 64 + block_n = 64 + num_stages = 1 + num_threads = 128 + + # Build kernel + decode_kernel = dllm_flash_attn_decode_kernel( + num_seqs, + num_groups, + num_page_blocks, + total_q_len, + total_kv_len, + num_heads, + head_dim, + is_block_attn, + diffusion_block_size, + max_seq_num_blocks, + page_block_size, + block_m, + block_n, + num_stages, + num_threads, + ) + + # Verify using CHECK_FLASH_ATTN_DECODE (it will run the kernel and verify) + CHECK_FLASH_ATTN_DECODE( + q, k, v, + k_cache, v_cache, + block_tables, + context_lens, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + decode_kernel, + scale, + num_groups, + page_block_size, + diffusion_block_size, + is_block_attn, + ) + + print(f"Test case {case_dir.name} passed! Shape: {q.shape}") + + +def test_all_failed_cases(): + """Test all failed test cases found in the failed_test_cases directory.""" + test_cases = find_failed_test_cases() + + if not test_cases: + print("No failed test cases found. Skipping test.") + return + + print(f"Found {len(test_cases)} failed test case(s) to verify:") + for case_dir in test_cases: + print(f" - {case_dir.name}") + + # Run each test case + for case_dir in test_cases: + print(f"\n{'='*80}") + print(f"Testing case: {case_dir.name}") + print(f"{'='*80}") + + try: + run_test_case_from_saved_data(case_dir) + except Exception as e: + print(f"Test case {case_dir.name} FAILED with error:") + print(f" {type(e).__name__}: {str(e)}") + raise + + +# Generate individual test functions for each failed test case +def generate_test_functions(): + """Dynamically generate test functions for each failed test case.""" + test_cases = find_failed_test_cases() + + for idx, case_dir in enumerate(test_cases): + case_name = case_dir.name.replace("decode_kernel_failure_", "").replace("-", "_").replace(".", "_") + test_func_name = f"test_case_{case_name}" + + # Create a closure with the case_dir captured + def make_test_func(case_path): + def test_func(): + run_test_case_from_saved_data(case_path) + return test_func + + # Create and register the test function + test_func = make_test_func(case_dir) + test_func.__name__ = test_func_name + test_func.__doc__ = f"Test case from {case_dir.name}" + globals()[test_func_name] = test_func + + +# Generate test functions at module load time +generate_test_functions() + + +if __name__ == "__main__": + tilelang.testing.main() + diff --git a/tests/python/kernel/test_dllm_flash_attn_prefill_kernel.py b/test/python/kernel/test_dllm_flash_attn_prefill_kernel.py similarity index 98% rename from tests/python/kernel/test_dllm_flash_attn_prefill_kernel.py rename to test/python/kernel/test_dllm_flash_attn_prefill_kernel.py index 6bc9ba8..b69b014 100644 --- a/tests/python/kernel/test_dllm_flash_attn_prefill_kernel.py +++ b/test/python/kernel/test_dllm_flash_attn_prefill_kernel.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from einops import rearrange -from diffulex_kernel.python.dllm_flash_attn import dllm_flash_attn_prefill_kernel +from diffulex_kernel.python.dllm_flash_attn_kernels import dllm_flash_attn_prefill_kernel def naive_sdpa_prefill( @@ -124,7 +124,7 @@ def run_dllm_flash_attn_prefill( ) kernel_source = prefill_kernel.get_kernel_source() - cuda_cache_dir = os.getenv("CUDA_CACHE_DIR", "/data1/jyj/Diffulex/cuda_cache") + cuda_cache_dir = os.getenv("CUDA_CACHE_DIR", "./cuda_cache") cache_root = Path(cuda_cache_dir) / "test_dllm_flash_attn_prefill_kernel" case_dir = cache_root / ( f"seq{num_seqs}_heads{num_heads}_kv{num_kv_heads}_hd{head_dim}_" diff --git a/test/python/kernel/test_paged_attn_decode_triton.py b/test/python/kernel/test_paged_attn_decode_triton.py new file mode 100644 index 0000000..055dece --- /dev/null +++ b/test/python/kernel/test_paged_attn_decode_triton.py @@ -0,0 +1,240 @@ +import pytest +import torch +import torch.nn.functional as F + +from einops import rearrange +from types import SimpleNamespace + +from diffulex_kernel.python.paged_attn_decode_triton import paged_attn_decode_unified_triton + + +def _has_fp8() -> bool: + return hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz") or hasattr(torch, "float8_e5m2") + + +def _build_cu_seqlens(lengths: torch.Tensor) -> torch.Tensor: + # lengths: [num_seqs] int32 on cuda + return torch.tensor( + [0] + list(torch.cumsum(lengths, dim=0).cpu().numpy()), + dtype=torch.int32, + device=lengths.device, + ) + + +def naive_sdpa_with_kvcache( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + scale: float, + num_groups: int, + page_block_size: int, +) -> torch.Tensor: + num_seqs = len(cu_seqlens_q) - 1 + output = torch.zeros_like(q) + for seq_idx in range(num_seqs): + q_start = int(cu_seqlens_q[seq_idx].item()) + q_end = int(cu_seqlens_q[seq_idx + 1].item()) + kv_start = int(cu_seqlens_k[seq_idx].item()) + kv_end = int(cu_seqlens_k[seq_idx + 1].item()) + + q_seq = q[q_start:q_end] # [q_len, Hq, D] + k_seq = k[kv_start:kv_end] # [new_len, Hkv, D] + v_seq = v[kv_start:kv_end] + + ctx = int(context_lens[seq_idx].item()) + k_cache_seq_list = [] + v_cache_seq_list = [] + for blk in range(block_tables.shape[1]): + page = int(block_tables[seq_idx, blk].item()) + if page < 0: + continue + blk_start = blk * page_block_size + if blk_start >= ctx: + continue + blk_end = min(blk_start + page_block_size, ctx) + n = blk_end - blk_start + k_cache_seq_list.append(k_cache[page, :n]) + v_cache_seq_list.append(v_cache[page, :n]) + + if k_cache_seq_list: + k_ctx = torch.cat(k_cache_seq_list, dim=0) + v_ctx = torch.cat(v_cache_seq_list, dim=0) + k_comb = torch.cat([k_ctx, k_seq], dim=0) + v_comb = torch.cat([v_ctx, v_seq], dim=0) + else: + k_comb = k_seq + v_comb = v_seq + + q_sdpa = rearrange(q_seq, "s h d -> 1 h s d") + k_sdpa = rearrange(k_comb, "s h d -> 1 h s d") + v_sdpa = rearrange(v_comb, "s h d -> 1 h s d") + attn_out = F.scaled_dot_product_attention( + q_sdpa, + k_sdpa, + v_sdpa, + dropout_p=0.0, + is_causal=False, + scale=scale, + enable_gqa=True, + ) + output[q_start:q_end] = rearrange(attn_out, "1 h s d -> s h d").to(output.dtype) + + return output + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for Triton paged-attention kernel") +def test_paged_decode_triton_bf16_cache_matches_reference(): + torch.manual_seed(0) + device = torch.device("cuda") + + num_seqs = 4 + num_heads = 32 + num_kv_heads = 8 + head_dim = 128 + page_size = 32 + diffusion_block_size = 32 + + num_groups = num_heads // num_kv_heads + + # Per-seq query/new KV length (decode step) + q_lens = torch.full((num_seqs,), diffusion_block_size, dtype=torch.int32, device=device) + cu_q = _build_cu_seqlens(q_lens) + cu_k = cu_q.clone() + total_q = int(cu_q[-1].item()) + + # Context lengths (vary per seq) + context_lens = torch.tensor([0, 17, 63, 128], dtype=torch.int32, device=device) + max_ctx = int(context_lens.max().item()) + max_seq_blocks = (max_ctx + page_size - 1) // page_size + num_page_blocks = num_seqs * max_seq_blocks + + # Assign each seq its own contiguous pages + block_tables = torch.full((num_seqs, max_seq_blocks), -1, dtype=torch.int32, device=device) + for s in range(num_seqs): + for b in range(max_seq_blocks): + block_tables[s, b] = s * max_seq_blocks + b + + q = torch.randn((total_q, num_heads, head_dim), device=device, dtype=torch.bfloat16) + k = torch.randn((total_q, num_kv_heads, head_dim), device=device, dtype=torch.bfloat16) + v = torch.randn_like(k) + + k_cache = torch.randn((num_page_blocks, page_size, num_kv_heads, head_dim), device=device, dtype=torch.bfloat16) + v_cache = torch.randn_like(k_cache) + + md = SimpleNamespace( + kv_cache_layout="unified", + block_tables=block_tables, + context_lens=context_lens, + cu_seqlens_q=cu_q, + max_seqlen_q=int(q_lens.max().item()), + page_block_size=page_size, + ) + scale = 1.0 / (head_dim**0.5) + + out = paged_attn_decode_unified_triton(q, k, v, k_cache, v_cache, md, softmax_scale=scale, fp8_cache=False) + ref = naive_sdpa_with_kvcache( + q, + k, + v, + k_cache, + v_cache, + block_tables, + context_lens, + cu_q, + cu_k, + scale, + num_groups, + page_size, + ) + + torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for Triton paged-attention kernel") +@pytest.mark.skipif(not _has_fp8(), reason="This torch build does not expose FP8 dtypes") +def test_paged_decode_triton_fp8_cache_matches_reference(): + torch.manual_seed(0) + device = torch.device("cuda") + + fp8_dtype = torch.float8_e4m3fn if hasattr(torch, "float8_e4m3fn") else torch.float8_e5m2 + + num_seqs = 2 + num_heads = 16 + num_kv_heads = 4 + head_dim = 128 + page_size = 32 + diffusion_block_size = 32 + num_groups = num_heads // num_kv_heads + + q_lens = torch.full((num_seqs,), diffusion_block_size, dtype=torch.int32, device=device) + cu_q = _build_cu_seqlens(q_lens) + cu_k = cu_q.clone() + total_q = int(cu_q[-1].item()) + + context_lens = torch.tensor([37, 55], dtype=torch.int32, device=device) + max_ctx = int(context_lens.max().item()) + max_seq_blocks = (max_ctx + page_size - 1) // page_size + num_page_blocks = num_seqs * max_seq_blocks + block_tables = torch.full((num_seqs, max_seq_blocks), -1, dtype=torch.int32, device=device) + for s in range(num_seqs): + for b in range(max_seq_blocks): + block_tables[s, b] = s * max_seq_blocks + b + + q = torch.randn((total_q, num_heads, head_dim), device=device, dtype=torch.bfloat16) + k = torch.randn((total_q, num_kv_heads, head_dim), device=device, dtype=torch.bfloat16) + v = torch.randn_like(k) + + # Build BF16 "true" cache values, then quantize to FP8 as (x / scale) -> fp8, with per-head scales. + k_cache_true = torch.randn((num_page_blocks, page_size, num_kv_heads, head_dim), device=device, dtype=torch.bfloat16) * 0.5 + v_cache_true = torch.randn_like(k_cache_true) * 0.5 + + eps = 1e-6 + k_absmax = k_cache_true.to(torch.float32).abs().amax(dim=(0, 1, 3)) + v_absmax = v_cache_true.to(torch.float32).abs().amax(dim=(0, 1, 3)) + fp8_max = 448.0 if fp8_dtype == torch.float8_e4m3fn else 57344.0 + k_scale = (k_absmax / fp8_max).clamp_min(eps).to(torch.float32) + v_scale = (v_absmax / fp8_max).clamp_min(eps).to(torch.float32) + + k_cache_fp8 = (k_cache_true.to(torch.float32) / k_scale.view(1, 1, -1, 1)).to(fp8_dtype) + v_cache_fp8 = (v_cache_true.to(torch.float32) / v_scale.view(1, 1, -1, 1)).to(fp8_dtype) + + md = SimpleNamespace( + kv_cache_layout="unified", + block_tables=block_tables, + context_lens=context_lens, + cu_seqlens_q=cu_q, + max_seqlen_q=int(q_lens.max().item()), + page_block_size=page_size, + k_scale=k_scale, + v_scale=v_scale, + ) + scale = 1.0 / (head_dim**0.5) + + out = paged_attn_decode_unified_triton(q, k, v, k_cache_fp8, v_cache_fp8, md, softmax_scale=scale, fp8_cache=True) + + # Reference uses dequantized cache. + k_cache_deq = (k_cache_fp8.float() * k_scale.view(1, 1, -1, 1)).to(torch.bfloat16) + v_cache_deq = (v_cache_fp8.float() * v_scale.view(1, 1, -1, 1)).to(torch.bfloat16) + ref = naive_sdpa_with_kvcache( + q, + k, + v, + k_cache_deq, + v_cache_deq, + block_tables, + context_lens, + cu_q, + cu_k, + scale, + num_groups, + page_size, + ) + + torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2) + diff --git a/test/python/test_kv_cache_fp8_distinct_load.py b/test/python/test_kv_cache_fp8_distinct_load.py new file mode 100644 index 0000000..4dabc75 --- /dev/null +++ b/test/python/test_kv_cache_fp8_distinct_load.py @@ -0,0 +1,143 @@ +import pytest +import torch + +from types import SimpleNamespace + +from diffulex.utils.quantization.factory import QuantizationStrategyFactory +from diffulex_kernel import store_kvcache_distinct_layout, load_kvcache + + +def _has_fp8() -> bool: + return hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz") or hasattr(torch, "float8_e5m2") + + +def _build_cu_seqlens(x: torch.Tensor) -> torch.Tensor: + # x: [num_seqs] int32 on cuda + return torch.tensor( + [0] + list(torch.cumsum(x, dim=0).cpu().numpy()), + dtype=torch.int32, + device=x.device, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for Triton KV-cache kernels") +@pytest.mark.skipif(not _has_fp8(), reason="This torch build does not expose FP8 dtypes") +def test_fp8_kv_cache_distinct_store_and_load(): + """ + Regression test for FP8 KV cache distinct layout: + - store: quantize+store context into distinct cache (uint8 storage) + - load: fused gather+dequant+scale from distinct cache into BF16 output, + and append active KV (k_new/v_new) exactly. + """ + torch.manual_seed(1234) + device = torch.device("cuda") + + # Enable FP8 KV quantization strategy in the global quantization context. + QuantizationStrategyFactory.create_from_config(SimpleNamespace(kv_cache_dtype="fp8_e4m3")) + + num_seqs = 2 + blk_sz = 64 + num_kv_heads = 4 + head_dim = 128 + x = 8 + diffusion_block_size = 32 + + # ctx/new lengths (make new divisible by diffusion_block_size to match kernel loop) + ctx_lens = torch.tensor([37, 55], dtype=torch.int32, device=device) + seq_lens = torch.tensor([32, 32], dtype=torch.int32, device=device) + total_lens = ctx_lens + seq_lens + + # Build concatenated [sum(total_lens), H, D] for store reference. + k_all = torch.randn((int(total_lens.sum().item()), num_kv_heads, head_dim), device=device, dtype=torch.bfloat16) + v_all = torch.randn_like(k_all) + + # slot_mapping: context tokens map to their block slots; new tokens use -1 (not stored). + slot_mapping: list[int] = [] + start = 0 + for seq_idx in range(num_seqs): + ctx = int(ctx_lens[seq_idx].item()) + new = int(seq_lens[seq_idx].item()) + slot_mapping.extend(list(range(seq_idx * blk_sz, seq_idx * blk_sz + ctx))) + slot_mapping.extend([-1] * new) + start += ctx + new + slot_mapping_ts = torch.tensor(slot_mapping, dtype=torch.int64, device=device) + + # Distinct caches (uint8 storage for FP8). + k_cache_u8 = torch.zeros((num_seqs, num_kv_heads, head_dim // x, blk_sz, x), device=device, dtype=torch.uint8) + v_cache_u8 = torch.zeros((num_seqs, num_kv_heads, head_dim, blk_sz), device=device, dtype=torch.uint8) + + # Scales: per-head absmax / fp8_max (same convention as strategy). + from diffulex.utils.quantization.kv_cache_dtype import parse_kv_cache_dtype + + spec = parse_kv_cache_dtype("fp8_e4m3") + assert spec.is_fp8 and spec.fp8_max is not None + fp8_max = float(spec.fp8_max) + eps = 1e-6 + k_absmax = k_all.to(torch.float32).abs().amax(dim=(0, 2)) + v_absmax = v_all.to(torch.float32).abs().amax(dim=(0, 2)) + k_scale = (k_absmax / fp8_max).clamp_min(eps).to(torch.float32) + v_scale = (v_absmax / fp8_max).clamp_min(eps).to(torch.float32) + + # Minimal metadata required by store/load. + block_tables = torch.arange(num_seqs, dtype=torch.int32, device=device).view(num_seqs, 1) + md = SimpleNamespace( + kv_cache_layout="distinct", + need_kv_cache_store=True, + slot_mapping=slot_mapping_ts, + context_lens=ctx_lens, + seq_lens_ts=seq_lens, + block_tables=block_tables, + cu_seqlens_q=_build_cu_seqlens(seq_lens), + cu_seqlens_k=_build_cu_seqlens(total_lens), + max_seqlen_q=int(seq_lens.max().item()), + max_seqlen_k=int(total_lens.max().item()), + seqs=[SimpleNamespace(diffusion_block_size=diffusion_block_size)], + k_scale=k_scale, + v_scale=v_scale, + ) + + # Store context into cache. + store_kvcache_distinct_layout(k_all, v_all, k_cache_u8, v_cache_u8, slot_mapping_ts, md) + + # Build k_new/v_new (only active tokens, concatenated over sequences). + k_new_list = [] + v_new_list = [] + start = 0 + for seq_idx in range(num_seqs): + ctx = int(ctx_lens[seq_idx].item()) + new = int(seq_lens[seq_idx].item()) + k_new_list.append(k_all[start + ctx : start + ctx + new]) + v_new_list.append(v_all[start + ctx : start + ctx + new]) + start += ctx + new + k_new = torch.cat(k_new_list, dim=0).contiguous() + v_new = torch.cat(v_new_list, dim=0).contiguous() + + # Load (fused dequant + gather) and append new tokens. + k_out, v_out = load_kvcache(k_cache_u8, v_cache_u8, md, k_new, v_new) + + # Split outputs per sequence to check ctx/new portions. + out_splits_k = torch.split(k_out, total_lens.tolist(), dim=0) + out_splits_v = torch.split(v_out, total_lens.tolist(), dim=0) + new_splits_k = torch.split(k_new, seq_lens.tolist(), dim=0) + new_splits_v = torch.split(v_new, seq_lens.tolist(), dim=0) + + start = 0 + for seq_idx in range(num_seqs): + ctx = int(ctx_lens[seq_idx].item()) + new = int(seq_lens[seq_idx].item()) + + k_ctx_ref = k_all[start : start + ctx].to(torch.float32) + v_ctx_ref = v_all[start : start + ctx].to(torch.float32) + k_ctx_got = out_splits_k[seq_idx][:ctx].to(torch.float32) + v_ctx_got = out_splits_v[seq_idx][:ctx].to(torch.float32) + + # Quantization error tolerance (FP8). + assert torch.allclose(k_ctx_got, k_ctx_ref, atol=2e-1, rtol=2e-1) + assert torch.allclose(v_ctx_got, v_ctx_ref, atol=2e-1, rtol=2e-1) + + # New tokens should be appended exactly (no quantization). + assert torch.equal(out_splits_k[seq_idx][ctx : ctx + new], new_splits_k[seq_idx]) + assert torch.equal(out_splits_v[seq_idx][ctx : ctx + new], new_splits_v[seq_idx]) + + start += ctx + new + diff --git a/test/python/test_linear_fp8.py b/test/python/test_linear_fp8.py new file mode 100644 index 0000000..9eaa71f --- /dev/null +++ b/test/python/test_linear_fp8.py @@ -0,0 +1,347 @@ +""" +Unit tests for FP8 Linear quantization strategies. +""" + +import pytest +import torch +import torch.nn.functional as F + +from diffulex.utils.quantization.registry import create_linear_strategy +from diffulex.utils.quantization.context import get_quantization_context + + +def test_linear_strategy_registry_fp8_e4m3_w8a16(): + """Test that fp8_e4m3+bf16 returns the real FP8 W8A16 strategy.""" + s = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + assert s.name == "linear_fp8_fp8_e4m3_w8a16" + assert s.linear_weight_format == "fp8_e4m3" + assert s.linear_act_format == "bf16" + + +def test_linear_strategy_registry_fp8_e5m2_w8a16(): + """Test that fp8_e5m2+bf16 returns the real FP8 W8A16 strategy.""" + s = create_linear_strategy(weight_dtype="fp8_e5m2", act_dtype="bf16") + assert s.name == "linear_fp8_fp8_e5m2_w8a16" + assert s.linear_weight_format == "fp8_e5m2" + assert s.linear_act_format == "bf16" + + +def test_linear_strategy_registry_fp8_e4m3_w8a8(): + """Test that fp8_e4m3+fp8_e4m3 returns the real FP8 W8A8 strategy.""" + s = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="fp8_e4m3") + assert s.name == "linear_fp8_fp8_e4m3_w8a8" + assert s.linear_weight_format == "fp8_e4m3" + assert s.linear_act_format == "fp8_e4m3" + + +def test_linear_strategy_registry_fp8_e5m2_w8a8(): + """Test that fp8_e5m2+fp8_e5m2 returns the real FP8 W8A8 strategy.""" + s = create_linear_strategy(weight_dtype="fp8_e5m2", act_dtype="fp8_e5m2") + assert s.name == "linear_fp8_fp8_e5m2_w8a8" + assert s.linear_weight_format == "fp8_e5m2" + assert s.linear_act_format == "fp8_e5m2" + + +def test_fp8_w8a16_quantize_dequantize_roundtrip(): + """Test FP8 W8A16 quantization and dequantization roundtrip.""" + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + + # Create a test weight tensor + weight = torch.randn(128, 256, dtype=torch.bfloat16) + + # Quantize + quantized, scales = strategy.quantize(weight) + + # Check output types and shapes + assert quantized.dtype == torch.uint8 + assert quantized.shape == weight.shape + assert scales.dtype == torch.float32 + assert scales.shape == (weight.shape[0],) + + # Dequantize + dequantized = strategy.dequantize(quantized, scales) + + # Check output type and shape + assert dequantized.dtype == torch.bfloat16 + assert dequantized.shape == weight.shape + + # Check approximate recovery (FP8 has limited precision) + # Use relaxed tolerance for FP8 + max_error = torch.abs(dequantized - weight).max() + relative_error = torch.abs((dequantized - weight) / (weight.abs() + 1e-8)).max() + # FP8 has ~3-4 bits of precision, so we expect some error + assert max_error < 0.5 # Relaxed tolerance + assert relative_error < 0.3 # 30% relative error is acceptable for FP8 + + +def test_fp8_w8a16_forward(): + """Test FP8 W8A16 forward pass.""" + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + + # Create test tensors + M, K, N = 4, 256, 128 + x = torch.randn(M, K, dtype=torch.bfloat16) + weight = torch.randn(N, K, dtype=torch.bfloat16) + bias = torch.randn(N, dtype=torch.bfloat16) + + # Compute reference output (bf16) + ref_out = F.linear(x, weight, bias) + + # Compute FP8 quantized output + fp8_out = strategy.linear_forward(x, weight, bias, quant_kind="attn") + + # Check output shape + assert fp8_out.shape == ref_out.shape + assert fp8_out.dtype == torch.bfloat16 + + # Check approximate correctness (FP8 has limited precision) + max_error = torch.abs(fp8_out - ref_out).max() + # FP8 quantization introduces error, but output should be reasonable + # FP8 has ~3-4 bits of precision, so we use more relaxed tolerance + # Only check absolute error to avoid issues with near-zero values + assert max_error < 2.0 # Relaxed tolerance for FP8 + # Check that outputs are in similar range (not completely broken) + assert fp8_out.abs().max() < ref_out.abs().max() * 3 # Output shouldn't be 3x larger + + +def test_fp8_w8a16_lazy_cache(): + """Test FP8 W8A16 lazy cache behavior.""" + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + + # Create test tensors + M, K, N = 4, 256, 128 + x = torch.randn(M, K, dtype=torch.bfloat16) + weight = torch.randn(N, K, dtype=torch.bfloat16) + bias = torch.randn(N, dtype=torch.bfloat16) + + # First forward pass should quantize and cache + out1 = strategy.linear_forward(x, weight, bias, quant_kind="attn") + assert len(strategy._weight_cache) == 1 + + # Second forward pass should use cached quantized weight + out2 = strategy.linear_forward(x, weight, bias, quant_kind="attn") + assert len(strategy._weight_cache) == 1 # Cache size unchanged + + # Outputs should be identical (same quantization) + assert torch.allclose(out1, out2, atol=1e-5, rtol=1e-5) + + # Clear cache + strategy.clear_cache() + assert len(strategy._weight_cache) == 0 + + +def test_fp8_w8a8_quantize_dequantize_roundtrip(): + """Test FP8 W8A8 quantization and dequantization roundtrip.""" + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="fp8_e4m3") + + # Test weight quantization + weight = torch.randn(128, 256, dtype=torch.bfloat16) + quantized_weight, w_scales = strategy.quantize(weight) + + assert quantized_weight.dtype == torch.uint8 + assert quantized_weight.shape == weight.shape + assert w_scales.dtype == torch.float16 + assert w_scales.shape == (weight.shape[0],) + + dequantized_weight = strategy.dequantize(quantized_weight, w_scales) + assert dequantized_weight.dtype == torch.bfloat16 + assert dequantized_weight.shape == weight.shape + + # Test activation quantization + x = torch.randn(4, 256, dtype=torch.bfloat16) + quantized_x, x_scales = strategy.quantize_act_for_kernel(x) + + assert quantized_x.dtype == torch.uint8 + assert quantized_x.shape == x.shape + assert x_scales.dtype == torch.float32 + assert x_scales.shape == (x.shape[0],) + + # Dequantize activation + dequantized_x = strategy._dequantize_act(quantized_x, x_scales) + assert dequantized_x.dtype == torch.bfloat16 + assert dequantized_x.shape == x.shape + + +def test_fp8_w8a8_forward(): + """Test FP8 W8A8 forward pass.""" + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="fp8_e4m3") + + # Create test tensors + M, K, N = 4, 256, 128 + x = torch.randn(M, K, dtype=torch.bfloat16) + weight = torch.randn(N, K, dtype=torch.bfloat16) + bias = torch.randn(N, dtype=torch.bfloat16) + + # Compute reference output (bf16) + ref_out = F.linear(x, weight, bias) + + # Compute FP8 quantized output + fp8_out = strategy.linear_forward(x, weight, bias, quant_kind="attn") + + # Check output shape + assert fp8_out.shape == ref_out.shape + assert fp8_out.dtype == torch.bfloat16 + + # Check approximate correctness (FP8 has limited precision) + max_error = torch.abs(fp8_out - ref_out).max() + # FP8 W8A8 quantization introduces larger error since both weights and activations are quantized + # FP8 has ~3-4 bits of precision, so we use more relaxed tolerance for W8A8 + # Only check absolute error to avoid issues with near-zero values + assert max_error < 3.0 # More relaxed tolerance for FP8 W8A8 (both W and A quantized) + # Check that outputs are in similar range (not completely broken) + assert fp8_out.abs().max() < ref_out.abs().max() * 3 # Output shouldn't be 3x larger + + +def test_fp8_w8a8_lazy_cache(): + """Test FP8 W8A8 lazy cache behavior.""" + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="fp8_e4m3") + + # Create test tensors + M, K, N = 4, 256, 128 + x = torch.randn(M, K, dtype=torch.bfloat16) + weight = torch.randn(N, K, dtype=torch.bfloat16) + bias = torch.randn(N, dtype=torch.bfloat16) + + # First forward pass should quantize and cache weight + out1 = strategy.linear_forward(x, weight, bias, quant_kind="attn") + assert len(strategy._weight_cache) == 1 + + # Second forward pass should use cached quantized weight + out2 = strategy.linear_forward(x, weight, bias, quant_kind="attn") + assert len(strategy._weight_cache) == 1 # Cache size unchanged + + # Outputs should be identical (same quantization) + assert torch.allclose(out1, out2, atol=1e-5, rtol=1e-5) + + # Clear cache + strategy.clear_cache() + assert len(strategy._weight_cache) == 0 + + +def test_fp8_w8a16_load_time_quantization(monkeypatch): + """Test FP8 W8A16 load-time quantization (quantized weight buffer).""" + import torch.distributed as dist + monkeypatch.setattr(dist, "get_rank", lambda: 0) + monkeypatch.setattr(dist, "get_world_size", lambda: 1) + + from diffulex.layer.linear import ReplicatedLinear + from diffulex.utils.quantization.context import get_quantization_context + + # Set up FP8 W8A16 strategy + ctx = get_quantization_context() + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + ctx.set_linear_strategy("attn", strategy) + + # Create Linear layer + linear = ReplicatedLinear(256, 128, bias=False, quant_kind="attn") + + # Load weight (should trigger quantization) + weight = torch.randn(128, 256, dtype=torch.bfloat16) + linear.weight.data.copy_(weight) + linear.weight_loader(linear.weight, weight) + + # Check that bf16 weight Parameter is removed + assert linear.weight is None or not hasattr(linear.weight, "data") + + # Check that quantized weight buffer is set + assert linear.has_quantized_weight() + assert linear.quant_weight_int8.dtype == torch.uint8 + assert linear.quant_weight_int8.shape == weight.shape + assert linear.quant_scales.dtype == torch.float32 + assert linear.quant_scales.shape == (weight.shape[0],) + + # Test forward with quantized weight + x = torch.randn(4, 256, dtype=torch.bfloat16) + out = linear(x) + assert out.shape == (4, 128) + assert out.dtype == torch.bfloat16 + + +def test_fp8_w8a8_load_time_quantization(monkeypatch): + """Test FP8 W8A8 load-time quantization (quantized weight buffer).""" + import torch.distributed as dist + monkeypatch.setattr(dist, "get_rank", lambda: 0) + monkeypatch.setattr(dist, "get_world_size", lambda: 1) + + from diffulex.layer.linear import ReplicatedLinear + from diffulex.utils.quantization.context import get_quantization_context + + # Set up FP8 W8A8 strategy + ctx = get_quantization_context() + strategy = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="fp8_e4m3") + ctx.set_linear_strategy("attn", strategy) + + # Create Linear layer + linear = ReplicatedLinear(256, 128, bias=False, quant_kind="attn") + + # Load weight (should trigger quantization) + weight = torch.randn(128, 256, dtype=torch.bfloat16) + linear.weight.data.copy_(weight) + linear.weight_loader(linear.weight, weight) + + # Check that bf16 weight Parameter is removed + assert linear.weight is None or not hasattr(linear.weight, "data") + + # Check that quantized weight buffer is set + assert linear.has_quantized_weight() + assert linear.quant_weight_int8.dtype == torch.uint8 + assert linear.quant_weight_int8.shape == weight.shape + assert linear.quant_scales.dtype == torch.float16 # FP8 W8A8 uses float16 scales + assert linear.quant_scales.shape == (weight.shape[0],) + + # Test forward with quantized weight + x = torch.randn(4, 256, dtype=torch.bfloat16) + out = linear(x) + assert out.shape == (4, 128) + assert out.dtype == torch.bfloat16 + + +def test_fp8_different_shapes(): + """Test FP8 strategies with different tensor shapes.""" + strategy_w8a16 = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + strategy_w8a8 = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="fp8_e4m3") + + # Test various shapes + shapes = [ + (1, 64, 32), # Small decode + (4, 128, 64), # Small batch + (16, 256, 128), # Medium batch + (32, 512, 256), # Large batch + ] + + for M, K, N in shapes: + x = torch.randn(M, K, dtype=torch.bfloat16) + weight = torch.randn(N, K, dtype=torch.bfloat16) + bias = torch.randn(N, dtype=torch.bfloat16) + + # Test W8A16 + out_w8a16 = strategy_w8a16.linear_forward(x, weight, bias, quant_kind="attn") + assert out_w8a16.shape == (M, N) + assert out_w8a16.dtype == torch.bfloat16 + + # Test W8A8 + out_w8a8 = strategy_w8a8.linear_forward(x, weight, bias, quant_kind="attn") + assert out_w8a8.shape == (M, N) + assert out_w8a8.dtype == torch.bfloat16 + + +def test_fp8_e5m2_vs_e4m3(): + """Test both FP8 formats (e4m3 and e5m2).""" + # Test W8A16 with both formats + strategy_e4m3 = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + strategy_e5m2 = create_linear_strategy(weight_dtype="fp8_e5m2", act_dtype="bf16") + + M, K, N = 4, 256, 128 + x = torch.randn(M, K, dtype=torch.bfloat16) + weight = torch.randn(N, K, dtype=torch.bfloat16) + bias = torch.randn(N, dtype=torch.bfloat16) + + out_e4m3 = strategy_e4m3.linear_forward(x, weight, bias, quant_kind="attn") + out_e5m2 = strategy_e5m2.linear_forward(x, weight, bias, quant_kind="attn") + + # Both should produce valid outputs + assert out_e4m3.shape == (M, N) + assert out_e5m2.shape == (M, N) + assert out_e4m3.dtype == torch.bfloat16 + assert out_e5m2.dtype == torch.bfloat16 + diff --git a/test/python/test_linear_quantization_module.py b/test/python/test_linear_quantization_module.py new file mode 100644 index 0000000..b76c558 --- /dev/null +++ b/test/python/test_linear_quantization_module.py @@ -0,0 +1,488 @@ +import pytest + + +def test_linear_strategy_registry_bf16_pair(): + from diffulex.utils.quantization.registry import create_linear_strategy + + s = create_linear_strategy(weight_dtype="bf16", act_dtype="bf16") + assert s.linear_weight_format == "bf16" + assert s.linear_act_format == "bf16" + + +def test_linear_strategy_registry_int8_w8a16(): + """Test that int8+bf16 returns the real W8A16 strategy (not stub).""" + from diffulex.utils.quantization.registry import create_linear_strategy + + s = create_linear_strategy(weight_dtype="int8", act_dtype="bf16") + assert s.name == "linear_int8_w8a16" + assert s.linear_weight_format == "int8" + assert s.linear_act_format == "bf16" + + +def test_linear_strategy_registry_int4_w4a16(): + """Test that int4+bf16 returns the real W4A16 strategy (not stub).""" + from diffulex.utils.quantization.registry import create_linear_strategy + + s = create_linear_strategy(weight_dtype="int4", act_dtype="bf16") + assert s.name == "linear_int4_w4a16" + assert s.linear_weight_format == "int4" + assert s.linear_act_format == "bf16" + + +def test_linear_strategy_registry_int8_w8a8(): + """Test that int8+int8 returns the real W8A8 strategy (not stub).""" + from diffulex.utils.quantization.registry import create_linear_strategy + + s = create_linear_strategy(weight_dtype="int8", act_dtype="int8") + assert s.name == "linear_int8_w8a8" + assert s.linear_weight_format == "int8" + assert s.linear_act_format == "int8" + + +def test_linear_strategy_registry_int4_w4a8(): + """Test that int4+int8 returns the real W4A8 strategy (not stub).""" + from diffulex.utils.quantization.registry import create_linear_strategy + + s = create_linear_strategy(weight_dtype="int4", act_dtype="int8") + assert s.name == "linear_int4_w4a8" + assert s.linear_weight_format == "int4" + assert s.linear_act_format == "int8" + + +def test_linear_strategy_registry_non_bf16_returns_stub(): + """Test that unimplemented combinations (e.g., fp8) return stub.""" + from diffulex.utils.quantization.registry import create_linear_strategy + + s = create_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16") + assert s.name.startswith("linear_stub") + assert s.linear_weight_format == "fp8_e4m3" + assert s.linear_act_format == "bf16" + + +def test_factory_injects_linear_strategies_into_context(): + from dataclasses import dataclass + + from diffulex.utils.quantization.factory import QuantizationStrategyFactory + from diffulex.utils.quantization.context import get_quantization_context + + @dataclass + class DummyConfig: + kv_cache_dtype: str = "bf16" + attn_q_dtype: str = "bf16" + linear_attn_weight_dtype: str = "bf16" + linear_mlp_weight_dtype: str = "bf16" + linear_attn_act_dtype: str = "bf16" + linear_mlp_act_dtype: str = "bf16" + + ctx = QuantizationStrategyFactory.create_from_config(DummyConfig()) + assert ctx is get_quantization_context() + assert ctx.get_linear_strategy("attn") is not None + assert ctx.get_linear_strategy("mlp") is not None + + +def test_linear_forward_raises_on_stub(monkeypatch): + # Avoid requiring torch.distributed process group init in unit tests. + import torch + import torch.nn.functional as F + import torch.distributed as dist + + monkeypatch.setattr(dist, "get_rank", lambda: 0) + monkeypatch.setattr(dist, "get_world_size", lambda: 1) + + from diffulex.layer.linear import ColumnParallelLinear + from diffulex.utils.quantization.registry import create_linear_strategy + from diffulex.utils.quantization.context import get_quantization_context + + # Install a stub strategy for attention linears (use int4, not implemented yet). + ctx = get_quantization_context() + ctx.set_linear_strategy("attn", create_linear_strategy(weight_dtype="int4", act_dtype="bf16")) + + lin = ColumnParallelLinear(4, 8, bias=False, quant_kind="attn") + # NOTE: default Linear weights are float32 unless a checkpoint loader overwrites them. + # Keep dtypes consistent for this unit test. + x = torch.randn(2, 4, dtype=torch.float32) + + with pytest.raises(NotImplementedError): + _ = lin(x) + + # Ensure bf16 path still works for other kinds. + lin2 = ColumnParallelLinear(4, 8, bias=False, quant_kind="other") + y = lin2(x) + ref = F.linear(x, lin2.weight, None) + assert torch.allclose(y, ref) + + +def test_linear_int8_w8a16_quantization(): + """Test that int8+bf16 strategy correctly quantizes and dequantizes weights.""" + from diffulex.utils.quantization.registry import create_linear_strategy + import torch + + strategy = create_linear_strategy(weight_dtype="int8", act_dtype="bf16") + assert strategy.name == "linear_int8_w8a16" + assert strategy.linear_weight_format == "int8" + assert strategy.linear_act_format == "bf16" + + # Test quantization/dequantization + weight = torch.randn(8, 4, dtype=torch.bfloat16) + quantized, scales = strategy.quantize(weight) + assert quantized.dtype == torch.int8 + assert quantized.shape == weight.shape + assert scales.shape == (weight.shape[0],) # Per-output-channel scales + + dequantized = strategy.dequantize(quantized, scales) + assert dequantized.dtype == torch.bfloat16 + assert dequantized.shape == weight.shape + + # Quantization error should be reasonable (int8 quantization introduces error) + error = (weight - dequantized).abs().max() + assert error.item() < 0.1, f"Quantization error too large: {error.item()}" + + +def test_linear_int8_w8a16_forward(): + """Test that int8+bf16 strategy's linear_forward produces reasonable outputs.""" + from diffulex.utils.quantization.registry import create_linear_strategy + import torch + import torch.nn.functional as F + + strategy = create_linear_strategy(weight_dtype="int8", act_dtype="bf16") + + x = torch.randn(2, 4, dtype=torch.bfloat16) + weight = torch.randn(8, 4, dtype=torch.bfloat16) + bias = torch.randn(8, dtype=torch.bfloat16) + + # Forward with quantized strategy + y_quant = strategy.linear_forward(x, weight, bias, quant_kind="test") + + # Reference forward (should be close but not exact due to quantization) + y_ref = F.linear(x, weight, bias) + + assert y_quant.shape == y_ref.shape + assert y_quant.dtype == torch.bfloat16 + + # Error should be reasonable (quantization introduces some error) + error = (y_quant - y_ref).abs().max() + assert error.item() < 0.5, f"Forward error too large: {error.item()}" + + +def test_linear_int8_w8a16_lazy_cache(): + """Test that W8A16 strategy caches quantized weights to avoid re-quantization.""" + from diffulex.utils.quantization.registry import create_linear_strategy + import torch + + strategy = create_linear_strategy(weight_dtype="int8", act_dtype="bf16") + + # Initial cache should be empty + assert len(strategy._weight_cache) == 0 + + weight = torch.randn(8, 4, dtype=torch.bfloat16) + x = torch.randn(2, 4, dtype=torch.bfloat16) + + # First forward - should cache + y1 = strategy.linear_forward(x, weight, None, quant_kind="test") + assert len(strategy._weight_cache) == 1 + assert id(weight) in strategy._weight_cache + + # Second forward with same weight - should use cache (same output) + y2 = strategy.linear_forward(x, weight, None, quant_kind="test") + assert len(strategy._weight_cache) == 1 # Cache size unchanged + assert torch.allclose(y1, y2), "Cached forward should produce same output" + + # Different weight - should cache new entry + weight2 = torch.randn(8, 4, dtype=torch.bfloat16) + y3 = strategy.linear_forward(x, weight2, None, quant_kind="test") + assert len(strategy._weight_cache) == 2 # New entry cached + + # Clear cache + strategy.clear_cache() + assert len(strategy._weight_cache) == 0 + + +def test_w8a16_tilelang_kernel_correctness(): + """Test that W8A16 TileLang kernel produces correct results (if available).""" + from diffulex.utils.quantization.registry import create_linear_strategy + import torch + + strategy = create_linear_strategy(weight_dtype="int8", act_dtype="bf16") + + # Skip test if TileLang kernel is not available + try: + from diffulex_kernel.python.linear_kernels import w8a16_gemm + tilelang_available = True + except ImportError: + tilelang_available = False + import pytest + pytest.skip("TileLang kernel not available") + + if not tilelang_available: + return + + # Create test data + M, N, K = 128, 256, 512 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + weight = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") + + # Quantize weight + quantized_weight, scales = strategy.quantize(weight) + quantized_weight = quantized_weight.to(device="cuda") + scales = scales.to(device="cuda") + + # Compute reference output (Python implementation) + ref_output = strategy._fallback_python_forward(x, quantized_weight, scales, None) + + # Compute output using strategy (kernel when available; may fall back if kernel unavailable). + out = strategy.linear_forward(x, weight, None, quant_kind="test") + + # Compare results + error = (out - ref_output).abs().max() + # Relative error can explode when ref_output is very close to 0. + # Use a masked relative error that only considers reasonably-sized reference values. + rel_mask = ref_output.abs() > 1.0 + if rel_mask.any(): + relative_error = (out - ref_output).abs() / (ref_output.abs() + 1e-8) + max_relative_error = relative_error[rel_mask].max() + else: + max_relative_error = None + + # Allow some numerical error (quantization + kernel precision) + assert error.item() < 1.0, f"Absolute error too large: {error.item()}" + if max_relative_error is not None: + assert max_relative_error.item() < 0.15, f"Relative error too large: {max_relative_error.item()}" + + +def test_w8a16_tilelang_kernel_tail_sizes_correctness(): + """Tail sizes (non-multiple M/N/K) should be handled without needing K%128==0.""" + from diffulex.utils.quantization.registry import create_linear_strategy + import torch + + # Skip test if TileLang kernel is not available + try: + from diffulex_kernel.python.linear_kernels import w8a16_gemm # noqa: F401 + tilelang_available = True + except ImportError: + tilelang_available = False + import pytest + pytest.skip("TileLang kernel not available") + + if not tilelang_available: + return + + strategy = create_linear_strategy(weight_dtype="int8", act_dtype="bf16") + + if not torch.cuda.is_available(): + import pytest + pytest.skip("CUDA not available") + + # Intentionally choose tail sizes (not multiples of block_M/N=64 and block_K=128). + M, N, K = 127, 255, 130 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + weight = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") + + # Strategy output (kernel when available; may fall back if kernel unavailable). + out = strategy.linear_forward(x, weight, None, quant_kind="test") + + # Reference (same as fallback implementation) + qweight, scales = strategy.quantize_weight_for_kernel(weight, device=x.device) + ref = strategy._fallback_python_forward(x, qweight, scales, None) + + assert out.shape == ref.shape + assert torch.allclose(out, ref, rtol=7e-2, atol=7e-2) + + +def test_w8a16_load_time_quantized_linear_saves_weight_memory(monkeypatch): + """Ensure load-time quantized Linear does not keep bf16 weight Parameter on CUDA.""" + import torch + import torch.distributed as dist + + if not torch.cuda.is_available(): + import pytest + pytest.skip("CUDA not available") + + # Avoid requiring torch.distributed process group init in unit tests. + monkeypatch.setattr(dist, "get_rank", lambda: 0) + monkeypatch.setattr(dist, "get_world_size", lambda: 1) + + from diffulex.layer.linear import ReplicatedLinear + from diffulex.utils.quantization.registry import create_linear_strategy + from diffulex.utils.quantization.context import get_quantization_context + + ctx = get_quantization_context() + strategy = create_linear_strategy(weight_dtype="int8", act_dtype="bf16") + ctx.set_linear_strategy("attn", strategy) + + lin = ReplicatedLinear(4096, 11008, bias=False, quant_kind="attn").cuda().to(dtype=torch.bfloat16) + + # Simulate checkpoint load: call weight_loader on the original Parameter. + param = lin._parameters["weight"] + loaded_weight = torch.randn_like(param, device=param.device, dtype=torch.bfloat16) + lin.weight_loader(param, loaded_weight) + + # Weight Parameter should be dropped and replaced by quant buffers. + assert lin.has_quantized_weight() + assert lin.weight is None + assert "weight" not in dict(lin.named_parameters()) + assert lin.quant_weight_int8.dtype == torch.int8 + assert lin.quant_scales.dtype == torch.bfloat16 + assert lin.quant_weight_int8.device.type == "cuda" + assert lin.quant_scales.device.type == "cuda" + + # Quant buffers should be significantly smaller than bf16 weight. + bf16_bytes = loaded_weight.numel() * loaded_weight.element_size() + q_bytes = lin.quant_weight_int8.numel() * lin.quant_weight_int8.element_size() + s_bytes = lin.quant_scales.numel() * lin.quant_scales.element_size() + assert (q_bytes + s_bytes) < bf16_bytes * 0.7 # conservative threshold + + # Forward should run and NOT populate the lazy cache (to avoid double-storage). + x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16) + before_cache = len(strategy._weight_cache) + + +# ========== W4A16 Tests ========== + +def test_linear_int4_w4a16_quantization(): + """Test W4A16 quantization and dequantization.""" + from diffulex.utils.quantization.registry import create_linear_strategy + import torch + torch.manual_seed(0) + + strategy = create_linear_strategy(weight_dtype="int4", act_dtype="bf16") + assert strategy.name == "linear_int4_w4a16" + assert strategy.linear_weight_format == "int4" + assert strategy.linear_act_format == "bf16" + + # Test quantization/dequantization + # Use a bounded distribution to make the quantization error check stable. + # With int4 per-channel quantization, very large random values can cause the max error + # to occasionally exceed a tight threshold. + weight = (torch.randn(8, 4, dtype=torch.float32) * 0.5).to(torch.bfloat16) + packed_weight, scales = strategy.quantize(weight) + assert packed_weight.dtype == torch.int8 + # Packed shape: [out_features, (in_features + 1) // 2] + assert packed_weight.shape == (weight.shape[0], (weight.shape[1] + 1) // 2) + assert scales.shape == (weight.shape[0],) # Per-output-channel scales + + dequantized = strategy.dequantize(packed_weight, scales, original_in_features=weight.shape[1]) + assert dequantized.dtype == torch.bfloat16 + assert dequantized.shape == weight.shape + + # Quantization error should be reasonable (int4 quantization introduces more error than int8) + error = (weight - dequantized).abs().max() + assert error.item() < 0.2, f"Quantization error too large: {error.item()}" + + +def test_linear_int4_w4a16_forward(): + """Test that int4+bf16 strategy's linear_forward produces reasonable outputs.""" + from diffulex.utils.quantization.registry import create_linear_strategy + import torch + import torch.nn.functional as F + + strategy = create_linear_strategy(weight_dtype="int4", act_dtype="bf16") + + x = torch.randn(2, 4, dtype=torch.bfloat16) + weight = torch.randn(8, 4, dtype=torch.bfloat16) + bias = torch.randn(8, dtype=torch.bfloat16) + + # Forward with quantized strategy + y_quant = strategy.linear_forward(x, weight, bias, quant_kind="test") + + # Reference forward (should be close but not exact due to quantization) + y_ref = F.linear(x, weight, bias) + + assert y_quant.shape == y_ref.shape + assert y_quant.dtype == torch.bfloat16 + + # Error should be reasonable (int4 quantization introduces more error than int8) + error = (y_quant - y_ref).abs().max() + assert error.item() < 1.0, f"Forward error too large: {error.item()}" + + +def test_linear_int4_w4a16_lazy_cache(): + """Test that W4A16 strategy caches quantized weights to avoid re-quantization.""" + from diffulex.utils.quantization.registry import create_linear_strategy + import torch + + strategy = create_linear_strategy(weight_dtype="int4", act_dtype="bf16") + + # Initial cache should be empty + assert len(strategy._weight_cache) == 0 + + weight = torch.randn(8, 4, dtype=torch.bfloat16) + x = torch.randn(2, 4, dtype=torch.bfloat16) + + # First forward - should cache + y1 = strategy.linear_forward(x, weight, None, quant_kind="test") + assert len(strategy._weight_cache) == 1 + assert id(weight) in strategy._weight_cache + + # Second forward with same weight - should use cache (same output) + y2 = strategy.linear_forward(x, weight, None, quant_kind="test") + assert len(strategy._weight_cache) == 1 # Cache size unchanged + assert torch.allclose(y1, y2, rtol=1e-3, atol=1e-3), "Cached forward should produce same output" + + # Different weight - should cache new entry + weight2 = torch.randn(8, 4, dtype=torch.bfloat16) + y3 = strategy.linear_forward(x, weight2, None, quant_kind="test") + assert len(strategy._weight_cache) == 2 # New entry cached + + # Clear cache + strategy.clear_cache() + assert len(strategy._weight_cache) == 0 + + +def test_w4a16_load_time_quantized_linear_saves_weight_memory(monkeypatch): + """Ensure load-time quantized W4A16 Linear does not keep bf16 weight Parameter on CUDA.""" + import torch + import torch.distributed as dist + from diffulex.layer.linear import ReplicatedLinear + from diffulex.utils.quantization.registry import create_linear_strategy + from diffulex.utils.quantization.context import get_quantization_context + + if not torch.cuda.is_available(): + import pytest + pytest.skip("CUDA not available") + + # Avoid requiring torch.distributed process group init in unit tests. + monkeypatch.setattr(dist, "get_rank", lambda: 0) + monkeypatch.setattr(dist, "get_world_size", lambda: 1) + + ctx = get_quantization_context() + strategy = create_linear_strategy(weight_dtype="int4", act_dtype="bf16") + ctx.set_linear_strategy("attn", strategy) + + lin = ReplicatedLinear(4096, 11008, bias=False, quant_kind="attn").cuda().to(dtype=torch.bfloat16) + + # Simulate checkpoint load: call weight_loader on the original Parameter. + param = lin._parameters["weight"] + loaded_weight = torch.randn_like(param, device=param.device, dtype=torch.bfloat16) + lin.weight_loader(param, loaded_weight) + + # Weight Parameter should be dropped and replaced by quant buffers. + assert lin.has_quantized_weight() + assert lin.weight is None + assert "weight" not in dict(lin.named_parameters()) + assert lin.quant_weight_int8.dtype == torch.int8 + assert lin.quant_scales.dtype == torch.bfloat16 + assert lin.quant_weight_int8.device.type == "cuda" + assert lin.quant_scales.device.type == "cuda" + + # Quant buffers should be significantly smaller than bf16 weight. + # For int4: packed shape is [out_features, (in_features + 1) // 2] + bf16_bytes = loaded_weight.numel() * loaded_weight.element_size() + q_bytes = lin.quant_weight_int8.numel() * lin.quant_weight_int8.element_size() + s_bytes = lin.quant_scales.numel() * lin.quant_scales.element_size() + # int4 packed should be ~50% of bf16 (plus small scales overhead) + assert (q_bytes + s_bytes) < bf16_bytes * 0.6 # conservative threshold + + # Forward should run and NOT populate the lazy cache (to avoid double-storage). + x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16) + before_cache = len(strategy._weight_cache) + out = lin(x) + after_cache = len(strategy._weight_cache) + assert after_cache == before_cache, "Load-time quantized forward should not populate lazy cache" + assert out.shape == (8, 11008) + assert out.dtype == torch.bfloat16 + y = lin(x) + after_cache = len(strategy._weight_cache) + assert y.shape == (8, 11008) + assert after_cache == before_cache + + diff --git a/test/python/test_quantization_e2e.py b/test/python/test_quantization_e2e.py new file mode 100644 index 0000000..3c23d7c --- /dev/null +++ b/test/python/test_quantization_e2e.py @@ -0,0 +1,188 @@ +""" +端到端测试:验证量化策略在实际使用场景中的集成 +""" +import torch +from types import SimpleNamespace + +from diffulex.utils.quantization.factory import QuantizationStrategyFactory +from diffulex.utils.quantization.context import get_kv_cache_strategy +from diffulex.attention.metadata import AttnMetaDataBase + + +def test_bf16_e2e(): + """端到端测试:BF16路径的完整流程""" + print("\n=== BF16 端到端测试 ===") + + # 1. 配置初始化 + config = SimpleNamespace( + kv_cache_dtype="bf16", + attn_q_dtype="bf16", + ) + ctx = QuantizationStrategyFactory.create_from_config(config) + strategy = get_kv_cache_strategy() + + # 2. 验证存储dtype + storage_dtype, itemsize = strategy.get_storage_dtype() + assert storage_dtype == torch.bfloat16 + print(f"✓ 存储dtype: {storage_dtype}, itemsize: {itemsize}") + + # 3. 模拟KV cache分配(类似ModelRunner.allocate_kv_cache) + num_layers = 2 + num_blocks = 4 + block_size = 32 + num_kv_heads = 8 + head_dim = 128 + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # 分配KV cache(unified layout) + kv_cache = torch.zeros( + 2, num_layers, num_blocks, block_size, num_kv_heads, head_dim, + dtype=storage_dtype, device=device + ) + print(f"✓ KV cache分配: shape={kv_cache.shape}, dtype={kv_cache.dtype}") + + # 4. 验证不需要scales + k_scale_init, v_scale_init = strategy.init_scales(num_kv_heads, device) + assert k_scale_init is None or v_scale_init is None + print("✓ BF16不需要scales") + + # 5. 模拟attention forward(类似Attention.forward) + seq_len = 16 + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device) + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device) + + # 模拟scale更新(应该被跳过,因为BF16不需要) + k_scale, v_scale = None, None + if strategy.requires_kv_cache_scales: + k_scale, v_scale = strategy.update_scales(k, v, k_scale, v_scale, num_kv_heads, device) + assert k_scale is None or v_scale is None + print("✓ Scale更新被正确跳过") + + # 6. 模拟metadata设置 + attn_metadata = AttnMetaDataBase() + strategy.maybe_set_attn_metadata_scales(attn_metadata, k_scale=k_scale, v_scale=v_scale) + assert attn_metadata.k_scale is None + assert attn_metadata.v_scale is None + print("✓ Metadata scales未设置(符合预期)") + + # 7. 验证cache view(应该直接返回原cache) + cache_view = strategy.view_kv_cache_for_kernels(kv_cache[0, 0, 0]) + assert cache_view is kv_cache[0, 0, 0] or torch.equal(cache_view, kv_cache[0, 0, 0]) + print("✓ Cache view正确(直接返回原cache)") + + print("✅ BF16端到端测试通过") + + +def test_fp8_e2e(): + """端到端测试:FP8路径的完整流程""" + print("\n=== FP8 端到端测试 ===") + + # 检查FP8支持 + has_fp8 = hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz") + if not has_fp8: + print("⚠ 当前PyTorch版本不支持FP8,跳过FP8端到端测试") + return True + + # 1. 配置初始化 + config = SimpleNamespace( + kv_cache_dtype="fp8", + attn_q_dtype="bf16", + ) + ctx = QuantizationStrategyFactory.create_from_config(config) + strategy = get_kv_cache_strategy() + + # 2. 验证存储dtype + storage_dtype, itemsize = strategy.get_storage_dtype() + assert storage_dtype == torch.uint8 + assert itemsize == 1 + print(f"✓ 存储dtype: {storage_dtype}, itemsize: {itemsize}") + + # 3. 模拟KV cache分配 + num_layers = 2 + num_blocks = 4 + block_size = 32 + num_kv_heads = 8 + head_dim = 128 + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # 分配KV cache(unified layout,uint8存储) + kv_cache = torch.zeros( + 2, num_layers, num_blocks, block_size, num_kv_heads, head_dim, + dtype=storage_dtype, device=device + ) + print(f"✓ KV cache分配: shape={kv_cache.shape}, dtype={kv_cache.dtype}") + + # 4. 分配scales(类似ModelRunner.allocate_kv_cache) + k_scale_init, v_scale_init = strategy.init_scales(num_kv_heads, device) + assert k_scale_init is not None and v_scale_init is not None + + k_scale = torch.zeros(num_layers, num_kv_heads, dtype=torch.float32, device=device) + v_scale = torch.zeros(num_layers, num_kv_heads, dtype=torch.float32, device=device) + k_scale[:] = k_scale_init[None, :] + v_scale[:] = v_scale_init[None, :] + print(f"✓ Scales分配: k_scale={k_scale.shape}, v_scale={v_scale.shape}") + + # 5. 模拟attention forward + seq_len = 16 + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device) + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device) + + # 模拟scale更新(类似Attention.forward中的逻辑) + layer_id = 0 + k_scale_layer = k_scale[layer_id] + v_scale_layer = v_scale[layer_id] + + k_scale_updated, v_scale_updated = strategy.update_scales( + k, v, k_scale_layer, v_scale_layer, num_kv_heads, device + ) + assert k_scale_updated is not None and v_scale_updated is not None + assert k_scale_updated.shape == (num_kv_heads,) + assert v_scale_updated.shape == (num_kv_heads,) + print(f"✓ Scale更新: k_scale范围=[{k_scale_updated.min():.4f}, {k_scale_updated.max():.4f}]") + + # 更新全局scales + k_scale[layer_id] = k_scale_updated + v_scale[layer_id] = v_scale_updated + + # 6. 模拟metadata设置(类似Attention.forward) + attn_metadata = AttnMetaDataBase() + strategy.maybe_set_attn_metadata_scales( + attn_metadata, k_scale=k_scale_layer, v_scale=v_scale_layer + ) + assert attn_metadata.k_scale is not None + assert attn_metadata.v_scale is not None + print("✓ Metadata scales已设置") + + # 7. 验证cache view(应该返回float8 view) + cache_view = strategy.view_kv_cache_for_kernels(kv_cache[0, 0, 0]) + assert cache_view.dtype != torch.uint8 + print(f"✓ Cache view dtype: {cache_view.dtype}") + + # 8. 模拟quantize_kv_for_store(类似store_kvcache中的逻辑) + k_quantized, v_quantized = strategy.quantize_kv_for_store( + k, v, k_scale=k_scale_layer, v_scale=v_scale_layer + ) + assert k_quantized.dtype == torch.uint8 + assert v_quantized.dtype == torch.uint8 + assert k_quantized.shape == k.shape + assert v_quantized.shape == v.shape + print(f"✓ KV量化: k={k_quantized.shape}, v={v_quantized.shape}") + + print("✅ FP8端到端测试通过") + + +if __name__ == "__main__": + print("开始端到端测试...") + + try: + test_bf16_e2e() + test_fp8_e2e() + print("\n✅ 所有端到端测试通过!") + except Exception as e: + print(f"\n❌ 测试失败: {e}") + import traceback + traceback.print_exc() + exit(1) + diff --git a/test/python/test_quantization_module.py b/test/python/test_quantization_module.py new file mode 100644 index 0000000..10ede0b --- /dev/null +++ b/test/python/test_quantization_module.py @@ -0,0 +1,70 @@ +import pytest +import torch + + +def test_kv_cache_strategy_registry_bf16_aliases(): + from diffulex.utils.quantization import create_kv_cache_strategy + + s1 = create_kv_cache_strategy("bf16") + assert getattr(s1, "kv_cache_format", None) == "bf16" + assert s1.requires_kv_cache_scales is False + + # Currently routed to BF16 kernels for compatibility. + s2 = create_kv_cache_strategy("fp16") + assert getattr(s2, "kv_cache_format", None) == "bf16" + + s3 = create_kv_cache_strategy("fp32") + assert getattr(s3, "kv_cache_format", None) == "bf16" + + +def test_attn_q_strategy_registry_and_factory(): + from types import SimpleNamespace + from diffulex.utils.quantization import ( + QuantizationStrategyFactory, + get_attn_q_strategy, + create_attn_q_strategy, + ) + + # Registry creation works + s_bf16 = create_attn_q_strategy("bf16") + assert s_bf16.attn_q_format == "bf16" + + s_fp8 = create_attn_q_strategy("fp8") + assert s_fp8.attn_q_format == "fp8" + + # Factory wiring: enable fp8 Q and ensure it lands in context + cfg = SimpleNamespace(kv_cache_dtype="bf16", attn_q_dtype="fp8") + QuantizationStrategyFactory.create_from_config(cfg) + active = get_attn_q_strategy() + assert active is not None + assert active.attn_q_format == "fp8" + + +@pytest.mark.skipif( + not (hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz")), + reason="This torch build does not expose float8 dtypes required by FP8 strategy.", +) +def test_kv_cache_fp8_strategy_metadata_and_views(): + from diffulex.utils.quantization import create_kv_cache_strategy + from diffulex.attention.metadata import AttnMetaDataBase + + s = create_kv_cache_strategy("fp8") + assert s.kv_cache_format == "fp8" + assert s.requires_kv_cache_scales is True + + md = AttnMetaDataBase() + k_scale = torch.ones((8,), dtype=torch.float32) + v_scale = torch.ones((8,), dtype=torch.float32) * 2 + s.maybe_set_attn_metadata_scales(md, k_scale=k_scale, v_scale=v_scale) + assert md.k_scale is k_scale + assert md.v_scale is v_scale + + with pytest.raises(ValueError): + s.maybe_set_attn_metadata_scales(AttnMetaDataBase(), k_scale=None, v_scale=None) + + # uint8 storage -> float8 view for kernels + cache_u8 = torch.empty((16,), dtype=torch.uint8) + cache_view = s.view_kv_cache_for_kernels(cache_u8) + assert cache_view.dtype != torch.uint8 + + diff --git a/test/python/test_quantization_paths.py b/test/python/test_quantization_paths.py new file mode 100644 index 0000000..3e5b065 --- /dev/null +++ b/test/python/test_quantization_paths.py @@ -0,0 +1,199 @@ +""" +验证量化路径:bf16路径和bf16+fp8 kv路径 +""" +import torch +from types import SimpleNamespace + +from diffulex.utils.quantization.factory import QuantizationStrategyFactory +from diffulex.utils.quantization.context import ( + get_kv_cache_strategy, + get_attn_q_strategy, + QuantizationContext, +) + + +def test_bf16_path(): + """测试bf16路径(默认,无量化)""" + print("\n=== 测试 BF16 路径 ===") + + # 创建配置 + config = SimpleNamespace( + kv_cache_dtype="bf16", + attn_q_dtype="bf16", + ) + + # 初始化量化上下文 + ctx = QuantizationStrategyFactory.create_from_config(config) + + # 获取策略 + kv_strategy = get_kv_cache_strategy() + attn_q_strategy = get_attn_q_strategy() + + assert kv_strategy is not None, "KV cache策略应该被创建" + assert attn_q_strategy is not None, "Attn-Q策略应该被创建" + + print(f"KV Cache策略: {kv_strategy.name}") + print(f"KV Cache格式: {kv_strategy.kv_cache_format}") + print(f"需要scales: {kv_strategy.requires_kv_cache_scales}") + + print(f"Attn-Q策略: {attn_q_strategy.name}") + print(f"Attn-Q格式: {attn_q_strategy.attn_q_format}") + + # 验证存储dtype + storage_dtype, itemsize = kv_strategy.get_storage_dtype() + assert storage_dtype == torch.bfloat16, f"期望bfloat16,得到{storage_dtype}" + assert itemsize == 2, f"期望itemsize=2,得到{itemsize}" + print(f"存储dtype: {storage_dtype}, itemsize: {itemsize}") + + # 验证不需要scales + assert not kv_strategy.requires_kv_cache_scales, "BF16不应该需要scales" + + # 验证scale初始化 + num_kv_heads = 8 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + k_scale, v_scale = kv_strategy.init_scales(num_kv_heads, device) + assert k_scale is None or v_scale is None, "BF16策略应该返回None scales" + print("✓ BF16路径验证通过") + + +def test_bf16_with_fp8_kv_path(): + """测试bf16 + fp8 kv路径""" + print("\n=== 测试 BF16 + FP8 KV 路径 ===") + + # 检查是否支持FP8 + has_fp8 = hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz") + if not has_fp8: + print("⚠ 当前PyTorch版本不支持FP8,跳过FP8测试") + return True + + # 创建配置 + config = SimpleNamespace( + kv_cache_dtype="fp8", # 或 "fp8_e4m3" + attn_q_dtype="bf16", + ) + + # 初始化量化上下文 + ctx = QuantizationStrategyFactory.create_from_config(config) + + # 获取策略 + kv_strategy = get_kv_cache_strategy() + attn_q_strategy = get_attn_q_strategy() + + assert kv_strategy is not None, "KV cache策略应该被创建" + assert attn_q_strategy is not None, "Attn-Q策略应该被创建" + + print(f"KV Cache策略: {kv_strategy.name}") + print(f"KV Cache格式: {kv_strategy.kv_cache_format}") + print(f"需要scales: {kv_strategy.requires_kv_cache_scales}") + + print(f"Attn-Q策略: {attn_q_strategy.name}") + print(f"Attn-Q格式: {attn_q_strategy.attn_q_format}") + + # 验证存储dtype(FP8应该用uint8存储) + storage_dtype, itemsize = kv_strategy.get_storage_dtype() + assert storage_dtype == torch.uint8, f"期望uint8,得到{storage_dtype}" + assert itemsize == 1, f"期望itemsize=1,得到{itemsize}" + print(f"存储dtype: {storage_dtype}, itemsize: {itemsize}") + + # 验证需要scales + assert kv_strategy.requires_kv_cache_scales, "FP8应该需要scales" + + # 验证scale初始化 + num_kv_heads = 8 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + k_scale, v_scale = kv_strategy.init_scales(num_kv_heads, device) + assert k_scale is not None and v_scale is not None, "FP8策略应该返回非None scales" + assert k_scale.shape == (num_kv_heads,), f"k_scale形状应该是({num_kv_heads},),得到{k_scale.shape}" + assert v_scale.shape == (num_kv_heads,), f"v_scale形状应该是({num_kv_heads},),得到{v_scale.shape}" + print(f"初始scales形状: k_scale={k_scale.shape}, v_scale={v_scale.shape}") + + # 验证scale更新逻辑 + seq_len = 32 + head_dim = 128 + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device) + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device) + + # 第一次更新(从None开始) + k_scale_new, v_scale_new = kv_strategy.update_scales( + k, v, None, None, num_kv_heads, device + ) + assert k_scale_new is not None and v_scale_new is not None + assert k_scale_new.shape == (num_kv_heads,) + assert v_scale_new.shape == (num_kv_heads,) + print(f"第一次更新scales: k_scale范围=[{k_scale_new.min():.4f}, {k_scale_new.max():.4f}]") + + # 第二次更新(使用已有scales) + k_scale_updated, v_scale_updated = kv_strategy.update_scales( + k, v, k_scale_new, v_scale_new, num_kv_heads, device + ) + assert k_scale_updated is not None and v_scale_updated is not None + print(f"第二次更新scales: k_scale范围=[{k_scale_updated.min():.4f}, {k_scale_updated.max():.4f}]") + + # 验证view_kv_cache_for_kernels + cache_u8 = torch.empty((16,), dtype=torch.uint8, device=device) + cache_view = kv_strategy.view_kv_cache_for_kernels(cache_u8) + assert cache_view.dtype != torch.uint8, "view应该返回非uint8的dtype" + print(f"Cache view dtype: {cache_view.dtype}") + + # 验证quantize_kv_for_store + k_quantized, v_quantized = kv_strategy.quantize_kv_for_store( + k, v, k_scale=k_scale_new, v_scale=v_scale_new + ) + assert k_quantized.dtype == torch.uint8, f"量化后的K应该是uint8,得到{k_quantized.dtype}" + assert v_quantized.dtype == torch.uint8, f"量化后的V应该是uint8,得到{v_quantized.dtype}" + assert k_quantized.shape == k.shape, f"量化后的K形状应该保持不变" + assert v_quantized.shape == v.shape, f"量化后的V形状应该保持不变" + print(f"量化后形状: k={k_quantized.shape}, v={v_quantized.shape}") + + print("✓ BF16 + FP8 KV路径验证通过") + + +def test_metadata_integration(): + """测试与AttnMetaData的集成""" + print("\n=== 测试 Metadata 集成 ===") + + from diffulex.attention.metadata import AttnMetaDataBase + + # BF16路径 + config_bf16 = SimpleNamespace(kv_cache_dtype="bf16", attn_q_dtype="bf16") + QuantizationStrategyFactory.create_from_config(config_bf16) + kv_strategy_bf16 = get_kv_cache_strategy() + + md_bf16 = AttnMetaDataBase() + kv_strategy_bf16.maybe_set_attn_metadata_scales(md_bf16, k_scale=None, v_scale=None) + assert md_bf16.k_scale is None, "BF16不应该设置scales" + assert md_bf16.v_scale is None, "BF16不应该设置scales" + print("✓ BF16 metadata集成正常") + + # FP8路径(如果支持) + has_fp8 = hasattr(torch, "float8_e4m3fn") or hasattr(torch, "float8_e4m3fnuz") + if has_fp8: + config_fp8 = SimpleNamespace(kv_cache_dtype="fp8", attn_q_dtype="bf16") + QuantizationStrategyFactory.create_from_config(config_fp8) + kv_strategy_fp8 = get_kv_cache_strategy() + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + k_scale = torch.ones((8,), dtype=torch.float32, device=device) + v_scale = torch.ones((8,), dtype=torch.float32, device=device) * 2 + + md_fp8 = AttnMetaDataBase() + kv_strategy_fp8.maybe_set_attn_metadata_scales(md_fp8, k_scale=k_scale, v_scale=v_scale) + assert md_fp8.k_scale is k_scale, "FP8应该设置k_scale" + assert md_fp8.v_scale is v_scale, "FP8应该设置v_scale" + print("✓ FP8 metadata集成正常") + + +if __name__ == "__main__": + print("开始验证量化路径...") + + try: + test_bf16_path() + test_bf16_with_fp8_kv_path() + test_metadata_integration() + print("\n✅ 所有路径验证通过!") + except Exception as e: + print(f"\n❌ 验证失败: {e}") + import traceback + traceback.print_exc() + exit(1) + diff --git a/test/python/utils/__init__.py b/test/python/utils/__init__.py new file mode 100644 index 0000000..36585dd --- /dev/null +++ b/test/python/utils/__init__.py @@ -0,0 +1,2 @@ +# test.python.utils package + diff --git a/test/python/utils/checker.py b/test/python/utils/checker.py new file mode 100755 index 0000000..479ea05 --- /dev/null +++ b/test/python/utils/checker.py @@ -0,0 +1,520 @@ +def CHECK_D2F_SLOT_MAPPING(seqs, slot_mapping): + # check slot mapping layout + start_idx = 0 + for seq in seqs: + cur_ref_slot_mapping = [] + for idx in range(seq.num_diffusion_blocks): + if seq.active_blocks[idx]: + padding_num_tokens = (seq.num_diffusion_blocks - idx) * seq.diffusion_block_size + cur_ref_slot_mapping.extend([-1] * padding_num_tokens) + break + elif seq.to_cache_blocks[idx]: + cur_ref_slot_mapping.extend([0] * seq.diffusion_block_size) + cur_slot_mapping = slot_mapping[start_idx:start_idx + len(cur_ref_slot_mapping)] + for slot, ref_slot in zip(cur_slot_mapping, cur_ref_slot_mapping): + try: + if ref_slot == -1: + assert slot == -1 + elif ref_slot == 0: + assert slot != -1 + elif ref_slot is not None: + assert slot is not None + except AssertionError: + raise ValueError(f"Slot mapping mismatch: {slot} != {ref_slot}. " + f"Check the implementation of prepare_decode.\n" + f"slot_mapping: {cur_slot_mapping}\n" + f"ref_slot_mapping: {cur_ref_slot_mapping}\n" + f"diff: {[s - r for s, r in zip(cur_slot_mapping, cur_ref_slot_mapping)]}") + start_idx += len(cur_ref_slot_mapping) + + +def CHECK_FLASH_ATTN_PREFILL( + q, k, v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + prefill_kernel, + diffusion_block_size: int = 32, + is_block_attn: bool = False, +): + """ + Verify prefill kernel correctness by comparing with PyTorch's scaled_dot_product_attention. + + Args: + q: Query tensor [total_q_len, num_heads, head_dim] + k: Key tensor [total_kv_len, num_kv_heads, head_dim] + v: Value tensor [total_kv_len, num_kv_heads, head_dim] + cu_seqlens_q: Cumulative sequence lengths for queries + cu_seqlens_k: Cumulative sequence lengths for keys/values + max_seqlen_q: Maximum sequence length for queries + prefill_kernel: The kernel function to test + diffusion_block_size: Size of diffusion blocks for block attention + is_block_attn: Whether this is block attention mode + """ + import torch + import torch.nn.functional as F + from einops import rearrange + + # Run kernel + kernel_output = prefill_kernel(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) + + # Compute reference output using PyTorch's SDPA + head_dim = q.shape[2] + scale = 1.0 / (head_dim ** 0.5) + num_seqs = len(cu_seqlens_q) - 1 + + gt_output = torch.zeros_like(q) + for seq_idx in range(num_seqs): + q_start = cu_seqlens_q[seq_idx].item() + q_end = cu_seqlens_q[seq_idx + 1].item() + kv_start = cu_seqlens_k[seq_idx].item() + kv_end = cu_seqlens_k[seq_idx + 1].item() + + q_seq = q[q_start:q_end] + k_seq = k[kv_start:kv_end] + v_seq = v[kv_start:kv_end] + + q_len = q_seq.shape[0] + kv_len = k_seq.shape[0] + + # Reshape for SDPA: [1, num_heads, seq_len, head_dim] + q_sdpa = rearrange(q_seq, 's h d -> 1 h s d') + k_sdpa = rearrange(k_seq, 's h d -> 1 h s d') + v_sdpa = rearrange(v_seq, 's h d -> 1 h s d') + + if not is_block_attn: + # Standard attention + attn_out = F.scaled_dot_product_attention( + q_sdpa, + k_sdpa, + v_sdpa, + dropout_p=0.0, + is_causal=False, + scale=scale, + enable_gqa=True, + ) + else: + # Block attention with mask + block_mask = torch.zeros((1, 1, q_len, kv_len), dtype=q.dtype, device=q.device).bool() + num_diffusion_blocks = (kv_len + diffusion_block_size - 1) // diffusion_block_size + for block_idx in range(num_diffusion_blocks): + block_start = block_idx * diffusion_block_size + block_end = min(block_start + diffusion_block_size, kv_len) + block_mask[..., block_start:block_end, :block_end] = True + + attn_out = F.scaled_dot_product_attention( + q_sdpa, + k_sdpa, + v_sdpa, + attn_mask=block_mask, + dropout_p=0.0, + is_causal=False, + scale=scale, + enable_gqa=True, + ) + + gt_output[q_start:q_end] = rearrange(attn_out, '1 h s d -> s h d').to(gt_output.dtype) + + # Compare results + atol = 1e-2 + rtol = 1e-2 + try: + torch.testing.assert_close( + kernel_output, + gt_output, + atol=atol, + rtol=rtol, + msg="Kernel output does not match reference implementation" + ) + except AssertionError as e: + # Compute error statistics for debugging + abs_diff = torch.abs(kernel_output - gt_output) + max_diff = torch.max(abs_diff).item() + mean_diff = torch.mean(abs_diff).item() + rel_diff = torch.abs((kernel_output - gt_output) / (gt_output + 1e-8)) + max_rel_diff = torch.max(rel_diff).item() + mean_rel_diff = torch.mean(rel_diff).item() + + # Count elements that exceed tolerance + total_elements = kernel_output.numel() + # Elements that exceed absolute tolerance + exceeds_atol = (abs_diff > atol) + num_exceeds_atol = exceeds_atol.sum().item() + # Elements that exceed relative tolerance + exceeds_rtol = (rel_diff > rtol) + num_exceeds_rtol = exceeds_rtol.sum().item() + # Elements that exceed either tolerance + exceeds_tolerance = exceeds_atol | exceeds_rtol + num_exceeds_tolerance = exceeds_tolerance.sum().item() + pct_exceeds_tolerance = (num_exceeds_tolerance / total_elements * 100) if total_elements > 0 else 0 + + raise AssertionError( + f"Prefill kernel verification failed!\n" + f"Max absolute difference: {max_diff:.6f}\n" + f"Mean absolute difference: {mean_diff:.6f}\n" + f"Max relative difference: {max_rel_diff:.6f}\n" + f"Mean relative difference: {mean_rel_diff:.6f}\n" + f"Total elements: {total_elements}\n" + f"Elements exceeding absolute tolerance (atol={atol}): {num_exceeds_atol} ({num_exceeds_atol/total_elements*100:.2f}%)\n" + f"Elements exceeding relative tolerance (rtol={rtol}): {num_exceeds_rtol} ({num_exceeds_rtol/total_elements*100:.2f}%)\n" + f"Elements exceeding either tolerance: {num_exceeds_tolerance} ({pct_exceeds_tolerance:.2f}%)\n" + f"Kernel output shape: {kernel_output.shape}\n" + f"Reference output shape: {gt_output.shape}\n" + f"Original error: {str(e)}" + ) + + +def CHECK_FLASH_ATTN_DECODE( + q, k, v, + k_cache, v_cache, + block_tables, + context_lens, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + decode_kernel, + scale: float, + num_groups: int, + page_block_size: int, + diffusion_block_size: int = 32, + is_block_attn: bool = False, +): + """ + Verify decode kernel correctness by comparing with reference implementation. + This function mimics engine-like scenarios with memory reuse testing. + + Args: + q: Query tensor [total_q_len, num_heads, head_dim] + k: Key tensor [total_kv_len, num_kv_heads, head_dim] + v: Value tensor [total_kv_len, num_kv_heads, head_dim] + k_cache: KV cache for keys [num_page_blocks, page_block_size, num_kv_heads, head_dim] + v_cache: KV cache for values [num_page_blocks, page_block_size, num_kv_heads, head_dim] + block_tables: Block tables [num_seqs, max_seq_num_blocks] + context_lens: Context lengths for each sequence [num_seqs] + cu_seqlens_q: Cumulative sequence lengths for queries + cu_seqlens_k: Cumulative sequence lengths for keys/values + max_seqlen_q: Maximum sequence length for queries + decode_kernel: The kernel function to test + scale: Attention scale factor + num_groups: Number of GQA groups (num_heads // num_kv_heads) + page_block_size: Size of page blocks in KV cache + diffusion_block_size: Size of diffusion blocks for block attention + is_block_attn: Whether this is block attention mode + """ + import torch + from test.python.kernel.test_dllm_flash_attn_decode_kernel import naive_sdpa_with_kvcache + + # Test with memory reuse (simulate engine's behavior) + # Run multiple times to check for memory corruption + outputs = [] + for run_idx in range(3): + output = decode_kernel( + q, k, v, k_cache, v_cache, + block_tables, + context_lens, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + ) + torch.cuda.synchronize() + outputs.append(output.clone()) + + # Verify consistency across runs + consistent = True + for i in range(1, len(outputs)): + if not torch.allclose(outputs[0], outputs[i], atol=1e-5, rtol=1e-5): + consistent = False + max_diff = (outputs[0] - outputs[i]).abs().max().item() + raise AssertionError( + f"Output inconsistency detected in run {i}: max_diff={max_diff:.6f}. " + f"This indicates potential memory corruption or non-deterministic behavior." + ) + + # Use the first output for comparison + kernel_output = outputs[0] + + # Compute reference output using naive_sdpa_with_kvcache (same as test file) + gt_output = naive_sdpa_with_kvcache( + q, k, v, k_cache, v_cache, + block_tables, context_lens, + cu_seqlens_q, cu_seqlens_k, + scale, num_groups, page_block_size, + ) + + # Compare results (using same tolerance as test file) + atol = 1e-2 + rtol = 1e-2 + try: + torch.testing.assert_close( + kernel_output, + gt_output, + atol=atol, + rtol=rtol, + msg="Decode kernel output does not match reference implementation" + ) + except AssertionError as e: + # Compute error statistics for debugging + abs_diff = torch.abs(kernel_output - gt_output) + max_diff = torch.max(abs_diff).item() + mean_diff = torch.mean(abs_diff).item() + rel_diff = torch.abs((kernel_output - gt_output) / (gt_output + 1e-8)) + max_rel_diff = torch.max(rel_diff).item() + mean_rel_diff = torch.mean(rel_diff).item() + + # Count elements that exceed tolerance + total_elements = kernel_output.numel() + # Elements that exceed absolute tolerance + exceeds_atol = (abs_diff > atol) + num_exceeds_atol = exceeds_atol.sum().item() + # Elements that exceed relative tolerance + exceeds_rtol = (rel_diff > rtol) + num_exceeds_rtol = exceeds_rtol.sum().item() + # Elements that exceed either tolerance + exceeds_tolerance = exceeds_atol | exceeds_rtol + num_exceeds_tolerance = exceeds_tolerance.sum().item() + pct_exceeds_tolerance = (num_exceeds_tolerance / total_elements * 100) if total_elements > 0 else 0 + + # Save test case data for debugging + import os + from pathlib import Path + import pickle + from datetime import datetime + + save_dir = Path(os.getenv("TEST_CASE_SAVE_DIR", "./failed_test_cases")) + save_dir.mkdir(parents=True, exist_ok=True) + + # Generate unique filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + case_name = f"decode_kernel_failure_{timestamp}" + case_dir = save_dir / case_name + case_dir.mkdir(parents=True, exist_ok=True) + + # Save all input and output tensors + test_data = { + 'inputs': { + 'q': q.cpu(), + 'k': k.cpu(), + 'v': v.cpu(), + 'k_cache': k_cache.cpu(), + 'v_cache': v_cache.cpu(), + 'block_tables': block_tables.cpu(), + 'context_lens': context_lens.cpu(), + 'cu_seqlens_q': cu_seqlens_q.cpu(), + 'cu_seqlens_k': cu_seqlens_k.cpu(), + }, + 'outputs': { + 'kernel_output': kernel_output.cpu(), + 'gt_output': gt_output.cpu(), + 'abs_diff': abs_diff.cpu(), + 'rel_diff': rel_diff.cpu(), + }, + 'parameters': { + 'max_seqlen_q': max_seqlen_q, + 'scale': scale, + 'num_groups': num_groups, + 'page_block_size': page_block_size, + 'diffusion_block_size': diffusion_block_size, + 'is_block_attn': is_block_attn, + 'atol': atol, + 'rtol': rtol, + }, + 'statistics': { + 'max_diff': max_diff, + 'mean_diff': mean_diff, + 'max_rel_diff': max_rel_diff, + 'mean_rel_diff': mean_rel_diff, + 'total_elements': total_elements, + 'num_exceeds_atol': num_exceeds_atol, + 'num_exceeds_rtol': num_exceeds_rtol, + 'num_exceeds_tolerance': num_exceeds_tolerance, + 'pct_exceeds_tolerance': pct_exceeds_tolerance, + }, + 'shapes': { + 'q_shape': list(q.shape), + 'k_shape': list(k.shape), + 'v_shape': list(v.shape), + 'k_cache_shape': list(k_cache.shape), + 'v_cache_shape': list(v_cache.shape), + 'block_tables_shape': list(block_tables.shape), + 'kernel_output_shape': list(kernel_output.shape), + 'gt_output_shape': list(gt_output.shape), + }, + } + + # Save as pickle + with open(case_dir / "test_data.pkl", "wb") as f: + pickle.dump(test_data, f) + + # Save kernel source (same as test file) + kernel_path = None + try: + kernel_source = decode_kernel.get_kernel_source() + kernel_path = case_dir / "kernel.cu" + kernel_path.write_text(kernel_source) + except Exception as kernel_err: + # If kernel source is not available, log but don't fail + pass + + # Generate a Python script to reproduce the test case + timestamp_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + repro_script = f'''""" +Auto-generated test case from failed CHECK_FLASH_ATTN_DECODE. +Generated at: {timestamp_str} + +To use this test case: +1. Load the data: test_data = pickle.load(open("test_data.pkl", "rb")) +2. Move tensors to device: q = test_data['inputs']['q'].to(device), etc. +3. Call your kernel with the loaded inputs +""" +import torch +import pickle +from pathlib import Path + +# Load test data +case_dir = Path(__file__).parent +with open(case_dir / "test_data.pkl", "rb") as f: + test_data = pickle.load(f) + +# Extract inputs +q = test_data['inputs']['q'] +k = test_data['inputs']['k'] +v = test_data['inputs']['v'] +k_cache = test_data['inputs']['k_cache'] +v_cache = test_data['inputs']['v_cache'] +block_tables = test_data['inputs']['block_tables'] +context_lens = test_data['inputs']['context_lens'] +cu_seqlens_q = test_data['inputs']['cu_seqlens_q'] +cu_seqlens_k = test_data['inputs']['cu_seqlens_k'] + +# Extract parameters +params = test_data['parameters'] +max_seqlen_q = params['max_seqlen_q'] +scale = params['scale'] +num_groups = params['num_groups'] +page_block_size = params['page_block_size'] +diffusion_block_size = params['diffusion_block_size'] +is_block_attn = params['is_block_attn'] + +# Extract expected outputs for comparison +gt_output = test_data['outputs']['gt_output'] + +# Print test case info +print("Test Case Information:") +q_shape = test_data['shapes']['q_shape'] +k_shape = test_data['shapes']['k_shape'] +v_shape = test_data['shapes']['v_shape'] +print(f" Shapes: q={{q_shape}}, k={{k_shape}}, v={{v_shape}}") +print(f" Parameters: scale={{scale}}, num_groups={{num_groups}}, page_block_size={{page_block_size}}") +max_diff_val = test_data['statistics']['max_diff'] +num_mismatches = test_data['statistics']['num_exceeds_tolerance'] +print(f" Statistics: max_diff={{max_diff_val:.6f}}, num_mismatches={{num_mismatches}}") + +# TODO: Add your kernel call here +# kernel_output = your_kernel(q, k, v, k_cache, v_cache, block_tables, context_lens, +# cu_seqlens_q, cu_seqlens_k, max_seqlen_q) +# torch.testing.assert_close(kernel_output, gt_output, atol=params['atol'], rtol=params['rtol']) +''' + + with open(case_dir / "reproduce_test.py", "w") as f: + f.write(repro_script) + + # Save error summary + error_summary = f"""Test Case Failure Summary +Generated at: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} + +Shapes: + q: {test_data['shapes']['q_shape']} + k: {test_data['shapes']['k_shape']} + v: {test_data['shapes']['v_shape']} + k_cache: {test_data['shapes']['k_cache_shape']} + v_cache: {test_data['shapes']['v_cache_shape']} + block_tables: {test_data['shapes']['block_tables_shape']} + kernel_output: {test_data['shapes']['kernel_output_shape']} + gt_output: {test_data['shapes']['gt_output_shape']} + +Parameters: + max_seqlen_q: {max_seqlen_q} + scale: {scale} + num_groups: {num_groups} + page_block_size: {page_block_size} + diffusion_block_size: {diffusion_block_size} + is_block_attn: {is_block_attn} + atol: {atol} + rtol: {rtol} + +Statistics: + Max absolute difference: {max_diff:.6f} + Mean absolute difference: {mean_diff:.6f} + Max relative difference: {max_rel_diff:.6f} + Mean relative difference: {mean_rel_diff:.6f} + Total elements: {total_elements} + Elements exceeding absolute tolerance: {num_exceeds_atol} ({num_exceeds_atol/total_elements*100:.2f}%) + Elements exceeding relative tolerance: {num_exceeds_rtol} ({num_exceeds_rtol/total_elements*100:.2f}%) + Elements exceeding either tolerance: {num_exceeds_tolerance} ({pct_exceeds_tolerance:.2f}%) +""" + + with open(case_dir / "error_summary.txt", "w") as f: + f.write(error_summary) + + save_info = f"\n\nTest case data saved to: {case_dir}\n" + save_info += f" - test_data.pkl: All input/output tensors and metadata\n" + save_info += f" - reproduce_test.py: Script to reproduce the test case\n" + save_info += f" - error_summary.txt: Summary of the failure\n" + if kernel_path is not None: + save_info += f" - kernel.cu: CUDA kernel source code\n" + + # Show mismatched elements layout + mismatch_info = "" + if num_exceeds_tolerance > 0: + # Get indices of mismatched elements + mismatch_indices = torch.nonzero(exceeds_tolerance, as_tuple=False) + num_to_show = min(50, num_exceeds_tolerance) # Show at most 50 mismatches + + mismatch_info = f"\n\nMismatched elements (showing first {num_to_show} of {num_exceeds_tolerance}):\n" + mismatch_info += "-" * 100 + "\n" + mismatch_info += f"{'Index':<30} {'Kernel Value':<20} {'Ref Value':<20} {'Abs Diff':<15} {'Rel Diff':<15}\n" + mismatch_info += "-" * 100 + "\n" + + for i in range(num_to_show): + idx = mismatch_indices[i] + idx_tuple = tuple(idx.tolist()) + + kernel_val = kernel_output[idx_tuple].item() + gt_val = gt_output[idx_tuple].item() + abs_err = abs_diff[idx_tuple].item() + rel_err = rel_diff[idx_tuple].item() + + mismatch_info += ( + f"{str(idx_tuple):<30} " + f"{kernel_val:>19.6f} " + f"{gt_val:>19.6f} " + f"{abs_err:>14.6f} " + f"{rel_err:>14.6f}\n" + ) + + if num_exceeds_tolerance > num_to_show: + mismatch_info += f"\n... and {num_exceeds_tolerance - num_to_show} more mismatches\n" + + # Show distribution of mismatches by dimension + if len(kernel_output.shape) >= 2: + mismatch_info += f"\nMismatch distribution by dimensions:\n" + for dim_idx in range(len(kernel_output.shape)): + dim_mismatches = exceeds_tolerance.sum(dim=tuple(j for j in range(len(kernel_output.shape)) if j != dim_idx)) + mismatch_info += f" Dim {dim_idx} (size {kernel_output.shape[dim_idx]}): {dim_mismatches.tolist()}\n" + + raise AssertionError( + f"Decode kernel verification failed!\n" + f"Max absolute difference: {max_diff:.6f}\n" + f"Mean absolute difference: {mean_diff:.6f}\n" + f"Max relative difference: {max_rel_diff:.6f}\n" + f"Mean relative difference: {mean_rel_diff:.6f}\n" + f"Total elements: {total_elements}\n" + f"Elements exceeding absolute tolerance (atol={atol}): {num_exceeds_atol} ({num_exceeds_atol/total_elements*100:.2f}%)\n" + f"Elements exceeding relative tolerance (rtol={rtol}): {num_exceeds_rtol} ({num_exceeds_rtol/total_elements*100:.2f}%)\n" + f"Elements exceeding either tolerance: {num_exceeds_tolerance} ({pct_exceeds_tolerance:.2f}%)\n" + f"Kernel output shape: {kernel_output.shape}\n" + f"Reference output shape: {gt_output.shape}\n" + f"{mismatch_info}" + f"{save_info}" + f"Original error: {str(e)}" + ) \ No newline at end of file diff --git a/test/test_gptq_awq_strategies.py b/test/test_gptq_awq_strategies.py new file mode 100644 index 0000000..7d5d12b --- /dev/null +++ b/test/test_gptq_awq_strategies.py @@ -0,0 +1,328 @@ +""" +Unit tests for GPTQ/AWQ quantization strategies. + +These tests verify the dequantization correctness for GPTQ and AWQ formats. +""" + +import pytest +import torch +import torch.nn.functional as F + +from diffulex.utils.quantization.strategies.linear_gptq_w4a16 import ( + LinearGPTQW4A16Strategy, + _dequantize_gptq, + _unpack_gptq_int4, +) +from diffulex.utils.quantization.strategies.linear_awq_w4a16 import ( + LinearAWQW4A16Strategy, + _dequantize_awq, + _unpack_awq_int4, +) + + +def _pack_int4_to_int8(int4_tensor: torch.Tensor) -> torch.Tensor: + """Pack int4 tensor into int8 format for testing. + + This matches the unpack implementation in _unpack_gptq_int4: + - Lower 4 bits: even columns (0, 2, 4, ...) + - Upper 4 bits: odd columns (1, 3, 5, ...) + """ + out_features, in_features = int4_tensor.shape + + # Clamp to int4 range [-8, 7] + int4_tensor = int4_tensor.clamp(-8, 7) + + # Pad in_features to even number if needed + if in_features % 2 != 0: + pad_size = 1 + padding = torch.zeros(out_features, pad_size, dtype=int4_tensor.dtype, device=int4_tensor.device) + int4_tensor = torch.cat([int4_tensor, padding], dim=1) + padded_in_features = in_features + pad_size + else: + padded_in_features = in_features + + # Convert to uint8 for bit manipulation + # Map [-8, 7] to [0, 15] by adding 8 + uint8_tensor = (int4_tensor + 8).to(torch.uint8) + + # Reshape to [out_features, in_features // 2, 2] + reshaped = uint8_tensor.view(out_features, padded_in_features // 2, 2) + + # Pack: even columns (reshaped[:, :, 0]) in lower 4 bits, odd columns (reshaped[:, :, 1]) in upper 4 bits + # This matches unpack: low = p_u8 & 0x0F (even), high = (p_u8 >> 4) & 0x0F (odd) + packed = reshaped[:, :, 0] | (reshaped[:, :, 1] << 4) + + # Convert back to int8 + return packed.to(torch.int8) + + +@pytest.mark.parametrize("out_features,in_features,group_size", [ + (128, 256, 128), + (256, 512, 128), + (128, 128, 128), +]) +def test_gptq_unpack_pack_roundtrip(out_features, in_features, group_size): + """Test that unpack and pack operations are inverse.""" + # Create random int4 weights + weight_int4 = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8) + + # Pack to int8 + packed = _pack_int4_to_int8(weight_int4) + + # Unpack back + unpacked = _unpack_gptq_int4(packed, out_features=out_features, in_features=in_features) + + # Verify roundtrip + assert unpacked.shape == weight_int4.shape + torch.testing.assert_close(unpacked, weight_int4, rtol=0, atol=0) + + +@pytest.mark.parametrize("out_features,in_features,group_size", [ + (128, 256, 128), + (256, 512, 128), + (128, 128, 128), +]) +def test_awq_unpack_pack_roundtrip(out_features, in_features, group_size): + """Test that unpack and pack operations are inverse.""" + # Create random int4 weights + weight_int4 = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8) + + # Pack to int8 + packed = _pack_int4_to_int8(weight_int4) + + # Unpack back + unpacked = _unpack_awq_int4(packed, out_features=out_features, in_features=in_features) + + # Verify roundtrip + assert unpacked.shape == weight_int4.shape + torch.testing.assert_close(unpacked, weight_int4, rtol=0, atol=0) + + +@pytest.mark.parametrize("out_features,in_features,group_size", [ + (128, 256, 128), + (256, 512, 128), + (128, 128, 128), +]) +def test_gptq_dequantize_correctness(out_features, in_features, group_size): + """Test GPTQ dequantization correctness.""" + device = torch.device("cpu") + + # Create reference float weights + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32) + + # Simulate GPTQ quantization + num_groups = (out_features + group_size - 1) // group_size + + # Quantize per group + qweight_list = [] + qzeros_list = [] + scales_list = [] + + for g in range(num_groups): + start_idx = g * group_size + end_idx = min((g + 1) * group_size, out_features) + group_weight = weight_fp32[start_idx:end_idx] # [group_size, in_features] + + # Compute scale per group (per input channel for GPTQ/AWQ) + # GPTQ/AWQ typically uses per-channel scales: [in_features] + abs_max_per_channel = torch.abs(group_weight).max(dim=0, keepdim=False)[0] # [in_features] + scales_per_channel = (abs_max_per_channel.clamp(min=1e-8) / 7.0).to(torch.float32) # [in_features] + + # Per-group zero point (typically zero for symmetric quantization) + zeros_per_channel = torch.zeros(in_features, dtype=torch.float32) + + # Quantize weight for this group + qweight_group = torch.round(group_weight / scales_per_channel.unsqueeze(0)).clamp(-8, 7).to(torch.int8) + # Quantize zeros (should be zero, but compute for consistency) + qzeros_per_channel = torch.round(zeros_per_channel / scales_per_channel).clamp(-8, 7).to(torch.int8) + + qweight_list.append(qweight_group) + qzeros_list.append(qzeros_per_channel.unsqueeze(0)) # [1, in_features] + scales_list.append(scales_per_channel.unsqueeze(0)) # [1, in_features] + + # Concatenate groups + qweight = torch.cat(qweight_list, dim=0) # [out_features, in_features] + qzeros = torch.cat(qzeros_list, dim=0) # [num_groups, in_features] + scales = torch.cat(scales_list, dim=0) # [num_groups, in_features] + + # Ensure shapes are correct + assert qzeros.shape == (num_groups, in_features), f"qzeros shape mismatch: got {qzeros.shape}, expected ({num_groups}, {in_features})" + assert scales.shape == (num_groups, in_features), f"scales shape mismatch: got {scales.shape}, expected ({num_groups}, {in_features})" + + # Pack to int8 + qweight_packed = _pack_int4_to_int8(qweight) + qzeros_packed = _pack_int4_to_int8(qzeros) + + # Dequantize + dequantized = _dequantize_gptq( + qweight=qweight_packed, + qzeros=qzeros_packed, + scales=scales, + out_features=out_features, + in_features=in_features, + group_size=group_size, + g_idx=None, + ) + + # Verify approximate correctness (allow small quantization error) + assert dequantized.shape == weight_fp32.shape + # Note: Exact match is not expected due to quantization, but should be close + error = torch.abs(dequantized.float() - weight_fp32) + max_error = error.max().item() + mean_error = error.mean().item() + + # Allow reasonable quantization error + assert max_error < 1.0, f"Max quantization error too large: {max_error}" + assert mean_error < 0.5, f"Mean quantization error too large: {mean_error}" + + +@pytest.mark.parametrize("out_features,in_features,group_size", [ + (128, 256, 128), + (256, 512, 128), + (128, 128, 128), +]) +def test_awq_dequantize_correctness(out_features, in_features, group_size): + """Test AWQ dequantization correctness.""" + device = torch.device("cpu") + + # Create reference float weights + weight_fp32 = torch.randn(out_features, in_features, dtype=torch.float32) + + # Simulate AWQ quantization + num_groups = (out_features + group_size - 1) // group_size + + # Quantize per group (sequential grouping) + qweight_list = [] + qzeros_list = [] + scales_list = [] + + for g in range(num_groups): + start_idx = g * group_size + end_idx = min((g + 1) * group_size, out_features) + group_weight = weight_fp32[start_idx:end_idx] # [group_size, in_features] + + # Compute scale per group (per input channel for AWQ) + # AWQ typically uses per-channel scales: [in_features] + abs_max_per_channel = torch.abs(group_weight).max(dim=0, keepdim=False)[0] # [in_features] + scales_per_channel = (abs_max_per_channel.clamp(min=1e-8) / 7.0).to(torch.float32) # [in_features] + + # Per-group zero point (typically zero for symmetric quantization) + zeros_per_channel = torch.zeros(in_features, dtype=torch.float32) + + # Quantize weight for this group + qweight_group = torch.round(group_weight / scales_per_channel.unsqueeze(0)).clamp(-8, 7).to(torch.int8) + # Quantize zeros (should be zero, but compute for consistency) + qzeros_per_channel = torch.round(zeros_per_channel / scales_per_channel).clamp(-8, 7).to(torch.int8) + + qweight_list.append(qweight_group) + qzeros_list.append(qzeros_per_channel.unsqueeze(0)) # [1, in_features] + scales_list.append(scales_per_channel.unsqueeze(0)) # [1, in_features] + + # Concatenate groups + qweight = torch.cat(qweight_list, dim=0) # [out_features, in_features] + qzeros = torch.cat(qzeros_list, dim=0) # [num_groups, in_features] + scales = torch.cat(scales_list, dim=0) # [num_groups, in_features] + + # Ensure shapes are correct + assert qzeros.shape == (num_groups, in_features), f"qzeros shape mismatch: got {qzeros.shape}, expected ({num_groups}, {in_features})" + assert scales.shape == (num_groups, in_features), f"scales shape mismatch: got {scales.shape}, expected ({num_groups}, {in_features})" + + # Pack to int8 + qweight_packed = _pack_int4_to_int8(qweight) + qzeros_packed = _pack_int4_to_int8(qzeros) + + # Dequantize + dequantized = _dequantize_awq( + qweight=qweight_packed, + qzeros=qzeros_packed, + scales=scales, + out_features=out_features, + in_features=in_features, + group_size=group_size, + ) + + # Verify approximate correctness + assert dequantized.shape == weight_fp32.shape + error = torch.abs(dequantized.float() - weight_fp32) + max_error = error.max().item() + mean_error = error.mean().item() + + # Allow reasonable quantization error + assert max_error < 1.0, f"Max quantization error too large: {max_error}" + assert mean_error < 0.5, f"Mean quantization error too large: {mean_error}" + + +def test_gptq_strategy_linear_forward(): + """Test GPTQ strategy linear forward pass.""" + strategy = LinearGPTQW4A16Strategy() + + out_features, in_features = 128, 256 + group_size = 128 + num_groups = (out_features + group_size - 1) // group_size + + # Create mock GPTQ tensors + qweight = torch.randint(-128, 127, (out_features, (in_features + 1) // 2), dtype=torch.int8) + qzeros = torch.randint(-128, 127, (num_groups, (in_features + 1) // 2), dtype=torch.int8) + scales = torch.randn(num_groups, in_features, dtype=torch.float32).abs() + 0.1 + + # Create input + batch_size = 4 + x = torch.randn(batch_size, in_features, dtype=torch.bfloat16) + + # Forward pass + output = strategy.linear_forward( + x=x, + weight=None, + bias=None, + quant_kind="other", + gptq_qweight=qweight, + gptq_qzeros=qzeros, + gptq_scales=scales, + gptq_group_size=group_size, + out_features=out_features, + in_features=in_features, + ) + + # Verify output shape + assert output.shape == (batch_size, out_features) + assert output.dtype == torch.bfloat16 + + +def test_awq_strategy_linear_forward(): + """Test AWQ strategy linear forward pass.""" + strategy = LinearAWQW4A16Strategy() + + out_features, in_features = 128, 256 + group_size = 128 + num_groups = (out_features + group_size - 1) // group_size + + # Create mock AWQ tensors + qweight = torch.randint(-128, 127, (out_features, (in_features + 1) // 2), dtype=torch.int8) + qzeros = torch.randint(-128, 127, (num_groups, (in_features + 1) // 2), dtype=torch.int8) + scales = torch.randn(num_groups, in_features, dtype=torch.float32).abs() + 0.1 + + # Create input + batch_size = 4 + x = torch.randn(batch_size, in_features, dtype=torch.bfloat16) + + # Forward pass + output = strategy.linear_forward( + x=x, + weight=None, + bias=None, + quant_kind="other", + awq_qweight=qweight, + awq_qzeros=qzeros, + awq_scales=scales, + awq_group_size=group_size, + out_features=out_features, + in_features=in_features, + ) + + # Verify output shape + assert output.shape == (batch_size, out_features) + assert output.dtype == torch.bfloat16 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])