Adding PagedAttention support for CausalLM models #982
    v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
    return k_out, v_out

def read_only_pagedAttention(self, block_index, updated, cache_kwargs):
nit: can we rename this to read_only_paged_attention()?
I was staying consistent with the name of the technique as used in the paper: https://arxiv.org/pdf/2309.06180
The attention mechanism is called PagedAttention rather than paged attention, so I kept pagedAttention in our naming. I can change it to paged_attention for all the methods if that fits the snake_case convention better.
    v_out = torch.where((invalid_mask.unsqueeze(1)).unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
    return k_out, v_out

def write_only_pagedAttention(self, key_states, value_states, cache_kwargs):
write_only_paged_attention()?
    """
    return self.layers[layer_idx].read_only_blockedKV(start_index, end_index, cache_kwargs)

def read_only_pagedAttention(self, block_index, updated, layer_idx, cache_kwargs):

_STRATEGIES: Dict[BlockingMode, Callable] = {
    BlockingMode.KV: blocked_kv_attention_forward,
    BlockingMode.KV_PAGED: blocked_kv_paged_attention_forward,
nit: @vaibverm can we add unit tests for all the methods mentioned here?

_STRATEGIES: Dict[BlockingMode, Callable] = {
    BlockingMode.KV: blocked_kv_attention_forward,
    BlockingMode.KV_PAGED: blocked_kv_paged_attention_forward,
nit: also, can you add an example showing how to enable this, similar to this one: https://github.com/quic/efficient-transformers/blob/main/examples/text_generation/blocked_attention_inference.py
Done. Added examples for CausalLM and Qwen3 models.
nit: lint and format checks are missing, please check.
Issue: Generated text is not accurate (unable to identify object in given image)
5/25: Found a workaround, seems like compiler issue - debugging further
Signed-off-by: vtirumal <vtirumal@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Reverts quic#1010
Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
Changed the code to avoid doing the exact same math repeatedly.
Signed-off-by: Anuj Gupta <anujgupt@qti.qualcomm.com>
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
…ormat cleanup
Signed-off-by: Vaibhav Verma <vaibverm@qti.qualcomm.com>
This PR adds PagedAttention (https://arxiv.org/pdf/2309.06180) support for all CausalLM models in QEfficient.
The major change is that the KV cache is no longer treated as one contiguous region of memory. Instead, it is a collection of blocks that can reside non-contiguously in memory. As a result, cache scatter and gather operations must happen per KV block.
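To illustrate the per-block gather described above, here is a minimal pure-Python sketch. It is not the code from this PR: `gather_sequence`, `KV_BLOCK_SIZE`, the pool layout, and the use of plain lists in place of tensors are all assumptions made only for illustration.

```python
from typing import List

KV_BLOCK_SIZE = 4  # hypothetical block size, chosen only for this example


def gather_sequence(physical_blocks: List[List[int]], block_table: List[int], seq_len: int) -> List[int]:
    """Rebuild a logically contiguous sequence from scattered physical blocks.

    block_table[i] holds the physical block index of the i-th logical block;
    -1 marks an unallocated entry (assumed convention for this sketch).
    """
    out: List[int] = []
    for logical_idx, block_id in enumerate(block_table):
        if block_id == -1:  # nothing allocated from here on
            break
        start = logical_idx * KV_BLOCK_SIZE
        valid = min(KV_BLOCK_SIZE, seq_len - start)  # last block may be partially filled
        if valid <= 0:
            break
        out.extend(physical_blocks[block_id][:valid])
    return out


# A pool of four physical blocks. The sequence lives in blocks 2 and 0, in that
# logical order, so its entries are not contiguous in the pool.
pool = [[4, 5, 0, 0], [9, 9, 9, 9], [0, 1, 2, 3], [7, 7, 7, 7]]
table = [2, 0, -1]
gathered = gather_sequence(pool, table, seq_len=6)
print(gathered)
```

The gather walks the block table one entry at a time, which is the cost a contiguous cache avoids and the reason the read and write paths operate per KV block.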
Summary of changes compared to BlockedKV:
4) a) block_id: each entry in the block_table is a block_id that points to the physical K/V block to be read or written. The entry used for a token is the (position_id // kv_block_size)-th entry of block_table. A value of -1 signifies an invalid/unallocated block.
4) b) slot_id: the number of entries already filled in the currently active block, so reads go up to index (slot_id - 1) and writes go after index (slot_id - 1).
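The block_id and slot_id lookup can be sketched as follows. This is a hedged illustration, not the PR's implementation: the helper names `locate`, `write_token`, and `read_filled` are hypothetical, and the sketch assumes that slot_id equals `position_id % kv_block_size`, which holds only if each block is filled sequentially from index 0.

```python
from typing import List, Optional, Tuple


def locate(position_id: int, block_table: List[int], kv_block_size: int) -> Tuple[int, int]:
    """Map a token position to (block_id, slot_id).

    Assumption for this sketch: slot_id = position_id % kv_block_size,
    i.e. the number of entries already filled in the active block.
    """
    block_id = block_table[position_id // kv_block_size]
    if block_id == -1:  # -1 marks an invalid/unallocated block
        raise ValueError(f"position {position_id} maps to an unallocated block")
    slot_id = position_id % kv_block_size
    return block_id, slot_id


def write_token(pool: List[List[Optional[int]]], block_id: int, slot_id: int, value: int) -> None:
    """Write into the first free slot, i.e. right after index slot_id - 1."""
    pool[block_id][slot_id] = value


def read_filled(pool: List[List[Optional[int]]], block_id: int, slot_id: int) -> List[Optional[int]]:
    """Read the filled entries of the block: indices 0 .. slot_id - 1."""
    return pool[block_id][:slot_id]


kv_block_size = 4
block_table = [3, 1, -1]  # logical block 0 -> physical 3, logical block 1 -> physical 1
pool: List[List[Optional[int]]] = [[None] * kv_block_size for _ in range(4)]

block_id, slot_id = locate(5, block_table, kv_block_size)  # position 5 is in logical block 1
write_token(pool, block_id, slot_id, value=42)
print(block_id, slot_id, read_filled(pool, block_id, slot_id + 1))
```

Raising on -1 is one possible policy. The diff hunks in this PR instead zero out entries under an invalid mask with `torch.where`, which suits a traced graph better than an exception.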