Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions QEfficient/blocking/attention_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,35 @@

from QEfficient.blocking.blocked_attention_forwards import (
blocked_bhqkv_attention_forward,
blocked_bhqkv_paged_attention_forward,
blocked_h_attention_forward,
blocked_h_mla_attention_forward,
blocked_hqkv_attention_forward,
blocked_hqkv_paged_attention_forward,
blocked_kv_attention_forward,
blocked_kv_mla_attention_forward,
blocked_kv_paged_attention_forward,
blocked_q_attention_forward,
blocked_qkv_attention_forward,
blocked_qkv_paged_attention_forward,
)


class BlockingMode(str, Enum):
    """String-valued identifiers for the available attention blocking modes.

    Every member is also a ``str`` instance, so ordinary string operations
    (equality with a plain string, substring tests such as ``"paged" in mode``)
    act on the member's value. ``NONE`` carries the empty string.
    """

    # No blocking.
    NONE = ""

    # Single-axis modes (letters presumably name the blocked axes: K/V, Q, H).
    KV = "kv"
    KV_PAGED = "kv_paged"
    Q = "q"
    H = "h"

    # Combined modes.
    QKV = "qkv"
    QKV_PAGED = "qkv_paged"
    HQ = "hq"
    HKV = "hkv"
    HKV_PAGED = "hkv_paged"
    HQKV = "hqkv"
    HQKV_PAGED = "hqkv_paged"
    BHQKV = "bhqkv"
    BHQKV_PAGED = "bhqkv_paged"


@dataclass
Expand All @@ -52,15 +61,24 @@ def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool:
return past_key_value is not None and hasattr(past_key_value, "read_only_blockedKV")


def supports_paged_attention_blocked_kv(past_key_value: Optional[Cache]) -> bool:
    """Report whether the cache exposes a ``read_only_pagedAttention`` attribute.

    Args:
        past_key_value: The cache object to inspect, or ``None``.

    Returns:
        ``False`` when ``past_key_value`` is ``None``; otherwise ``True`` only
        if the object has an attribute named ``read_only_pagedAttention``.
    """
    if past_key_value is None:
        return False
    return hasattr(past_key_value, "read_only_pagedAttention")


# Dispatch table: blocking mode -> blocked attention forward implementation.
# Each ``*_PAGED`` mode is listed right after its non-paged counterpart and
# maps to the paged variant of that counterpart's forward function.
_STRATEGIES: Dict[BlockingMode, Callable] = {
    # KV blocking.
    BlockingMode.KV: blocked_kv_attention_forward,
    BlockingMode.KV_PAGED: blocked_kv_paged_attention_forward,
    # Q-only and H-only blocking (no paged variants registered).
    BlockingMode.Q: blocked_q_attention_forward,
    BlockingMode.H: blocked_h_attention_forward,
    # QKV blocking.
    BlockingMode.QKV: blocked_qkv_attention_forward,
    BlockingMode.QKV_PAGED: blocked_qkv_paged_attention_forward,
    # HQ, HKV and HQKV all dispatch to the HQKV forward.
    BlockingMode.HQ: blocked_hqkv_attention_forward,
    BlockingMode.HKV: blocked_hqkv_attention_forward,
    BlockingMode.HKV_PAGED: blocked_hqkv_paged_attention_forward,
    BlockingMode.HQKV: blocked_hqkv_attention_forward,
    BlockingMode.HQKV_PAGED: blocked_hqkv_paged_attention_forward,
    # BHQKV blocking.
    BlockingMode.BHQKV: blocked_bhqkv_attention_forward,
    BlockingMode.BHQKV_PAGED: blocked_bhqkv_paged_attention_forward,
}

_STRATEGIES_MLA: Dict[BlockingMode, Callable] = {
Expand All @@ -79,10 +97,17 @@ def past_key_value_update(
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
block_table: Optional[torch.LongTensor] = None,
slot_id: Optional[torch.LongTensor] = None,
sliding_window: Optional[int] = None,
):
if past_key_value is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
cache_kwargs = {
"batch_index": batch_index,
"position_ids": position_ids,
"block_table": block_table,
"slot_id": slot_id,
}
if sliding_window is not None:
cache_kwargs.update(
{
Expand Down Expand Up @@ -110,6 +135,8 @@ def generic_blocked_attention_interface(
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
block_table: Optional[torch.LongTensor] = None,
slot_id: Optional[torch.LongTensor] = None,
past_seen_tokens: Optional[int] = None,
non_blocked_forward: Callable = None,
score_mod: Optional[Callable] = None,
Expand All @@ -122,8 +149,29 @@ def generic_blocked_attention_interface(
blocking_config is not None and "kv" in blocking_config.mode and supports_blocked_kv(past_key_value)
)

use_paged_kv_blocked = (
blocking_config is not None
and "paged" in blocking_config.mode
and supports_paged_attention_blocked_kv(past_key_value)
)

if past_key_value is not None:
if use_kv_blocked and sliding_window is None:
if use_paged_kv_blocked and sliding_window is None:
cache_kwargs = {
"batch_index": batch_index,
"position_ids": position_ids,
"block_table": block_table,
"slot_id": slot_id,
}
if sliding_window is not None:
cache_kwargs.update(
{
"is_sliding": sliding_window is not None,
"sliding_window": past_key_value.sliding_window_len,
}
)
past_key_value.write_only_pagedAttention(key, value, module.layer_idx, cache_kwargs)
elif use_kv_blocked and sliding_window is None:
cache_kwargs = {
"batch_index": batch_index,
"position_ids": position_ids,
Expand All @@ -147,6 +195,8 @@ def generic_blocked_attention_interface(
comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
position_ids=position_ids,
block_table=block_table,
slot_id=slot_id,
sliding_window=sliding_window,
)

Expand Down
Loading
Loading