Adding PagedAttention support for CausalLM models #982

Open
vaibverm wants to merge 14 commits into quic:release/v1.22.0_tmp from vaibverm:PR_branch

@vaibverm
Contributor

This PR adds PagedAttention (https://arxiv.org/pdf/2309.06180) support for all CausalLM models in QEfficient.
The major change is that the KV cache is no longer treated as one contiguous region of memory but as a collection of blocks that can reside non-contiguously in memory. As a result, cache scatter and gather operations must happen per KV block.

Summary of changes compared to BlockedKV:

  1. The cache shape changes from [BS, num_kv_heads, CL, dh] to [total_num_kv_blocks, num_kv_heads, kv_block_size, dh].
  2. num_kv_blocks = -(-ctx_len // kv_block_size), i.e. ceil(ctx_len / kv_block_size), is the number of physical blocks required for 1 batch element in the K cache.
  3. total_num_kv_blocks = BS (kv_batch_size) * num_kv_blocks is the total number of physical blocks available for the K cache.
  4. Two new inputs, block_table [BS, num_kv_blocks] and slot_id [BS], are passed to the ONNX model.
     a) Each entry in block_table is a block_id that points to a physical K/V block. The block to read or write for a given position is the (position_id // kv_block_size)th entry in block_table. A value of -1 signifies an invalid/unallocated block.
     b) slot_id tells how many entries are already filled in the currently active block => read up to / write after (slot_id - 1).
  5. Limitation: the cache is written to only 1 block at a time per batch element => CPL = kv_block_size. Hence, cache writes should not cross a block boundary.
  6. vLLM provides a KV Cache Manager implementation that maintains the KV cache block_table (logical-to-physical block mapping) and slot_id (location within the active block).
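
To make the indexing above concrete, here is a minimal pure-Python sketch of the block arithmetic. It is not the code from this PR: the helper names (`num_kv_blocks`, `locate`, `write_fits_in_block`), the example block table, and the use of `position_id % kv_block_size` as the offset inside a block are assumptions made only for illustration.

```python
def num_kv_blocks(ctx_len: int, kv_block_size: int) -> int:
    """Ceil division, written the same way as in the summary above."""
    return -(-ctx_len // kv_block_size)


def locate(block_table, batch_idx, position_id, kv_block_size):
    """Map a logical token position to (physical block_id, offset in block).

    Returns None when the logical block is unallocated (block_id == -1).
    The offset formula is an assumption of this sketch.
    """
    logical_block = position_id // kv_block_size
    block_id = block_table[batch_idx][logical_block]
    if block_id == -1:
        return None
    return block_id, position_id % kv_block_size


def write_fits_in_block(offset, num_new_tokens, kv_block_size):
    """Hypothetical check for item 5: a write must stay inside one block."""
    return offset + num_new_tokens <= kv_block_size


ctx_len, kv_block_size, batch_size = 10, 4, 2
blocks_per_seq = num_kv_blocks(ctx_len, kv_block_size)
total_num_kv_blocks = batch_size * blocks_per_seq

# Made-up block_table of shape [BS, num_kv_blocks]; -1 marks an unallocated block.
block_table = [
    [4, 1, -1],
    [0, 5, 2],
]

hit = locate(block_table, 0, 5, kv_block_size)    # logical block 1 of sequence 0
other = locate(block_table, 1, 9, kv_block_size)  # logical block 2 of sequence 1
miss = locate(block_table, 0, 9, kv_block_size)   # points at an unallocated block
```

The two sequences share one pool of `total_num_kv_blocks` physical blocks, and their logical blocks map to arbitrary, non-adjacent physical block ids, which is why reads and writes have to be resolved per block.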

@anujgupt-github added the enhancement (New feature or request) label May 20, 2026
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out

def read_only_pagedAttention(self, block_index, updated, cache_kwargs):
Contributor

nit: can we rename this to read_only_paged_attention()?

Contributor Author

I was staying consistent with the actual name of the technique from the paper: https://arxiv.org/pdf/2309.06180
The paper calls the mechanism PagedAttention rather than paged attention, so I kept pagedAttention in our naming. I can change all the methods to paged_attention if that fits the snake_case convention better.

v_out = torch.where((invalid_mask.unsqueeze(1)).unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out

def write_only_pagedAttention(self, key_states, value_states, cache_kwargs):
Contributor

write_only_paged_attention()?

"""
return self.layers[layer_idx].read_only_blockedKV(start_index, end_index, cache_kwargs)

def read_only_pagedAttention(self, block_index, updated, layer_idx, cache_kwargs):
Contributor

nit: same as above


_STRATEGIES: Dict[BlockingMode, Callable] = {
BlockingMode.KV: blocked_kv_attention_forward,
BlockingMode.KV_PAGED: blocked_kv_paged_attention_forward,
Contributor

nit: @vaibverm can we add unit tests to all methods mentioned here?

Contributor Author

Done


_STRATEGIES: Dict[BlockingMode, Callable] = {
BlockingMode.KV: blocked_kv_attention_forward,
BlockingMode.KV_PAGED: blocked_kv_paged_attention_forward,
Contributor

Contributor Author

Done. Added examples for causalLM and Qwen3 models.

@vbaddi
Contributor

vbaddi commented May 27, 2026

nit: lint and format are missing, pls check.

@vbaddi changed the base branch from main to release/v1.22.0_tmp June 5, 2026 18:53
tv-karthikeya and others added 14 commits June 5, 2026 19:11
Issue: Generated text is not accurate (unable to identify object in
given image)
5/25: Found a workaround, seems like compiler issue - debugging further

---------

Signed-off-by: vtirumal <vtirumal@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Reverts quic#1010

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
changed the code from doing the exact same math repeatedly.

Signed-off-by: Anuj Gupta <anujgupt@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
…ormat cleanup

Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Labels

enhancement New feature or request

5 participants