From d16cbaa693bc8e8ac16c76968270cda08a0b0792 Mon Sep 17 00:00:00 2001 From: Karthikeya Date: Tue, 2 Jun 2026 11:38:16 +0530 Subject: [PATCH 01/14] [WIP] Fix for acc issue in Qwen3 VL moe (#1010) Issue: Generated text is not accurate (unable to identify object in given image) 5/25: Found a workaround, seems like compiler issue - debugging further --------- Signed-off-by: vtirumal Signed-off-by: Vaibhav Verma --- .../models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 9 +++++---- .../models/qwen3_vl_moe/qwen3_vl_disagg_mode.py | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 4a6259bf8..91b529b3c 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -642,6 +642,7 @@ def _deepstack_process( ): visual_pos_masks = visual_pos_masks.unsqueeze(-1).expand(-1, -1, self.config.hidden_size) visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + hidden_states = hidden_states.clone() visual_mask = visual_pos_masks.to(hidden_states.dtype) return hidden_states + (visual_embeds * visual_mask) @@ -653,10 +654,10 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens x = hidden_states.view(T, H) act = getattr(self.experts, "act_fn", F.silu) - router_hidden_states = x.reshape(-1, self.gate.hidden_dim) - router_logits = F.linear(router_hidden_states, self.gate.weight) - top_w, top_i = torch.topk(router_logits, self.gate.top_k, dim=-1) - top_w = F.softmax(top_w, dim=-1, dtype=torch.float) + router_logits = self.gate(x) # [T, E] + prob = F.softmax(router_logits, dim=-1, dtype=torch.float32) + top_w, top_i = torch.topk(prob, self.top_k, dim=-1) + top_w = top_w / torch.einsum("bi->b", top_w)[:, None] top_w = top_w.to(hidden_states.dtype) num_experts = getattr(self, 
"num_experts", self.gate.num_experts) routing_weights = torch.zeros((T, num_experts), dtype=x.dtype) diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 585a532fa..38ca6b478 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -217,6 +217,7 @@ decode_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"] decode_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"] + st = perf_counter() decode_out = lang_decode_session.run(decode_inputs) print(f"time for first run of decode with KV as input = {perf_counter() - st} sec\n") From e90c41075233f2cdc001ba710dde72a6439d4621 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Tue, 2 Jun 2026 11:43:02 +0530 Subject: [PATCH 02/14] Revert "[WIP] Fix for acc issue in Qwen3 VL moe" (#1019) Reverts quic/efficient-transformers#1010 Signed-off-by: Rishin Raj Signed-off-by: Vaibhav Verma --- .../models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 9 ++++++--- .../models/qwen3_vl_moe/qwen3_vl_disagg_mode.py | 14 ++++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 91b529b3c..cc6899fc1 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -643,8 +643,11 @@ def _deepstack_process( visual_pos_masks = visual_pos_masks.unsqueeze(-1).expand(-1, -1, self.config.hidden_size) visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) hidden_states = hidden_states.clone() - visual_mask = visual_pos_masks.to(hidden_states.dtype) - return hidden_states + (visual_embeds * visual_mask) + mixed_embeds = hidden_states + 
visual_embeds + + local_this = torch.where(visual_pos_masks, mixed_embeds, hidden_states) + + return local_this class QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock): @@ -655,7 +658,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens act = getattr(self.experts, "act_fn", F.silu) router_logits = self.gate(x) # [T, E] - prob = F.softmax(router_logits, dim=-1, dtype=torch.float32) + prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) top_w, top_i = torch.topk(prob, self.top_k, dim=-1) top_w = top_w / torch.einsum("bi->b", top_w)[:, None] top_w = top_w.to(hidden_states.dtype) diff --git a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py index 38ca6b478..6e3c43951 100644 --- a/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py +++ b/examples/image_text_to_text/models/qwen3_vl_moe/qwen3_vl_disagg_mode.py @@ -50,7 +50,7 @@ mxfp6_matmul=True, aic_enable_depth_first=True, skip_vision=skip_vision, - split_model_io=True, + split_retained_state_io=True, skip_lang=True, use_onnx_subfunctions=True, ) @@ -66,7 +66,7 @@ mxfp6_matmul=True, mxint8_kv_cache=True, retain_full_kv=True, - split_model_io=True, # This should be used for disagg serving via VLLM + split_retained_state_io=True, # This should be used for disagg serving via VLLM mos=1, aic_enable_depth_first=True, prefill_only=True, @@ -86,7 +86,8 @@ num_devices=1, mxfp6_matmul=True, mxint8_kv_cache=True, - split_model_io=True, # This should be used for disagg serving via VLLM + retain_full_kv=True, + split_retained_state_io=True, # This should be used for disagg serving via VLLM mos=1, aic_enable_depth_first=True, prefill_only=False, @@ -117,7 +118,6 @@ "content": [ {"type": "image", "image": image}, {"type": "text", "text": "Describe all the colors seen in the image."}, - # {"type": "text", "text": "Can you describe the image in 
detail?"}, ], }, ] @@ -217,6 +217,9 @@ decode_inputs[f"past_key.{i}"] = outputs[f"past_key.{i}_RetainedState"] decode_inputs[f"past_value.{i}"] = outputs[f"past_value.{i}_RetainedState"] +decode_inputs["image_idx"] = outputs["image_idx_output"] +decode_inputs["vision_embeds"] = outputs["vision_embeds_RetainedState"] +decode_inputs["deepstack_features"] = outputs["deepstack_features_RetainedState"] st = perf_counter() decode_out = lang_decode_session.run(decode_inputs) @@ -232,6 +235,9 @@ for i in range(config.text_config.num_hidden_layers): loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] +loop_decode_inputs["image_idx"] = decode_out["image_idx_output"] +loop_decode_inputs["vision_embeds"] = decode_out["vision_embeds_RetainedState"] +loop_decode_inputs["deepstack_features"] = decode_out["deepstack_features_RetainedState"] st = perf_counter() From f2195996d856cced9b9ccf05a1091832cb7e7fa7 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Fri, 8 May 2026 18:24:58 -0500 Subject: [PATCH 03/14] Rebased PagedAttention support with latest Qeff for PR Signed-off-by: Vaibhav Verma --- QEfficient/blocking/attention_blocking.py | 39 +- .../blocking/blocked_attention_forwards.py | 626 +++++++++++++++++- QEfficient/blocking/blocking_configurator.py | 38 +- QEfficient/customop/__init__.py | 4 + QEfficient/customop/ctx_scatter_gather.py | 92 +++ QEfficient/transformers/cache_utils.py | 130 ++++ .../transformers/models/modeling_auto.py | 35 +- 7 files changed, 933 insertions(+), 31 deletions(-) diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py index b75342013..34d027eb4 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -16,26 +16,35 @@ from QEfficient.blocking.blocked_attention_forwards import ( blocked_bhqkv_attention_forward, + 
blocked_bhqkv_paged_attention_forward, blocked_h_attention_forward, blocked_h_mla_attention_forward, blocked_hqkv_attention_forward, + blocked_hqkv_paged_attention_forward, blocked_kv_attention_forward, + blocked_kv_paged_attention_forward, blocked_kv_mla_attention_forward, blocked_q_attention_forward, blocked_qkv_attention_forward, + blocked_qkv_paged_attention_forward, ) class BlockingMode(str, Enum): NONE = "" KV = "kv" + KV_PAGED = "kv_paged" Q = "q" H = "h" QKV = "qkv" + QKV_PAGED = "qkv_paged" HQ = "hq" HKV = "hkv" + HKV_PAGED = "hkv_paged" HQKV = "hqkv" + HQKV_PAGED = "hqkv_paged" BHQKV = "bhqkv" + BHQKV_PAGED = "bhqkv_paged" @dataclass @@ -52,15 +61,24 @@ def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool: return past_key_value is not None and hasattr(past_key_value, "read_only_blockedKV") +def supports_paged_attention_blocked_kv(past_key_value: Optional[Cache]) -> bool: + return past_key_value is not None and hasattr(past_key_value, "read_only_pagedAttention") + + _STRATEGIES: Dict[BlockingMode, Callable] = { BlockingMode.KV: blocked_kv_attention_forward, + BlockingMode.KV_PAGED: blocked_kv_paged_attention_forward, BlockingMode.Q: blocked_q_attention_forward, BlockingMode.H: blocked_h_attention_forward, BlockingMode.QKV: blocked_qkv_attention_forward, + BlockingMode.QKV_PAGED: blocked_qkv_paged_attention_forward, BlockingMode.HQ: blocked_hqkv_attention_forward, BlockingMode.HKV: blocked_hqkv_attention_forward, + BlockingMode.HKV_PAGED: blocked_hqkv_paged_attention_forward, BlockingMode.HQKV: blocked_hqkv_attention_forward, + BlockingMode.HQKV_PAGED: blocked_hqkv_paged_attention_forward, BlockingMode.BHQKV: blocked_bhqkv_attention_forward, + BlockingMode.BHQKV_PAGED: blocked_bhqkv_paged_attention_forward, } _STRATEGIES_MLA: Dict[BlockingMode, Callable] = { @@ -122,8 +140,27 @@ def generic_blocked_attention_interface( blocking_config is not None and "kv" in blocking_config.mode and supports_blocked_kv(past_key_value) ) + use_paged_kv_blocked = ( + 
blocking_config is not None and "paged" in blocking_config.mode and supports_paged_attention_blocked_kv(past_key_value) + ) + if past_key_value is not None: - if use_kv_blocked and sliding_window is None: + if use_paged_kv_blocked and sliding_window is None: + cache_kwargs = { + "batch_index": batch_index, + "position_ids": position_ids, + "block_table": block_table, + "slot_id": slot_id, + } + if sliding_window is not None: + cache_kwargs.update( + { + "is_sliding": sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + ) + past_key_value.write_only_pagedAttention(key, value, module.layer_idx, cache_kwargs) + elif use_kv_blocked and sliding_window is None: cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index b5533ee33..80f40200f 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -208,6 +208,121 @@ def blocked_kv_attention_forward( return attn_output, attn_weights +def blocked_kv_paged_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + num_kv_blocks: int, + cache_kwargs: Dict[str, Any], + layer_idx: int, + past_key_value: Cache, + *, + use_causal_mask: bool = False, + sliding_window: Optional[int] = None, + skip_kv: bool = False, + position_bias: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Initialize result tensor + output = torch.zeros_like(query) + + # Initialize Running Maximum and Denominator + batch_size, num_heads, seq_len, _ = query.shape + current_max = torch.full( + (batch_size, num_heads, seq_len), + float(MIN_MASKED_ATTENTION_VALUE), + device=query.device, + ) + current_denominator = 
torch.zeros(batch_size, num_heads, seq_len, device=query.device) + + if torch.onnx.is_in_onnx_export(): + attention_mask = None + use_causal_mask = True + position_ids = cache_kwargs.get("position_ids") + block_table = cache_kwargs.get("block_table") # [BS, num_kv_blocks] -> each entry is block_id value + kv_block_size = past_key_value.get_seq_length() if past_key_value is not None else 0 + past_seen_tokens = kv_block_size * num_kv_blocks + + if hasattr(module, "config"): + mask_dtype = module.config.torch_dtype + else: + mask_dtype = value.dtype + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=query.device) + current_position = position_ids.max(dim=-1).values + # needed for GPT-OSS + if sinks is not None: + sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + + for j in range(num_kv_blocks): + start_index = j * kv_block_size + if j == num_kv_blocks - 1: + kv_len_block = past_seen_tokens - start_index + else: + kv_len_block = kv_block_size + end_index = start_index + kv_len_block + + skip_future = None + if skip_kv: + skip_future = (torch.tensor(start_index, device=query.device) > current_position).all() + # Eager mode Only + if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): + if skip_future.item(): + break + + block_index = block_table[:, j] + updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j + k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) + k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) + + attn_weights_block = torch.matmul(query, k_block_states.transpose(2, 3)) * scaling + # position bias needed for mpt model + if position_bias is not None: + attn_weights_block = attn_weights_block + position_bias[:, :, start_index:end_index] + + mask_block = None + if attention_mask is not None: + mask_block = attention_mask[..., start_index:end_index] + if mask_block.shape[-1] != 
attn_weights_block.shape[-1]: + mask_block = None + + if use_causal_mask or mask_block is None: + target_length = torch.where( + torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), + past_seen_tokens, + end_index, + ) + causal_mask_block = _create_causal_mask( + position_ids=position_ids, + target_length=target_length, + sliding_window=sliding_window, + start_index=start_index, + ) + if mask_block is None: + mask_block = causal_mask_block + else: + mask_block = mask_block.to(torch.bool) | causal_mask_block + + if mask_block is not None: + attn_weights_block = torch.where(mask_block, masked_tensor, attn_weights_block) + + current_max, current_denominator, output = update_running_softmax( + current_max, attn_weights_block, current_denominator, output, v_block_states, skip_kv, skip_future + ) + + # If present, apply Attention Sinks, needed for GPT-OSS + if sinks is not None: + _, _, output = update_running_softmax(current_max, sinks, current_denominator, output, None) + + attn_output = output.transpose(1, 2).contiguous() + attn_weights = None + + return attn_output, attn_weights + + def blocked_qkv_attention_forward( module: nn.Module, query: torch.Tensor, @@ -350,6 +465,145 @@ def blocked_qkv_attention_forward( return attn_output, attn_weights +def blocked_qkv_paged_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + num_kv_blocks: int, + num_q_blocks: int, + cache_kwargs: Dict[str, Any], + layer_idx: int, + past_key_value: Cache, + *, + use_causal_mask: bool = False, + sliding_window: Optional[int] = None, + skip_kv: bool = False, + position_bias: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Initialize Running Maximum and Denominator + batch_size, num_heads, seq_len, DH = query.shape + + if 
torch.onnx.is_in_onnx_export(): + attention_mask = None + use_causal_mask = True + position_ids = cache_kwargs.get("position_ids") + block_table = cache_kwargs.get("block_table") # [BS, num_kv_blocks] -> each entry is block_id value + kv_block_size = past_key_value.get_seq_length() if past_key_value is not None else 0 + past_seen_tokens = kv_block_size * num_kv_blocks + + num_q_blocks = max(1, num_q_blocks) + q_block_positions = [-(-i * seq_len) // num_q_blocks for i in range(num_q_blocks)] + + q_output_blocks = [] + q_attn_blocks = [] + if hasattr(module, "config"): + mask_dtype = module.config.torch_dtype + else: + mask_dtype = value.dtype + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=query.device) + current_position = position_ids.max(dim=-1).values + # needed for GPT-OSS + if sinks is not None: + sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + + for q_block_idx in range(num_q_blocks): + q_start = q_block_positions[q_block_idx] + if q_block_idx == num_q_blocks - 1: + q_len_block = seq_len - q_start + else: + q_len_block = q_block_positions[q_block_idx + 1] - q_start + + q_block = query[:, :, q_start : q_start + q_len_block, :] + + current_max = torch.full( + (batch_size, num_heads, q_len_block), + float(MIN_MASKED_ATTENTION_VALUE), + device=query.device, + ) + current_denominator = torch.zeros(batch_size, num_heads, q_len_block, device=query.device) + output_blocks = torch.zeros((batch_size, num_heads, q_len_block, DH), device=query.device, dtype=query.dtype) + + for j in range(num_kv_blocks): + start_index = j * kv_block_size + if j == num_kv_blocks - 1: + kv_len_block = past_seen_tokens - start_index + else: + kv_len_block = kv_block_size + end_index = start_index + kv_len_block + + skip_future = None + if skip_kv: + skip_future = (torch.tensor(start_index, device=query.device) > current_position).all() + # Eager mode Only + if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): + if 
skip_future.item(): + break + + block_index = block_table[:, j] + updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j + k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) + k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) + + attn_weights_block = torch.matmul(q_block, k_block_states.transpose(2, 3)) * scaling + # position bias needed for mpt model + if position_bias is not None: + attn_weights_block = attn_weights_block + position_bias[:, :, start_index:end_index] + + mask_block = None + if attention_mask is not None: + mask_block = attention_mask[..., start_index:end_index] + if mask_block.shape[-1] != attn_weights_block.shape[-1]: + mask_block = None + + if use_causal_mask or mask_block is None: + # target_length = min(total_seen_tokens, end_index) + target_length = torch.where( + torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), + past_seen_tokens, + end_index, + ) + causal_mask_block = _create_causal_mask( + position_ids=position_ids, + target_length=target_length, + sliding_window=sliding_window, + start_index=start_index, + ) + if mask_block is None: + mask_block = causal_mask_block + else: + mask_block = mask_block.to(torch.bool) | causal_mask_block + + if mask_block is not None: + attn_mask_block = mask_block[:, :, q_start : q_start + q_len_block, :] + attn_weights_block = torch.where(attn_mask_block, masked_tensor, attn_weights_block) + + current_max, current_denominator, output_blocks = update_running_softmax( + current_max, + attn_weights_block, + current_denominator, + output_blocks, + v_block_states, + skip_kv, + skip_future, + ) + + # If present, apply Attention Sinks, needed for GPT-OSS + if sinks is not None: + _, _, output_blocks = update_running_softmax(current_max, sinks, current_denominator, output_blocks, None) + q_output_blocks.append(output_blocks) + q_attn_blocks.append(attn_weights_block) + + 
attn_output = torch.cat(q_output_blocks, dim=2).transpose(1, 2).contiguous() + attn_weights = torch.cat(q_attn_blocks, dim=2) + + return attn_output, attn_weights + + def blocked_hqkv_attention_forward( module: nn.Module, query: torch.Tensor, @@ -506,7 +760,7 @@ def blocked_hqkv_attention_forward( return attn_output, attn_weights -def blocked_bhqkv_attention_forward( +def blocked_hqkv_paged_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -515,13 +769,11 @@ def blocked_bhqkv_attention_forward( scaling: float, num_kv_blocks: int, num_q_blocks: int, - num_batch_blocks: int, head_block_size: int, cache_kwargs: Dict[str, Any], layer_idx: int, past_key_value: Cache, *, - score_mod: Optional[Callable[[torch.Tensor, int, int], torch.Tensor]] = None, use_causal_mask: bool = False, sliding_window: Optional[int] = None, skip_kv: bool = False, @@ -532,7 +784,6 @@ def blocked_bhqkv_attention_forward( # Initialize Running Maximum and Denominator batch_size, num_heads, seq_len, DH = query.shape - past_seen_tokens = cache_kwargs.get("past_seen_tokens") if torch.onnx.is_in_onnx_export(): attention_mask = None use_causal_mask = True @@ -540,25 +791,19 @@ def blocked_bhqkv_attention_forward( if head_block_size <= 0: head_block_size = num_heads num_head_blocks = math.ceil(num_heads / head_block_size) - num_q_blocks = max(1, _normalize_int(num_q_blocks)) + num_q_blocks = max(1, num_q_blocks) q_block_positions = [-(-i * seq_len) // num_q_blocks for i in range(num_q_blocks)] - num_kv_blocks = max(1, num_kv_blocks) - kv_block_size = -(-past_seen_tokens // num_kv_blocks) + block_table = cache_kwargs.get("block_table") # [BS, num_kv_blocks] -> each entry is block_id value + kv_block_size = past_key_value.get_seq_length() if past_key_value is not None else 0 + past_seen_tokens = kv_block_size * num_kv_blocks h_output_blocks = [] h_attn_blocks = [] - - num_batch_blocks = max( - 1, min(batch_size, _normalize_int(num_batch_blocks)) - ) # default to batch size 
for number of batch blocks - batch_block_positions = [(i * batch_size) // num_batch_blocks for i in range(num_batch_blocks)] - if hasattr(module, "config"): mask_dtype = module.config.torch_dtype else: mask_dtype = value.dtype masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=query.device) - current_position = position_ids.max(dim=-1).values # needed for GPT-OSS if sinks is not None: @@ -582,7 +827,174 @@ def blocked_bhqkv_attention_forward( else: q_len_block = q_block_positions[q_block_idx + 1] - q_start - q_block_head = q_g[:, :, q_start : q_start + q_len_block, :] + q_block = q_g[:, :, q_start : q_start + q_len_block, :] + + current_max = torch.full( + (batch_size, h_end - h_start, q_len_block), + float(MIN_MASKED_ATTENTION_VALUE), + device=query.device, + ) + current_denominator = torch.zeros(batch_size, h_end - h_start, q_len_block, device=query.device) + output_blocks = torch.zeros( + (batch_size, h_end - h_start, q_len_block, DH), device=query.device, dtype=query.dtype + ) + + for j in range(num_kv_blocks): + start_index = j * kv_block_size + if j == num_kv_blocks - 1: + kv_len_block = past_seen_tokens - start_index + else: + kv_len_block = kv_block_size + end_index = start_index + kv_len_block + + skip_future = None + if skip_kv: + skip_future = (torch.tensor(start_index, device=query.device) > current_position).all() + # Eager mode Only + if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): + if skip_future.item(): + break + + block_index = block_table[:, j] + updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j + k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) + k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) + + k_g = k_block_states[:, h_start:h_end, :, :] + v_g = v_block_states[:, h_start:h_end, :, :] + + attn_weights_block = torch.matmul(q_block, k_g.transpose(2, 3)) * scaling + # position bias 
needed for mpt model + if position_bias is not None: + attn_weights_block = attn_weights_block + position_bias[h_start:h_end, :, start_index:end_index] + + mask_block = None + if attention_mask is not None: + mask_block = attention_mask[..., start_index:end_index] + if mask_block.shape[-1] != attn_weights_block.shape[-1]: + mask_block = None + + if use_causal_mask or mask_block is None: + # target_length = min(total_seen_tokens, end_index) + target_length = torch.where( + torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), + past_seen_tokens, + end_index, + ) + causal_mask_block = _create_causal_mask( + position_ids=position_ids, + target_length=target_length, + sliding_window=sliding_window, + start_index=start_index, + ) + if mask_block is None: + mask_block = causal_mask_block + else: + mask_block = mask_block.to(torch.bool) | causal_mask_block + + if mask_block is not None: + mask_block_g = mask_block[:, :, q_start : q_start + q_len_block, :] + attn_weights_block = torch.where(mask_block_g, masked_tensor, attn_weights_block) + + current_max, current_denominator, output_blocks = update_running_softmax( + current_max, attn_weights_block, current_denominator, output_blocks, v_g, skip_kv, skip_future + ) + # If present, apply Attention Sinks, needed for GPT-OSS + if sinks is not None: + _, _, output_blocks = update_running_softmax( + current_max, sinks, current_denominator, output_blocks, None + ) + q_output_blocks.append(output_blocks) + q_attn_blocks.append(attn_weights_block) + + head_output = torch.cat(q_output_blocks, dim=2) + head_attn_weights = torch.cat(q_attn_blocks, dim=2) + h_output_blocks.append(head_output) + h_attn_blocks.append(head_attn_weights) + + attn_output = torch.cat(h_output_blocks, dim=1).transpose(1, 2).contiguous() + attn_weights = torch.cat(h_attn_blocks, dim=1) + + return attn_output, attn_weights + + +def blocked_bhqkv_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: 
torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + num_kv_blocks: int, + num_q_blocks: int, + num_batch_blocks: int, + head_block_size: int, + cache_kwargs: Dict[str, Any], + layer_idx: int, + past_key_value: Cache, + *, + score_mod: Optional[Callable[[torch.Tensor, int, int], torch.Tensor]] = None, + use_causal_mask: bool = False, + sliding_window: Optional[int] = None, + skip_kv: bool = False, + position_bias: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Initialize Running Maximum and Denominator + batch_size, num_heads, seq_len, DH = query.shape + + past_seen_tokens = cache_kwargs.get("past_seen_tokens") + if torch.onnx.is_in_onnx_export(): + attention_mask = None + use_causal_mask = True + position_ids = cache_kwargs.get("position_ids") + if head_block_size <= 0: + head_block_size = num_heads + num_head_blocks = math.ceil(num_heads / head_block_size) + num_q_blocks = max(1, _normalize_int(num_q_blocks)) + q_block_positions = [-(-i * seq_len) // num_q_blocks for i in range(num_q_blocks)] + num_kv_blocks = max(1, num_kv_blocks) + kv_block_size = -(-past_seen_tokens // num_kv_blocks) + + h_output_blocks = [] + h_attn_blocks = [] + + num_batch_blocks = max( + 1, min(batch_size, _normalize_int(num_batch_blocks)) + ) # default to batch size for number of batch blocks + batch_block_positions = [(i * batch_size) // num_batch_blocks for i in range(num_batch_blocks)] + + if hasattr(module, "config"): + mask_dtype = module.config.torch_dtype + else: + mask_dtype = value.dtype + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=query.device) + + current_position = position_ids.max(dim=-1).values + # needed for GPT-OSS + if sinks is not None: + sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + + # Process each head block independently + for head_block_idx in range(num_head_blocks): 
+ h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, num_heads) + + # Extract head blocks + q_g = query[:, h_start:h_end, :, :] + + q_output_blocks = [] + q_attn_blocks = [] + + for q_block_idx in range(num_q_blocks): + q_start = q_block_positions[q_block_idx] + if q_block_idx == num_q_blocks - 1: + q_len_block = seq_len - q_start + else: + q_len_block = q_block_positions[q_block_idx + 1] - q_start + + q_block_head = q_g[:, :, q_start : q_start + q_len_block, :] batch_output_blocks = [] batch_attn_blocks = [] @@ -689,6 +1101,190 @@ def blocked_bhqkv_attention_forward( return attn_output, attn_weights +def blocked_bhqkv_paged_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + num_kv_blocks: int, + num_q_blocks: int, + num_batch_blocks: int, + head_block_size: int, + cache_kwargs: Dict[str, Any], + layer_idx: int, + past_key_value: Cache, + *, + score_mod: Optional[Callable[[torch.Tensor, int, int], torch.Tensor]] = None, + use_causal_mask: bool = False, + sliding_window: Optional[int] = None, + skip_kv: bool = False, + position_bias: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Initialize Running Maximum and Denominator + batch_size, num_heads, seq_len, DH = query.shape + + if torch.onnx.is_in_onnx_export(): + attention_mask = None + use_causal_mask = True + position_ids = cache_kwargs.get("position_ids") + if head_block_size <= 0: + head_block_size = num_heads + num_head_blocks = math.ceil(num_heads / head_block_size) + num_q_blocks = max(1, _normalize_int(num_q_blocks)) + q_block_positions = [-(-i * seq_len) // num_q_blocks for i in range(num_q_blocks)] + block_table = cache_kwargs.get("block_table") # [BS, num_kv_blocks] -> each entry is block_id value + kv_block_size = past_key_value.get_seq_length() if past_key_value is 
not None else 0 + past_seen_tokens = kv_block_size * num_kv_blocks + + + h_output_blocks = [] + h_attn_blocks = [] + + num_batch_blocks = max( + 1, min(batch_size, _normalize_int(num_batch_blocks)) + ) # default to batch size for number of batch blocks + batch_block_positions = [(i * batch_size) // num_batch_blocks for i in range(num_batch_blocks)] + + if hasattr(module, "config"): + mask_dtype = module.config.torch_dtype + else: + mask_dtype = value.dtype + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=mask_dtype, device=query.device) + + current_position = position_ids.max(dim=-1).values + # needed for GPT-OSS + if sinks is not None: + sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + + # Process each head block independently + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, num_heads) + + # Extract head blocks + q_g = query[:, h_start:h_end, :, :] + + q_output_blocks = [] + q_attn_blocks = [] + + for q_block_idx in range(num_q_blocks): + q_start = q_block_positions[q_block_idx] + if q_block_idx == num_q_blocks - 1: + q_len_block = seq_len - q_start + else: + q_len_block = q_block_positions[q_block_idx + 1] - q_start + + q_block_head = q_g[:, :, q_start : q_start + q_len_block, :] + + batch_output_blocks = [] + batch_attn_blocks = [] + + for b_block_idx in range(num_batch_blocks): + batch_start = batch_block_positions[b_block_idx] + if b_block_idx == num_batch_blocks - 1: + batch_len = batch_size - batch_start + else: + batch_len = batch_block_positions[b_block_idx + 1] - batch_start + + q_block = q_block_head[batch_start : batch_start + batch_len, :, :, :] + + current_max = torch.full( + (batch_len, h_end - h_start, q_len_block), + float(MIN_MASKED_ATTENTION_VALUE), + device=query.device, + ) + current_denominator = torch.zeros(batch_len, h_end - h_start, q_len_block, device=query.device) + output_blocks = torch.zeros( + (batch_len, h_end - 
h_start, q_len_block, DH), device=query.device, dtype=query.dtype + ) + + for j in range(num_kv_blocks): + start_index = j * kv_block_size + if j == num_kv_blocks - 1: + kv_len_block = past_seen_tokens - start_index + else: + kv_len_block = kv_block_size + end_index = start_index + kv_len_block + + skip_future = None + if skip_kv: + skip_future = (torch.tensor(start_index, device=query.device) > current_position).all() + # Eager mode Only + if not torch.onnx.is_in_onnx_export() and not torch.jit.is_tracing(): + if skip_future.item(): + break + + block_index = block_table[:, j] + updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j + k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) + k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) + + k_g = k_block_states[batch_start : batch_start + batch_len, h_start:h_end, :, :] + v_g = v_block_states[batch_start : batch_start + batch_len, h_start:h_end, :, :] + + attn_weights_block = torch.matmul(q_block, k_g.transpose(2, 3)) * scaling + # position bias needed for mpt model + if position_bias is not None: + attn_weights_block = attn_weights_block + position_bias[h_start:h_end, :, start_index:end_index] + + mask_block = None + if attention_mask is not None: + mask_block = attention_mask[..., start_index:end_index] + if mask_block.shape[-1] != attn_weights_block.shape[-1]: + mask_block = None + + if use_causal_mask or mask_block is None: + # target_length = min(total_seen_tokens, end_index) + target_length = torch.where( + torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), + past_seen_tokens, + end_index, + ) + causal_mask_block = _create_causal_mask( + position_ids=position_ids, + target_length=target_length, + sliding_window=sliding_window, + start_index=start_index, + ) + if mask_block is None: + mask_block = causal_mask_block + else: + mask_block = mask_block.to(torch.bool) | 
causal_mask_block + + if mask_block is not None: + mask_block_g = mask_block[ + batch_start : batch_start + batch_len, :, q_start : q_start + q_len_block, : + ] + attn_weights_block = torch.where(mask_block_g, masked_tensor, attn_weights_block) + + current_max, current_denominator, output_blocks = update_running_softmax( + current_max, attn_weights_block, current_denominator, output_blocks, v_g, skip_kv, skip_future + ) + batch_output_blocks.append(output_blocks) + batch_attn_blocks.append(attn_weights_block) + # If present, apply Attention Sinks, needed for GPT-OSS + if sinks is not None: + _, _, batch_output_blocks = update_running_softmax( + current_max, sinks, current_denominator, batch_output_blocks, None + ) + q_output_blocks.append(torch.cat(batch_output_blocks, dim=0)) + q_attn_blocks.append(torch.cat(batch_attn_blocks, dim=0)) + + head_output = torch.cat(q_output_blocks, dim=2) + head_attn_weights = torch.cat(q_attn_blocks, dim=2) + h_output_blocks.append(head_output) + h_attn_blocks.append(head_attn_weights) + + attn_output = torch.cat(h_output_blocks, dim=1).transpose(1, 2).contiguous() + attn_weights = torch.cat(h_attn_blocks, dim=1) + + return attn_output, attn_weights + + def blocked_h_attention_forward( module: nn.Module, query: torch.Tensor, diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index eaa256611..05b0d21b2 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -42,17 +42,25 @@ def _infer_data_bytes(compile_config: Dict[str, Any]) -> int: def _normalize_attention_mode(raw_mode: str) -> str: mode = raw_mode.lower() - if "h" in mode and "q" in mode and "kv" in mode: + if "h" in mode and "q" in mode and "kv" in mode and "paged" in mode: + return "hqkv_paged" + if "h" in mode and "q" in mode and "kv" in mode and "paged" not in mode: return "hqkv" if "h" in mode and "q" in mode: return "hq" - if "h" in mode and "kv" in mode: + if "h" 
in mode and "kv" in mode and "paged" in mode: + return "hkv_paged" + if "h" in mode and "kv" in mode and "paged" not in mode: return "hkv" if "h" in mode: return "h" - if "q" in mode and "kv" in mode: + if "q" in mode and "kv" in mode and "paged" in mode: + return "qkv_paged" + if "q" in mode and "kv" in mode and "paged" not in mode: return "qkv" - if "kv" in mode: + if "kv" in mode and "paged" in mode: + return "kv_paged" + if "kv" in mode and "paged" not in mode: return "kv" if "q" in mode: return "q" @@ -61,8 +69,8 @@ def _normalize_attention_mode(raw_mode: str) -> str: def _resolve_effective_blocking_mode(attention_cfg: Dict[str, Any], requested_mode: str) -> str: mode = _normalize_attention_mode(requested_mode) - if mode == "": - return "" + #if mode == "": #this should not be here since it will prevent 'h' and 'b' modes + #return "" num_q_blocks = attention_cfg.get("num_q_blocks") or 1 num_kv_blocks = attention_cfg.get("num_kv_blocks") or 1 head_block_size = (attention_cfg.get("head_block_size") or 1) if attention_cfg.get("head_blocking_enabled") else 1 @@ -74,13 +82,13 @@ def _resolve_effective_blocking_mode(attention_cfg: Dict[str, Any], requested_mo if head_block_size > 1 and num_kv_blocks > 1: return "hkv" if head_block_size > 1: - return "hqkv" + return "h" + mode if num_q_blocks > 1 and num_kv_blocks > 1: - return "qkv" + return mode if num_q_blocks > 1: - return "q" + return mode if num_kv_blocks > 1: - return "kv" + return mode return "" @@ -239,8 +247,9 @@ def update_best_config(num_q_blocks: int, num_kv_blocks: int, q_kv_ratio: float, for num_kv_blocks, kv_cl_per_nsp, kv_size_per_nsp in kv_metrics: qk_size_per_nsp = num_heads_per_iter * bs * q_sl_per_nsp * kv_cl_per_nsp * data_bytes - vtcm_footprint = q_size_per_nsp + kv_size_per_nsp + qk_size_per_nsp - + qkv_size_per_nsp = q_size_per_nsp + # For KV Blocking for loop, q input and qkv output should be persistent in VTCM + vtcm_footprint = q_size_per_nsp + kv_size_per_nsp + qk_size_per_nsp + 
qkv_size_per_nsp q_kv_ratio = max(q_size_per_nsp / kv_size_per_nsp, kv_size_per_nsp / q_size_per_nsp) num_total_blocks = num_q_blocks * num_kv_blocks @@ -347,9 +356,12 @@ def build_transformer_blocking_config_for_transform( else: blocking_config = AttentionBlockingConfig() mode_from_config = "" - if qaic_config.get("num_kv_blocks", False) and enable_blocking and "kv" in blocking_mode: + if qaic_config.get("num_kv_blocks", False) and enable_blocking and "kv" in blocking_mode and "paged" not in blocking_mode: mode_from_config = "kv" + mode_from_config blocking_config.num_kv_blocks = _get_valid_num_blocks(qaic_config, "num_kv_blocks") + if qaic_config.get("num_kv_blocks", False) and enable_blocking and "kv" in blocking_mode and "paged" in blocking_mode: + mode_from_config = "kv_paged" + mode_from_config + blocking_config.num_kv_blocks = _get_valid_num_blocks(qaic_config, "num_kv_blocks") if qaic_config.get("num_q_blocks", False) and enable_blocking and "q" in blocking_mode: mode_from_config = "q" + mode_from_config blocking_config.num_q_blocks = _get_valid_num_blocks(qaic_config, "num_q_blocks") diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py index dcf5662fb..9380feaa0 100644 --- a/QEfficient/customop/__init__.py +++ b/QEfficient/customop/__init__.py @@ -10,10 +10,12 @@ CtxGatherFunc3D, CtxGatherFunc3DGeneralized, CtxGatherFuncBlockedKV, + CtxGatherFuncPagedAttention, CtxScatterFunc, CtxScatterFunc3D, CtxScatterFunc3DGeneralized, CtxScatterFunc3DInt, + CtxScatterFuncPagedAttention, ) from QEfficient.customop.ctx_scatter_gather_cb import ( CtxGatherFuncBlockedKVCB, @@ -27,7 +29,9 @@ __all__ = [ "CtxGatherFunc", "CtxGatherFuncBlockedKV", + "CtxGatherFuncPagedAttention", "CtxScatterFunc", + "CtxScatterFuncPagedAttention", "CtxGatherFunc3D", "CtxScatterFunc3D", "CtxGatherFunc3DGeneralized", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index aedddb186..75d6b0bba 100644 --- 
a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -56,6 +56,53 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxScatterPagedAttention( + data: onnxscript.FLOAT, block_index: onnxscript.INT32, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT +) -> onnxscript.FLOAT: + # Find dims + num_blocks = ops.Gather(ops.Shape(block_index), [0]) + num_heads = ops.Gather(ops.Shape(data), [1]) + seq_len = ops.Gather(ops.Shape(position_ids), [1]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + exp_shape = ops.Concat(num_blocks, num_heads, seq_len, one, axis=0) + + # Create indices + block_idx = ops.Expand(ops.Unsqueeze(block_index, [3]), exp_shape) + head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape) + ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape) + indices = ops.Concat(block_idx, head_idx, ctx_idx, axis=3) + + return ops.ScatterND(data, indices, updates) + + +class CtxScatterFuncPagedAttention(torch.autograd.Function): + """ + Function to scatter the current key values into KV-cache. 
+ """ + + @staticmethod + def forward(data: torch.Tensor, block_index: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + block_index = block_index.view(-1, 1, 1) + head_idx = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_idx = position_ids.unsqueeze(1) + data[block_index, head_idx, ctx_idx] = updates + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic( + g: torch.Graph, data: torch.Value, block_index: torch.Value, position_ids: torch.Value, updates: torch.Value + ) -> torch.Value: + return g.onnxscript_op(CtxScatterPagedAttention, data, block_index, position_ids, updates).setTypeAs(data) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: # Find dims @@ -274,3 +321,48 @@ def setup_context(ctx, inputs, outputs): @staticmethod def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: return g.onnxscript_op(CtxGatherBlockedKV, data, ctx_indices).setTypeAs(data) + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxGatherPagedAttention( + data: onnxscript.FLOAT, block_indices: onnxscript.INT32, ctx_indices: onnxscript.INT32 +) -> onnxscript.FLOAT: + num_kv_blocks = ops.Gather(ops.Shape(block_indices), [0]) + num_heads = ops.Gather(ops.Shape(data), [1]) + seq_len = ops.Gather(ops.Shape(ctx_indices), [1]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + exp_shape = ops.Concat(num_kv_blocks, num_heads, seq_len, one, axis=0) + + # Create indices + block_idx = ops.Expand(ops.Unsqueeze(block_indices, [1, 3]), exp_shape) + head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape) + ctx_idx = ops.Expand(ops.Unsqueeze(ctx_indices, [1, 3]), exp_shape) + indices = ops.Concat(block_idx, head_idx, 
ctx_idx, axis=3) + + return ops.GatherND(data, indices) + + +class CtxGatherFuncPagedAttention(torch.autograd.Function): + """ + Function to gather only the valid key values from KV-cache. + """ + + @staticmethod + def forward(data: torch.Tensor, block_indices: torch.Tensor, ctx_indices: torch.Tensor): + block_indices = block_indices.view(-1, 1, 1) + head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = ctx_indices.unsqueeze(1) + return data[block_indices, head_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic( + g: torch.Graph, data: torch.Value, block_indices: torch.Value, ctx_indices: torch.Value + ) -> torch.Value: + return g.onnxscript_op(CtxGatherPagedAttention, data, block_indices, ctx_indices).setTypeAs(data) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index f6c2b6128..12754aab2 100755 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -19,10 +19,12 @@ CtxGatherFuncBlockedKVCB, CtxGatherFuncCB, CtxGatherFuncCB3D, + CtxGatherFuncPagedAttention, CtxScatterFunc, CtxScatterFunc3D, CtxScatterFuncCB, CtxScatterFuncCB3D, + CtxScatterFuncPagedAttention, ) @@ -246,6 +248,87 @@ def read_only_blockedKV(self, start_index, end_index, cache_kwargs): v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out + def read_only_pagedAttention(self, block_index, updated, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer for each KV block. + + Parameters: + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + block_index : + Block index of the K/V block to read from the block table + + updated: + Was the current block updated during the current write? 
If yes, then read slot_id + seq_len worth of entires from current block + + Return: + A tuple containing the updated key and value states. + """ + # Gather + k_out, v_out = self.keys, self.values + position_ids = cache_kwargs.get("position_ids") + batch, seq_len = position_ids.shape + slot_id = cache_kwargs.get("slot_id", None) + num_kv_blocks, num_kv_heads, block_size, dh = k_out.shape + ctx_indices = torch.arange(block_size)[None, ...] + gather_limit = torch.where(updated, slot_id.unsqueeze(-1) + seq_len, block_size) + block_indices = block_index.unsqueeze(-1) + invalid_mask = torch.ones_like(position_ids, dtype=torch.bool) + invalid_mask = torch.where(block_indices < 0, invalid_mask, ctx_indices >= gather_limit) + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + k_out = CtxGatherFuncPagedAttention.apply(k_out, block_indices, ctx_indices) + v_out = CtxGatherFuncPagedAttention.apply(v_out, block_indices, ctx_indices) + + v_out = torch.where((invalid_mask.unsqueeze(1)).unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + + def write_only_pagedAttention(self, key_states, value_states, cache_kwargs): + """ + Write in the cache with the new `key_states` and `value_states` for the layer. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. 
+ """ + # Update the cache + if self.keys is None: + self.keys = key_states + self.values = value_states + else: + position_ids = cache_kwargs.get("position_ids") + block_table = cache_kwargs.get("block_table") # [BS, num_kv_blocks/BS] -> each entry is block_id value + slot_id = cache_kwargs.get("slot_id") + batch, num_kv_heads, seq_len, dh = key_states.shape + num_kv_blocks, num_kv_heads, block_size, dh = self.keys.shape + block_index = position_ids.max(1).values // block_size # Assuming only 1 block is written at max + invalid_scatter_index = torch.iinfo(torch.int32).max + ctx_indices = torch.arange(seq_len) + slot_id.unsqueeze(-1) + + rows = torch.arange(batch) + block_id = block_table[rows, block_index].unsqueeze(-1) + ctx_indices = torch.where(block_id < 0, invalid_scatter_index, ctx_indices) + block_id = block_id.unsqueeze(-1) + self.keys = CtxScatterFuncPagedAttention.apply(self.keys, block_id, ctx_indices, key_states) + self.values = CtxScatterFuncPagedAttention.apply(self.values, block_id, ctx_indices, value_states) + + def get_seq_lengthPagedAttention(self, cache_position=None) -> int: + """Returns the sequence length of the cached states for pagedAttention.""" + if self.keys is None or self.keys.numel() == 0: + return 0 + return self.keys.shape[-2] * self.keys.shape[0] + def write_only(self, key_states, value_states, cache_kwargs): """ Write in the cache with the new `key_states` and `value_states` for the layer. @@ -644,6 +727,53 @@ def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs): """ return self.layers[layer_idx].read_only_blockedKV(start_index, end_index, cache_kwargs) + def read_only_pagedAttention(self, block_index, updated, layer_idx, cache_kwargs): + # def read_only_pagedAttention(self, start_index, end_index, layer_idx, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer `layer_idx`. 
+ + Parameters: + start_index (`int`): + Start index of the K/V block to read + end_index (`int`): + End index of the K/V block to read + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + return self.layers[layer_idx].read_only_pagedAttention(block_index, updated, cache_kwargs) + # return self.layers[layer_idx].read_only_pagedAttention(start_index, end_index, cache_kwargs) + + def write_only_pagedAttention(self, key_states, value_states, layer_idx, cache_kwargs): + """ + Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + """ + self.append_new_layers(layer_idx) + return self.layers[layer_idx].write_only_pagedAttention(key_states, value_states, cache_kwargs) + + def get_seq_lengthPagedAttention(self, layer_idx: int = 0, cache_position=None) -> int: + """Returns the sequence length of the cache for the given layer. 
TODO: deprecate in favor of cache_position""" + if layer_idx >= len(self.layers): + return 0 + # Hack since QuantizedCache messes with keys shape as it becomes the residual cache + # if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor): + # return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length(cache_position) + return self.layers[layer_idx].get_seq_lengthPagedAttention(cache_position) + def write_only(self, key_states, value_states, layer_idx, cache_kwargs): """ Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 65b89d274..869868546 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3271,7 +3271,8 @@ def export( bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - # increase seq_len if using a larger number of blocks + self.supports_paged_attention = False + # increase seq_len if using a larger number of blocks and set PagedAttention params if required if self.hash_params.get("blocking_kwargs", None): max_blocks = -1 for num_blocks in self.hash_params.get("blocking_kwargs").__dict__.values(): @@ -3279,8 +3280,10 @@ def export( max_blocks = max(max_blocks, num_blocks) block_size = -(-seq_len // max_blocks) seq_len = block_size * max_blocks - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + num_kv_blocks = self.hash_params["blocking_config"].num_kv_blocks + self.supports_paged_attention = "paged" in self.hash_params["blocking_config"].mode + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len ) @@ -3340,6 +3343,7 @@ def export( "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: 
"seq_len"}, } + if self.ccl_enabled: example_inputs["comp_ctx_lengths"] = torch.randint(0, 127, (512,), dtype=torch.int64) dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} @@ -3362,6 +3366,21 @@ def export( else: output_names.append("logits") + if self.supports_paged_attention: + batch, num_kv_heads, CL, dh = kv_cache_shape + total_num_kv_blocks = batch * num_kv_blocks + kv_block_size = (-CL) // (-num_kv_blocks) + kv_cache_shape = [total_num_kv_blocks, num_kv_heads, kv_block_size, dh] + example_inputs["block_table"] = torch.arange((bs * num_kv_blocks), dtype=torch.int64).view(bs, num_kv_blocks) + example_inputs["slot_id"] = torch.zeros(bs, dtype=torch.int64) + dynamic_axes["block_table"] = {0: "batch_size", 1: "num_kv_blocks"} + dynamic_axes["slot_id"] = {0: "batch_size"} + # Assuming 4d pkv, might have to recheck for GPTBigCode with 3d pkv + pkv_dynamic_axes = { + 0: "total_num_kv_blocks", + 2: "kv_block_size", + } + # TODO Update the get_padding_shape_from_config method to handle the case when the model config has attention_chunk_size or sliding_window and it should return a list of shapes for each layer if ( hasattr(self.model.config, "model_type") @@ -3588,6 +3607,12 @@ def build_prefill_specialization( # TODO: remove this; not required if full_batch_size: spec["full_batch_exec_size"] = exec_batch_size + if self.hash_params.get("blocking_kwargs", None): + if "paged" in self.hash_params["blocking_config"].mode + num_kv_blocks = self.hash_params["blocking_config"].num_kv_blocks + spec["num_kv_blocks"] = num_kv_blocks + spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + spec["kv_block_size"] = (-ctx_len) // (-num_kv_blocks) result = {k: v for k, v in spec.items() if v is not None} result["_graph_name"] = "Decode" if prefill_seq_len == 1 and kwargs.get("prefill_only") is False else "Prefill" return result @@ -3651,6 +3676,12 @@ def build_decode_specialization( spec["full_batch_size"] = kv_cache_batch_size else: spec["batch_size"] = 
kv_cache_batch_size + if self.hash_params.get("blocking_kwargs", None): + if "paged" in self.hash_params["blocking_config"].mode + num_kv_blocks = self.hash_params["blocking_config"].num_kv_blocks + spec["num_kv_blocks"] = num_kv_blocks + spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + spec["kv_block_size"] = (-ctx_len) // (-num_kv_blocks) result = {k: v for k, v in spec.items() if v is not None} result["_graph_name"] = "Decode" return result From 99e556682a9c195f4bdd0e53b71db5ff17200d65 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Fri, 8 May 2026 19:52:36 -0500 Subject: [PATCH 04/14] Added block_table and slot_id inputs + minor modelling_auto.py changes Signed-off-by: Vaibhav Verma --- QEfficient/blocking/attention_blocking.py | 13 +++++++- .../models/deepseek_v3/modeling_deepseek.py | 30 +++++++++++++++++++ .../models/gemma/modeling_gemma.py | 18 +++++++++++ .../models/gemma2/modeling_gemma2.py | 18 +++++++++++ .../models/gpt_oss/modeling_gpt_oss.py | 26 ++++++++++++++++ .../models/granite/modeling_granite.py | 16 ++++++++++ .../models/granitemoe/modeling_granitemoe.py | 18 +++++++++++ .../models/llama/modeling_llama.py | 18 +++++++++++ .../models/mistral/modeling_mistral.py | 18 +++++++++++ .../models/mixtral_moe/modeling_mixtral.py | 16 ++++++++++ .../transformers/models/modeling_auto.py | 12 ++++---- .../transformers/models/mpt/modeling_mpt.py | 18 +++++++++++ .../models/qwen2/modeling_qwen2.py | 18 +++++++++++ .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 22 ++++++++++++++ .../models/qwen3/modeling_qwen3.py | 18 +++++++++++ .../models/qwen3_moe/modeling_qwen3_moe.py | 18 +++++++++++ .../models/qwen3_vl/modeling_qwen3_vl.py | 26 ++++++++++++++++ .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 26 ++++++++++++++++ .../models/starcoder2/modeling_starcoder2.py | 18 +++++++++++ 19 files changed, 360 insertions(+), 7 deletions(-) diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py index 
34d027eb4..d4e941bd4 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -97,10 +97,17 @@ def past_key_value_update( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, sliding_window: Optional[int] = None, ): if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "batch_index": batch_index, + "position_ids": position_ids, + "block_table": block_table, + "slot_id": slot_id, + } if sliding_window is not None: cache_kwargs.update( { @@ -128,6 +135,8 @@ def generic_blocked_attention_interface( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_seen_tokens: Optional[int] = None, non_blocked_forward: Callable = None, score_mod: Optional[Callable] = None, @@ -184,6 +193,8 @@ def generic_blocked_attention_interface( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, sliding_window=sliding_window, ) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 3dad27103..c603baf42 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -265,6 +265,8 @@ def fused_forward_h_blocking( position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: 
Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, compressed_kvs: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -332,6 +334,8 @@ def fused_forward_kv_blocking( position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, compressed_kvs: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -414,6 +418,8 @@ def fused_forward_orig( position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, compressed_kvs: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -533,6 +539,8 @@ def fused_forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, compressed_kvs: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -595,6 +603,8 @@ def forward_full_kv( position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, @@ -660,6 +670,8 @@ def forward_full_kv_h_blocking( position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: 
Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, @@ -716,6 +728,8 @@ def forward_full_kv_h_blocking( blocking_config=blocking_config, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) @@ -728,6 +742,8 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, @@ -915,6 +931,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, compressed_kvs: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -936,6 +954,8 @@ def forward( hidden_states=orig_hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, position_embeddings=position_embeddings, past_key_value=past_key_value, compressed_kvs=compressed_kvs, @@ -951,6 +971,8 @@ def forward( hidden_states=orig_hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, position_embeddings=position_embeddings, past_key_value=past_key_value, batch_index=batch_index, @@ -1008,6 +1030,8 @@ def forward( input_ids: torch.LongTensor = None, 
attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, compressed_kvs: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, batch_index: Optional[torch.LongTensor] = None, @@ -1072,6 +1096,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, compressed_kvs=compressed_kvs, past_key_value=past_key_values, batch_index=batch_index, @@ -1123,6 +1149,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, compressed_kvs: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, batch_index: Optional[torch.LongTensor] = None, @@ -1147,6 +1175,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, compressed_kvs=compressed_kvs, past_key_values=past_key_values, batch_index=batch_index, diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 9ee513a25..867915cec 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -116,6 +116,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -150,6 +152,8 @@ def 
forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -162,6 +166,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -190,6 +196,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -221,6 +229,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -258,6 +268,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -311,6 +323,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -358,6 +372,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: 
Optional[Union[Cache, List[torch.FloatTensor]]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -375,6 +391,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 80d6205ef..3e0103ded 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -123,6 +123,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -157,6 +159,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -169,6 +173,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -196,6 +202,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -231,6 +239,8 @@ def forward( hidden_states=hidden_states, 
attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -279,6 +289,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -348,6 +360,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -403,6 +417,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -427,6 +443,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index b7f42c0c5..0925165aa 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -773,6 +773,8 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = 
None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -855,6 +857,8 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -933,6 +937,8 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -974,6 +980,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, sliding_window=self.sliding_window, sinks=self.sinks, @@ -988,6 +996,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, sliding_window=self.sliding_window, ) attn_output, attn_weights = eager_attention_forward( @@ -1014,6 +1024,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -1033,6 
+1045,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -1074,6 +1088,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, @@ -1134,6 +1150,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, batch_index=batch_index, use_cache=use_cache, @@ -1174,6 +1192,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -1236,6 +1256,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -1282,6 +1304,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -1334,6 +1358,8 @@ def forward( 
input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 7082385c4..6ea84221a 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -143,6 +143,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -155,6 +157,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -181,6 +185,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, batch_index: Optional[torch.LongTensor] = None, @@ -221,6 +227,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, output_attentions=output_attentions, batch_index=batch_index, @@ -258,6 +266,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, 
@@ -316,6 +326,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -369,6 +381,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -422,6 +436,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index a10486606..c9da1bd17 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -132,6 +132,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -144,6 +146,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -197,6 +201,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: 
Optional[bool] = False, use_cache: Optional[bool] = False, @@ -241,6 +247,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -290,6 +298,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -350,6 +360,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -364,6 +376,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, @@ -574,6 +588,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -622,6 +638,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/llama/modeling_llama.py 
b/QEfficient/transformers/models/llama/modeling_llama.py index 811b3f84d..2b870bd80 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -112,6 +112,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -153,6 +155,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -165,6 +169,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -193,6 +199,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -211,6 +219,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -246,6 +256,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, 
comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -297,6 +309,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -342,6 +356,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -359,6 +375,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index e73bf3153..238af1a83 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -119,6 +119,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -161,6 +163,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -173,6 +177,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, 
batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -201,6 +207,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -234,6 +242,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -272,6 +282,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -335,6 +347,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -385,6 +399,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -409,6 +425,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, 
position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index ed73cb388..626130a98 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -117,6 +117,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -158,6 +160,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -170,6 +174,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -280,6 +286,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -320,6 +328,8 @@ def forward( position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -367,6 +377,8 
@@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -480,6 +492,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -504,6 +518,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 869868546..1eb51fc86 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3280,8 +3280,8 @@ def export( max_blocks = max(max_blocks, num_blocks) block_size = -(-seq_len // max_blocks) seq_len = block_size * max_blocks - num_kv_blocks = self.hash_params["blocking_config"].num_kv_blocks - self.supports_paged_attention = "paged" in self.hash_params["blocking_config"].mode + num_kv_blocks = self.hash_params["blocking_kwargs"].num_kv_blocks + self.supports_paged_attention = "paged" in self.hash_params["blocking_kwargs"].mode fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( @@ -3608,8 +3608,8 @@ def build_prefill_specialization( if full_batch_size: spec["full_batch_exec_size"] = 
exec_batch_size if self.hash_params.get("blocking_kwargs", None): - if "paged" in self.hash_params["blocking_config"].mode - num_kv_blocks = self.hash_params["blocking_config"].num_kv_blocks + if "paged" in self.hash_params["blocking_kwargs"].mode: + num_kv_blocks = self.hash_params["blocking_kwargs"].num_kv_blocks spec["num_kv_blocks"] = num_kv_blocks spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks spec["kv_block_size"] = (-ctx_len) // (-num_kv_blocks) @@ -3677,8 +3677,8 @@ def build_decode_specialization( else: spec["batch_size"] = kv_cache_batch_size if self.hash_params.get("blocking_kwargs", None): - if "paged" in self.hash_params["blocking_config"].mode - num_kv_blocks = self.hash_params["blocking_config"].num_kv_blocks + if "paged" in self.hash_params["blocking_kwargs"].mode: + num_kv_blocks = self.hash_params["blocking_kwargs"].num_kv_blocks spec["num_kv_blocks"] = num_kv_blocks spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks spec["kv_block_size"] = (-ctx_len) // (-num_kv_blocks) diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index c92467384..48a6e16d4 100644 --- a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -43,6 +43,8 @@ def forward( hidden_states: torch.Tensor, position_bias: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, @@ -86,6 +88,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, position_bias=position_bias, ) @@ -102,6 +106,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, 
batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale @@ -149,6 +155,8 @@ def forward( position_bias: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, @@ -166,6 +174,8 @@ def forward( layernorm_output, position_bias=position_bias, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, batch_index=batch_index, attention_mask=attention_mask, past_key_value=layer_past, @@ -199,6 +209,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -261,6 +273,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, batch_index=batch_index, use_cache=use_cache, output_attentions=output_attentions, @@ -316,6 +330,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, @@ -338,6 +354,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, 
attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index d6f8d4bb7..52af064bc 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -127,6 +127,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -162,6 +164,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -174,6 +178,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -203,6 +209,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -237,6 +245,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -275,6 +285,8 @@ def forward( input_ids: torch.LongTensor = 
None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -330,6 +342,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -377,6 +391,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -394,6 +410,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 357c4af16..39ee13276 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -438,6 +438,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -481,6 
+483,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids[0], + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -493,6 +497,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids[0], + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -520,6 +526,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -562,6 +570,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -602,6 +612,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -654,6 +666,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -695,6 +709,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + 
slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -718,6 +734,8 @@ def forward( outputs = self.language_model( input_ids=None, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, attention_mask=attention_mask, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, @@ -787,6 +805,8 @@ def forward( position_ids, image_idx, past_key_values, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[List[int]] = None, ): @@ -802,6 +822,8 @@ def forward( outputs = self.model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index 9844c9101..535120825 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -129,6 +129,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -164,6 +166,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -176,6 +180,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + 
block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -205,6 +211,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -239,6 +247,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -277,6 +287,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -332,6 +344,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -379,6 +393,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -398,6 +414,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, 
past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 7794e752e..9c073ae1c 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -306,6 +306,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[torch.Tensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -342,6 +344,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -354,6 +358,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -375,6 +381,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -409,6 +417,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -447,6 +457,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, 
attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, @@ -505,6 +517,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -546,6 +560,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, @@ -564,6 +580,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 0f6ab210d..548263da5 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -378,6 +378,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -415,6 
+417,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids[0], + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -427,6 +431,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids[0], + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -453,6 +459,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[torch.Tensor]] = None, batch_index: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, @@ -495,6 +503,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, batch_index=batch_index, comp_ctx_lengths=comp_ctx_lengths, @@ -534,6 +544,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -590,6 +602,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -697,6 +711,8 @@ def forward( position_ids, image_idx, past_key_values, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, 
comp_ctx_lengths: Optional[List[int]] = None, ): @@ -726,6 +742,8 @@ def forward( outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -748,6 +766,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -771,6 +791,8 @@ def forward( outputs = self.language_model( input_ids=None, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, attention_mask=attention_mask, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, @@ -806,6 +828,8 @@ def forward( input_ids, position_ids, past_key_values, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_idx: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[List[int]] = None, @@ -827,6 +851,8 @@ def forward( outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index cc6899fc1..dbdfa3606 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -368,6 +368,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: 
Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -404,6 +406,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids[0], + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -416,6 +420,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids[0], + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -442,6 +448,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[torch.Tensor]] = None, batch_index: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, @@ -483,6 +491,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, batch_index=batch_index, comp_ctx_lengths=comp_ctx_lengths, @@ -528,6 +538,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -592,6 +604,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + 
slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -698,6 +712,8 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -721,6 +737,8 @@ def forward( outputs = self.language_model( input_ids=None, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, attention_mask=attention_mask, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, @@ -810,6 +828,8 @@ def forward( position_ids=None, image_idx=None, past_key_values=None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[List[int]] = None, ): @@ -846,6 +866,8 @@ def forward( outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -865,6 +887,8 @@ def forward( outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -881,6 +905,8 @@ def forward( outputs = self.language_model( inputs_embeds=inputs_embeds, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py 
index 8ebe32faf..1640029fc 100644 --- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -74,6 +74,8 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -107,6 +109,8 @@ def forward( comp_ctx_length=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_seen_tokens=past_seen_tokens, ) else: @@ -119,6 +123,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, ) attn_output, attn_weights = eager_attention_forward( self, @@ -148,6 +154,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -184,6 +192,8 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -216,6 +226,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, 
comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -270,6 +282,8 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_value=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, @@ -316,6 +330,8 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -333,6 +349,8 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + block_table=block_table, + slot_id=slot_id, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, From e580faf03e8d87b18b51e10198a280da695da179 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Sat, 9 May 2026 05:12:33 -0500 Subject: [PATCH 05/14] Working version with PagedAttention Signed-off-by: Vaibhav Verma --- .../blocking/blocked_attention_forwards.py | 16 ++++---- .../generation/text_generation_inference.py | 37 +++++++++++++++++-- QEfficient/transformers/cache_utils.py | 4 ++ .../transformers/models/modeling_auto.py | 24 ++++++------ 4 files changed, 56 insertions(+), 25 deletions(-) diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 80f40200f..a8f147e54 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -273,8 +273,8 @@ def blocked_kv_paged_attention_forward( if skip_future.item(): break - block_index = block_table[:, i] - updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == i + 
block_index = block_table[:, j] + updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) @@ -544,8 +544,8 @@ def blocked_qkv_paged_attention_forward( if skip_future.item(): break - block_index = block_table[:, i] - updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == i + block_index = block_table[:, j] + updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) @@ -855,8 +855,8 @@ def blocked_hqkv_paged_attention_forward( if skip_future.item(): break - block_index = block_table[:, i] - updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == i + block_index = block_table[:, j] + updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) @@ -1218,8 +1218,8 @@ def blocked_bhqkv_paged_attention_forward( if skip_future.item(): break - block_index = block_table[:, i] - updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == i + block_index = block_table[:, j] + updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index fcb969886..ebb69edd2 100755 --- a/QEfficient/generation/text_generation_inference.py +++ 
b/QEfficient/generation/text_generation_inference.py @@ -195,7 +195,9 @@ def get_compilation_dims(qpc_path: str) -> Tuple[int, int, Optional[int]]: compilation_ctx_len = int(spec["ctx_len"]) if compilation_fbs := spec.get("full_batch_size", None): compilation_fbs = int(compilation_fbs) - return compilation_batch_size, compilation_ctx_len, compilation_fbs + if compilation_num_kv_blocks := spec.get("num_kv_blocks", None): + compilation_num_kv_blocks = int(compilation_num_kv_blocks) + return compilation_batch_size, compilation_ctx_len, compilation_fbs, compilation_num_kv_blocks def get_input_prompts(prompt: str, prompts_txt_file_path: str) -> List[str]: @@ -380,7 +382,7 @@ def cloud_ai_100_exec_kv( exec_info = QEfficient.cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc_path=qpc_path, prompt="Hi there!!", device_id=[0]) """ - batch_size, ctx_len, full_batch_size = get_compilation_dims(qpc_path) + batch_size, ctx_len, full_batch_size, num_kv_blocks = get_compilation_dims(qpc_path) prompt: List[str] = get_input_prompts(prompt, prompts_txt_file_path) prompt = fix_prompts(prompt, batch_size, full_batch_size) if prompt_to_lora_id_mapping is not None: @@ -397,6 +399,7 @@ def cloud_ai_100_exec_kv( enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, full_batch_size=full_batch_size, + num_kv_blocks=num_kv_blocks, is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, @@ -440,6 +443,7 @@ def __init__( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], qpc_path: str, full_batch_size: Optional[int] = None, + num_kv_blocks: Optional[int] = None, ctx_len: Optional[int] = None, comp_ctx_lengths_prefill: Optional[List[int]] = None, comp_ctx_lengths_decode: Optional[List[int]] = None, @@ -482,6 +486,8 @@ def __init__( self.full_batch_size = ( full_batch_size if full_batch_size else self._fetch_full_batch_size() ) # Check and fetch full batch size if CB is enabled + self.num_kv_blocks = num_kv_blocks if num_kv_blocks else 1 + self.kv_block_size = 
-(-self._ctx_len // self.num_kv_blocks) if num_kv_blocks else 1 # Initialize the storage variables. self.batch_index = None @@ -630,6 +636,11 @@ def prepare_decode_inputs(self): """ batch_size = self.full_batch_size if self.full_batch_size is not None else self.batch_size decode_inputs = {} + if self.num_kv_blocks: + decode_inputs["block_table"] = np.arange(batch_size * self.num_kv_blocks, dtype=np.int64).reshape( + batch_size, self.num_kv_blocks + ) + decode_inputs["slot_id"] = (self.decode_pos_ids % self.kv_block_size).reshape(batch_size) if self.is_tlm: position_ids = np.full((batch_size, self._decode_seq_len), -1, dtype=np.int64) position_ids[:, -1] = self.decode_pos_ids.flatten() @@ -744,10 +755,12 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): """ for decode_batch_id in range(self.full_batch_size): next_prompt = prompt_queue.popleft() + block_table = self.block_table[decode_batch_id].reshape(1, -1) if self.num_kv_blocks else None # run prefill for num_chunks outputs, position_ids, generation_len = self.run_prefill( - next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) + next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + block_table=block_table, ) _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) @@ -769,7 +782,7 @@ def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) self._session.set_buffers({"logits": logits_out_placeholder}) - def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None, block_table=None): """ Runs prefill for a given prompt and generation length. 
@@ -805,6 +818,10 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) inputs.pop("token_type_ids", None) + if block_table is not None: + inputs["block_table"] = block_table + inputs["slot_id"] = np.zeros((prefill_logit_bs), dtype=np.int64) + if decode_batch_id is not None: inputs["batch_index"] = decode_batch_id @@ -967,6 +984,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): # If the generated sequence is valid and within generation len prepare for next decode decode_inputs["input_ids"][decode_batch_id, -1] = next_token_id[decode_batch_id, -1] decode_inputs["position_ids"][decode_batch_id][..., -1] += 1 + if self.num_kv_blocks: + decode_inputs["slot_id"][decode_batch_id] += 1 + decode_inputs["slot_id"][decode_batch_id] %= self.kv_block_size self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = ( next_token_id[decode_batch_id, -1] ) @@ -1031,6 +1051,9 @@ def run_decode( # Prepare inputs for next iteration decode_inputs["input_ids"] = self._fetch_next_token_id(outputs) decode_inputs["position_ids"][:, -1] += 1 + if self.num_kv_blocks: + decode_inputs["slot_id"][:] += 1 + decode_inputs["slot_id"][:] %= self.kv_block_size cache_index += 1 self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id @@ -1080,6 +1103,7 @@ def __init__( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], qpc_path: str, full_batch_size: Optional[int] = None, + num_kv_blocks: Optional[int] = None, ctx_len: Optional[int] = None, comp_ctx_lengths_prefill: Optional[List[int]] = None, comp_ctx_lengths_decode: Optional[List[int]] = None, @@ -1096,6 +1120,7 @@ def __init__( tokenizer=tokenizer, qpc_path=qpc_path, full_batch_size=full_batch_size, + num_kv_blocks=num_kv_blocks, ctx_len=ctx_len, 
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, @@ -1109,6 +1134,7 @@ def __init__( sampling_params=sampling_params, ) self._full_batch_size = self._qaic_model.full_batch_size + self._num_kv_blocks = self._qaic_model.num_kv_blocks self._tokenizer = self._qaic_model.tokenizer self._ctx_len = ctx_len self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill @@ -1138,6 +1164,9 @@ def _setup_model_execution_inputs( self._full_batch_size if self._full_batch_size is not None else self._qaic_model.batch_size ) max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + self._qaic_model.block_table = np.arange( + execution_batch_size * self._num_kv_blocks, dtype=np.int64 + ).reshape(execution_batch_size, self._num_kv_blocks) if self._num_kv_blocks else None # Create a prompt queue. self._prompt_queue = deque(prompt) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 12754aab2..f09a6df48 100755 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -267,6 +267,8 @@ def read_only_pagedAttention(self, block_index, updated, cache_kwargs): """ # Gather k_out, v_out = self.keys, self.values + if k_out is not None: + self._mark_initialized(k_out) position_ids = cache_kwargs.get("position_ids") batch, seq_len = position_ids.shape slot_id = cache_kwargs.get("slot_id", None) @@ -306,7 +308,9 @@ def write_only_pagedAttention(self, key_states, value_states, cache_kwargs): if self.keys is None: self.keys = key_states self.values = value_states + self._mark_initialized(self.keys) else: + self._mark_initialized(self.keys) position_ids = cache_kwargs.get("position_ids") block_table = cache_kwargs.get("block_table") # [BS, num_kv_blocks/BS] -> each entry is block_id value slot_id = cache_kwargs.get("slot_id") diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py 
index 1eb51fc86..1270d45c5 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3282,6 +3282,7 @@ def export( seq_len = block_size * max_blocks num_kv_blocks = self.hash_params["blocking_kwargs"].num_kv_blocks self.supports_paged_attention = "paged" in self.hash_params["blocking_kwargs"].mode + seq_len = kv_block_size = -(-seq_len // num_kv_blocks) if self.supports_paged_attention else seq_len fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( @@ -3369,7 +3370,6 @@ def export( if self.supports_paged_attention: batch, num_kv_heads, CL, dh = kv_cache_shape total_num_kv_blocks = batch * num_kv_blocks - kv_block_size = (-CL) // (-num_kv_blocks) kv_cache_shape = [total_num_kv_blocks, num_kv_heads, kv_block_size, dh] example_inputs["block_table"] = torch.arange((bs * num_kv_blocks), dtype=torch.int64).view(bs, num_kv_blocks) example_inputs["slot_id"] = torch.zeros(bs, dtype=torch.int64) @@ -3607,12 +3607,11 @@ def build_prefill_specialization( # TODO: remove this; not required if full_batch_size: spec["full_batch_exec_size"] = exec_batch_size - if self.hash_params.get("blocking_kwargs", None): - if "paged" in self.hash_params["blocking_kwargs"].mode: - num_kv_blocks = self.hash_params["blocking_kwargs"].num_kv_blocks - spec["num_kv_blocks"] = num_kv_blocks - spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks - spec["kv_block_size"] = (-ctx_len) // (-num_kv_blocks) + if "paged" in self.model.qaic_config["blocking_mode"]: + num_kv_blocks = self.model.qaic_config["num_kv_blocks"] + spec["num_kv_blocks"] = num_kv_blocks + spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + spec["kv_block_size"] = -(-ctx_len // num_kv_blocks) result = {k: v for k, v in spec.items() if v is not None} result["_graph_name"] = "Decode" if prefill_seq_len == 1 and kwargs.get("prefill_only") is False else "Prefill" return result @@ -3676,12 +3675,11 @@ def 
build_decode_specialization( spec["full_batch_size"] = kv_cache_batch_size else: spec["batch_size"] = kv_cache_batch_size - if self.hash_params.get("blocking_kwargs", None): - if "paged" in self.hash_params["blocking_kwargs"].mode: - num_kv_blocks = self.hash_params["blocking_kwargs"].num_kv_blocks - spec["num_kv_blocks"] = num_kv_blocks - spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks - spec["kv_block_size"] = (-ctx_len) // (-num_kv_blocks) + if "paged" in self.model.qaic_config["blocking_mode"]: + num_kv_blocks = self.model.qaic_config["num_kv_blocks"] + spec["num_kv_blocks"] = num_kv_blocks + spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + spec["kv_block_size"] = -(-ctx_len // num_kv_blocks) result = {k: v for k, v in spec.items() if v is not None} result["_graph_name"] = "Decode" return result From 7ddff49a67b775c23f9a5a02af7303850af66c65 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Sat, 9 May 2026 05:18:52 -0500 Subject: [PATCH 06/14] Minor fixes to specialization builder Signed-off-by: Vaibhav Verma --- QEfficient/transformers/models/modeling_auto.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 1270d45c5..23fc06eee 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3607,7 +3607,7 @@ def build_prefill_specialization( # TODO: remove this; not required if full_batch_size: spec["full_batch_exec_size"] = exec_batch_size - if "paged" in self.model.qaic_config["blocking_mode"]: + if self.model.qaic_config is not None and "paged" in self.model.qaic_config.get("blocking_mode", None): num_kv_blocks = self.model.qaic_config["num_kv_blocks"] spec["num_kv_blocks"] = num_kv_blocks spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks @@ -3675,7 +3675,7 @@ def build_decode_specialization( spec["full_batch_size"] = 
kv_cache_batch_size else: spec["batch_size"] = kv_cache_batch_size - if "paged" in self.model.qaic_config["blocking_mode"]: + if self.model.qaic_config is not None and "paged" in self.model.qaic_config.get("blocking_mode", None): num_kv_blocks = self.model.qaic_config["num_kv_blocks"] spec["num_kv_blocks"] = num_kv_blocks spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks From f1f0bec53853d5cda05bdf9a05ac4644a26fef9b Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Tue, 12 May 2026 17:35:16 -0500 Subject: [PATCH 07/14] Added support for Qwen2.5_VL PagedAttention Signed-off-by: Vaibhav Verma --- .../transformers/models/modeling_auto.py | 17 ++++- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 72 ++++++++++++++----- 2 files changed, 72 insertions(+), 17 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 23fc06eee..909f30eb2 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1088,6 +1088,11 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): super().__init__(model, **kwargs) self.model = model.get_qeff_language_decoder() self.model.qaic_config = qaic_config + # Below is to pass qaic_config downstream + if hasattr(self.model, "model"): + self.model.model.qaic_config = qaic_config + if hasattr(self.model.model, "model"): + self.model.model.model.qaic_config = qaic_config self.hash_params["qeff_auto_class"] = self.__class__.__name__ self.continuous_batching = False @@ -1892,7 +1897,7 @@ def kv_offload_generate( if self.vision_model.qpc_path: vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) - batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) + batch_size, ctx_len, fbs, num_kv_blocks = get_compilation_dims(self.lang_model.qpc_path) pad_token_id = 1 @@ -2003,6 +2008,13 @@ def kv_offload_generate( prefill_ccl_id = 0 
lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + if num_kv_blocks: + kv_block_size = -(-ctx_len // num_kv_blocks) + lang_inputs["block_table"] = np.arange(batch_size * num_kv_blocks, dtype=np.int64).reshape( + batch_size, num_kv_blocks + ) + lang_inputs["slot_id"] = np.zeros(batch_size, dtype=np.int64) + lang_start = perf_counter() # Run prefill chunk_inputs = lang_inputs.copy() @@ -2097,6 +2109,9 @@ def kv_offload_generate( lang_inputs["mm_token_type_ids"] = np.zeros_like( lang_inputs["input_ids"], dtype=lang_inputs["mm_token_type_ids"].dtype ) + if num_kv_blocks: + lang_inputs["slot_id"] += 1 + lang_inputs["slot_id"] %= kv_block_size generated_ids[:, num_token] = lang_inputs["input_ids"].squeeze(1) if streamer: streamer.put(lang_inputs["input_ids"][0]) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 39ee13276..6c1aade90 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -764,6 +764,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ @@ -788,6 +789,7 @@ def __init__(self, model): super().__init__() self.model = model self.language_model = self.model.model.language_model + self.config = model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ @@ -880,29 +882,42 @@ def get_dummy_inputs( lang_inputs = {} vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=self.config.torch_dtype) vision_inputs["image_grid_thw"] = torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + # Add data for KV + kv_cache_shape = 
get_padding_shape_from_config( + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, + seq_len=prefill_seq_len, + ) + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + num_kv_blocks = qaic_config["num_kv_blocks"] + batch, num_kv_heads, CL, dh = kv_cache_shape + prefill_seq_len = kv_block_size = -(-CL // num_kv_blocks) + total_num_kv_blocks = batch * num_kv_blocks + kv_cache_shape = [total_num_kv_blocks, num_kv_heads, kv_block_size, dh] + lang_inputs["block_table"] = torch.arange((bs * num_kv_blocks), dtype=torch.int64).view(bs, num_kv_blocks) + lang_inputs["slot_id"] = torch.zeros(bs, dtype=torch.int64) + lang_inputs["input_ids"] = torch.zeros((bs, prefill_seq_len), dtype=torch.int64) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=self.config.torch_dtype) lang_inputs["position_ids"] = ( ( torch.arange(prefill_seq_len, dtype=torch.int64) .view(1, prefill_seq_len) - .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + .repeat(bs, 1) ) .unsqueeze(0) .repeat(4, 1, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - - # Add data for KV - kv_cache_shape = get_padding_shape_from_config( - config=self.model.config.text_config, - batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, - ) - lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] for i in range(self.model.config.text_config.num_hidden_layers): for kv in ["key", "value"]: @@ -1076,6 +1091,14 @@ def smart_resize( if full_batch_size: lang_prefill["full_batch_exec_size"] = full_batch_size + 
qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + num_kv_blocks = self.model.qaic_config["num_kv_blocks"] + lang_prefill["num_kv_blocks"] = num_kv_blocks + lang_prefill["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + lang_prefill["kv_block_size"] = -(-ctx_len // num_kv_blocks) + lang_decode = { "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": 1, @@ -1089,6 +1112,14 @@ def smart_resize( else: lang_decode["batch_size"] = kv_cache_batch_size + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + num_kv_blocks = self.model.qaic_config["num_kv_blocks"] + lang_decode["num_kv_blocks"] = num_kv_blocks + lang_decode["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + lang_decode["kv_block_size"] = -(-ctx_len // num_kv_blocks) + lang = [lang_prefill, lang_decode] specializations = {} @@ -1119,16 +1150,25 @@ def get_onnx_dynamic_axes( "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } - for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = { + pkv_dynamic_axes = { 0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len", } - lang_dynamic_axes[f"past_value.{i}"] = { - 0: "full_batch_size" if continuous_batching else "batch_size", - 2: "ctx_len", + + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + lang_dynamic_axes["block_table"] = {0: "batch_size", 1: "num_kv_blocks"} + lang_dynamic_axes["slot_id"] = {0: "batch_size"} + pkv_dynamic_axes = { + 0: "total_num_kv_blocks", + 2: "kv_block_size", } + for i in 
range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = pkv_dynamic_axes + lang_dynamic_axes[f"past_value.{i}"] = pkv_dynamic_axes + if continuous_batching: lang_dynamic_axes["batch_index"] = {0: "batch_size"} From 4bc8b91cf1e9e5d379967fe1c4cdf964cc4cf0a4 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Tue, 12 May 2026 20:36:49 -0500 Subject: [PATCH 08/14] slot_id fix for Qwen2.5_VL PagedAttention decode Signed-off-by: Vaibhav Verma --- QEfficient/transformers/models/modeling_auto.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 909f30eb2..9b8a06333 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2067,6 +2067,8 @@ def kv_offload_generate( lang_inputs["mm_token_type_ids"] = np.zeros_like( lang_inputs["input_ids"], dtype=lang_inputs["mm_token_type_ids"].dtype ) + if num_kv_blocks: + lang_inputs["slot_id"] = (np.max(lang_inputs["position_ids"]) % kv_block_size).reshape(batch_size) if "cross_attention_mask" in lang_inputs: bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape lang_inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64).numpy() From ad02fbb3651d6af332bf08ad33c29446db9d695a Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Tue, 12 May 2026 20:38:12 -0500 Subject: [PATCH 09/14] Added support for Qwen3_VL PagedAttention Signed-off-by: Vaibhav Verma --- .../models/qwen3_vl/modeling_qwen3_vl.py | 73 +++++++++++++++---- 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 548263da5..829b176fc 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -666,6 +666,7 @@ def __init__(self, model): 
super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ @@ -693,6 +694,7 @@ def __init__(self, model): super().__init__() self.model = model self.language_model = self.model.model.language_model + self.config = model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ @@ -909,15 +911,38 @@ def get_dummy_inputs( (inputs_shapes["pixel_values"]), dtype=self.model.config.torch_dtype ) vision_inputs["image_grid_thw"] = torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + # Add data for KV + kv_cache_shape = get_padding_shape_from_config( + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, + seq_len=prefill_seq_len, + ) + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) lang_inputs["vision_embeds"] = torch.zeros( (inputs_shapes["vision_embeds"]), dtype=self.model.config.torch_dtype ) + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + num_kv_blocks = qaic_config["num_kv_blocks"] + batch, num_kv_heads, CL, dh = kv_cache_shape + prefill_seq_len = kv_block_size = -(-CL // num_kv_blocks) + total_num_kv_blocks = batch * num_kv_blocks + kv_cache_shape = [total_num_kv_blocks, num_kv_heads, kv_block_size, dh] + lang_inputs["block_table"] = torch.arange((bs * num_kv_blocks), dtype=torch.int64).view(bs, num_kv_blocks) + lang_inputs["slot_id"] = torch.zeros(bs, dtype=torch.int64) + lang_inputs["input_ids"] = torch.zeros((bs, prefill_seq_len), dtype=torch.int64) + lang_inputs["position_ids"] = ( ( torch.arange(prefill_seq_len, dtype=torch.int64) .view(1, prefill_seq_len) - 
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + .repeat(bs, 1) ) .unsqueeze(0) .repeat(4, 1, 1) @@ -926,16 +951,7 @@ def get_dummy_inputs( lang_inputs["deepstack_features"] = torch.zeros( (inputs_shapes["deepstack_features"]), dtype=self.model.config.torch_dtype ) - # Add data for KV - - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - - kv_cache_shape = get_padding_shape_from_config( - config=self.model.config.text_config, - batch_size=fbs if continuous_batching else bs, - seq_len=prefill_seq_len, - ) + lang_inputs["deepstack_features"] = torch.zeros((inputs_shapes["deepstack_features"]), dtype=torch.float32) lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] for i in range(self.model.config.text_config.num_hidden_layers): @@ -1116,6 +1132,14 @@ def smart_resize( if full_batch_size: lang_prefill["full_batch_exec_size"] = full_batch_size + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + num_kv_blocks = self.model.qaic_config["num_kv_blocks"] + lang_prefill["num_kv_blocks"] = num_kv_blocks + lang_prefill["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + lang_prefill["kv_block_size"] = -(-ctx_len // num_kv_blocks) + lang_decode = { "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": 1, @@ -1130,6 +1154,14 @@ def smart_resize( else: lang_decode["batch_size"] = kv_cache_batch_size + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + num_kv_blocks = self.model.qaic_config["num_kv_blocks"] + lang_decode["num_kv_blocks"] = num_kv_blocks + lang_decode["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + 
lang_decode["kv_block_size"] = -(-ctx_len // num_kv_blocks) + lang = [lang_prefill, lang_decode] specializations = {} @@ -1161,16 +1193,25 @@ def get_onnx_dynamic_axes( "deepstack_features": {0: "num_feature_layers", 1: "vision_batch_size", 2: "vision_size"}, } - for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = { + pkv_dynamic_axes = { 0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len", } - lang_dynamic_axes[f"past_value.{i}"] = { - 0: "full_batch_size" if continuous_batching else "batch_size", - 2: "ctx_len", + + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + lang_dynamic_axes["block_table"] = {0: "batch_size", 1: "num_kv_blocks"} + lang_dynamic_axes["slot_id"] = {0: "batch_size"} + pkv_dynamic_axes = { + 0: "total_num_kv_blocks", + 2: "kv_block_size", } + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = pkv_dynamic_axes + lang_dynamic_axes[f"past_value.{i}"] = pkv_dynamic_axes + if continuous_batching: lang_dynamic_axes["batch_index"] = {0: "batch_size"} From bd428adad1bee7dbd20cb53a3ff97e9ef619fda3 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Wed, 13 May 2026 02:37:30 -0500 Subject: [PATCH 10/14] Removed commented code corrected in rebase Signed-off-by: Vaibhav Verma --- QEfficient/blocking/blocking_configurator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index 05b0d21b2..bfc08977c 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -69,8 +69,6 @@ def _normalize_attention_mode(raw_mode: str) -> str: def _resolve_effective_blocking_mode(attention_cfg: Dict[str, Any], requested_mode: str) -> str: mode = _normalize_attention_mode(requested_mode) - #if mode == "": #this should not be 
here since it will prevent 'h' and 'b' modes - #return "" num_q_blocks = attention_cfg.get("num_q_blocks") or 1 num_kv_blocks = attention_cfg.get("num_kv_blocks") or 1 head_block_size = (attention_cfg.get("head_block_size") or 1) if attention_cfg.get("head_blocking_enabled") else 1 From a32fd471baf95b64c8458b0d182882594a2ef553 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Wed, 13 May 2026 13:15:04 -0500 Subject: [PATCH 11/14] Minor fix for enum bug Signed-off-by: Vaibhav Verma --- QEfficient/blocking/attention_blocking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py index d4e941bd4..6a6de70d5 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -40,7 +40,7 @@ class BlockingMode(str, Enum): QKV_PAGED = "qkv_paged" HQ = "hq" HKV = "hkv" - HKV = "hkv_paged" + HKV_PAGED = "hkv_paged" HQKV = "hqkv" HQKV_PAGED = "hqkv_paged" BHQKV = "bhqkv" From 51c04f61e799be96a480b8832fbff68785576e54 Mon Sep 17 00:00:00 2001 From: Anuj Gupta Date: Wed, 13 May 2026 21:21:11 +0530 Subject: [PATCH 12/14] Optimize attention blocking nested loops (#957) changed the code from doing the exact same math repeatedly. 
Signed-off-by: Anuj Gupta Signed-off-by: Vaibhav Verma --- QEfficient/blocking/blocking_configurator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index bfc08977c..bd36b30cc 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -248,6 +248,7 @@ def update_best_config(num_q_blocks: int, num_kv_blocks: int, q_kv_ratio: float, qkv_size_per_nsp = q_size_per_nsp # For KV Blocking for loop, q input and qkv output should be persistent in VTCM vtcm_footprint = q_size_per_nsp + kv_size_per_nsp + qk_size_per_nsp + qkv_size_per_nsp + q_kv_ratio = max(q_size_per_nsp / kv_size_per_nsp, kv_size_per_nsp / q_size_per_nsp) num_total_blocks = num_q_blocks * num_kv_blocks From 17fdadef46dcfe68b1fe8263b7b2d1d0a881832e Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Wed, 13 May 2026 14:41:51 -0500 Subject: [PATCH 13/14] Adding PagedAttention support for Qwen3_VL_MOE Signed-off-by: Vaibhav Verma --- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 70 ++++++++++++++----- 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index dbdfa3606..6dad7e1fc 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -767,6 +767,7 @@ def __init__(self, model): super().__init__() self.model = model.model self.model.vision_model = self.model.visual + self.config = model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ @@ -809,6 +810,7 @@ def __init__(self, model): super().__init__() self.model = model self.language_model = self.model.model.language_model + self.config = model.config def get_submodules_for_export(self) -> Type[nn.Module]: """ @@ -1003,30 +1005,37 @@ def
get_dummy_inputs( lang_inputs["vision_embeds"] = torch.zeros( (inputs_shapes["vision_embeds"]), dtype=self.model.config.torch_dtype ) - lang_inputs["position_ids"] = ( - ( - torch.arange(prefill_seq_len, dtype=torch.int64) - .view(1, prefill_seq_len) - .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) - ) - .unsqueeze(0) - .repeat(4, 1, 1) - ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) lang_inputs["deepstack_features"] = torch.zeros( (inputs_shapes["deepstack_features"]), dtype=self.model.config.torch_dtype ) - # Add data for KV bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.model.config.text_config, batch_size=fbs if continuous_batching else bs, seq_len=prefill_seq_len, ) + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + num_kv_blocks = qaic_config["num_kv_blocks"] + batch, num_kv_heads, CL, dh = kv_cache_shape + prefill_seq_len = kv_block_size = -(-CL // num_kv_blocks) + total_num_kv_blocks = batch * num_kv_blocks + kv_cache_shape = [total_num_kv_blocks, num_kv_heads, kv_block_size, dh] + lang_inputs["block_table"] = torch.arange((bs * num_kv_blocks), dtype=torch.int64).view(bs, num_kv_blocks) + lang_inputs["slot_id"] = torch.zeros(bs, dtype=torch.int64) + lang_inputs["input_ids"] = torch.zeros((bs, prefill_seq_len), dtype=torch.int64) + + lang_inputs["position_ids"] = ( + (torch.arange(prefill_seq_len, dtype=torch.int64).view(1, prefill_seq_len).repeat(bs, 1)).unsqueeze(0).repeat(4, 1, 1) + ) + lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] for i in range(self.model.config.text_config.num_hidden_layers): for kv in ["key", "value"]: @@ -1206,6 +1215,14 @@ def smart_resize( if 
full_batch_size: lang_prefill["full_batch_exec_size"] = full_batch_size + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + num_kv_blocks = self.model.qaic_config["num_kv_blocks"] + lang_prefill["num_kv_blocks"] = num_kv_blocks + lang_prefill["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + lang_prefill["kv_block_size"] = -(-ctx_len // num_kv_blocks) + lang_decode = { "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": 1, @@ -1220,6 +1237,14 @@ def smart_resize( else: lang_decode["batch_size"] = kv_cache_batch_size + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + num_kv_blocks = self.model.qaic_config["num_kv_blocks"] + lang_decode["num_kv_blocks"] = num_kv_blocks + lang_decode["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks + lang_decode["kv_block_size"] = -(-ctx_len // num_kv_blocks) + lang = [lang_prefill, lang_decode] specializations = {} @@ -1251,16 +1276,25 @@ def get_onnx_dynamic_axes( "deepstack_features": {0: "num_feature_layers", 1: "vision_batch_size", 2: "vision_size"}, } - for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = { - 0: "full_batch_size" if continuous_batching else "batch_size", - 2: "ctx_len", - } - lang_dynamic_axes[f"past_value.{i}"] = { - 0: "full_batch_size" if continuous_batching else "batch_size", - 2: "ctx_len", + pkv_dynamic_axes = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + qaic_config = getattr(self.model, "qaic_config", None) + blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None + if blocking_mode is not None and "paged" in blocking_mode: + 
lang_dynamic_axes["block_table"] = {0: "batch_size", 1: "num_kv_blocks"} + lang_dynamic_axes["slot_id"] = {0: "batch_size"} + pkv_dynamic_axes = { + 0: "total_num_kv_blocks", + 2: "kv_block_size", } + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = pkv_dynamic_axes + lang_dynamic_axes[f"past_value.{i}"] = pkv_dynamic_axes + if continuous_batching: lang_dynamic_axes["batch_index"] = {0: "batch_size"} From ee2dadb3722ffaf6f685dbc9fcdad6a7866e52df Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Thu, 4 Jun 2026 08:55:00 -0500 Subject: [PATCH 14/14] Adding PagedAttention specific test, unit tests and examples + lint/format cleanup Signed-off-by: Vaibhav Verma --- QEfficient/blocking/attention_blocking.py | 6 +- .../blocking/blocked_attention_forwards.py | 11 +- QEfficient/blocking/blocking_configurator.py | 14 +- .../generation/text_generation_inference.py | 16 +- QEfficient/transformers/cache_utils.py | 13 +- .../models/granite/modeling_granite.py | 2 + .../models/granitemoe/modeling_granitemoe.py | 2 + .../transformers/models/modeling_auto.py | 14 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 6 +- .../models/qwen3_vl/modeling_qwen3_vl.py | 6 +- QEfficient/utils/generate_inputs.py | 73 ++++++- QEfficient/utils/run_utils.py | 4 +- .../qwen3_vl_blocked_paged_attention.py | 158 ++++++++++++++ .../causallm/example_pytorch_transforms.py | 12 +- .../causalLM_paged_attention_example.py | 108 ++++++++++ .../causal_lm_models/check_causal_models.py | 1 + .../test_causal_lm_blocking_hqkv.py | 195 ++++++++++++++++++ .../unit_test/base/test_modeling_qeff_base.py | 4 +- .../unit_test/models/test_model_quickcheck.py | 6 +- tests/unit_test/utils/test_generation.py | 14 +- 20 files changed, 617 insertions(+), 48 deletions(-) create mode 100644 examples/image_text_to_text/models/qwen3vl/qwen3_vl_blocked_paged_attention.py create mode 100644 examples/text_generation/causalLM_paged_attention_example.py diff --git a/QEfficient/blocking/attention_blocking.py 
b/QEfficient/blocking/attention_blocking.py index 6a6de70d5..63a32e8f8 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -22,8 +22,8 @@ blocked_hqkv_attention_forward, blocked_hqkv_paged_attention_forward, blocked_kv_attention_forward, - blocked_kv_paged_attention_forward, blocked_kv_mla_attention_forward, + blocked_kv_paged_attention_forward, blocked_q_attention_forward, blocked_qkv_attention_forward, blocked_qkv_paged_attention_forward, @@ -150,7 +150,9 @@ def generic_blocked_attention_interface( ) use_paged_kv_blocked = ( - blocking_config is not None and "paged" in blocking_config.mode and supports_paged_attention_blocked_kv(past_key_value) + blocking_config is not None + and "paged" in blocking_config.mode + and supports_paged_attention_blocked_kv(past_key_value) ) if past_key_value is not None: diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index a8f147e54..2402eacd1 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -255,7 +255,7 @@ def blocked_kv_paged_attention_forward( current_position = position_ids.max(dim=-1).values # needed for GPT-OSS if sinks is not None: - sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) for j in range(num_kv_blocks): start_index = j * kv_block_size @@ -857,7 +857,9 @@ def blocked_hqkv_paged_attention_forward( block_index = block_table[:, j] updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j - k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) + k_block, v_block = past_key_value.read_only_pagedAttention( + block_index, updated, layer_idx, cache_kwargs + ) k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) k_g = k_block_states[:, h_start:h_end, :, 
:] @@ -1140,7 +1142,6 @@ def blocked_bhqkv_paged_attention_forward( kv_block_size = past_key_value.get_seq_length() if past_key_value is not None else 0 past_seen_tokens = kv_block_size * num_kv_blocks - h_output_blocks = [] h_attn_blocks = [] @@ -1220,7 +1221,9 @@ def blocked_bhqkv_paged_attention_forward( block_index = block_table[:, j] updated = (position_ids.max(1, keepdim=True).values // kv_block_size) == j - k_block, v_block = past_key_value.read_only_pagedAttention(block_index, updated, layer_idx, cache_kwargs) + k_block, v_block = past_key_value.read_only_pagedAttention( + block_index, updated, layer_idx, cache_kwargs + ) k_block_states, v_block_states = _get_kv_states(module, k_block, v_block) k_g = k_block_states[batch_start : batch_start + batch_len, h_start:h_end, :, :] diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index bd36b30cc..c5ed9209a 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -355,10 +355,20 @@ def build_transformer_blocking_config_for_transform( else: blocking_config = AttentionBlockingConfig() mode_from_config = "" - if qaic_config.get("num_kv_blocks", False) and enable_blocking and "kv" in blocking_mode and "paged" not in blocking_mode: + if ( + qaic_config.get("num_kv_blocks", False) + and enable_blocking + and "kv" in blocking_mode + and "paged" not in blocking_mode + ): mode_from_config = "kv" + mode_from_config blocking_config.num_kv_blocks = _get_valid_num_blocks(qaic_config, "num_kv_blocks") - if qaic_config.get("num_kv_blocks", False) and enable_blocking and "kv" in blocking_mode and "paged" in blocking_mode: + if ( + qaic_config.get("num_kv_blocks", False) + and enable_blocking + and "kv" in blocking_mode + and "paged" in blocking_mode + ): mode_from_config = "kv_paged" + mode_from_config blocking_config.num_kv_blocks = _get_valid_num_blocks(qaic_config, "num_kv_blocks") if qaic_config.get("num_q_blocks", 
False) and enable_blocking and "q" in blocking_mode: diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index ebb69edd2..a7c61b784 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -486,7 +486,7 @@ def __init__( self.full_batch_size = ( full_batch_size if full_batch_size else self._fetch_full_batch_size() ) # Check and fetch full batch size if CB is enabled - self.num_kv_blocks = num_kv_blocks if num_kv_blocks else 1 + self.num_kv_blocks = num_kv_blocks if num_kv_blocks else None self.kv_block_size = -(-self._ctx_len // self.num_kv_blocks) if num_kv_blocks else 1 # Initialize the storage variables. @@ -759,7 +759,9 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): # run prefill for num_chunks outputs, position_ids, generation_len = self.run_prefill( - next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + next_prompt, + generation_len, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), block_table=block_table, ) @@ -1164,9 +1166,13 @@ def _setup_model_execution_inputs( self._full_batch_size if self._full_batch_size is not None else self._qaic_model.batch_size ) max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) - self._qaic_model.block_table = np.arange( - execution_batch_size * self._num_kv_blocks, dtype=np.int64 - ).reshape(execution_batch_size, self._num_kv_blocks) if self._num_kv_blocks else None + self._qaic_model.block_table = ( + np.arange(execution_batch_size * self._num_kv_blocks, dtype=np.int64).reshape( + execution_batch_size, self._num_kv_blocks + ) + if self._num_kv_blocks + else None + ) # Create a prompt queue. 
self._prompt_queue = deque(prompt) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index f09a6df48..c92b4cb23 100755 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -260,7 +260,7 @@ def read_only_pagedAttention(self, block_index, updated, cache_kwargs): Block index of the K/V block to read from the block table updated: - Was the current block updated during the current write? If yes, then read slot_id + seq_len worth of entires from current block + Was the current block updated during the current write? If yes, then read position_ids.max % block_size worth of entires from current block Return: A tuple containing the updated key and value states. @@ -271,13 +271,14 @@ def read_only_pagedAttention(self, block_index, updated, cache_kwargs): self._mark_initialized(k_out) position_ids = cache_kwargs.get("position_ids") batch, seq_len = position_ids.shape - slot_id = cache_kwargs.get("slot_id", None) num_kv_blocks, num_kv_heads, block_size, dh = k_out.shape ctx_indices = torch.arange(block_size)[None, ...] 
- gather_limit = torch.where(updated, slot_id.unsqueeze(-1) + seq_len, block_size) + block_fill_len = position_ids.max(1, keepdim=True).values % block_size + gather_limit = torch.where(updated, block_fill_len, block_size) + block_indices = block_index.unsqueeze(-1) - invalid_mask = torch.ones_like(position_ids, dtype=torch.bool) - invalid_mask = torch.where(block_indices < 0, invalid_mask, ctx_indices >= gather_limit) + gather_limit = torch.where(block_indices < 0, 0, gather_limit) + invalid_mask = ctx_indices > gather_limit if torch.onnx.is_in_onnx_export(): invalid_idx_value = torch.iinfo(torch.int32).max @@ -312,7 +313,7 @@ def write_only_pagedAttention(self, key_states, value_states, cache_kwargs): else: self._mark_initialized(self.keys) position_ids = cache_kwargs.get("position_ids") - block_table = cache_kwargs.get("block_table") # [BS, num_kv_blocks/BS] -> each entry is block_id value + block_table = cache_kwargs.get("block_table") # [BS, num_kv_blocks] -> each entry is block_id value slot_id = cache_kwargs.get("slot_id") batch, num_kv_heads, seq_len, dh = key_states.shape num_kv_blocks, num_kv_heads, block_size, dh = self.keys.shape diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 6ea84221a..079fc416b 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -108,6 +108,8 @@ def forward( self, hidden_states: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index c9da1bd17..0c8c2c545 100644 --- 
a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -96,6 +96,8 @@ def forward( self, hidden_states: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + slot_id: Optional[torch.LongTensor] = None, position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 9b8a06333..7ceb87efb 100755 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2009,7 +2009,7 @@ def kv_offload_generate( lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] if num_kv_blocks: - kv_block_size = -(-ctx_len // num_kv_blocks) + kv_block_size = -(-ctx_len // num_kv_blocks) lang_inputs["block_table"] = np.arange(batch_size * num_kv_blocks, dtype=np.int64).reshape( batch_size, num_kv_blocks ) @@ -3297,8 +3297,8 @@ def export( max_blocks = max(max_blocks, num_blocks) block_size = -(-seq_len // max_blocks) seq_len = block_size * max_blocks - num_kv_blocks = self.hash_params["blocking_kwargs"].num_kv_blocks - self.supports_paged_attention = "paged" in self.hash_params["blocking_kwargs"].mode + num_kv_blocks = self.hash_params["blocking_kwargs"].num_kv_blocks + self.supports_paged_attention = "paged" in self.hash_params["blocking_kwargs"].mode seq_len = kv_block_size = -(-seq_len // num_kv_blocks) if self.supports_paged_attention else seq_len fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS @@ -3388,7 +3388,9 @@ def export( batch, num_kv_heads, CL, dh = kv_cache_shape total_num_kv_blocks = batch * num_kv_blocks kv_cache_shape = [total_num_kv_blocks, num_kv_heads, kv_block_size, dh] - example_inputs["block_table"] = torch.arange((bs * num_kv_blocks), 
dtype=torch.int64).view(bs, num_kv_blocks) + example_inputs["block_table"] = torch.arange((bs * num_kv_blocks), dtype=torch.int64).view( + bs, num_kv_blocks + ) example_inputs["slot_id"] = torch.zeros(bs, dtype=torch.int64) dynamic_axes["block_table"] = {0: "batch_size", 1: "num_kv_blocks"} dynamic_axes["slot_id"] = {0: "batch_size"} @@ -3624,7 +3626,7 @@ def build_prefill_specialization( # TODO: remove this; not required if full_batch_size: spec["full_batch_exec_size"] = exec_batch_size - if self.model.qaic_config is not None and "paged" in self.model.qaic_config.get("blocking_mode", None): + if self.model.qaic_config is not None and "paged" in self.model.qaic_config.get("blocking_mode", ""): num_kv_blocks = self.model.qaic_config["num_kv_blocks"] spec["num_kv_blocks"] = num_kv_blocks spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks @@ -3692,7 +3694,7 @@ def build_decode_specialization( spec["full_batch_size"] = kv_cache_batch_size else: spec["batch_size"] = kv_cache_batch_size - if self.model.qaic_config is not None and "paged" in self.model.qaic_config.get("blocking_mode", None): + if self.model.qaic_config is not None and "paged" in self.model.qaic_config.get("blocking_mode", ""): num_kv_blocks = self.model.qaic_config["num_kv_blocks"] spec["num_kv_blocks"] = num_kv_blocks spec["total_num_kv_blocks"] = kv_cache_batch_size * num_kv_blocks diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 6c1aade90..b4c8975a9 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1151,9 +1151,9 @@ def get_onnx_dynamic_axes( } pkv_dynamic_axes = { - 0: "full_batch_size" if continuous_batching else "batch_size", - 2: "ctx_len", - } + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } qaic_config = getattr(self.model, "qaic_config", None) 
blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 829b176fc..755425dd1 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -1194,9 +1194,9 @@ def get_onnx_dynamic_axes( } pkv_dynamic_axes = { - 0: "full_batch_size" if continuous_batching else "batch_size", - 2: "ctx_len", - } + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } qaic_config = getattr(self.model, "qaic_config", None) blocking_mode = qaic_config.get("blocking_mode") if qaic_config is not None else None diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index 4a89597fe..b94761340 100755 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -20,7 +20,7 @@ class InputHandler: def __init__( - self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, dtype=torch.float32 + self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, dtype=torch.float32, qaic_config=None ): """ Initialization @@ -43,6 +43,7 @@ def __init__( self.full_batch_size = full_batch_size self.config = config self.dtype = dtype + self.qaic_config = qaic_config self.n_layer = get_num_layers_from_config(config) self.padding_shape = get_padding_shape_from_config( config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len @@ -90,6 +91,23 @@ def _get_layer_cache_shape(self, layer_idx): batch = self.full_batch_size if self.full_batch_size else self.padding_shape[0] return [batch, n_heads, ctx_len, d_head] + def _is_paged_attention(self) -> bool: + if self.qaic_config is None: + return False + blocking_mode = self.qaic_config.get("blocking_mode") + return blocking_mode is not None and "paged" in 
blocking_mode + + def _get_num_kv_blocks(self): + if not self._is_paged_attention(): + return None + return self.qaic_config.get("num_kv_blocks") + + def _get_kv_block_size(self): + num_kv_blocks = self._get_num_kv_blocks() + if num_kv_blocks is None: + return None + return -(-self.ctx_len // num_kv_blocks) + def prepare_pytorch_inputs(self): """ Function responsible for creating Prefill stage tensor inputs for PyTorch model. @@ -131,8 +149,21 @@ def prepare_pytorch_inputs(self): inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) past_key_values = [] + num_kv_blocks = self._get_num_kv_blocks() + kv_block_size = self._get_kv_block_size() + + if self._is_paged_attention() and num_kv_blocks and kv_block_size: + inputs["block_table"] = torch.arange(batch_size * num_kv_blocks, dtype=torch.int64).reshape( + batch_size, num_kv_blocks + ) + inputs["slot_id"] = torch.zeros(batch_size, dtype=torch.int64) + pad_shape_paged = [batch_size * num_kv_blocks, self.padding_shape[1], kv_block_size, self.padding_shape[3]] + for i in range(self.n_layer): - pad_shape = self._get_layer_cache_shape(i) + if self._is_paged_attention() and num_kv_blocks and kv_block_size: + pad_shape = pad_shape_paged + else: + pad_shape = self._get_layer_cache_shape(i) past_key = torch.zeros((pad_shape), dtype=self.dtype) past_value = torch.zeros((pad_shape), dtype=self.dtype) pkv = (past_key, past_value) @@ -178,6 +209,16 @@ def update_pytorch_inputs(self, inputs, pt_outputs): else: updated_inputs["past_key_values"] = pkv + num_kv_blocks = self._get_num_kv_blocks() + kv_block_size = self._get_kv_block_size() + if self._is_paged_attention() and num_kv_blocks and kv_block_size and "slot_id" in inputs: + updated_inputs["slot_id"] = ((updated_inputs["position_ids"]) % kv_block_size).view( + updated_inputs["position_ids"].shape[0] + ) + updated_inputs["block_table"] = torch.arange( + updated_inputs["input_ids"].shape[0] * inputs["block_table"].shape[1], dtype=torch.int64 + 
).reshape(updated_inputs["input_ids"].shape[0], inputs["block_table"].shape[1]) + return updated_inputs def prepare_ort_inputs(self): @@ -207,12 +248,26 @@ def prepare_ort_inputs(self): axis=1, ).astype(np.int64) + num_kv_blocks = self._get_num_kv_blocks() + kv_block_size = self._get_kv_block_size() + + if self._is_paged_attention() and num_kv_blocks and kv_block_size: + inputs["block_table"] = np.arange(batch_size * num_kv_blocks, dtype=np.int64).reshape( + batch_size, num_kv_blocks + ) + inputs["slot_id"] = np.zeros(batch_size, dtype=np.int64) + pad_shape_paged = [batch_size * num_kv_blocks, self.padding_shape[1], kv_block_size, self.padding_shape[3]] + for i in range(self.n_layer): - pad_shape = self._get_layer_cache_shape(i) + if self._is_paged_attention() and num_kv_blocks and kv_block_size: + pad_shape = pad_shape_paged + else: + pad_shape = self._get_layer_cache_shape(i) inputs["past_key." + str(i)] = np.zeros((pad_shape), dtype=np.float32) inputs["past_value." + str(i)] = np.zeros((pad_shape), dtype=np.float32) if self.full_batch_size: inputs["batch_index"] = np.arange(self.full_batch_size).reshape(-1, 1) + return inputs def update_ort_inputs(self, inputs, ort_outputs): @@ -235,6 +290,17 @@ def update_ort_inputs(self, inputs, ort_outputs): updated_inputs["past_value." 
+ str(i)] = ort_outputs["past_key_values"][i * 2 + 1] if self.full_batch_size: updated_inputs["batch_index"] = inputs["batch_index"] + + num_kv_blocks = self._get_num_kv_blocks() + kv_block_size = self._get_kv_block_size() + if self._is_paged_attention() and num_kv_blocks and kv_block_size and "slot_id" in inputs: + updated_inputs["slot_id"] = ((updated_inputs["position_ids"]) % kv_block_size).reshape( + updated_inputs["position_ids"].shape[0] + ) + updated_inputs["block_table"] = np.arange( + updated_inputs["input_ids"].shape[0] * inputs["block_table"].shape[1], dtype=np.int64 + ).reshape(updated_inputs["input_ids"].shape[0], inputs["block_table"].shape[1]) + return updated_inputs def update_ort_outputs(self, ort_outputs): @@ -440,6 +506,7 @@ def update_vlm_ort_inputs(self, inputs, ort_outputs): updated_inputs["mm_token_type_ids"] = np.zeros_like( updated_inputs["input_ids"], dtype=inputs["mm_token_type_ids"].dtype ) + if "cross_attention_mask" in inputs.keys(): bs, _, num_images, img_tiles = inputs["cross_attention_mask"].shape updated_inputs["cross_attention_mask"] = torch.ones( diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index ec6b085d0..2fc2eecc1 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -17,6 +17,7 @@ from QEfficient.generation.text_generation_inference import TextGeneration from QEfficient.transformers.cache_utils import QEffDynamicCache + from QEfficient.utils.generate_inputs import InputHandler, InputHandlerInternVL, InputHandlerVLM @@ -33,7 +34,7 @@ class ApiRunner: """ def __init__( - self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size=None, dtype=torch.float32 + self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size=None, dtype=torch.float32, qaic_config=None ): """ Initialization @@ -55,6 +56,7 @@ def __init__( ctx_len=ctx_len, full_batch_size=full_batch_size, dtype=dtype, + qaic_config=qaic_config, ) self.gen_len = 
self.input_handler.ctx_len - self.input_handler.prompt_len diff --git a/examples/image_text_to_text/models/qwen3vl/qwen3_vl_blocked_paged_attention.py b/examples/image_text_to_text/models/qwen3vl/qwen3_vl_blocked_paged_attention.py new file mode 100644 index 000000000..d96abade5 --- /dev/null +++ b/examples/image_text_to_text/models/qwen3vl/qwen3_vl_blocked_paged_attention.py @@ -0,0 +1,158 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "Qwen/Qwen3-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) + +# config.vision_config.depth = 9 +# config.text_config.num_hidden_layers = 1 +# config.vision_config.deepstack_visual_indexes = [8] + +HEAD_BLOCK_SIZE = 8 +NUM_KV_BLOCKS = 2 +NUM_Q_BLOCKS = 2 +NUM_BATCH_BLOCKS = 2 +# kv blocking with PagedAttention +qaic_config = dict( + enable_blocking=True, + blocking_mode="kv_paged", + num_kv_blocks=NUM_KV_BLOCKS, +) + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config=qaic_config, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) +### set skip_vision=True to run text only; set it to False to run vision + text ### +skip_vision = True + +if skip_vision: + ## Only Text ## + + ## Set Batch_Size ## + batch_size = 1 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=32768, + num_cores=16, + num_devices=8, + height=354, + width=536, + mxfp6_matmul=True, + 
aic_enable_depth_first=True, + skip_vision=True, + mos=1, + use_onnx_subfunctions=False, + qaic_config=qaic_config, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(processor.tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=32768, + num_cores=16, + num_devices=8, + height=354, + width=536, + # height=1024, + # width=1024, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + use_onnx_subfunctions=False, + qaic_config=qaic_config, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + # image_url = ( + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + # ) + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Descibe the image in details."}, + ], + }, + ] + + # messages_2 = [ + # { + # "role": "user", + # "content": [ + # {"type": "image", "image": image}, + # {"type": "text", "text": "Describe about the color of the dog."}, + # ], + # }, + # ] + + messages = [messages_1] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + 
inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=30000) + print(output.generated_ids) + print(processor.tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index ff62588f9..503efc12d 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,12 +27,6 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union -from QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from torch import nn # Example imports for three representative models @@ -62,6 +56,12 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, diff --git a/examples/text_generation/causalLM_paged_attention_example.py b/examples/text_generation/causalLM_paged_attention_example.py new file mode 100644 index 000000000..fc1f4bfe6 --- /dev/null +++ b/examples/text_generation/causalLM_paged_attention_example.py @@ -0,0 +1,108 @@ +# 
----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + + +def main(): + parser = argparse.ArgumentParser(description="Basic text generation inference with PagedAttention") + parser.add_argument("--model-name", type=str, default="meta-llama/Llama-3.2-1B", help="HuggingFace model ID") + parser.add_argument("--prompt", type=str, default="Hello", help="Input prompt") + parser.add_argument("--prefill-seq-len", type=int, default=4, help="Prefill sequence length") + parser.add_argument("--ctx-len", type=int, default=32, help="Context length") + parser.add_argument("--generation-len", type=int, default=25, help="Number of tokens to generate") + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + parser.add_argument( + "--device-group", + type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], + default=[0, 1, 2, 3], + help="Device IDs (comma-separated) e.g. 
[0,1]", + ) + parser.add_argument( + "--blocking-mode", + type=str, + default="kv_paged", + help="PagedAttention blocking mode, valid options: kv_paged, qkv_paged, hqkv_paged", + ) + parser.add_argument( + "--num-kv-blocks", + type=int, + default="8", + help="Number of KV blocks required for 1 batch element in PagedAttention", + ) + parser.add_argument( + "--compare-non-blocking", + action="store_true", + help="Compile and print results for non-blocked version of model as well", + ) + args = parser.parse_args() + + # Load tokenizer and model + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + if args.compare_non_blocking: + model = QEFFAutoModelForCausalLM.from_pretrained(args.model_name, num_hidden_layers=2) + + # Compile the model + qpc_path = model.compile( + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.num_cores, + num_devices=4, + ) + print(f"Model compiled to: {qpc_path}") + + # Generate text + exec_info = model.generate( + tokenizer=tokenizer, + prompts=[args.prompt], + generation_len=args.generation_len, + ) + + print(f"\nPrompt: {args.prompt}") + print(f"\nGenerated non-blocked: {exec_info.generated_texts[0]}") + + # setup qaic config to enable PagedAttention blocking + qaic_config = {"enable_blocking": True, "blocking_mode": args.blocking_mode, "num_kv_blocks": args.num_kv_blocks} + model_blocked = QEFFAutoModelForCausalLM.from_pretrained( + args.model_name, num_hidden_layers=2, qaic_config=qaic_config + ) + + # Compile the model + qpc_path_blocked = model_blocked.compile( + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.num_cores, + num_devices=4, + qaic_config=qaic_config, + ) + print(f"Model compiled to: {qpc_path_blocked}") + + # Generate text + exec_info_blocked = model_blocked.generate( + tokenizer=tokenizer, + prompts=[args.prompt], + generation_len=args.generation_len, + ) + + print(f"\nPrompt: {args.prompt}") + print(f"\nGenerated blocked PagedAttention: 
{exec_info_blocked.generated_texts[0]}") + + if args.compare_non_blocking: + print("\nPerformance non-blocked:") + print(exec_info) + + print("\nPerformance blocked PagedAttention:") + print(exec_info_blocked) + + +if __name__ == "__main__": + main() diff --git a/tests/transformers/models/causal_lm_models/check_causal_models.py b/tests/transformers/models/causal_lm_models/check_causal_models.py index f878acbe7..360f13a4c 100644 --- a/tests/transformers/models/causal_lm_models/check_causal_models.py +++ b/tests/transformers/models/causal_lm_models/check_causal_models.py @@ -79,6 +79,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( Constants.PROMPT_LEN, Constants.CTX_LEN, full_batch_size if continuous_batching else None, + qaic_config=qaic_config, ) qeff_model = QEFFAutoModelForCausalLM( copy.deepcopy(model_hf), diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py index 0568939cd..1c8436d2a 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py @@ -47,6 +47,12 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manu model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 ) + # kv_paged_attention blocking only + qaic_config = dict(enable_blocking=True, blocking_mode="kv_paged", num_kv_blocks=NUM_KV_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -59,6 +65,14 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manu model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 ) + # 
qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, blocking_mode="qkv_paged", num_kv_blocks=NUM_KV_BLOCKS, num_q_blocks=NUM_Q_BLOCKS + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 + ) + # head qkv blocking qaic_config = dict( enable_blocking=True, @@ -70,6 +84,18 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manu model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 ) + # head qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, + blocking_mode="hqkv_paged", + head_block_size=HEAD_BLOCK_SIZE, + num_kv_blocks=NUM_KV_BLOCKS, + num_q_blocks=NUM_Q_BLOCKS, + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, manual_cleanup=manual_cleanup, num_devices=4 + ) + @pytest.mark.few_layers @pytest.mark.llm_model @@ -92,6 +118,12 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manua model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, manual_cleanup=manual_cleanup ) + # kv_paged_attention blocking only + qaic_config = dict(enable_blocking=True, blocking_mode="kv_paged", num_kv_blocks=NUM_KV_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, manual_cleanup=manual_cleanup + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -104,6 +136,14 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manua model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, manual_cleanup=manual_cleanup ) + # qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, blocking_mode="qkv_paged", num_kv_blocks=NUM_KV_BLOCKS, num_q_blocks=NUM_Q_BLOCKS + ) + 
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, manual_cleanup=manual_cleanup + ) + # head qkv blocking qaic_config = dict( enable_blocking=True, @@ -115,6 +155,18 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manua model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, manual_cleanup=manual_cleanup ) + # head qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, + blocking_mode="hqkv_paged", + head_block_size=HEAD_BLOCK_SIZE, + num_kv_blocks=NUM_KV_BLOCKS, + num_q_blocks=NUM_Q_BLOCKS, + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, manual_cleanup=manual_cleanup + ) + @pytest.mark.dummy_layers @pytest.mark.llm_model @@ -146,6 +198,12 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, man model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, config=hf_config, manual_cleanup=manual_cleanup ) + # kv_paged_attention blocking only + qaic_config = dict(enable_blocking=True, blocking_mode="kv_paged", num_kv_blocks=NUM_KV_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, config=hf_config, manual_cleanup=manual_cleanup + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -158,6 +216,14 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, man model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, config=hf_config, manual_cleanup=manual_cleanup ) + # qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, blocking_mode="qkv_paged", num_kv_blocks=NUM_KV_BLOCKS, num_q_blocks=NUM_Q_BLOCKS + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, config=hf_config, 
manual_cleanup=manual_cleanup + ) + # head qkv blocking qaic_config = dict( enable_blocking=True, @@ -169,6 +235,18 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, man model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, config=hf_config, manual_cleanup=manual_cleanup ) + # head qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, + blocking_mode="hqkv_paged", + head_block_size=HEAD_BLOCK_SIZE, + num_kv_blocks=NUM_KV_BLOCKS, + num_q_blocks=NUM_Q_BLOCKS, + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, qaic_config=qaic_config, n_layer=n_layer, config=hf_config, manual_cleanup=manual_cleanup + ) + @pytest.mark.full_layers @pytest.mark.llm_model @@ -199,6 +277,16 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, m num_devices=4, ) + # kv_paged_attention blocking only + qaic_config = dict(enable_blocking=True, blocking_mode="kv_paged", num_kv_blocks=NUM_KV_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + qaic_config=qaic_config, + manual_cleanup=manual_cleanup, + continuous_batching=True, + num_devices=4, + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -219,6 +307,18 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, m num_devices=4, ) + # qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, blocking_mode="qkv_paged", num_kv_blocks=NUM_KV_BLOCKS, num_q_blocks=NUM_Q_BLOCKS + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + qaic_config=qaic_config, + manual_cleanup=manual_cleanup, + continuous_batching=True, + num_devices=4, + ) + # head qkv blocking qaic_config = dict( enable_blocking=True, @@ -234,6 +334,22 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, m num_devices=4, ) + # head qkv_paged_attention 
blocking + qaic_config = dict( + enable_blocking=True, + blocking_mode="hqkv_paged", + head_block_size=HEAD_BLOCK_SIZE, + num_kv_blocks=NUM_KV_BLOCKS, + num_q_blocks=NUM_Q_BLOCKS, + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + qaic_config=qaic_config, + manual_cleanup=manual_cleanup, + continuous_batching=True, + num_devices=4, + ) + @pytest.mark.few_layers @pytest.mark.llm_model @@ -264,6 +380,16 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, ma continuous_batching=True, ) + # kv_paged_attention blocking only + qaic_config = dict(enable_blocking=True, blocking_mode="kv_paged", num_kv_blocks=NUM_KV_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + qaic_config=qaic_config, + n_layer=n_layer, + manual_cleanup=manual_cleanup, + continuous_batching=True, + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -284,6 +410,18 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, ma continuous_batching=True, ) + # qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, blocking_mode="qkv_paged", num_kv_blocks=NUM_KV_BLOCKS, num_q_blocks=NUM_Q_BLOCKS + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + qaic_config=qaic_config, + n_layer=n_layer, + manual_cleanup=manual_cleanup, + continuous_batching=True, + ) + # head qkv blocking qaic_config = dict( enable_blocking=True, @@ -299,6 +437,22 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, ma continuous_batching=True, ) + # head qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, + blocking_mode="hqkv_paged", + head_block_size=HEAD_BLOCK_SIZE, + num_kv_blocks=NUM_KV_BLOCKS, + num_q_blocks=NUM_Q_BLOCKS, + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + qaic_config=qaic_config, + 
n_layer=n_layer, + manual_cleanup=manual_cleanup, + continuous_batching=True, + ) + @pytest.mark.dummy_layers @pytest.mark.llm_model @@ -340,6 +494,17 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, continuous_batching=True, ) + # kv_paged_attention blocking only + qaic_config = dict(enable_blocking=True, blocking_mode="kv_paged", num_kv_blocks=NUM_KV_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + qaic_config=qaic_config, + n_layer=n_layer, + config=hf_config, + manual_cleanup=manual_cleanup, + continuous_batching=True, + ) + # q block only qaic_config = dict(enable_blocking=True, num_q_blocks=NUM_Q_BLOCKS) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -362,6 +527,19 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, continuous_batching=True, ) + # qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, blocking_mode="qkv_paged", num_kv_blocks=NUM_KV_BLOCKS, num_q_blocks=NUM_Q_BLOCKS + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + qaic_config=qaic_config, + n_layer=n_layer, + config=hf_config, + manual_cleanup=manual_cleanup, + continuous_batching=True, + ) + # head qkv blocking qaic_config = dict( enable_blocking=True, @@ -377,3 +555,20 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup=manual_cleanup, continuous_batching=True, ) + + # head qkv_paged_attention blocking + qaic_config = dict( + enable_blocking=True, + blocking_mode="hqkv_paged", + head_block_size=HEAD_BLOCK_SIZE, + num_kv_blocks=NUM_KV_BLOCKS, + num_q_blocks=NUM_Q_BLOCKS, + ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + qaic_config=qaic_config, + n_layer=n_layer, + config=hf_config, + manual_cleanup=manual_cleanup, + continuous_batching=True, + ) diff --git a/tests/unit_test/base/test_modeling_qeff_base.py b/tests/unit_test/base/test_modeling_qeff_base.py index 
2b68a1606..644040e35 100644 --- a/tests/unit_test/base/test_modeling_qeff_base.py +++ b/tests/unit_test/base/test_modeling_qeff_base.py @@ -256,7 +256,9 @@ def test_hash_params_contains_pretrained_model_name(self): class TestQEFFBaseModelTransformBlocking: """Tests for QEFFBaseModel.transform() attention blocking behavior.""" - @pytest.mark.parametrize("blocking_mode", ["kv", "q", "qkv", "hq", "hkv", "hqkv"]) + @pytest.mark.parametrize( + "blocking_mode", ["kv", "kv_paged", "q", "qkv", "qkv_paged", "hq", "hkv", "hkv_paged", "hqkv", "hqkv_paged"] + ) def test_transform_enable_blocking_runs_auto_configurator(self, blocking_mode): # Use a slightly larger head count here to make it possible for "h" mode to result in head blocking # when num_devices > 1. diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index 4b7ed6f17..b20dfa712 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -1162,7 +1162,7 @@ def test_new_named_format(self, tmp_path): ] }, ) - bs, ctx, fbs = get_compilation_dims(qpc_path) + bs, ctx, fbs, num_kv_blocks = get_compilation_dims(qpc_path) assert bs == 1 assert ctx == 4096 assert fbs is None @@ -1185,7 +1185,7 @@ def test_new_named_format_with_full_batch_size(self, tmp_path): ] }, ) - bs, ctx, fbs = get_compilation_dims(qpc_path) + bs, ctx, fbs, num_kv_blocks = get_compilation_dims(qpc_path) assert bs == 1 assert ctx == 4096 assert fbs == 16 @@ -1202,7 +1202,7 @@ def test_legacy_flat_format_still_works(self, tmp_path): ] }, ) - bs, ctx, fbs = get_compilation_dims(qpc_path) + bs, ctx, fbs, num_kv_blocks = get_compilation_dims(qpc_path) assert bs == 1 assert ctx == 4096 assert fbs is None diff --git a/tests/unit_test/utils/test_generation.py b/tests/unit_test/utils/test_generation.py index b85c3c4b8..cd9f95e08 100644 --- a/tests/unit_test/utils/test_generation.py +++ b/tests/unit_test/utils/test_generation.py @@ -441,16 +441,24 
@@ def _write_spec(self, tmp_path, spec): def test_basic(self, tmp_path): path = self._write_spec(tmp_path, {"specializations": [{"batch_size": "4", "ctx_len": "128"}]}) - bs, cl, fbs = get_compilation_dims(path) + bs, cl, fbs, num_kv_blocks = get_compilation_dims(path) assert bs == 4 and cl == 128 and fbs is None def test_with_full_batch_size(self, tmp_path): path = self._write_spec( tmp_path, {"specializations": [{"batch_size": "4", "ctx_len": "128", "full_batch_size": "16"}]} ) - bs, cl, fbs = get_compilation_dims(path) + bs, cl, fbs, num_kv_blocks = get_compilation_dims(path) assert fbs == 16 + def test_with_pagedAttention(self, tmp_path): + path = self._write_spec( + tmp_path, + {"specializations": [{"batch_size": "4", "ctx_len": "128", "full_batch_size": "16", "num_kv_blocks": "8"}]}, + ) + bs, cl, fbs, num_kv_blocks = get_compilation_dims(path) + assert num_kv_blocks == 8 + def test_missing_file_raises(self, tmp_path): qpc_dir = tmp_path / "qpc" qpc_dir.mkdir() @@ -459,7 +467,7 @@ def test_missing_file_raises(self, tmp_path): def test_returns_ints(self, tmp_path): path = self._write_spec(tmp_path, {"specializations": [{"batch_size": "2", "ctx_len": "64"}]}) - bs, cl, fbs = get_compilation_dims(path) + bs, cl, fbs, num_kv_blocks = get_compilation_dims(path) assert isinstance(bs, int) and isinstance(cl, int)