Integrate attention sink into ET LLM export and runner #18860
base: main
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,3 +1,3 @@ | ||||||||||||||||||||||
| /* | ||||||||||||||||||||||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||||||||||||||||
| * All rights reserved. | ||||||||||||||||||||||
|
|
@@ -110,6 +110,8 @@ | |||||||||||||||||||||
| stats_->inference_start_ms = time_in_ms(); | ||||||||||||||||||||||
| shouldStop_ = false; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Get max_seq_len for single prefill chunk limit | ||||||||||||||||||||||
| int64_t max_seq_len = metadata_.at(kMaxSeqLen); | ||||||||||||||||||||||
| int64_t max_context_len = metadata_.at(kMaxContextLen); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| uint64_t cur_token = 0; | ||||||||||||||||||||||
|
|
@@ -138,13 +140,26 @@ | |||||||||||||||||||||
| InvalidArgument, | ||||||||||||||||||||||
| "Expected at least 1 prompt token"); | ||||||||||||||||||||||
| ET_CHECK_OR_RETURN_ERROR( | ||||||||||||||||||||||
| pos_ + num_prompt_tokens < max_context_len, | ||||||||||||||||||||||
| num_prompt_tokens <= max_seq_len, | ||||||||||||||||||||||
| InvalidArgument, | ||||||||||||||||||||||
| "pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64 | ||||||||||||||||||||||
| ", Max seq length exceeded - please increase max seq len value in your export script", | ||||||||||||||||||||||
| pos_, | ||||||||||||||||||||||
| "num_prompt_tokens %d > max_seq_len %" PRId64 | ||||||||||||||||||||||
| ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", | ||||||||||||||||||||||
| num_prompt_tokens, | ||||||||||||||||||||||
| max_context_len); | ||||||||||||||||||||||
| max_seq_len); | ||||||||||||||||||||||
| // For non-sliding-window models, also check that we won't exceed | ||||||||||||||||||||||
| // KV cache capacity. Sliding window models (where max_seq_len < | ||||||||||||||||||||||
| // max_context_len) handle position wrapping internally. | ||||||||||||||||||||||
| if (max_seq_len >= max_context_len) { | ||||||||||||||||||||||
| ET_CHECK_OR_RETURN_ERROR( | ||||||||||||||||||||||
| pos_ + num_prompt_tokens < max_context_len, | ||||||||||||||||||||||
| InvalidArgument, | ||||||||||||||||||||||
| "pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64 | ||||||||||||||||||||||
| ", Max seq length exceeded - please increase max seq len value in " | ||||||||||||||||||||||
| "your export script", | ||||||||||||||||||||||
| pos_, | ||||||||||||||||||||||
| num_prompt_tokens, | ||||||||||||||||||||||
| max_context_len); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // print prompts | ||||||||||||||||||||||
| if (config.echo) { | ||||||||||||||||||||||
|
|
@@ -168,9 +183,10 @@ | |||||||||||||||||||||
| prefill_next_token_.reset(); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Resolve max_new_tokens. pos_ now reflects all occupied positions | ||||||||||||||||||||||
| // (including prompt tokens just prefilled). | ||||||||||||||||||||||
| int max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_); | ||||||||||||||||||||||
| // For sliding window models, the ring buffer recycles space — pos_ doesn't | ||||||||||||||||||||||
| // represent consumed capacity, so pass 0 to get the full budget. | ||||||||||||||||||||||
| int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos_; | ||||||||||||||||||||||
| int max_new_tokens = config.resolve_max_new_tokens(max_context_len, effective_pos); | ||||||||||||||||||||||
|
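The two prefill checks in the hunk above can be read as one small decision function. The following is a minimal sketch, not the runner's code: `PrefillCheck` and `check_prefill` are hypothetical names introduced for illustration, and the real runner reports failures through `ET_CHECK_OR_RETURN_ERROR` rather than an enum.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical result type for this sketch only.
enum class PrefillCheck { Ok, ChunkTooLarge, ContextExceeded };

// Mirrors the two checks in the diff:
//  1. a single prefill chunk must fit in max_seq_len;
//  2. only when max_seq_len >= max_context_len (no sliding window) must
//     pos + num_prompt_tokens stay strictly below max_context_len.
inline PrefillCheck check_prefill(
    int64_t pos,
    int num_prompt_tokens,
    int64_t max_seq_len,
    int64_t max_context_len) {
  assert(num_prompt_tokens >= 1); // the runner rejects empty prompts earlier
  if (num_prompt_tokens > max_seq_len) {
    return PrefillCheck::ChunkTooLarge;
  }
  const bool sliding_window = max_seq_len < max_context_len;
  if (!sliding_window && pos + num_prompt_tokens >= max_context_len) {
    return PrefillCheck::ContextExceeded;
  }
  return PrefillCheck::Ok;
}
```

Note that in the sliding-window branch nothing bounds `pos`, which is the behavior the review thread below questions.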
Comment on lines +186 to +189
|
||||||||||||||||||||||
| // For sliding window models, the ring buffer recycles space — pos_ doesn't | |
| // represent consumed capacity, so pass 0 to get the full budget. | |
| int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos_; | |
| int max_new_tokens = config.resolve_max_new_tokens(max_context_len, effective_pos); | |
| // Resolve generation budget from the actual starting position so | |
| // start_pos + max_new_tokens stays within max_context_len. Sliding-window | |
| // KV-cache reuse must not bypass positional/context-length limits. | |
| int64_t effective_pos = pos_; | |
| int max_new_tokens = | |
| config.resolve_max_new_tokens(max_context_len, effective_pos); |
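To make the difference between the PR's version and the suggested change concrete, here is a minimal sketch that compares the two budgets. `remaining_budget`, `budget_pr`, and `budget_suggested` are hypothetical helpers. In particular, `remaining_budget` is a simplified stand-in that assumes the budget is just the room left in the context window; the actual rules inside `config.resolve_max_new_tokens` are not shown on this page.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical stand-in for config.resolve_max_new_tokens(): assumes the
// budget is simply the room left in the context window, clamped at zero.
inline int64_t remaining_budget(int64_t max_context_len, int64_t start_pos) {
  return std::max<int64_t>(0, max_context_len - start_pos);
}

// PR version: sliding-window models (max_seq_len < max_context_len) pass 0,
// so the budget ignores how many positions are already occupied.
inline int64_t budget_pr(
    int64_t pos, int64_t max_seq_len, int64_t max_context_len) {
  const int64_t effective_pos = (max_seq_len < max_context_len) ? 0 : pos;
  return remaining_budget(max_context_len, effective_pos);
}

// Suggested version: always start from the actual position, so
// pos + budget never exceeds max_context_len.
inline int64_t budget_suggested(int64_t pos, int64_t max_context_len) {
  return remaining_budget(max_context_len, pos);
}
```

Under these assumptions, with `pos = 100`, `max_seq_len = 128`, and `max_context_len = 1024`, the PR version yields 1024 new tokens while the suggested version yields 924.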
`CustomKVCacheWithAttentionSink.update()` computes ring-buffer write `indices` but does not validate that the number of window tokens written in a single update fits within the ring portion (i.e., avoids duplicate indices due to modulo wrap). `KVCacheWithAttentionSink.update()` has an explicit `num_window_tokens <= ring_size` guard to prevent non-deterministic `index_copy_` behavior; the same kind of guard is needed here as well (especially since `update_cache_with_indices` is also a scatter-style update).
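
The guard described in this comment can be illustrated with a small sketch. It is written in C++ to match the runner diff, whereas the cache classes under discussion are Python modules. The layout is an assumption: slots `[0, sink_size)` hold sink tokens and slots `[sink_size, sink_size + ring_size)` form the ring. `ring_write_indices` is a hypothetical helper, not a function from the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Computes the cache slot for each token position in
// [start_pos, start_pos + num_tokens). Returns std::nullopt when more window
// tokens would be written than the ring holds, because the modulo wrap would
// then map two positions to the same slot and a scatter-style write would
// have no defined winner.
inline std::optional<std::vector<int64_t>> ring_write_indices(
    int64_t start_pos,
    int64_t num_tokens,
    int64_t sink_size,
    int64_t ring_size) {
  assert(start_pos >= 0 && num_tokens >= 0 && sink_size >= 0 && ring_size > 0);
  const int64_t end_pos = start_pos + num_tokens;
  // Tokens at positions >= sink_size land in the ring portion.
  const int64_t first_window_pos = std::max(start_pos, sink_size);
  const int64_t num_window_tokens =
      std::max<int64_t>(0, end_pos - first_window_pos);
  if (num_window_tokens > ring_size) {
    return std::nullopt; // the guard: reject updates that would self-overlap
  }
  std::vector<int64_t> indices;
  indices.reserve(static_cast<size_t>(num_tokens));
  for (int64_t p = start_pos; p < end_pos; ++p) {
    indices.push_back(
        p < sink_size ? p : sink_size + (p - sink_size) % ring_size);
  }
  return indices;
}
```

For example, with `sink_size = 2` and `ring_size = 4`, writing five tokens starting at position 2 would map positions 2 and 6 to the same slot, so the sketch rejects that update instead of emitting duplicate indices.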