diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..adb2692e
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "angelslim/compressor/sparsity/vecattention/ops/vllm-flash-attention"]
+ path = angelslim/compressor/sparsity/vecattention/ops/vllm-flash-attention
+ url = git@github.com:anminliu/vllm-flash-attention.git
diff --git a/README.md b/README.md
index 7d6d7a43..de3c3f1e 100644
--- a/README.md
+++ b/README.md
@@ -135,6 +135,12 @@ A more accessible, comprehensive, and efficient toolkit for large model compress
+ -
+ Sparse Attention
+
+
-
Token Pruning
diff --git a/README_cn.md b/README_cn.md
index df58cf5b..079ab305 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -136,6 +136,12 @@
|
+ -
+ 稀疏注意力
+
+
-
Token剪枝
diff --git a/angelslim/compressor/sparsity/__init__.py b/angelslim/compressor/sparsity/__init__.py
index eaa4c7a2..52c1a740 100644
--- a/angelslim/compressor/sparsity/__init__.py
+++ b/angelslim/compressor/sparsity/__init__.py
@@ -13,5 +13,6 @@
# limitations under the License.
from .stem import StemInference # noqa: F401
+from .vecattention import VecAttentionInference # noqa: F401
-__all__ = ["StemInference"]
+__all__ = ["StemInference", "VecAttentionInference"]
diff --git a/angelslim/compressor/sparsity/vecattention/__init__.py b/angelslim/compressor/sparsity/vecattention/__init__.py
new file mode 100644
index 00000000..632c2880
--- /dev/null
+++ b/angelslim/compressor/sparsity/vecattention/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .vecattention import VecAttentionInference # noqa: F401
+
+__all__ = ["VecAttentionInference"]
diff --git a/angelslim/compressor/sparsity/vecattention/modules/__init__.py b/angelslim/compressor/sparsity/vecattention/modules/__init__.py
new file mode 100644
index 00000000..a5de244f
--- /dev/null
+++ b/angelslim/compressor/sparsity/vecattention/modules/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""VecAttention-patched attention forward methods for VLM."""
+
+from .forward import qwen_vl_attn_forward
+
+__all__ = ["qwen_vl_attn_forward"]
diff --git a/angelslim/compressor/sparsity/vecattention/modules/forward.py b/angelslim/compressor/sparsity/vecattention/modules/forward.py
new file mode 100644
index 00000000..6ecf6988
--- /dev/null
+++ b/angelslim/compressor/sparsity/vecattention/modules/forward.py
@@ -0,0 +1,381 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""VecAttention-patched attention forward pass for Vision-Language Models.
+
+This module provides the replacement ``forward`` method that is bound to each
+attention layer by the patch function. During **prefill** (``q_len > 1``) it
+delegates to the VecAttention sparse backend; during **decode** (``q_len == 1``)
+it falls back to the model's original attention implementation.
+
+VecAttention applies sparse attention only to the vision token region;
+text tokens before/after use standard full attention.
+"""
+
+from __future__ import annotations
+
+import math
+import os
+import sys
+from typing import Union
+
+import torch
+from torch import nn
+from transformers.cache_utils import Cache
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+from ..ops.vecattention_kernel import average_vector, fuse_qk_softmax_minp_wo_causal
+
+# Ensure VecAttention's custom vllm_flash_attn is importable.
+# The vllm-flash-attention source lives as a standalone submodule under ops/.
+_OPS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ops")
+_VLLM_FA_DIR = os.path.join(_OPS_DIR, "vllm-flash-attention")
+if os.path.isdir(_VLLM_FA_DIR) and _VLLM_FA_DIR not in sys.path:
+ sys.path.insert(0, _VLLM_FA_DIR)
+
+try:
+ from vllm_flash_attn import sparse_attn_func
+except ImportError:
+ raise ImportError(
+ "vllm_flash_attn with sparse_attn_func not found. "
+ "Please init the submodule and build:\n"
+ " git submodule update --init --recursive\n"
+ " cd angelslim/compressor/sparsity/vecattention/ops/vllm-flash-attention\n"
+ " pip install -e . --no-build-isolation"
+ )
+
+
+# ---------------------------------------------------------------------------
+# Helper functions
+# ---------------------------------------------------------------------------
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """Repeat KV heads to match the number of query heads (GQA support)."""
+ if n_rep == 1:
+ return hidden_states
+ batch, num_kv_heads, slen, head_dim = hidden_states.shape
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_kv_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Eager scaled dot-product attention for decode fallback."""
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+def _full_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ causal: bool = True,
+) -> torch.Tensor:
+ """Standard scaled dot-product attention (no sparsity).
+
+ Used for text token segments where VecAttention is not applied.
+ """
+ scaling = query.shape[-1] ** -0.5
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+ if causal:
+ Lq, Lk = query.shape[2], key.shape[2]
+ causal_mask = torch.ones(Lq, Lk, dtype=torch.bool, device=query.device).triu(
+ diagonal=Lk - Lq + 1
+ )
+ attn_weights = attn_weights.masked_fill(causal_mask[None, None, :, :], float("-inf"))
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ return torch.matmul(attn_weights, value)
+
+
+# ---------------------------------------------------------------------------
+# VecAttention prefill core
+# ---------------------------------------------------------------------------
+
+
+def vecattention_forward(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ threshold: Union[float, torch.Tensor] = None,
+ q_pooling_size: int = 128,
+ k_local_size: int = 128,
+ group_k_block: int = 1,
+ causal: bool = True,
+ chunk_size: int = 16 * 1024,
+) -> torch.Tensor:
+ """Vectorized sparse attention prefill (VecAttention).
+
+ Selects important key-value blocks per query using a fused Triton kernel
+ that applies a per-head MinP threshold on averaged Q*K^T scores, then
+ runs sparse attention via vllm_flash_attn.sparse_attn_func.
+
+ Args:
+ query_states: (batch, num_heads, seq_len, head_dim)
+ key_states: (batch, num_kv_heads, seq_len, head_dim)
+ value_states: (batch, num_kv_heads, seq_len, head_dim)
+ threshold: MinP threshold.
+ q_pooling_size: Query block size for vector pooling. Must be 64 or 128.
+ k_local_size: Key local block size for column selection.
+ group_k_block: Number of k-blocks processed together.
+ causal: Whether to apply causal masking.
+ chunk_size: Prefill chunk size (tokens).
+
+ Returns:
+ attn_output: (batch, num_heads, seq_len, head_dim)
+ """
+ assert chunk_size % q_pooling_size == 0, "chunk_size must be a multiple of q_pooling_size"
+ assert q_pooling_size in [64, 128], "q_pooling_size must be 64 or 128"
+ SPATTN_BLOCK_SIZE_K = 64
+
+ if isinstance(threshold, torch.Tensor):
+ gap = -torch.log(threshold + 1e-9)
+ else:
+ gap = -math.log(threshold + 1e-9, math.e)
+
+ batch_size, num_heads, seq_len, head_dim = query_states.shape
+ num_q_blocks = math.ceil(seq_len / q_pooling_size)
+
+ attn_output = torch.empty(
+ batch_size,
+ num_heads,
+ seq_len,
+ head_dim,
+ dtype=query_states.dtype,
+ device=query_states.device,
+ )
+
+ if causal:
+ n = q_pooling_size // SPATTN_BLOCK_SIZE_K
+ blk_count = torch.full(
+ (batch_size, num_heads, num_q_blocks),
+ 2 * n,
+ dtype=torch.int32,
+ device=query_states.device,
+ )
+ blk_count[..., 0] = math.ceil(min(seq_len, q_pooling_size) / SPATTN_BLOCK_SIZE_K)
+ if seq_len > q_pooling_size and seq_len % q_pooling_size != 0:
+ blk_count[..., -1] = n + math.ceil(
+ (seq_len - math.floor(seq_len / q_pooling_size) * q_pooling_size)
+ / SPATTN_BLOCK_SIZE_K
+ )
+ blk_idx = torch.zeros(
+ batch_size,
+ num_heads,
+ num_q_blocks,
+ 2 * n,
+ dtype=torch.int32,
+ device=query_states.device,
+ )
+ offsets = (
+ torch.arange(n, device=query_states.device, dtype=torch.int32) * SPATTN_BLOCK_SIZE_K
+ )
+ blk_idx[..., :n] = offsets
+ base = (
+ torch.arange(0, num_q_blocks, device=query_states.device, dtype=torch.int32)
+ * q_pooling_size
+ ).unsqueeze(-1)
+ blk_idx[..., n:] = base + offsets
+ else:
+ blk_count = torch.zeros(
+ batch_size, num_heads, num_q_blocks, dtype=torch.int32, device=query_states.device
+ )
+ blk_idx = torch.zeros(
+ batch_size, num_heads, num_q_blocks, 1, dtype=torch.int32, device=query_states.device
+ )
+
+ for i in range(0, seq_len, chunk_size):
+ q_chunk = query_states[:, :, i : i + chunk_size, :]
+
+ avg_q_chunk = average_vector(q_chunk, q_pooling_size, use_triton=True)
+ col_count, col_idx = fuse_qk_softmax_minp_wo_causal(
+ avg_q_chunk,
+ key_states,
+ i // q_pooling_size,
+ gap,
+ causal,
+ q_pooling_size,
+ k_local_size,
+ wo_initial=causal,
+ group_k_block=group_k_block,
+ )
+
+ blk_count_chunk = blk_count[:, :, i // q_pooling_size : (i + chunk_size) // q_pooling_size]
+ blk_idx_chunk = blk_idx[:, :, i // q_pooling_size : (i + chunk_size) // q_pooling_size, :]
+
+ k_chunk = key_states[:, :, : i + chunk_size, :] if causal else key_states
+ v_chunk = value_states[:, :, : i + chunk_size, :] if causal else value_states
+
+ attn_output[:, :, i : i + chunk_size, :] = sparse_attn_func(
+ q_chunk.transpose(1, 2).contiguous(),
+ k_chunk.transpose(1, 2).contiguous(),
+ v_chunk.transpose(1, 2).contiguous(),
+ q_pooling_size,
+ blk_count_chunk.contiguous(),
+ blk_idx_chunk.contiguous(),
+ col_count,
+ col_idx,
+ return_softmax_lse=False,
+ causal=causal,
+ ).transpose(1, 2)
+
+ return attn_output
+
+
+# ---------------------------------------------------------------------------
+# Patched forward for Qwen2.5-VL models
+# ---------------------------------------------------------------------------
+
+
+def qwen_vl_attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: torch.LongTensor | None = None,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """VecAttention-patched attention forward for Qwen2.5-VL.
+
+ During prefill, applies VecAttention sparse attention to the vision token
+ region and full attention to surrounding text tokens.
+ During decode, falls back to the model's original attention.
+ """
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+ apply_multimodal_rotary_pos_emb,
+ )
+
+ rope_scaling = getattr(self, "rope_scaling", None) or self.config.rope_scaling
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, rope_scaling["mrope_section"]
+ )
+
+ if past_key_values is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ config = self.vecattention_config
+ vision_start = config.attn_kwargs.get("vision_start_position", None)
+ vision_end = config.attn_kwargs.get("vision_end_position", None)
+
+ if q_len > 1 and q_len >= config.block_size_q:
+ if vision_start is None and vision_end is None:
+ attn_output = vecattention_forward(
+ query_states,
+ key_states,
+ value_states,
+ threshold=config.threshold,
+ q_pooling_size=config.block_size_q,
+ k_local_size=config.block_size_k,
+ group_k_block=config.group_k_block,
+ causal=True,
+ chunk_size=config.chunk_size,
+ )
+ else:
+ if vision_start is not None and vision_start > 0:
+ attn_output_before = _full_attention(
+ query_states[:, :, :vision_start, :],
+ key_states[:, :, :vision_start, :],
+ value_states[:, :, :vision_start, :],
+ causal=True,
+ )
+
+ ve = vision_end if vision_end is not None else q_len
+ attn_output_va = vecattention_forward(
+ query_states[:, :, :ve, :],
+ key_states[:, :, :ve, :],
+ value_states[:, :, :ve, :],
+ threshold=config.threshold,
+ q_pooling_size=config.block_size_q,
+ k_local_size=config.block_size_k,
+ group_k_block=config.group_k_block,
+ causal=True,
+ chunk_size=config.chunk_size,
+ )
+
+ if vision_start is not None and vision_start > 0:
+ attn_output_va[:, :, :vision_start, :] = attn_output_before
+
+ if vision_end is not None and vision_end < q_len:
+ attn_output_after = _full_attention(
+ query_states[:, :, vision_end:, :],
+ key_states,
+ value_states,
+ causal=True,
+ )
+ attn_output = torch.cat([attn_output_va, attn_output_after], dim=2)
+ else:
+ attn_output = attn_output_va
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_weights = None
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS.get(
+ getattr(self.config, "_attn_implementation", None),
+ eager_attention_forward,
+ )
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else getattr(self, "attention_dropout", 0.0),
+ scaling=self.head_dim**-0.5,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
diff --git a/angelslim/compressor/sparsity/vecattention/ops/__init__.py b/angelslim/compressor/sparsity/vecattention/ops/__init__.py
new file mode 100644
index 00000000..14a3dceb
--- /dev/null
+++ b/angelslim/compressor/sparsity/vecattention/ops/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""VecAttention ops: Triton kernels and utility functions."""
diff --git a/angelslim/compressor/sparsity/vecattention/ops/cache/dit_vecattention_kernels_best_eff_configs.pkl b/angelslim/compressor/sparsity/vecattention/ops/cache/dit_vecattention_kernels_best_eff_configs.pkl
new file mode 100644
index 00000000..dda3b713
Binary files /dev/null and b/angelslim/compressor/sparsity/vecattention/ops/cache/dit_vecattention_kernels_best_eff_configs.pkl differ
diff --git a/angelslim/compressor/sparsity/vecattention/ops/cache/vlm_vecattention_kernels_best_eff_configs.pkl b/angelslim/compressor/sparsity/vecattention/ops/cache/vlm_vecattention_kernels_best_eff_configs.pkl
new file mode 100644
index 00000000..d57abf11
Binary files /dev/null and b/angelslim/compressor/sparsity/vecattention/ops/cache/vlm_vecattention_kernels_best_eff_configs.pkl differ
diff --git a/angelslim/compressor/sparsity/vecattention/ops/vecattention_kernel.py b/angelslim/compressor/sparsity/vecattention/ops/vecattention_kernel.py
new file mode 100644
index 00000000..cd8d89f5
--- /dev/null
+++ b/angelslim/compressor/sparsity/vecattention/ops/vecattention_kernel.py
@@ -0,0 +1,797 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import os
+import pickle
+
+import torch
+import triton
+import triton.language as tl
+
+try:
+ from eval.check_env import get_env_name
+except ImportError:
+
+ def get_env_name():
+ """Fallback: assume 'vlm' environment when eval module is unavailable."""
+ return "vlm"
+
+
+# Set to True to enable Triton autotuning (benchmarks kernel configs on your hardware).
+# When False (default), pre-computed configs are loaded from the cache file.
+# To run autotuning: set this to True in the source (or module attribute) before the first import.
+USE_TRITON_AUTOTUNE: bool = False
+BEST_EFF_CONFIGS_CACHE = {}
+env_prefix = get_env_name()
+
+_KERNEL_DIR = os.path.dirname(os.path.abspath(__file__))
+CONFIG_PATH = os.path.join(
+ _KERNEL_DIR, "cache", f"{env_prefix}_vecattention_kernels_best_eff_configs.pkl"
+)
+
+
+def load_best_eff_configs(acc_config, file_path=CONFIG_PATH):
+ """Load the best efficiency kernel config for a given accuracy config tuple.
+
+ Caches results in memory to avoid repeated disk reads.
+
+ Args:
+ acc_config: Tuple of (q_pooling_size, k_local_size, group_k_block, head_dim, causal).
+ file_path: Path to the pickled config cache file.
+
+ Returns:
+ dict of kernel hyperparameters (BLOCK_SIZE_Q, num_warps, num_stages).
+ """
+ global BEST_EFF_CONFIGS_CACHE
+ if acc_config in BEST_EFF_CONFIGS_CACHE:
+ return BEST_EFF_CONFIGS_CACHE[acc_config]
+
+ try:
+ with open(file_path, "rb") as f:
+ all_configs = pickle.load(f)
+ except FileNotFoundError:
+ all_configs = {}
+
+ eff_config = None
+ for k, v in all_configs.items():
+ if acc_config == k[: len(acc_config)]:
+ eff_config = v
+ break
+ if eff_config is None:
+ # No cached config: fall back to safe defaults
+ eff_config = {"BLOCK_SIZE_Q": 64, "num_warps": 4, "num_stages": 3}
+
+ BEST_EFF_CONFIGS_CACHE[acc_config] = eff_config
+ return eff_config
+
+
+def fuse_qk_softmax_minp_wo_causal(
+ avg_q_chunk,
+ k,
+ avg_q_chunk_offset,
+ gap,
+ causal=True,
+ q_pooling_size=128,
+ k_local_size=128,
+ wo_initial=False,
+ group_k_block=1,
+):
+ """Fused Triton kernel for Q·K^T scoring with MinP column selection.
+
+ Computes block-level attention scores between pooled query vectors and
+ full key vectors, then selects key columns whose score exceeds a
+ per-row running maximum minus ``gap`` (i.e., MinP-style thresholding).
+
+ Args:
+ avg_q_chunk: Pooled query blocks, shape (B, H, q_num_blocks, head_dim).
+ k: Full key tensor, shape (B, H, k_len, head_dim).
+ avg_q_chunk_offset: Starting Q-block index in the full sequence (for chunked prefill).
+ gap: Logit gap threshold (= -log(min_p)). Float or (B, H) Tensor for per-head thresholds.
+ causal: Whether to apply causal masking.
+ q_pooling_size: Number of tokens per query block.
+ k_local_size: Key block size for column selection (BLOCK_SIZE_K in the kernel).
+ wo_initial: If True, skip the initial (sink) k-block to avoid double-counting.
+ group_k_block: Number of consecutive k-blocks merged into one Triton program.
+
+ Returns:
+ column_count: (B, H, q_num_blocks) int32 — number of selected columns per Q-block.
+ column_index: (B, H, q_num_blocks, max_count) int32 — indices of selected key tokens.
+ """
+ batch_size, num_heads, q_num_blocks, head_dim = avg_q_chunk.shape
+ assert avg_q_chunk.shape[1] == k.shape[1]
+ k_len = k.shape[-2]
+ sm_scale = 1.0 / math.sqrt(head_dim)
+ BLOCK_SIZE_K = k_local_size
+
+ if USE_TRITON_AUTOTUNE:
+ BLOCK_SIZE_Q = 128
+ kernel_autotune_args = {}
+ else:
+ kernel_autotune_args = load_best_eff_configs(
+ (q_pooling_size, k_local_size, group_k_block, head_dim, causal)
+ )
+ BLOCK_SIZE_Q = kernel_autotune_args.get("BLOCK_SIZE_Q", 64)
+
+ padded_q_num_blocks = math.ceil(q_num_blocks / BLOCK_SIZE_Q) * BLOCK_SIZE_Q
+ avg_q_chunk = torch.cat(
+ [
+ avg_q_chunk,
+ torch.zeros(
+ (batch_size, num_heads, padded_q_num_blocks - q_num_blocks, head_dim),
+ dtype=avg_q_chunk.dtype,
+ device=avg_q_chunk.device,
+ ),
+ ],
+ dim=2,
+ )
+ avg_q_chunk_real_length = q_num_blocks
+
+ padded_k_len = math.ceil(k_len / BLOCK_SIZE_K) * BLOCK_SIZE_K
+ k = torch.cat(
+ [
+ k,
+ torch.full(
+ (batch_size, num_heads, padded_k_len - k_len, head_dim),
+ torch.nan,
+ dtype=k.dtype,
+ device=k.device,
+ ),
+ ],
+ dim=2,
+ )
+
+ assert not causal or q_pooling_size % BLOCK_SIZE_K == 0
+
+ column_count = torch.zeros(
+ (batch_size, num_heads, padded_q_num_blocks), dtype=torch.int32, device=avg_q_chunk.device
+ )
+ column_index = torch.zeros(
+ (batch_size, num_heads, padded_q_num_blocks, padded_k_len),
+ dtype=torch.int32,
+ device=avg_q_chunk.device,
+ )
+
+ def grid(META):
+ return (
+ padded_q_num_blocks // META["BLOCK_SIZE_Q"],
+ math.ceil(padded_k_len / (k_local_size * group_k_block)),
+ batch_size * num_heads,
+ )
+
+ if isinstance(gap, float):
+ kernel_args = [
+ avg_q_chunk,
+ k,
+ sm_scale,
+ column_count,
+ column_index,
+ avg_q_chunk.stride(0),
+ avg_q_chunk.stride(1),
+ avg_q_chunk.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ column_count.stride(0),
+ column_count.stride(1),
+ column_index.stride(0),
+ column_index.stride(1),
+ column_index.stride(2),
+ avg_q_chunk_offset,
+ avg_q_chunk_real_length,
+ k_len,
+ num_heads,
+ head_dim,
+ gap,
+ causal,
+ q_pooling_size,
+ k_local_size,
+ group_k_block,
+ wo_initial,
+ ]
+ _causal_fuse_qk_cutoff_wo_causal_kernel[grid](*kernel_args, **kernel_autotune_args)
+ elif isinstance(gap, torch.Tensor):
+ kernel_args = [
+ avg_q_chunk,
+ k,
+ sm_scale,
+ gap,
+ column_count,
+ column_index,
+ avg_q_chunk.stride(0),
+ avg_q_chunk.stride(1),
+ avg_q_chunk.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ column_count.stride(0),
+ column_count.stride(1),
+ column_index.stride(0),
+ column_index.stride(1),
+ column_index.stride(2),
+ avg_q_chunk_offset,
+ avg_q_chunk_real_length,
+ k_len,
+ num_heads,
+ head_dim,
+ causal,
+ q_pooling_size,
+ k_local_size,
+ group_k_block,
+ wo_initial,
+ ]
+ _causal_fuse_qk_cutoff_wo_causal_perHead_kernel[grid](*kernel_args, **kernel_autotune_args)
+
+ column_count = column_count[:, :, :q_num_blocks].contiguous()
+ max_column_count = torch.max(column_count)
+ column_index = column_index[
+ :, :, :q_num_blocks, : max(min(k_len, max_column_count), 1)
+ ].contiguous()
+ return column_count, column_index
+
+
+def clean_count(nargs):
+ nargs["column_count"].fill_(0)
+
+
+configs = [
+ triton.Config({"BLOCK_SIZE_Q": BSQ}, num_warps=w, num_stages=s, pre_hook=clean_count)
+ for BSQ in [16, 32, 64, 128]
+ for s in [1, 2, 3]
+ for w in [4, 8]
+]
+
+if USE_TRITON_AUTOTUNE:
+ decorator = triton.autotune(
+ configs=configs,
+ key=["q_pooling_size", "k_local_size", "group_k_block", "HEAD_DIM", "CAUSAL"],
+ cache_results=True,
+ )
+else:
+
+ def decorator(func):
+ return func
+
+
+@decorator
+@triton.jit
+def _causal_fuse_qk_cutoff_wo_causal_kernel(
+ q,
+ k,
+ sm_scale,
+ column_count,
+ column_index,
+ stride_qb,
+ stride_qh,
+ stride_qn,
+ stride_kb,
+ stride_kh,
+ stride_kn,
+ stride_ccb,
+ stride_cch,
+ stride_cib,
+ stride_cih,
+ stride_cin,
+ q_chunk_offset,
+ q_length,
+ k_length,
+ NUM_HEADS: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
+ gap: tl.constexpr,
+ CAUSAL: tl.constexpr,
+ q_pooling_size: tl.constexpr,
+ k_local_size: tl.constexpr,
+ group_k_block: tl.constexpr,
+ wo_init: tl.constexpr,
+ BLOCK_SIZE_Q: tl.constexpr,
+):
+ BLOCK_SIZE_K: tl.constexpr = k_local_size
+
+ qblock_id = tl.program_id(0).to(tl.int64)
+ kblock_id = tl.program_id(1).to(tl.int64) * group_k_block
+ batch_id = tl.program_id(2).to(tl.int64) // NUM_HEADS
+ head_id = tl.program_id(2).to(tl.int64) % NUM_HEADS
+
+ k_offset_start = (
+ kblock_id * BLOCK_SIZE_K if not wo_init else max(q_pooling_size, kblock_id * BLOCK_SIZE_K)
+ )
+ max_valid_offset = (
+ k_length
+ if not CAUSAL
+ else min(
+ k_length,
+ (q_chunk_offset + (qblock_id + 1) * BLOCK_SIZE_Q) * q_pooling_size - q_pooling_size,
+ )
+ )
+ k_offset_end = min(max_valid_offset, kblock_id * BLOCK_SIZE_K + k_local_size * group_k_block)
+
+ if k_offset_end <= k_offset_start:
+ return
+
+ offset_at_causal = q_chunk_offset * q_pooling_size + qblock_id * BLOCK_SIZE_Q * q_pooling_size
+ k_ptrs = k + batch_id * stride_kb + head_id * stride_kh + k_offset_start * stride_kn
+ k_ptrs = (
+ k_ptrs + tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_kn + tl.arange(0, HEAD_DIM)[:, None]
+ )
+
+ q_ptrs = q + batch_id * stride_qb + head_id * stride_qh + qblock_id * BLOCK_SIZE_Q * stride_qn
+ q_ptrs = (
+ q_ptrs + tl.arange(0, BLOCK_SIZE_Q)[:, None] * stride_qn + tl.arange(0, HEAD_DIM)[None, :]
+ )
+ qblock = tl.load(q_ptrs)
+ qblock = (qblock * sm_scale).to(qblock.type.element_ty)
+
+ column_count_ptr = (
+ column_count + batch_id * stride_ccb + head_id * stride_cch + qblock_id * BLOCK_SIZE_Q
+ )
+ column_count_ptr = column_count_ptr + tl.arange(0, BLOCK_SIZE_Q)[:, None]
+ column_index_ptr = (
+ column_index
+ + batch_id * stride_cib.to(tl.int64)
+ + head_id * stride_cih
+ + qblock_id * BLOCK_SIZE_Q * stride_cin
+ )
+ column_index_ptr = (
+ column_index_ptr
+ + tl.arange(0, BLOCK_SIZE_Q)[:, None] * stride_cin
+ + tl.arange(0, BLOCK_SIZE_K)[None, :]
+ )
+
+ qk_max = tl.zeros((BLOCK_SIZE_Q, 1), dtype=tl.float32) - float("inf")
+
+ for k_offset in tl.range(k_offset_start, k_offset_end, BLOCK_SIZE_K):
+ kblock = tl.load(k_ptrs)
+ qk = tl.dot(qblock, kblock)
+ qk_max = tl.maximum(qk_max, tl.max(qk, axis=1, keep_dims=True))
+ qk_mask = (qk + gap) >= qk_max
+ if CAUSAL:
+ qk_mask = qk_mask & (
+ tl.arange(0, BLOCK_SIZE_Q)[:, None]
+ > ((k_offset - offset_at_causal) // q_pooling_size)
+ )
+
+ row_counts = tl.sum(qk_mask.to(tl.int32), axis=1, keep_dims=True)
+ idx = tl.arange(0, BLOCK_SIZE_K)[None, :] + k_offset
+ idx = tl.where(qk_mask, idx, k_length)
+ idx = tl.sort(idx, dim=1, descending=False)
+
+ # Requires triton > 3.4.0 for correct atomic_add broadcast behaviour.
+ # See: https://github.com/triton-lang/triton/issues/7402
+ column_index_offset = tl.atomic_add(column_count_ptr, row_counts, sem="relaxed")
+ tl.store(column_index_ptr + column_index_offset, idx, mask=idx < k_length)
+
+ k_ptrs = k_ptrs + BLOCK_SIZE_K * stride_kn
+
+
+@decorator
+@triton.jit
+def _causal_fuse_qk_cutoff_wo_causal_perHead_kernel(
+ q,
+ k,
+ sm_scale,
+ gaps,
+ column_count,
+ column_index,
+ stride_qb,
+ stride_qh,
+ stride_qn,
+ stride_kb,
+ stride_kh,
+ stride_kn,
+ stride_ccb,
+ stride_cch,
+ stride_cib,
+ stride_cih,
+ stride_cin,
+ q_chunk_offset,
+ q_length,
+ k_length,
+ NUM_HEADS: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
+ CAUSAL: tl.constexpr,
+ q_pooling_size: tl.constexpr,
+ k_local_size: tl.constexpr,
+ group_k_block: tl.constexpr,
+ wo_init: tl.constexpr,
+ BLOCK_SIZE_Q: tl.constexpr,
+):
+ """Per-head variant: each head reads its own gap value from the ``gaps`` tensor."""
+ BLOCK_SIZE_K: tl.constexpr = k_local_size
+
+ qblock_id = tl.program_id(0).to(tl.int64)
+ kblock_id = tl.program_id(1).to(tl.int64) * group_k_block
+ batch_id = tl.program_id(2).to(tl.int64) // NUM_HEADS
+ head_id = tl.program_id(2).to(tl.int64) % NUM_HEADS
+
+ gap_ptr = gaps + batch_id * NUM_HEADS + head_id
+ gap = tl.load(gap_ptr)
+
+ k_offset_start = (
+ kblock_id * BLOCK_SIZE_K if not wo_init else max(q_pooling_size, kblock_id * BLOCK_SIZE_K)
+ )
+ max_valid_offset = (
+ k_length
+ if not CAUSAL
+ else min(
+ k_length,
+ (q_chunk_offset + (qblock_id + 1) * BLOCK_SIZE_Q) * q_pooling_size - q_pooling_size,
+ )
+ )
+ k_offset_end = min(max_valid_offset, kblock_id * BLOCK_SIZE_K + k_local_size * group_k_block)
+
+ if k_offset_end <= k_offset_start:
+ return
+
+ offset_at_causal = q_chunk_offset * q_pooling_size + qblock_id * BLOCK_SIZE_Q * q_pooling_size
+ k_ptrs = k + batch_id * stride_kb + head_id * stride_kh + k_offset_start * stride_kn
+ k_ptrs = (
+ k_ptrs + tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_kn + tl.arange(0, HEAD_DIM)[:, None]
+ )
+
+ q_ptrs = q + batch_id * stride_qb + head_id * stride_qh + qblock_id * BLOCK_SIZE_Q * stride_qn
+ q_ptrs = (
+ q_ptrs + tl.arange(0, BLOCK_SIZE_Q)[:, None] * stride_qn + tl.arange(0, HEAD_DIM)[None, :]
+ )
+ qblock = tl.load(q_ptrs)
+ qblock = (qblock * sm_scale).to(qblock.type.element_ty)
+
+ column_count_ptr = (
+ column_count + batch_id * stride_ccb + head_id * stride_cch + qblock_id * BLOCK_SIZE_Q
+ )
+ column_count_ptr = column_count_ptr + tl.arange(0, BLOCK_SIZE_Q)[:, None]
+ column_index_ptr = (
+ column_index
+ + batch_id * stride_cib.to(tl.int64)
+ + head_id * stride_cih
+ + qblock_id * BLOCK_SIZE_Q * stride_cin
+ )
+ column_index_ptr = (
+ column_index_ptr
+ + tl.arange(0, BLOCK_SIZE_Q)[:, None] * stride_cin
+ + tl.arange(0, BLOCK_SIZE_K)[None, :]
+ )
+
+ qk_max = tl.zeros((BLOCK_SIZE_Q, 1), dtype=tl.float32) - float("inf")
+
+ for k_offset in tl.range(k_offset_start, k_offset_end, BLOCK_SIZE_K):
+ kblock = tl.load(k_ptrs)
+ qk = tl.dot(qblock, kblock)
+ qk_max = tl.maximum(qk_max, tl.max(qk, axis=1, keep_dims=True))
+ qk_mask = (qk + gap) >= qk_max
+ if CAUSAL:
+ qk_mask = qk_mask & (
+ tl.arange(0, BLOCK_SIZE_Q)[:, None]
+ > ((k_offset - offset_at_causal) // q_pooling_size)
+ )
+
+ row_counts = tl.sum(qk_mask.to(tl.int32), axis=1, keep_dims=True)
+ idx = tl.arange(0, BLOCK_SIZE_K)[None, :] + k_offset
+ idx = tl.where(qk_mask, idx, k_length)
+ idx = tl.sort(idx, dim=1, descending=False)
+
+ # Requires triton > 3.4.0 for correct atomic_add broadcast behaviour.
+ # See: https://github.com/triton-lang/triton/issues/7402
+ column_index_offset = tl.atomic_add(column_count_ptr, row_counts, sem="relaxed")
+ tl.store(column_index_ptr + column_index_offset, idx, mask=idx < k_length)
+
+ k_ptrs = k_ptrs + BLOCK_SIZE_K * stride_kn
+
+
+# ====================== Autotuning utilities =========================
+
+
+def find_best_eff_config(
+ q_pooling_size: int,
+ k_local_size: int,
+ group_k_block: int,
+ seq_len: int = 64 * 1024,
+ head_dim: int = 128,
+ causal: bool = True,
+):
+ """Run Triton autotune for one (q_pooling_size, k_local_size, group_k_block) combination.
+
+ Must be called with USE_TRITON_AUTOTUNE=1.
+
+ Returns:
+ The best Triton Config object, or None if not found in cache.
+ """
+ device = "cuda"
+ dtype = torch.float16
+ batch_size, num_heads = 1, 32
+ gap = 1e6
+
+ k_len = ((seq_len + q_pooling_size - 1) // q_pooling_size) * q_pooling_size
+ q_num_blocks = k_len // q_pooling_size
+
+ avg_q = torch.randn(batch_size, num_heads, q_num_blocks, head_dim, dtype=dtype, device=device)
+ k = torch.randn(batch_size, num_heads, k_len, head_dim, dtype=dtype, device=device)
+ fuse_qk_softmax_minp_wo_causal(
+ avg_q, k, 0, gap, causal, q_pooling_size, k_local_size, causal, group_k_block
+ )
+ torch.cuda.synchronize()
+ key = (q_pooling_size, k_local_size, group_k_block, head_dim, causal)
+ for k, v in _causal_fuse_qk_cutoff_wo_causal_kernel.cache.items():
+ if key == k[: len(key)]:
+ return v
+ return None
+
+
+def find_all_config(
+ q_pooling_size_list=None,
+ k_local_size_list=None,
+ group_k_block_list=None,
+ seq_len=64 * 1024,
+ head_dim=128,
+ causal=True,
+ save=False,
+ save_path=CONFIG_PATH,
+ load_existing_results=True,
+):
+ """Sweep all (q_pooling_size, k_local_size, group_k_block) combinations.
+
+ Find the best kernel config.
+
+ Args:
+ q_pooling_size_list: List of Q pooling sizes to sweep (default [64, 128]).
+ k_local_size_list: List of K local sizes to sweep (default [16, 32, 64, 128]).
+ group_k_block_list: List of group_k_block values to sweep (default [1, 2, 4, 8]).
+ seq_len: Sequence length used for benchmarking.
+ head_dim: Head dimension.
+ causal: Whether to benchmark causal attention.
+ save: If True, pickle results to ``save_path``.
+ save_path: Output path for the pickled config cache.
+ load_existing_results: If True, skip configurations already in the saved file.
+
+ Returns:
+ dict mapping config tuples to their best kernel hyperparameters.
+ """
+ if q_pooling_size_list is None:
+ q_pooling_size_list = [64, 128]
+ if k_local_size_list is None:
+ k_local_size_list = [16, 32, 64, 128]
+ if group_k_block_list is None:
+ group_k_block_list = [1, 2, 4, 8]
+
+ if load_existing_results and os.path.exists(save_path):
+ with open(save_path, "rb") as f:
+ best_configs = pickle.load(f)
+ print(f"Loaded existing best configs from {save_path}")
+ else:
+ best_configs = {}
+
+ for q_pooling_size in q_pooling_size_list:
+ for k_local_size in k_local_size_list:
+ if k_local_size >= q_pooling_size:
+ # k_local_size == q_pooling_size causes register spilling
+ continue
+ for group_k_block in group_k_block_list:
+ key = (q_pooling_size, k_local_size, group_k_block, head_dim, causal)
+ if key in best_configs:
+ continue
+ print(f"Testing config: {key}")
+ best_eff_config = find_best_eff_config(
+ q_pooling_size, k_local_size, group_k_block, seq_len, head_dim, causal
+ )
+ if best_eff_config is None:
+ raise ValueError(f"No best config found for: {key}")
+ best_eff_config = best_eff_config.__dict__
+ best_configs[key] = {
+ **best_eff_config["kwargs"],
+ **{
+ k: v
+ for k, v in best_eff_config.items()
+ if k not in ["kwargs", "pre_hook", "ir_override"]
+ },
+ }
+ print(f"Best config for {key}: {best_configs[key]}")
+
+ if save:
+ with open(save_path, "wb") as f:
+ pickle.dump(best_configs, f)
+ print(f"Saved best configs to {save_path}")
+
+ return best_configs
+
+
+def autotune_main():
+ """Entry point for generating the kernel efficiency config cache.
+
+ Sets ``USE_TRITON_AUTOTUNE = True`` automatically so that the Triton autotune
+ decorator is active for the benchmarking sweep.
+ """
+ global USE_TRITON_AUTOTUNE
+ USE_TRITON_AUTOTUNE = True
+ q_pooling_size_list = [64, 128]
+ k_local_size_list = [16, 32, 64, 128]
+ group_k_block_list = [1, 2, 4, 8, 16, 32, 64, 128]
+
+ find_all_config(
+ q_pooling_size_list,
+ k_local_size_list,
+ group_k_block_list,
+ seq_len=64 * 1024,
+ save=True,
+ save_path=CONFIG_PATH,
+ load_existing_results=False,
+ causal=True if env_prefix == "vlm" else False,
+ )
+
+
+if __name__ == "__main__":
+ autotune_main()
+
+
+# ---------------------------------------------------------------------------
+# Utility functions: pooling kernel and helpers
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def bnhd_pool_kernel(
+ x_ptr,
+ y_ptr,
+ pool_type: tl.constexpr,
+ batch_size,
+ seq_len,
+ num_heads,
+ head_dim: tl.constexpr,
+ stride_xb,
+ stride_xn,
+ stride_xh,
+ stride_xd,
+ stride_yb,
+ stride_yn,
+ stride_yh,
+ stride_yd,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_H: tl.constexpr,
+ BLOCK_SIZE_D: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ pid_h = tl.program_id(2)
+
+ x_ptr = (
+ x_ptr
+ + pid_b * stride_xb
+ + pid_n * BLOCK_SIZE_N * stride_xn
+ + pid_h * BLOCK_SIZE_H * stride_xh
+ )
+
+ off_n = tl.arange(0, BLOCK_SIZE_N)
+ off_h = tl.arange(0, BLOCK_SIZE_H)
+ off_d = tl.arange(0, BLOCK_SIZE_D)
+
+ cur_block_size_n = min(seq_len - pid_n * BLOCK_SIZE_N, BLOCK_SIZE_N)
+
+ x_mask = (
+ (off_n < seq_len - pid_n * BLOCK_SIZE_N)[:, None, None]
+ & (off_h < num_heads - pid_h * BLOCK_SIZE_H)[None, :, None]
+ & (off_d < head_dim)[None, None, :]
+ )
+ x = tl.load(
+ x_ptr
+ + off_n[:, None, None] * stride_xn
+ + off_h[None, :, None] * stride_xh
+ + off_d[None, None, :] * stride_xd,
+ mask=x_mask,
+ other=0,
+ )
+ if pool_type == 0:
+ y = tl.sum(x, axis=0) / cur_block_size_n
+ elif pool_type == 1:
+ y = tl.max(x, axis=0)
+ elif pool_type == 2:
+ y = tl.min(x, axis=0)
+ elif pool_type == 3:
+ y = tl.max(tl.abs(x), axis=0)
+ elif pool_type == 4:
+ y = tl.sum(x, axis=0)
+ else:
+ y = tl.sum(x, axis=0) / cur_block_size_n
+
+ y_ptr = y_ptr + pid_b * stride_yb + pid_n * stride_yn + pid_h * BLOCK_SIZE_H * stride_yh
+ y_mask = (off_h < num_heads - pid_h * BLOCK_SIZE_H)[:, None] & (off_d < head_dim)[None, :]
+ tl.store(y_ptr + off_h[:, None] * stride_yh + off_d[None, :] * stride_yd, y, mask=y_mask)
+
+
+def triton_bnhd_pool(x: torch.Tensor, kernel_size: int, pool_type: str = "avg"):
+ b, n, h, d = x.shape
+ assert d in {16, 32, 64, 128}
+ assert kernel_size in {1, 16, 32, 64, 128, 256, 512}
+ if kernel_size == 1:
+ return x
+ m = triton.cdiv(n, kernel_size)
+ y = torch.zeros(b, m, h, d, device=x.device, dtype=x.dtype)
+
+ if pool_type == "last":
+ if n % kernel_size == 0:
+ return x[:, kernel_size - 1 :: kernel_size, ...]
+ else:
+ return torch.cat((x[:, kernel_size - 1 :: kernel_size, ...], x[:, -1:, ...]), dim=1)
+
+ block_size_h = triton.next_power_of_2(h)
+ while kernel_size * block_size_h * d > 128 * 128 * 128:
+ block_size_h = block_size_h // 2
+ assert block_size_h != 0
+
+ block_size_d = triton.next_power_of_2(d)
+ pool_str_to_type = {"avg": 0, "max": 1, "min": 2, "maxabs": 3, "sum": 4}
+ pool_type = pool_str_to_type[pool_type]
+
+ def grid(META):
+ return (
+ b,
+ triton.cdiv(n, META["BLOCK_SIZE_N"]),
+ triton.cdiv(h, META["BLOCK_SIZE_H"]),
+ )
+
+ bnhd_pool_kernel[grid](
+ x,
+ y,
+ pool_type,
+ b,
+ n,
+ h,
+ d,
+ x.stride(0),
+ x.stride(1),
+ x.stride(2),
+ x.stride(3),
+ y.stride(0),
+ y.stride(1),
+ y.stride(2),
+ y.stride(3),
+ BLOCK_SIZE_N=kernel_size,
+ BLOCK_SIZE_H=block_size_h,
+ BLOCK_SIZE_D=block_size_d,
+ )
+ return y
+
+
+def causal_mask_in_uneqal_block(seqlen, block_size_q, block_size_k, device):
+ num_q_blocks = math.ceil(seqlen / block_size_q)
+ num_k_blocks = math.ceil(seqlen / block_size_k)
+ q_start = torch.arange(num_q_blocks, device=device) * block_size_q
+ q_end = q_start + block_size_q - 1
+ k_start = torch.arange(num_k_blocks, device=device) * block_size_k
+ k_end = k_start + block_size_k - 1
+ overlap = (q_start.unsqueeze(1) <= k_end.unsqueeze(0)) & (
+ k_start.unsqueeze(0) <= q_end.unsqueeze(1)
+ )
+ assert overlap.any(dim=-1).all()
+ last_true_idx = overlap.float().cumsum(dim=-1).argmax(dim=-1)
+ col_idx = torch.arange(num_k_blocks, device=overlap.device).unsqueeze(0)
+ mask = col_idx <= last_true_idx.unsqueeze(1)
+ return overlap, mask
+
+
+def average_vector(q, block_size, use_triton=True):
+ """Average pool query vectors: (B, H, L, D) -> (B, H, num_blocks, D)."""
+ batch_size, num_heads, seq_len, head_dim = q.shape
+ dtype = q.dtype
+ q = q.float()
+ num_blocks = math.ceil(seq_len / block_size)
+ if use_triton:
+ q = q.transpose(1, 2)
+ return triton_bnhd_pool(q, block_size).transpose(1, 2).to(dtype)
+ else:
+ pad_q = num_blocks * block_size - seq_len
+ avg_q = (
+ torch.nn.functional.pad(q, (0, 0, 0, pad_q), value=0)
+ .view(batch_size, num_heads, num_blocks, block_size, head_dim)
+ .mean(-2)
+ )
+ avg_q[:, :, -1, :] = avg_q[:, :, -1, :] * block_size / (block_size - pad_q)
+ return avg_q.to(dtype)
diff --git a/angelslim/compressor/sparsity/vecattention/ops/vllm-flash-attention b/angelslim/compressor/sparsity/vecattention/ops/vllm-flash-attention
new file mode 160000
index 00000000..4a6393c1
--- /dev/null
+++ b/angelslim/compressor/sparsity/vecattention/ops/vllm-flash-attention
@@ -0,0 +1 @@
+Subproject commit 4a6393c19c953bab156d763d5091ece18c40817b
diff --git a/angelslim/compressor/sparsity/vecattention/patch.py b/angelslim/compressor/sparsity/vecattention/patch.py
new file mode 100644
index 00000000..b3d59463
--- /dev/null
+++ b/angelslim/compressor/sparsity/vecattention/patch.py
@@ -0,0 +1,77 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model-patching logic: replace the standard attention forward with VecAttention's."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .modules.forward import qwen_vl_attn_forward
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+ from .vecattention_configuration import VecAttentionConfig
+
+
+def vecattention_patch(
+ model: "PreTrainedModel", config: "VecAttentionConfig"
+) -> "PreTrainedModel":
+ """Replace each attention layer's ``forward`` with VecAttention sparse prefill.
+
+ Supports Qwen2.5-VL and Qwen2-VL models.
+
+ Args:
+ model: A HuggingFace VLM model (e.g. Qwen2.5-VL).
+ config: VecAttention runtime configuration.
+
+ Returns:
+ The same *model* object, mutated in-place with VecAttention.
+
+ Raises:
+ ValueError: If the model's ``model_type`` is not supported.
+ """
+ model_type = model.config.model_type.lower()
+
+ if "qwen2_5_vl" in model_type or "qwen2_vl" in model_type:
+ _patch_qwen_vl(model, config)
+ else:
+ raise ValueError(
+ f"VecAttention does not support model_type={model_type!r}. "
+ f"Supported: qwen2_5_vl, qwen2_vl."
+ )
+
+ return model
+
+
+def _patch_qwen_vl(model: "PreTrainedModel", config: "VecAttentionConfig") -> None:
+ """Apply VecAttention patch to Qwen2.5-VL / Qwen2-VL models."""
+ if hasattr(model.model, "language_model"):
+ layers = model.model.language_model.layers
+ elif hasattr(model.model, "layers"):
+ layers = model.model.layers
+ else:
+ raise AttributeError(
+ "Cannot find attention layers in Qwen VL model. "
+ "Expected model.model.language_model.layers or model.model.layers."
+ )
+
+ AttentionClass = layers[0].self_attn.__class__
+
+ for i, layer in enumerate(layers):
+ attn = layer.self_attn
+ attn.layer_idx = i
+ attn.vecattention_config = config
+ attn.forward = qwen_vl_attn_forward.__get__(attn, AttentionClass)
diff --git a/angelslim/compressor/sparsity/vecattention/vecattention.py b/angelslim/compressor/sparsity/vecattention/vecattention.py
new file mode 100644
index 00000000..81e1a5ac
--- /dev/null
+++ b/angelslim/compressor/sparsity/vecattention/vecattention.py
@@ -0,0 +1,53 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""High-level entry point for applying VecAttention to a VLM model."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .patch import vecattention_patch
+from .vecattention_configuration import VecAttentionConfig
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+
+class VecAttentionInference:
+ """Callable object that patches a model to use VecAttention sparse prefill.
+
+ Only the prefill phase (q_len > 1) uses VecAttention; decode falls back
+ to the model's original attention implementation.
+
+ Usage::
+
+ vec = VecAttentionInference(attn_kwargs={"threshold": 0.9})
+ model = vec(model)
+
+ Args:
+ attn_kwargs: Forwarded to ``VecAttentionConfig``. See its docstring
+ for valid keys (threshold, block_size_q, block_size_k,
+ group_k_block, chunk_size).
+ """
+
+ def __init__(self, attn_kwargs: dict | None = None) -> None:
+ self.config = VecAttentionConfig(attn_kwargs=attn_kwargs)
+
+ def __call__(self, model: "PreTrainedModel") -> "PreTrainedModel":
+ """Apply the VecAttention patch and return the modified model."""
+ return vecattention_patch(model, self.config)
+
+ def __repr__(self) -> str:
+ return f"VecAttentionInference(config={self.config!r})"
diff --git a/angelslim/compressor/sparsity/vecattention/vecattention_configuration.py b/angelslim/compressor/sparsity/vecattention/vecattention_configuration.py
new file mode 100644
index 00000000..06d3b819
--- /dev/null
+++ b/angelslim/compressor/sparsity/vecattention/vecattention_configuration.py
@@ -0,0 +1,83 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration class for VecAttention sparse prefill."""
+
+from __future__ import annotations
+
+
+class VecAttentionConfig:
+ """Configuration container for VecAttention sparse attention.
+
+ Args:
+ attn_kwargs: Dictionary of keyword arguments. Recognised keys:
+
+ - ``threshold`` (float): MinP threshold for key column selection.
+ Higher = more aggressive sparsity. Default 0.9.
+ - ``block_size_q`` (int): Query pooling block size. Must be 64 or 128.
+ Default 64.
+ - ``block_size_k`` (int): Key local block size. Default 16.
+ - ``group_k_block`` (int): Number of k-blocks processed together.
+ Default 1.
+ - ``chunk_size`` (int): Prefill chunk size (tokens). Must be a
+ multiple of block_size_q. Default 65536.
+
+ Raises:
+ ValueError: If block_size_q is not 64 or 128, or chunk_size is not
+ a multiple of block_size_q.
+ """
+
+ def __init__(self, attn_kwargs: dict | None = None) -> None:
+ self.attn_kwargs: dict = dict(attn_kwargs or {})
+ # Set defaults
+ self.attn_kwargs.setdefault("threshold", 0.1)
+ self.attn_kwargs.setdefault("block_size_q", 64)
+ self.attn_kwargs.setdefault("block_size_k", 16)
+ self.attn_kwargs.setdefault("group_k_block", 16)
+ self.attn_kwargs.setdefault("chunk_size", 64 * 1024)
+
+ # Validate
+ block_size_q = self.attn_kwargs["block_size_q"]
+ chunk_size = self.attn_kwargs["chunk_size"]
+
+ if block_size_q not in (64, 128):
+ raise ValueError(f"block_size_q must be 64 or 128, got {block_size_q}")
+ if chunk_size % block_size_q != 0:
+ raise ValueError(
+ f"chunk_size ({chunk_size}) must be a multiple of "
+ f"block_size_q ({block_size_q})"
+ )
+
+ @property
+ def threshold(self):
+ return self.attn_kwargs["threshold"]
+
+ @property
+ def block_size_q(self):
+ return self.attn_kwargs["block_size_q"]
+
+ @property
+ def block_size_k(self):
+ return self.attn_kwargs["block_size_k"]
+
+ @property
+ def group_k_block(self):
+ return self.attn_kwargs["group_k_block"]
+
+ @property
+ def chunk_size(self):
+ return self.attn_kwargs["chunk_size"]
+
+ def __repr__(self) -> str:
+ return f"VecAttentionConfig(attn_kwargs={self.attn_kwargs!r})"
diff --git a/tools/run_vecattention.py b/tools/run_vecattention.py
new file mode 100644
index 00000000..8e329200
--- /dev/null
+++ b/tools/run_vecattention.py
@@ -0,0 +1,324 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+VecAttention sparse attention inference script for Vision-Language Models.
+
+Usage:
+ # Text-only query on Qwen2.5-VL
+ python tools/run_vecattention.py --mode vecattention \
+ --model-path Qwen/Qwen2.5-VL-3B-Instruct \
+ --prompt "What is the capital of France?"
+
+ # Image (URL) + text query
+ python tools/run_vecattention.py --mode vecattention \
+ --model-path Qwen/Qwen2.5-VL-3B-Instruct \
+ --prompt "Describe this image in detail." \
+ --image https://inews.gtimg.com/news_bt/OQSQBp_mW8TxXv7UsR55mi2DMfWW4D2aJJ-jsFphE5YD8AA/1000
+
+ # Image (local file) + text query
+ python tools/run_vecattention.py --mode vecattention \
+ --model-path Qwen/Qwen2.5-VL-3B-Instruct \
+ --prompt "What's in this photo?" \
+ --image /path/to/local/image.jpg
+
+ # Video (local file) + text query
+ python tools/run_vecattention.py --mode vecattention \
+ --model-path Qwen/Qwen2.5-VL-3B-Instruct \
+ --prompt "Summarize what happens in this video." \
+ --video /path/to/video.mp4 --nframes 24
+
+ # Video (URL) + text query
+ python tools/run_vecattention.py --mode vecattention \
+ --model-path Qwen/Qwen2.5-VL-3B-Instruct \
+ --prompt "Describe the key events in this video." \
+ --video https://example.com/video.mp4 --nframes 16
+
+ # Dense baseline for comparison
+ python tools/run_vecattention.py --mode dense \
+ --model-path Qwen/Qwen2.5-VL-3B-Instruct \
+ --prompt "Describe this image." \
+ --image https://example.com/image.jpg
+"""
+
+import argparse
+import sys
+import time
+from io import BytesIO
+
+import requests
+import torch
+from PIL import Image
+from transformers import AutoProcessor
+
+from angelslim.compressor.sparsity.vecattention import VecAttentionInference
+
+# Default test image
+DEFAULT_IMAGE_URL = (
+ "https://inews.gtimg.com/news_bt/" "OQSQBp_mW8TxXv7UsR55mi2DMfWW4D2aJJ-jsFphE5YD8AA/1000"
+)
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def load_image(image_source: str) -> Image.Image:
+ """Load an image from a local path or HTTP(S) URL.
+
+ Args:
+ image_source: Local file path or URL starting with http:// or https://
+
+ Returns:
+ PIL Image in RGB mode.
+ """
+ if image_source.startswith(("http://", "https://")):
+ response = requests.get(image_source, timeout=15)
+ response.raise_for_status()
+ image = Image.open(BytesIO(response.content))
+ else:
+ image = Image.open(image_source)
+ return image.convert("RGB")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="VecAttention VLM inference: Dense vs VecAttention on Qwen2.5-VL.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default="vecattention",
+ choices=["dense", "vecattention"],
+ help="Attention mode: 'dense' (no patch) or 'vecattention' (sparse prefill).",
+ )
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ default="Qwen/Qwen2.5-VL-3B-Instruct",
+ help="Path to the VLM model directory or HuggingFace model ID.",
+ )
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default="Describe the content of this image in one short sentence.",
+ help="Text prompt for the model.",
+ )
+ parser.add_argument(
+ "--image",
+ type=str,
+ default=None,
+ help="Image source: local file path or HTTP(S) URL. "
+ "If not provided, runs text-only inference.",
+ )
+ parser.add_argument(
+ "--video",
+ type=str,
+ default=None,
+ help="Video source: local file path or HTTP(S) URL.",
+ )
+ parser.add_argument(
+ "--nframes",
+ type=int,
+ default=24,
+ help="Number of frames to sample from video.",
+ )
+ parser.add_argument(
+ "--max-new-tokens",
+ type=int,
+ default=128,
+ help="Maximum number of new tokens to generate.",
+ )
+ # VecAttention parameters
+ parser.add_argument("--threshold", type=float, default=0.1, help="MinP threshold.")
+ parser.add_argument(
+ "--block-size-q", type=int, default=64, choices=[64, 128], help="Q pooling block size."
+ )
+ parser.add_argument("--block-size-k", type=int, default=16, help="K local block size.")
+ parser.add_argument(
+ "--group-k-block", type=int, default=16, help="K block grouping (default 16 for VLM)."
+ )
+ parser.add_argument("--chunk-size", type=int, default=64 * 1024, help="Prefill chunk size.")
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ print(f"[Env] Python: {sys.executable}, Device: {DEVICE}")
+ print(f"[Config] mode={args.mode}, model={args.model_path}")
+ if args.image:
+ print(f"[Config] image={args.image}")
+ if args.video:
+ print(f"[Config] video={args.video}, nframes={args.nframes}")
+ if args.mode == "vecattention":
+ print(
+ f"[Config] threshold={args.threshold}, block_size_q={args.block_size_q}, "
+ f"block_size_k={args.block_size_k}, group_k_block={args.group_k_block}, "
+ f"chunk_size={args.chunk_size}"
+ )
+
+ # --- 1. Load model and processor ---
+ print("Loading model...")
+ # Use the same loading pattern as test_token_pruning.py
+ from transformers import Qwen2_5_VLForConditionalGeneration
+
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ args.model_path,
+ torch_dtype=torch.bfloat16 if DEVICE != "cpu" else torch.float32,
+ device_map=DEVICE,
+ trust_remote_code=True,
+ ).eval()
+
+ processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
+ num_layers = getattr(
+ model.config,
+ "num_hidden_layers",
+ getattr(model.config, "text_config", model.config).num_hidden_layers,
+ )
+ print(f"Model: {model.config.model_type}, {num_layers} layers")
+
+ # --- 2. Prepare inputs (before patch, so we can detect vision positions) ---
+ print("Preparing inputs...")
+ image = None
+ if args.video:
+ from qwen_vl_utils import process_vision_info
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "video", "video": args.video, "nframes": args.nframes},
+ {"type": "text", "text": args.prompt},
+ ],
+ }
+ ]
+ image_inputs, video_inputs = process_vision_info(messages)
+ text_prompt = processor.apply_chat_template(
+ messages, add_generation_prompt=True, tokenize=False
+ )
+ inputs = processor(
+ text=[text_prompt], videos=video_inputs, padding=True, return_tensors="pt"
+ ).to(model.device)
+ elif args.image:
+ image = load_image(args.image)
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": args.prompt},
+ ],
+ }
+ ]
+ text_prompt = processor.apply_chat_template(
+ messages, add_generation_prompt=True, tokenize=False
+ )
+ inputs = processor(
+ text=[text_prompt], images=[image], padding=True, return_tensors="pt"
+ ).to(model.device)
+ else:
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": args.prompt},
+ ],
+ }
+ ]
+ text_prompt = processor.apply_chat_template(
+ messages, add_generation_prompt=True, tokenize=False
+ )
+ inputs = processor(text=[text_prompt], padding=True, return_tensors="pt").to(model.device)
+
+ input_len = inputs.input_ids.shape[1]
+ print(f"[Input] token_length={input_len}")
+
+ # --- 3. Detect vision token positions from input_ids ---
+ vision_start_position = None
+ vision_end_position = None
+ if args.image or args.video:
+ input_ids = inputs.input_ids[0]
+ VISION_START_TOKEN_ID = 151652 # <|vision_start|>
+ VISION_END_TOKEN_ID = 151653 # <|vision_end|>
+
+ vision_start_indices = (input_ids == VISION_START_TOKEN_ID).nonzero(as_tuple=True)[0]
+ vision_end_indices = (input_ids == VISION_END_TOKEN_ID).nonzero(as_tuple=True)[0]
+
+ if len(vision_start_indices) > 0:
+ vision_start_position = int(vision_start_indices[0].item())
+ if len(vision_end_indices) > 0:
+ # vision_end_position is one past the last vision token
+ vision_end_position = int(vision_end_indices[-1].item()) + 1
+
+ print(f"[Vision] start={vision_start_position}, end={vision_end_position}")
+
+ # --- 4. Apply VecAttention patch ---
+ if args.mode == "vecattention":
+ print("Applying VecAttention patch...")
+ attn_kwargs = {
+ "threshold": args.threshold,
+ "block_size_q": args.block_size_q,
+ "block_size_k": args.block_size_k,
+ "group_k_block": args.group_k_block,
+ "chunk_size": args.chunk_size,
+ }
+ # Pass vision positions so VecAttention only applies to vision region
+ if vision_start_position is not None:
+ attn_kwargs["vision_start_position"] = vision_start_position
+ if vision_end_position is not None:
+ attn_kwargs["vision_end_position"] = vision_end_position
+
+ vec = VecAttentionInference(attn_kwargs=attn_kwargs)
+ model = vec(model)
+ print(f"[VecAttention] Patched {num_layers} attention layers.")
+ else:
+ print("[Dense] No patch applied. Using standard attention.")
+
+ # --- 5. Generate ---
+ print("Generating...")
+ torch.cuda.synchronize()
+ start = time.time()
+ with torch.no_grad():
+ generated_ids = model.generate(
+ **inputs,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+ torch.cuda.synchronize()
+ elapsed = time.time() - start
+
+ # --- 6. Decode and display ---
+ generated_ids_trimmed = generated_ids[:, input_len:]
+ output_text = processor.batch_decode(
+ generated_ids_trimmed,
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )[0]
+
+ print("=" * 80)
+ print(f"Mode: {args.mode}")
+ print(f"Input tokens: {input_len}")
+ print(f"Generated tokens: {generated_ids_trimmed.shape[1]}")
+ print(f"Total time: {elapsed:.3f}s")
+ print(f"Tokens/sec: {generated_ids_trimmed.shape[1] / (elapsed + 1e-9):.1f}")
+ print("-" * 80)
+ print("Output:")
+ print(output_text.strip())
+ print("=" * 80)
+
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+ main()
|