Skip to content

Commit 653d5dd

Browse files
author
Ankur Kaul
committed
fix: preserve recurrent/hybrid model state when the full prompt is already cached
1 parent 3850aff commit 653d5dd

4 files changed

Lines changed: 90 additions & 7 deletions

File tree

.gitignore

191 Bytes
Binary file not shown.

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
- fix: preserve recurrent/hybrid model state when the full prompt is already cached
11+
1012
## [0.3.29]
1113

1214
- feat(example): use MTMD batch encoding by @abetlen in #2301

llama_cpp/llama.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -913,13 +913,35 @@ def generate(
913913
if (
914914
self._is_recurrent or self._is_hybrid
915915
) and longest_prefix < self.n_tokens:
916-
longest_prefix = 0
917-
reset = True
918-
if self.verbose:
919-
print(
920-
"Llama.generate: recurrent/hybrid model requires full state reset",
921-
file=sys.stderr,
922-
)
916+
# Full prompt already cached -> preserve state (no reset).
917+
# zip(self._input_ids, tokens[:-1]) compares N versus N-1
918+
# tokens so longest_prefix is always < self.n_tokens when
919+
# the cache holds the full prompt. Without this guard the
920+
# recurrent/hybrid branch always fires and wipes the state,
921+
# including multimodal embeddings injected by handlers
922+
# such as MTMDChatHandler.
923+
if (
924+
len(tokens) > 0
925+
and longest_prefix >= len(tokens) - 1
926+
and self.n_tokens == len(tokens)
927+
and self._input_ids[-1] == tokens[-1]
928+
):
929+
reset = False
930+
tokens = []
931+
longest_prefix = 0
932+
if self.verbose:
933+
print(
934+
"Llama.generate: full prompt already cached for hybrid model, skipping reset",
935+
file=sys.stderr,
936+
)
937+
else:
938+
longest_prefix = 0
939+
reset = True
940+
if self.verbose:
941+
print(
942+
"Llama.generate: recurrent/hybrid model requires full state reset",
943+
file=sys.stderr,
944+
)
923945

924946
if longest_prefix > 0:
925947
if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):

tests/test_llama.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,65 @@ def test_hybrid_model_prompt_cache_reset(llama_cpp_hybrid_model_path):
339339
)
340340

341341

342+
def _assert_repeated_full_prompt_preserves_state(
343+
model_path,
344+
*,
345+
is_recurrent: bool,
346+
is_hybrid: bool,
347+
):
348+
model = llama_cpp.Llama(
349+
model_path,
350+
n_ctx=32,
351+
n_batch=32,
352+
n_ubatch=32,
353+
n_threads=multiprocessing.cpu_count(),
354+
n_threads_batch=multiprocessing.cpu_count(),
355+
logits_all=False,
356+
verbose=False,
357+
)
358+
359+
assert model._is_recurrent is is_recurrent
360+
assert model._is_hybrid is is_hybrid
361+
362+
prompt = "The quick brown fox"
363+
max_tokens = 3
364+
temperature = 0.0
365+
366+
output_1 = model.create_completion(
367+
prompt, max_tokens=max_tokens, temperature=temperature
368+
)
369+
n_tokens_1 = model.n_tokens
370+
371+
output_2 = model.create_completion(
372+
prompt, max_tokens=max_tokens, temperature=temperature
373+
)
374+
n_tokens_2 = model.n_tokens
375+
376+
# Without the fix, generate() resets the state and re-evaluates from
377+
# scratch, keeping n_tokens flat. With the fix, the state is preserved
378+
# and n_tokens keeps growing across repeated calls.
379+
assert n_tokens_2 > n_tokens_1
380+
assert output_2["choices"][0]["text"]
381+
382+
383+
def test_recurrent_model_repeated_full_prompt_preserves_state(
384+
llama_cpp_recurrent_model_path,
385+
):
386+
_assert_repeated_full_prompt_preserves_state(
387+
llama_cpp_recurrent_model_path,
388+
is_recurrent=True,
389+
is_hybrid=False,
390+
)
391+
392+
393+
def test_hybrid_model_repeated_full_prompt_preserves_state(llama_cpp_hybrid_model_path):
394+
_assert_repeated_full_prompt_preserves_state(
395+
llama_cpp_hybrid_model_path,
396+
is_recurrent=False,
397+
is_hybrid=True,
398+
)
399+
400+
342401
def test_real_llama_embeddings(llama_cpp_embedding_model_path):
343402
model = llama_cpp.Llama(
344403
llama_cpp_embedding_model_path,

0 commit comments

Comments
 (0)