diff --git a/QEfficient/blocking/attention_blocking.py b/QEfficient/blocking/attention_blocking.py index a64c649cee..720ca9d81d 100644 --- a/QEfficient/blocking/attention_blocking.py +++ b/QEfficient/blocking/attention_blocking.py @@ -46,6 +46,11 @@ class AttentionBlockingConfig: head_block_size: Optional[int] = None skip_kv: Optional[bool] = True num_batch_blocks: Optional[int] = None + skip_softmax: bool = False + skip_softmax_scale: Optional[float] = None + skip_softmax_prefill_scale: float = 1.0 + skip_softmax_decode_scale: float = 1.0 + skip_softmax_min_keep_blocks: int = 1 def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool: @@ -164,10 +169,16 @@ def generic_blocked_attention_interface( num_kv_blocks=blocking_config.num_kv_blocks, num_q_blocks=blocking_config.num_q_blocks, head_block_size=blocking_config.head_block_size, + skip_kv=blocking_config.skip_kv, num_batch_blocks=blocking_config.num_batch_blocks, score_mod=score_mod, position_bias=position_bias, sinks=sinks, + skip_softmax=blocking_config.skip_softmax, + skip_softmax_scale=blocking_config.skip_softmax_scale, + skip_softmax_prefill_scale=blocking_config.skip_softmax_prefill_scale, + skip_softmax_decode_scale=blocking_config.skip_softmax_decode_scale, + skip_softmax_min_keep_blocks=blocking_config.skip_softmax_min_keep_blocks, ) return attn_output, attn_weights @@ -227,10 +238,16 @@ def generic_blocked_mla_attention_interface( num_kv_blocks=blocking_config.num_kv_blocks, num_q_blocks=blocking_config.num_q_blocks, head_block_size=blocking_config.head_block_size, + skip_kv=blocking_config.skip_kv, num_batch_blocks=blocking_config.num_batch_blocks, score_mod=score_mod, position_bias=position_bias, sinks=sinks, + skip_softmax=blocking_config.skip_softmax, + skip_softmax_scale=blocking_config.skip_softmax_scale, + skip_softmax_prefill_scale=blocking_config.skip_softmax_prefill_scale, + skip_softmax_decode_scale=blocking_config.skip_softmax_decode_scale, + skip_softmax_min_keep_blocks=blocking_config.skip_softmax_min_keep_blocks, ) return attn_output, attn_weights diff --git a/QEfficient/blocking/blocked_attention_forwards.py b/QEfficient/blocking/blocked_attention_forwards.py index 6aed6e49f9..0388fe9a73 100644 --- a/QEfficient/blocking/blocked_attention_forwards.py +++ b/QEfficient/blocking/blocked_attention_forwards.py @@ -42,6 +42,72 @@ def _normalize_int(value: Optional[torch.Tensor | int]) -> int: return int(value) if value is not None else 0 +def _get_skip_softmax_log_lambda( + *, + skip_softmax: bool, + query: torch.Tensor, + ctx_len: int, + skip_softmax_scale: Optional[float] = None, + skip_softmax_prefill_scale: float = 1.0, + skip_softmax_decode_scale: float = 1.0, +) -> Optional[torch.Tensor]: + if not skip_softmax: + return None + + if skip_softmax_scale is not None: + scale = skip_softmax_scale + else: + scale = skip_softmax_decode_scale if query.shape[2] == 1 else skip_softmax_prefill_scale + if scale <= 0: + return None + + lambda_eff = torch.tensor(float(scale) / max(1.0, float(ctx_len)), dtype=torch.float32, device=query.device) + lambda_eff = torch.clamp(lambda_eff, min=torch.finfo(torch.float32).tiny) + return torch.log(lambda_eff) + + +def _compute_skip_softmax_mask( + *, + block_max: torch.Tensor, + current_max: torch.Tensor, + current_denominator: torch.Tensor, + log_lambda: torch.Tensor, + block_idx: int, + min_keep_blocks: int, +) -> torch.Tensor: + if block_idx < max(0, int(min_keep_blocks)): + return torch.zeros_like(block_max, dtype=torch.bool) + + running_max = torch.maximum(current_max, block_max) + has_prior_contribution = current_denominator > 0 + return ((block_max.to(torch.float32) - running_max.to(torch.float32)) < log_lambda) & has_prior_contribution + + +def _build_skip_mask( + *, + attn_weights_block: torch.Tensor, + current_max: torch.Tensor, + current_denominator: torch.Tensor, + skip_softmax_log_lambda: Optional[torch.Tensor], + block_idx: int, + skip_softmax_min_keep_blocks: int, + skip_future: Optional[torch.Tensor], +) -> Optional[torch.Tensor]: + skip_mask = None + if skip_softmax_log_lambda is not None: + skip_mask = _compute_skip_softmax_mask( + block_max=attn_weights_block.max(dim=3).values, + current_max=current_max, + current_denominator=current_denominator, + log_lambda=skip_softmax_log_lambda, + block_idx=block_idx, + min_keep_blocks=skip_softmax_min_keep_blocks, + ) + if skip_future is not None: + skip_mask = skip_future if skip_mask is None else (skip_mask | skip_future) + return skip_mask + + def update_running_softmax( current_max: torch.Tensor, attn_weights_block: torch.Tensor, @@ -49,7 +115,8 @@ def update_running_softmax( output: torch.Tensor, v_block: torch.Tensor, skip_kv: bool = False, - skip_future: Optional(torch.Tensor) = None, + skip_future: Optional[torch.Tensor] = None, + skip_mask: Optional[torch.Tensor] = None, ): # Update Running row maximum prev_max = current_max @@ -78,10 +145,13 @@ def update_running_softmax( * torch.exp(delta_max.unsqueeze(-1)) ) - if skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): - current_max = torch.where(skip_future, prev_max, current_max_updated) - current_denominator = torch.where(skip_future, prev_denominator, current_denominator_updated) - output = torch.where(skip_future.unsqueeze(-1), prev_output, output_updated) + if skip_mask is None and skip_kv and (torch.onnx.is_in_onnx_export() or torch.jit.is_tracing()): + skip_mask = skip_future + + if skip_mask is not None: + current_max = torch.where(skip_mask, prev_max, current_max_updated) + current_denominator = torch.where(skip_mask, prev_denominator, current_denominator_updated) + output = torch.where(skip_mask.unsqueeze(-1), prev_output, output_updated) else: # Eager mode current_max = current_max_updated @@ -107,6 +177,11 @@ def blocked_kv_attention_forward( skip_kv: bool = False, position_bias: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, + skip_softmax: bool = False, + skip_softmax_scale: Optional[float] = None, + skip_softmax_prefill_scale: float = 1.0, + skip_softmax_decode_scale: float = 1.0, + skip_softmax_min_keep_blocks: int = 1, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Initialize result tensor @@ -138,6 +213,15 @@ def blocked_kv_attention_forward( if sinks is not None: sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + skip_softmax_log_lambda = _get_skip_softmax_log_lambda( + skip_softmax=skip_softmax, + query=query, + ctx_len=past_seen_tokens, + skip_softmax_scale=skip_softmax_scale, + skip_softmax_prefill_scale=skip_softmax_prefill_scale, + skip_softmax_decode_scale=skip_softmax_decode_scale, + ) + for j in range(num_kv_blocks): start_index = j * kv_block_size if j == num_kv_blocks - 1: @@ -188,8 +272,25 @@ def blocked_kv_attention_forward( if mask_block is not None: attn_weights_block = torch.where(mask_block, masked_tensor, attn_weights_block) + skip_mask = _build_skip_mask( + attn_weights_block=attn_weights_block, + current_max=current_max, + current_denominator=current_denominator, + skip_softmax_log_lambda=skip_softmax_log_lambda, + block_idx=j, + skip_softmax_min_keep_blocks=skip_softmax_min_keep_blocks, + skip_future=skip_future, + ) + current_max, current_denominator, output = update_running_softmax( - current_max, attn_weights_block, current_denominator, output, v_block_states, skip_kv, skip_future + current_max, + attn_weights_block, + current_denominator, + output, + v_block_states, + skip_kv, + skip_future, + skip_mask, ) # If present, apply Attention Sinks, needed for GPT-OSS @@ -220,6 +321,11 @@ def blocked_qkv_attention_forward( skip_kv: bool = False, position_bias: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, + skip_softmax: bool = False, + skip_softmax_scale: Optional[float] = None, + skip_softmax_prefill_scale: float = 1.0, + skip_softmax_decode_scale: float = 1.0, + skip_softmax_min_keep_blocks: int = 1, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Initialize Running Maximum and Denominator @@ -248,6 +354,15 @@ def blocked_qkv_attention_forward( if sinks is not None: sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + skip_softmax_log_lambda = _get_skip_softmax_log_lambda( + skip_softmax=skip_softmax, + query=query, + ctx_len=past_seen_tokens, + skip_softmax_scale=skip_softmax_scale, + skip_softmax_prefill_scale=skip_softmax_prefill_scale, + skip_softmax_decode_scale=skip_softmax_decode_scale, + ) + for q_block_idx in range(num_q_blocks): q_start = q_block_positions[q_block_idx] if q_block_idx == num_q_blocks - 1: @@ -317,6 +432,16 @@ def blocked_qkv_attention_forward( attn_mask_block = mask_block[:, :, q_start : q_start + q_len_block, :] attn_weights_block = torch.where(attn_mask_block, masked_tensor, attn_weights_block) + skip_mask = _build_skip_mask( + attn_weights_block=attn_weights_block, + current_max=current_max, + current_denominator=current_denominator, + skip_softmax_log_lambda=skip_softmax_log_lambda, + block_idx=j, + skip_softmax_min_keep_blocks=skip_softmax_min_keep_blocks, + skip_future=skip_future, + ) + current_max, current_denominator, output_blocks = update_running_softmax( current_max, attn_weights_block, @@ -325,6 +450,7 @@ def blocked_qkv_attention_forward( v_block_states, skip_kv, skip_future, + skip_mask, ) # If present, apply Attention Sinks, needed for GPT-OSS @@ -358,6 +484,11 @@ def blocked_hqkv_attention_forward( skip_kv: bool = False, position_bias: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, + skip_softmax: bool = False, + skip_softmax_scale: Optional[float] = None, + skip_softmax_prefill_scale: float = 1.0, + skip_softmax_decode_scale: float = 1.0, + skip_softmax_min_keep_blocks: int = 1, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Initialize Running Maximum and Denominator @@ -388,6 +519,15 @@ def blocked_hqkv_attention_forward( if sinks is not None: sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + skip_softmax_log_lambda = _get_skip_softmax_log_lambda( + skip_softmax=skip_softmax, + query=query, + ctx_len=past_seen_tokens, + skip_softmax_scale=skip_softmax_scale, + skip_softmax_prefill_scale=skip_softmax_prefill_scale, + skip_softmax_decode_scale=skip_softmax_decode_scale, + ) + # Process each head block independently for head_block_idx in range(num_head_blocks): h_start = head_block_idx * head_block_size @@ -473,8 +613,25 @@ def blocked_hqkv_attention_forward( mask_block_g = mask_block[:, :, q_start : q_start + q_len_block, :] attn_weights_block = torch.where(mask_block_g, masked_tensor, attn_weights_block) + skip_mask = _build_skip_mask( + attn_weights_block=attn_weights_block, + current_max=current_max, + current_denominator=current_denominator, + skip_softmax_log_lambda=skip_softmax_log_lambda, + block_idx=j, + skip_softmax_min_keep_blocks=skip_softmax_min_keep_blocks, + skip_future=skip_future, + ) + current_max, current_denominator, output_blocks = update_running_softmax( - current_max, attn_weights_block, current_denominator, output_blocks, v_g, skip_kv, skip_future + current_max, + attn_weights_block, + current_denominator, + output_blocks, + v_g, + skip_kv, + skip_future, + skip_mask, ) # If present, apply Attention Sinks, needed for GPT-OSS if sinks is not None: @@ -516,6 +673,11 @@ def blocked_bhqkv_attention_forward( skip_kv: bool = False, position_bias: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, + skip_softmax: bool = False, + skip_softmax_scale: Optional[float] = None, + skip_softmax_prefill_scale: float = 1.0, + skip_softmax_decode_scale: float = 1.0, + skip_softmax_min_keep_blocks: int = 1, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Initialize Running Maximum and Denominator @@ -553,6 +715,15 @@ def blocked_bhqkv_attention_forward( if sinks is not None: sinks = sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1) + skip_softmax_log_lambda = _get_skip_softmax_log_lambda( + skip_softmax=skip_softmax, + query=query, + ctx_len=past_seen_tokens, + skip_softmax_scale=skip_softmax_scale, + skip_softmax_prefill_scale=skip_softmax_prefill_scale, + skip_softmax_decode_scale=skip_softmax_decode_scale, + ) + # Process each head block independently for head_block_idx in range(num_head_blocks): h_start = head_block_idx * head_block_size @@ -654,8 +825,25 @@ def blocked_bhqkv_attention_forward( ] attn_weights_block = torch.where(mask_block_g, masked_tensor, attn_weights_block) + skip_mask = _build_skip_mask( + attn_weights_block=attn_weights_block, + current_max=current_max, + current_denominator=current_denominator, + skip_softmax_log_lambda=skip_softmax_log_lambda, + block_idx=j, + skip_softmax_min_keep_blocks=skip_softmax_min_keep_blocks, + skip_future=skip_future, + ) + current_max, current_denominator, output_blocks = update_running_softmax( - current_max, attn_weights_block, current_denominator, output_blocks, v_g, skip_kv, skip_future + current_max, + attn_weights_block, + current_denominator, + output_blocks, + v_g, + skip_kv, + skip_future, + skip_mask, ) batch_output_blocks.append(output_blocks) batch_attn_blocks.append(attn_weights_block) @@ -837,6 +1025,11 @@ def blocked_kv_mla_attention_forward( skip_kv: bool = False, position_bias: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, + skip_softmax: bool = False, + skip_softmax_scale: Optional[float] = None, + skip_softmax_prefill_scale: float = 1.0, + skip_softmax_decode_scale: float = 1.0, + skip_softmax_min_keep_blocks: int = 1, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Initialize result tensor @@ -867,6 +1060,15 @@ def blocked_kv_mla_attention_forward( position_ids = cache_kwargs.get("position_ids") current_position = position_ids.max(dim=-1).values + skip_softmax_log_lambda = _get_skip_softmax_log_lambda( + skip_softmax=skip_softmax, + query=query, + ctx_len=ctx_len, + skip_softmax_scale=skip_softmax_scale, + skip_softmax_prefill_scale=skip_softmax_prefill_scale, + skip_softmax_decode_scale=skip_softmax_decode_scale, + ) + for j in range(num_kv_blocks): start_index = j * kv_block_size if j == num_kv_blocks - 1: @@ -920,6 +1122,15 @@ def blocked_kv_mla_attention_forward( attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling # [1, 64, q_len, 576] X [1, 1, 576, kv_block_size] -> [1, 64, q_len, kv_block_size] attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) + skip_mask = _build_skip_mask( + attn_weights_block=attn_weights_block, + current_max=current_max, + current_denominator=current_denominator, + skip_softmax_log_lambda=skip_softmax_log_lambda, + block_idx=j, + skip_softmax_min_keep_blocks=skip_softmax_min_keep_blocks, + skip_future=skip_future, + ) current_max, current_denominator, output = update_running_softmax( current_max, attn_weights_block, @@ -928,6 +1139,7 @@ def blocked_kv_mla_attention_forward( compressed_kv_block, skip_kv, skip_future, + skip_mask, ) # [1, 64, q_len, kv_block_size] X [1, 1, kv_block_size, 512] -> [1, 64, q_len, 512] else: knope = torch.matmul(compressed_kv_block, per_head_k_up_normal) @@ -940,6 +1152,15 @@ def blocked_kv_mla_attention_forward( krope_nope = torch.cat((knope, k_pe_block), dim=-1) attn_weights_block = torch.matmul(query, krope_nope.transpose(2, 3)) * scaling attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) + skip_mask = _build_skip_mask( + attn_weights_block=attn_weights_block, + current_max=current_max, + current_denominator=current_denominator, + skip_softmax_log_lambda=skip_softmax_log_lambda, + block_idx=j, + skip_softmax_min_keep_blocks=skip_softmax_min_keep_blocks, + skip_future=skip_future, + ) current_max, current_denominator, output = update_running_softmax( current_max, attn_weights_block, @@ -948,6 +1169,7 @@ def blocked_kv_mla_attention_forward( compressed_kv_block, skip_kv, skip_future, + skip_mask, ) attn_output = torch.matmul(output, per_head_v_up) diff --git a/QEfficient/blocking/blocking_configurator.py b/QEfficient/blocking/blocking_configurator.py index eaa256611a..2403cb0032 100644 --- a/QEfficient/blocking/blocking_configurator.py +++ b/QEfficient/blocking/blocking_configurator.py @@ -376,4 +376,18 @@ def build_transformer_blocking_config_for_transform( if qaic_config.get("skip_kv", False) and enable_blocking: blocking_config.skip_kv = qaic_config.get("skip_kv") + if qaic_config.get("skip_softmax", False) and enable_blocking: + blocking_config.skip_softmax = bool(qaic_config.get("skip_softmax")) + if qaic_config.get("skip_softmax_scale", None) is not None: + blocking_config.skip_softmax_scale = float(qaic_config.get("skip_softmax_scale")) + blocking_config.skip_softmax_prefill_scale = float( + qaic_config.get("skip_softmax_prefill_scale", blocking_config.skip_softmax_prefill_scale) + ) + blocking_config.skip_softmax_decode_scale = float( + qaic_config.get("skip_softmax_decode_scale", blocking_config.skip_softmax_decode_scale) + ) + blocking_config.skip_softmax_min_keep_blocks = int( + qaic_config.get("skip_softmax_min_keep_blocks", blocking_config.skip_softmax_min_keep_blocks) + ) + return blocking_config diff --git a/examples/text_generation/skip_softmax_kv_blocking_inference.py b/examples/text_generation/skip_softmax_kv_blocking_inference.py new file mode 100644 index 0000000000..8b3fb31357 --- /dev/null +++ b/examples/text_generation/skip_softmax_kv_blocking_inference.py @@ -0,0 +1,127 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +from pprint import pprint + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + + +def _parse_device_group(device_ids): + if device_ids is None: + return None + return [int(x) for x in device_ids.strip("[]").split(",") if x] + + +def main(): + parser = argparse.ArgumentParser(description="KV-blocked text generation with skip-softmax attention") + parser.add_argument("--model-name", type=str, default="meta-llama/Llama-3.2-1B", help="HuggingFace model ID") + parser.add_argument("--prompt", type=str, default="Hello", help="Input prompt") + parser.add_argument("--prefill-seq-len", type=int, default=1, help="Prefill sequence length") + parser.add_argument("--ctx-len", type=int, default=32768, help="Context length used for compilation") + parser.add_argument("--generation-len", type=int, default=100, help="Number of tokens to generate") + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + parser.add_argument("--aic-hw-version", type=str, default="ai100", help="AIC hardware version") + parser.add_argument( + "--device-group", + type=_parse_device_group, + default=None, + help="Device IDs, e.g. [0] or [0,1,2,3]", + ) + parser.add_argument("--num-kv-blocks", type=int, default=8, help="Number of KV blocks for blocked attention") + parser.add_argument( + "--kv-blocking", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable KV-blocked attention. Use --no-kv-blocking for plain non-blocked baseline.", + ) + parser.add_argument( + "--skip-softmax", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable skip-softmax on top of KV blocking. Use --no-skip-softmax for KV-blocking baseline.", + ) + parser.add_argument( + "--skip-softmax-scale", + type=float, + default=None, + help="Explicit BLASST scale. Effective lambda is scale / ctx_len and overrides phase scales.", + ) + parser.add_argument( + "--skip-softmax-prefill-scale", + type=float, + default=1.0, + help="Prefill BLASST scale used when --skip-softmax-scale is not set", + ) + parser.add_argument( + "--skip-softmax-decode-scale", + type=float, + default=1.0, + help="Decode BLASST scale used when --skip-softmax-scale is not set", + ) + parser.add_argument( + "--skip-softmax-min-keep-blocks", + type=int, + default=1, + help="Minimum leading KV blocks to keep before applying skip-softmax", + ) + args = parser.parse_args() + + qaic_config = None + if args.kv_blocking: + qaic_config = { + "enable_blocking": True, + "blocking_mode": "kv", + "num_kv_blocks": args.num_kv_blocks, + "skip_kv": True, + "skip_softmax": args.skip_softmax, + "skip_softmax_prefill_scale": args.skip_softmax_prefill_scale, + "skip_softmax_decode_scale": args.skip_softmax_decode_scale, + "skip_softmax_min_keep_blocks": args.skip_softmax_min_keep_blocks, + } + if args.skip_softmax_scale is not None: + qaic_config["skip_softmax_scale"] = args.skip_softmax_scale + + print("qaic_config:") + pprint(qaic_config) + + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + model = QEFFAutoModelForCausalLM.from_pretrained(args.model_name) + + qpc_path = model.compile( + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.num_cores, + aic_hw_version=args.aic_hw_version, + num_devices=(1 if args.device_group is None else len(args.device_group)), + qaic_config=qaic_config, + ) + print(f"Model compiled to: {qpc_path}") + + exec_info = model.generate( + tokenizer=tokenizer, + prompts=[args.prompt], + device_id=args.device_group, + generation_len=args.generation_len, + ) + + print(f"\nPrompt: {args.prompt}") + print(f"Generated: {exec_info.generated_texts[0]}") + if not args.kv_blocking: + perf_label = "plain non-blocked" + elif args.skip_softmax: + perf_label = "KV-blocked skip-softmax" + else: + perf_label = "KV-blocked no skip-softmax" + print(f"Performance {perf_label}:") + print(exec_info) + + +if __name__ == "__main__": + main() diff --git a/tests/test_skip_softmax_kv_blocking.py b/tests/test_skip_softmax_kv_blocking.py new file mode 100644 index 0000000000..432aa99876 --- /dev/null +++ b/tests/test_skip_softmax_kv_blocking.py @@ -0,0 +1,135 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import math + +import torch + +from QEfficient.blocking.blocked_attention_forwards import ( + _compute_skip_softmax_mask, + _get_skip_softmax_log_lambda, + update_running_softmax, +) + + +def test_skip_softmax_lambda_uses_explicit_scale_when_provided(): + query = torch.zeros(1, 2, 1, 4) + + log_lambda = _get_skip_softmax_log_lambda( + skip_softmax=True, + query=query, + ctx_len=4096, + skip_softmax_scale=16.0, + skip_softmax_prefill_scale=999.0, + skip_softmax_decode_scale=999.0, + ) + + assert torch.allclose(log_lambda.cpu(), torch.tensor(math.log(16.0 / 4096.0), dtype=torch.float32)) + + +def test_skip_softmax_lambda_uses_decode_scale_for_q_len_one(): + query = torch.zeros(1, 2, 1, 4) + + log_lambda = _get_skip_softmax_log_lambda( + skip_softmax=True, + query=query, + ctx_len=1024, + skip_softmax_prefill_scale=8.0, + skip_softmax_decode_scale=2.0, + ) + + assert torch.allclose(log_lambda.cpu(), torch.tensor(math.log(2.0 / 1024.0), dtype=torch.float32)) + + +def test_skip_softmax_lambda_uses_prefill_scale_for_q_len_greater_than_one(): + query = torch.zeros(1, 2, 16, 4) + + log_lambda = _get_skip_softmax_log_lambda( + skip_softmax=True, + query=query, + ctx_len=2048, + skip_softmax_prefill_scale=4.0, + skip_softmax_decode_scale=2.0, + ) + + assert torch.allclose(log_lambda.cpu(), torch.tensor(math.log(4.0 / 2048.0), dtype=torch.float32)) + + +def test_skip_softmax_mask_does_not_skip_before_prior_contribution(): + block_max = torch.tensor([[[0.0]]]) + current_max = torch.tensor([[[10.0]]]) + current_denominator = torch.zeros(1, 1, 1) + log_lambda = torch.tensor(math.log(1.0 / 1024.0), dtype=torch.float32) + + mask = _compute_skip_softmax_mask( + block_max=block_max, + current_max=current_max, + current_denominator=current_denominator, + log_lambda=log_lambda, + block_idx=1, + min_keep_blocks=1, + ) + + assert not mask.item() + + +def test_skip_softmax_mask_skips_after_min_keep_blocks(): + block_max = torch.tensor([[[0.0]]]) + current_max = torch.tensor([[[10.0]]]) + current_denominator = torch.ones(1, 1, 1) + log_lambda = torch.tensor(math.log(1.0 / 1024.0), dtype=torch.float32) + + mask = _compute_skip_softmax_mask( + block_max=block_max, + current_max=current_max, + current_denominator=current_denominator, + log_lambda=log_lambda, + block_idx=1, + min_keep_blocks=1, + ) + + assert mask.item() + + +def test_update_running_softmax_preserves_state_for_skip_mask(): + current_max = torch.zeros(1, 1, 1) + current_denominator = torch.ones(1, 1, 1) + output = torch.ones(1, 1, 1, 2) + + scores = torch.full((1, 1, 1, 2), -10.0) + value = torch.ones(1, 1, 2, 2) * 7.0 + skip_mask = torch.ones(1, 1, 1, dtype=torch.bool) + + next_max, next_den, next_output = update_running_softmax( + current_max=current_max, + attn_weights_block=scores, + current_denominator=current_denominator, + output=output, + v_block=value, + skip_mask=skip_mask, + ) + + assert torch.equal(next_max, current_max) + assert torch.equal(next_den, current_denominator) + assert torch.equal(next_output, output) + + +def test_update_running_softmax_updates_state_when_not_skipped(): + current_max = torch.zeros(1, 1, 1) + current_denominator = torch.ones(1, 1, 1) + output = torch.zeros(1, 1, 1, 2) + + scores = torch.zeros(1, 1, 1, 2) + value = torch.ones(1, 1, 2, 2) + skip_mask = torch.zeros(1, 1, 1, dtype=torch.bool) + + _, next_den, next_output = update_running_softmax( + current_max, scores, current_denominator, output, value, skip_mask=skip_mask + ) + + assert torch.all(next_den > current_denominator) + assert torch.all(next_output > output)