Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,16 +1781,17 @@ def _get_source_transforms( # noqa
use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

if use_sdpa_with_kv_cache:
transforms.append(replace_kv_cache_with_custom_kv_cache)
# todo: do this optionally
# if use attention mask instead of causal attention
# then create partial function that sets use_attention_mask=True
# Replace SDPA first, then KV cache. Order matters: the KV cache
# replacement sets SDPACustom.use_attention_mask=True for ring buffer
# models (attention sink, sliding window). If SDPA is replaced after,
# a new SDPACustom(use_attention_mask=False) would overwrite it.
if use_attention_mask_for_custom_sdpa:
transforms.append(
partial(replace_sdpa_with_custom_op, use_attention_mask=True)
)
else:
transforms.append(replace_sdpa_with_custom_op)
transforms.append(replace_kv_cache_with_custom_kv_cache)

if quantize_kv_cache:
assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
Expand Down
110 changes: 109 additions & 1 deletion examples/models/llama/source_transformation/custom_kv_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -371,8 +371,41 @@


def _replace_kv_cache_with_custom_kv_cache(module):
# Import here to avoid circular imports
from executorch.examples.models.llama.source_transformation.attention_sink import (
KVCacheWithAttentionSink,
)

for name, child in module.named_children():
if isinstance(child, KVCache):
if isinstance(child, KVCacheWithAttentionSink):
# Replace with custom op variant for performance
setattr(
module,
name,
CustomKVCacheWithAttentionSink.from_kv_cache_with_attention_sink(child),
)
# If parent has SDPACustom, enable explicit mask mode
sdpa = getattr(module, "SDPA", None)
if sdpa is not None and hasattr(sdpa, "use_attention_mask"):
sdpa.use_attention_mask = True
elif isinstance(child, RingKVCache):
# RingKVCache (e.g., from attention sink with sink_size=0) needs
# CustomRingKVCache, not plain CustomKVCache
setattr(
module,
name,
CustomRingKVCache(
child.max_batch_size,
child.window_size,
child.n_heads,
child.head_dim,
dtype=child.k_cache.dtype,
),
)
sdpa = getattr(module, "SDPA", None)
if sdpa is not None and hasattr(sdpa, "use_attention_mask"):
sdpa.use_attention_mask = True
elif isinstance(child, KVCache):
cache_shape = child.k_cache.shape
cache_dtype = child.k_cache.dtype
max_batch_size, n_heads, max_context_length, head_dim = cache_shape
Expand Down Expand Up @@ -466,6 +499,81 @@
)


class CustomKVCacheWithAttentionSink(CustomKVCache):
    """
    CustomKVCache variant for attention sink models.

    Uses the custom update_cache_with_indices op for performance while
    supporting sink tokens (fixed slots) + ring buffer (sliding window).
    Modeled after CustomRingKVCache but with CachePositionsManagerWithSink.
    """

    def __init__(
        self,
        max_batch_size,
        n_heads,
        head_dim,
        window_size,
        sink_size,
        dtype=torch.float32,
    ):
        # Total cache size: sink slots + ring buffer (2x window for wrap safety)
        total_cache_size = sink_size + window_size * 2
        super().__init__(
            max_batch_size, total_cache_size, n_heads, head_dim, dtype
        )
        # Import here to avoid a circular import with attention_sink.py.
        from executorch.examples.models.llama.source_transformation.attention_sink import (
            CachePositionsManagerWithSink,
            _create_causal_mask_for_attention_sink,
        )

        self.cache_positions_manager = CachePositionsManagerWithSink(
            total_cache_size, sink_size
        )
        # Marks this cache as ring-buffer-style so SDPA wrappers know to
        # request an explicit attention mask instead of causal masking.
        self.is_ring_buffer = True
        self.window_size = window_size
        self.sink_size = sink_size
        self._create_causal_mask_for_attention_sink = (
            _create_causal_mask_for_attention_sink
        )

    def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
        """Build the explicit attention mask for positions [start_pos, start_pos + seq_len)."""
        cache_positions = self.cache_positions_manager.cache_positions
        if self.sink_size > 0:
            return self._create_causal_mask_for_attention_sink(
                cache_positions, self.window_size, self.sink_size, start_pos, seq_len
            )
        else:
            # sink_size == 0 degenerates to a plain ring buffer.
            return _create_causal_mask_for_ring_buffer(
                cache_positions, self.window_size, start_pos, seq_len
            )

    def update(self, input_pos, k_val, v_val):
        """Scatter k_val/v_val into the cache at sink/ring positions.

        k_val/v_val are laid out as (batch, n_heads, seq_len, head_dim);
        the transpose exposes seq_len for the capacity check.
        """
        seq_len = k_val.transpose(1, 2).size(1)
        assert seq_len <= self.k_cache.size(
            1
        ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(1)})"
        indices = self.cache_positions_manager.calculate_positions_and_update_indices(
            input_pos, seq_len
        )
        # Guard against modulo wrap-around producing duplicate write indices:
        # a scatter-style update (update_cache_with_indices / index_copy_)
        # with duplicate indices is non-deterministic. Tokens that land in
        # the ring portion (index >= sink_size) must fit within the ring
        # size (2 * window_size), mirroring the guard in
        # KVCacheWithAttentionSink.update().
        num_window_tokens = int((indices >= self.sink_size).sum().item())
        assert num_window_tokens <= 2 * self.window_size, (
            f"Update writes {num_window_tokens} window tokens, which exceeds the "
            f"ring buffer size {2 * self.window_size}; duplicate cache indices "
            "would make the cache update non-deterministic"
        )
        indices = indices.unsqueeze(0)

        return super().update(input_pos, k_val, v_val, indices)

    @classmethod
    def from_kv_cache_with_attention_sink(cls, kv_cache):
        """Create from an existing KVCacheWithAttentionSink."""
        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
        return cls(
            max_batch_size,
            n_heads,
            head_dim,
            kv_cache.window_size,
            kv_cache.sink_size,
            dtype=kv_cache.k_cache.dtype,
        )


class CustomRingKVCache(CustomKVCache):
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,48 @@ def test_beyond_context_window_basic(self):
self.assertTrue(
torch.isfinite(out).all(), "Output contains non-finite values"
)

def test_beyond_context_window_custom_sdpa(self):
    """Generate tokens beyond context window with custom SDPA + custom KV cache."""
    sink_size = 4
    window_size = 16
    args = self._make_args(max_context_len=128)
    model = self._build_model(args, sink_size, window_size, use_custom_sdpa=True)

    # Verify KV caches were replaced with CustomKVCacheWithAttentionSink
    from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
        CustomKVCacheWithAttentionSink,
    )

    has_sink_cache = any(
        isinstance(submodule, CustomKVCacheWithAttentionSink)
        for submodule in model.modules()
    )
    self.assertTrue(
        has_sink_cache, "Expected CustomKVCacheWithAttentionSink in model"
    )

    # Generate 80 tokens — well beyond KV cache size of 36
    outputs = self._run_generation(model, args, num_tokens=80)

    self.assertEqual(len(outputs), 77)
    for step_output in outputs:
        self.assertTrue(
            torch.isfinite(step_output).all(),
            "Output contains non-finite values",
        )

def test_sink_zero_custom_sdpa(self):
    """Degenerate case: sink_size=0 with custom SDPA (pure ring buffer)."""
    sink_size, window_size = 0, 16
    args = self._make_args(max_context_len=128)
    model = self._build_model(args, sink_size, window_size, use_custom_sdpa=True)

    outputs = self._run_generation(model, args, num_tokens=60)

    self.assertEqual(len(outputs), 57)
    for step_output in outputs:
        self.assertTrue(
            torch.isfinite(step_output).all(),
            "Output contains non-finite values",
        )
32 changes: 24 additions & 8 deletions extension/llm/runner/text_llm_runner.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -110,6 +110,8 @@
stats_->inference_start_ms = time_in_ms();
shouldStop_ = false;

// Get max_seq_len for single prefill chunk limit
int64_t max_seq_len = metadata_.at(kMaxSeqLen);
int64_t max_context_len = metadata_.at(kMaxContextLen);

uint64_t cur_token = 0;
Expand Down Expand Up @@ -138,13 +140,26 @@
InvalidArgument,
"Expected at least 1 prompt token");
ET_CHECK_OR_RETURN_ERROR(
pos_ + num_prompt_tokens < max_context_len,
num_prompt_tokens <= max_seq_len,
InvalidArgument,
"pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
", Max seq length exceeded - please increase max seq len value in your export script",
pos_,
"num_prompt_tokens %d > max_seq_len %" PRId64
", Single prefill chunk too large - please reduce prompt size or increase max_seq_len",
num_prompt_tokens,
max_context_len);
max_seq_len);
// For non-sliding-window models, also check that we won't exceed
// KV cache capacity. Sliding window models (where max_seq_len <
// max_context_len) handle position wrapping internally.
if (max_seq_len >= max_context_len) {
ET_CHECK_OR_RETURN_ERROR(
pos_ + num_prompt_tokens < max_context_len,
InvalidArgument,
"pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
", Max seq length exceeded - please increase max seq len value in "
"your export script",
pos_,
num_prompt_tokens,
max_context_len);
}

// print prompts
if (config.echo) {
Expand All @@ -168,9 +183,10 @@
prefill_next_token_.reset();
}

// Resolve max_new_tokens. pos_ now reflects all occupied positions
// (including prompt tokens just prefilled).
int max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
// For sliding window models, the ring buffer recycles space — pos_ doesn't
// represent consumed capacity, so pass 0 to get the full budget.
int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos_;
int max_new_tokens = config.resolve_max_new_tokens(max_context_len, effective_pos);
Comment on lines +186 to +189
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

effective_pos is set to 0 for sliding-window models, but resolve_max_new_tokens(max_context_len, effective_pos) is intended to ensure start_pos + max_new_tokens <= max_context_len. With effective_pos=0, a nonzero pos_ can lead to generating past max_context_len (e.g., RoPE tables typically guard input_pos + seq_len <= max_context_len), causing runtime failures. Consider keeping pos_ (or a wrapped position if the model truly wraps positions) when resolving max_new_tokens, and only relax KV-cache-capacity checks separately from the RoPE/context-length limit.

Suggested change
// For sliding window models, the ring buffer recycles space — pos_ doesn't
// represent consumed capacity, so pass 0 to get the full budget.
int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos_;
int max_new_tokens = config.resolve_max_new_tokens(max_context_len, effective_pos);
// Resolve generation budget from the actual starting position so
// start_pos + max_new_tokens stays within max_context_len. Sliding-window
// KV-cache reuse must not bypass positional/context-length limits.
int64_t effective_pos = pos_;
int max_new_tokens =
config.resolve_max_new_tokens(max_context_len, effective_pos);

Copilot uses AI. Check for mistakes.

ET_LOG(
Info,
Expand Down
Loading