From 85ab161169504428716da0e663d1b1af45c90a07 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 13 Apr 2026 17:02:47 -0700
Subject: [PATCH] Integrate attention sink into ET LLM export and runner

Summary: Add custom op support, export pipeline integration, and C++ runner
fixes for the attention sink ring buffer implementation.

- CustomKVCacheWithAttentionSink: custom op variant using
  update_cache_with_indices for scatter-write performance. Replaces
  KVCacheWithAttentionSink during export.
- CustomRingKVCache replacement: handle RingKVCache -> CustomRingKVCache in
  the replacement pass, and set SDPACustom.use_attention_mask=True for ring
  buffer models.
- Export transform ordering: replace SDPA before KV cache so that
  _replace_kv_cache_with_custom_kv_cache can set use_attention_mask=True on
  the already-existing SDPACustom (previously the ordering was reversed,
  causing the mask flag to be overwritten by a new SDPACustom).
- C++ runner: add max_seq_len prefill check; make context length check
  conditional for sliding window models (max_seq_len < max_context_len)
  since they handle position wrapping internally via the ring buffer.

Differential Revision: D100216686 --- examples/models/llama/export_llama_lib.py | 9 +- .../source_transformation/custom_kv_cache.py | 110 +++++++++++++++++- .../test_attention_sink.py | 45 +++++++ extension/llm/runner/text_llm_runner.cpp | 32 +++-- 4 files changed, 183 insertions(+), 13 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index b0acca44304..dbcd385c35e 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -1781,16 +1781,17 @@ def _get_source_transforms( # noqa use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask if use_sdpa_with_kv_cache: - transforms.append(replace_kv_cache_with_custom_kv_cache) - # todo: do this optionally - # if use attention mask instead of causal attention - # then create partial function that sets use_attention_mask=True + # Replace SDPA first, then KV cache. Order matters: the KV cache + # replacement sets SDPACustom.use_attention_mask=True for ring buffer + # models (attention sink, sliding window). If SDPA is replaced after, + # a new SDPACustom(use_attention_mask=False) would overwrite it. 
if use_attention_mask_for_custom_sdpa: transforms.append( partial(replace_sdpa_with_custom_op, use_attention_mask=True) ) else: transforms.append(replace_sdpa_with_custom_op) + transforms.append(replace_kv_cache_with_custom_kv_cache) if quantize_kv_cache: assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 8d4d37e0e93..907904c8ba2 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -371,8 +371,41 @@ def replace_kv_cache_with_custom_kv_cache(module): def _replace_kv_cache_with_custom_kv_cache(module): + # Import here to avoid circular imports + from executorch.examples.models.llama.source_transformation.attention_sink import ( + KVCacheWithAttentionSink, + ) + for name, child in module.named_children(): - if isinstance(child, KVCache): + if isinstance(child, KVCacheWithAttentionSink): + # Replace with custom op variant for performance + setattr( + module, + name, + CustomKVCacheWithAttentionSink.from_kv_cache_with_attention_sink(child), + ) + # If parent has SDPACustom, enable explicit mask mode + sdpa = getattr(module, "SDPA", None) + if sdpa is not None and hasattr(sdpa, "use_attention_mask"): + sdpa.use_attention_mask = True + elif isinstance(child, RingKVCache): + # RingKVCache (e.g., from attention sink with sink_size=0) needs + # CustomRingKVCache, not plain CustomKVCache + setattr( + module, + name, + CustomRingKVCache( + child.max_batch_size, + child.window_size, + child.n_heads, + child.head_dim, + dtype=child.k_cache.dtype, + ), + ) + sdpa = getattr(module, "SDPA", None) + if sdpa is not None and hasattr(sdpa, "use_attention_mask"): + sdpa.use_attention_mask = True + elif isinstance(child, KVCache): cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype max_batch_size, n_heads, 
max_context_length, head_dim = cache_shape @@ -466,6 +499,81 @@ def from_quantized_kv_cache( ) +class CustomKVCacheWithAttentionSink(CustomKVCache): + """ + CustomKVCache variant for attention sink models. + + Uses the custom update_cache_with_indices op for performance while + supporting sink tokens (fixed slots) + ring buffer (sliding window). + Modeled after CustomRingKVCache but with CachePositionsManagerWithSink. + """ + + def __init__( + self, + max_batch_size, + n_heads, + head_dim, + window_size, + sink_size, + dtype=torch.float32, + ): + # Total cache size: sink slots + ring buffer (2x window for wrap safety) + total_cache_size = sink_size + window_size * 2 + super().__init__( + max_batch_size, total_cache_size, n_heads, head_dim, dtype + ) + from executorch.examples.models.llama.source_transformation.attention_sink import ( + CachePositionsManagerWithSink, + _create_causal_mask_for_attention_sink, + ) + + self.cache_positions_manager = CachePositionsManagerWithSink( + total_cache_size, sink_size + ) + self.is_ring_buffer = True + self.window_size = window_size + self.sink_size = sink_size + self._create_causal_mask_for_attention_sink = ( + _create_causal_mask_for_attention_sink + ) + + def create_causal_mask_for_ring_buffer(self, start_pos, seq_len): + cache_positions = self.cache_positions_manager.cache_positions + if self.sink_size > 0: + return self._create_causal_mask_for_attention_sink( + cache_positions, self.window_size, self.sink_size, start_pos, seq_len + ) + else: + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len + ) + + def update(self, input_pos, k_val, v_val): + seq_len = k_val.transpose(1, 2).size(1) + assert seq_len <= self.k_cache.size( + 1 + ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})" + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + indices = 
indices.unsqueeze(0) + + return super().update(input_pos, k_val, v_val, indices) + + @classmethod + def from_kv_cache_with_attention_sink(cls, kv_cache): + """Create from an existing KVCacheWithAttentionSink.""" + max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape + return cls( + max_batch_size, + n_heads, + head_dim, + kv_cache.window_size, + kv_cache.sink_size, + dtype=kv_cache.k_cache.dtype, + ) + + class CustomRingKVCache(CustomKVCache): def __init__( self, diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py index 51474d75969..54cf1e57ac5 100644 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -397,3 +397,48 @@ def test_beyond_context_window_basic(self): self.assertTrue( torch.isfinite(out).all(), "Output contains non-finite values" ) + + def test_beyond_context_window_custom_sdpa(self): + """Generate tokens beyond context window with custom SDPA + custom KV cache.""" + sink_size = 4 + window_size = 16 + args = self._make_args(max_context_len=128) + model = self._build_model(args, sink_size, window_size, use_custom_sdpa=True) + + # Verify KV caches were replaced with CustomKVCacheWithAttentionSink + from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + CustomKVCacheWithAttentionSink, + ) + + found_custom_cache = False + for m in model.modules(): + if isinstance(m, CustomKVCacheWithAttentionSink): + found_custom_cache = True + break + self.assertTrue( + found_custom_cache, "Expected CustomKVCacheWithAttentionSink in model" + ) + + # Generate 80 tokens — well beyond KV cache size of 36 + outputs = self._run_generation(model, args, num_tokens=80) + + self.assertEqual(len(outputs), 77) + for out in outputs: + self.assertTrue( + torch.isfinite(out).all(), "Output contains non-finite values" + ) + + def 
test_sink_zero_custom_sdpa(self): + """Degenerate case: sink_size=0 with custom SDPA (pure ring buffer).""" + sink_size = 0 + window_size = 16 + args = self._make_args(max_context_len=128) + model = self._build_model(args, sink_size, window_size, use_custom_sdpa=True) + + outputs = self._run_generation(model, args, num_tokens=60) + + self.assertEqual(len(outputs), 57) + for out in outputs: + self.assertTrue( + torch.isfinite(out).all(), "Output contains non-finite values" + ) diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index e8b37ba8863..dc3b0f9a7a7 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -110,6 +110,8 @@ Error TextLLMRunner::generate( stats_->inference_start_ms = time_in_ms(); shouldStop_ = false; + // Get max_seq_len for single prefill chunk limit + int64_t max_seq_len = metadata_.at(kMaxSeqLen); int64_t max_context_len = metadata_.at(kMaxContextLen); uint64_t cur_token = 0; @@ -138,13 +140,26 @@ Error TextLLMRunner::generate( InvalidArgument, "Expected at least 1 prompt token"); ET_CHECK_OR_RETURN_ERROR( - pos_ + num_prompt_tokens < max_context_len, + num_prompt_tokens <= max_seq_len, InvalidArgument, - "pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64 - ", Max seq length exceeded - please increase max seq len value in your export script", - pos_, + "num_prompt_tokens %d > max_seq_len %" PRId64 + ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", num_prompt_tokens, - max_context_len); + max_seq_len); + // For non-sliding-window models, also check that we won't exceed + // KV cache capacity. Sliding window models (where max_seq_len < + // max_context_len) handle position wrapping internally. 
+ if (max_seq_len >= max_context_len) { + ET_CHECK_OR_RETURN_ERROR( + pos_ + num_prompt_tokens < max_context_len, + InvalidArgument, + "pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64 + ", Max seq length exceeded - please increase max seq len value in " + "your export script", + pos_, + num_prompt_tokens, + max_context_len); + } // print prompts if (config.echo) { @@ -168,9 +183,10 @@ Error TextLLMRunner::generate( prefill_next_token_.reset(); } - // Resolve max_new_tokens. pos_ now reflects all occupied positions - // (including prompt tokens just prefilled). - int max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_); + // For sliding window models, the ring buffer recycles space — pos_ doesn't + // represent consumed capacity, so pass 0 to get the full budget. + int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos_; + int max_new_tokens = config.resolve_max_new_tokens(max_context_len, effective_pos); ET_LOG( Info,