Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified .gitignore
Binary file not shown.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- fix: preserve recurrent/hybrid model state when the full prompt is already cached

## [0.3.29]

- feat(example): use MTMD batch encoding by @abetlen in #2301
Expand Down
36 changes: 29 additions & 7 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,13 +913,35 @@ def generate(
if (
self._is_recurrent or self._is_hybrid
) and longest_prefix < self.n_tokens:
longest_prefix = 0
reset = True
if self.verbose:
print(
"Llama.generate: recurrent/hybrid model requires full state reset",
file=sys.stderr,
)
# Full prompt already cached -> preserve state (no reset).
# zip(self._input_ids, tokens[:-1]) compares N versus N-1
# tokens so longest_prefix is always < self.n_tokens when
# the cache holds the full prompt. Without this guard the
# recurrent/hybrid branch always fires and wipes the state,
# including multimodal embeddings injected by handlers
# such as MTMDChatHandler.
if (
len(tokens) > 0
and longest_prefix >= len(tokens) - 1
and self.n_tokens == len(tokens)
and self._input_ids[-1] == tokens[-1]
):
reset = False
tokens = []
longest_prefix = 0
if self.verbose:
print(
"Llama.generate: full prompt already cached for hybrid model, skipping reset",
file=sys.stderr,
)
else:
longest_prefix = 0
reset = True
if self.verbose:
print(
"Llama.generate: recurrent/hybrid model requires full state reset",
file=sys.stderr,
)

if longest_prefix > 0:
if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
Expand Down
59 changes: 59 additions & 0 deletions tests/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,65 @@ def test_hybrid_model_prompt_cache_reset(llama_cpp_hybrid_model_path):
)


def _assert_repeated_full_prompt_preserves_state(
    model_path,
    *,
    is_recurrent: bool,
    is_hybrid: bool,
):
    """Run one prompt through the same model twice and compare ``n_tokens``.

    Loads ``model_path`` with a small 32-token context, checks that the
    loaded model reports the expected ``_is_recurrent`` / ``_is_hybrid``
    flags, then issues two identical greedy completions. The check passes
    when ``n_tokens`` after the second completion is strictly greater than
    after the first and the second completion produced non-empty text.

    Args:
        model_path: Path handed straight to ``llama_cpp.Llama``.
        is_recurrent: Expected value of the model's ``_is_recurrent`` flag.
        is_hybrid: Expected value of the model's ``_is_hybrid`` flag.
    """
    thread_count = multiprocessing.cpu_count()
    llm = llama_cpp.Llama(
        model_path,
        n_ctx=32,
        n_batch=32,
        n_ubatch=32,
        n_threads=thread_count,
        n_threads_batch=thread_count,
        logits_all=False,
        verbose=False,
    )

    # Guard against a fixture that points at the wrong kind of model.
    assert llm._is_recurrent is is_recurrent
    assert llm._is_hybrid is is_hybrid

    prompt = "The quick brown fox"
    # temperature=0.0 keeps both runs greedy so they are comparable.
    completion_kwargs = {"max_tokens": 3, "temperature": 0.0}

    # The first call's result is not inspected; only its effect on the
    # model's token count matters.
    llm.create_completion(prompt, **completion_kwargs)
    tokens_after_first = llm.n_tokens

    second_completion = llm.create_completion(prompt, **completion_kwargs)
    tokens_after_second = llm.n_tokens

    # The test treats a growing n_tokens as the signal that generate() kept
    # the existing state instead of resetting and re-evaluating from scratch
    # (in which case the count would stay flat).
    # NOTE(review): this relies on generate()'s cache handling, which is not
    # fully visible from this test -- confirm against Llama.generate.
    assert tokens_after_second > tokens_after_first
    assert second_completion["choices"][0]["text"]


def test_recurrent_model_repeated_full_prompt_preserves_state(
    llama_cpp_recurrent_model_path,
):
    """Apply the repeated-full-prompt check to the recurrent model fixture."""
    # The fixture is expected to load as recurrent and not as hybrid.
    expected_flags = {"is_recurrent": True, "is_hybrid": False}
    _assert_repeated_full_prompt_preserves_state(
        llama_cpp_recurrent_model_path, **expected_flags
    )


def test_hybrid_model_repeated_full_prompt_preserves_state(llama_cpp_hybrid_model_path):
    """Apply the repeated-full-prompt check to the hybrid model fixture."""
    # The fixture is expected to load as hybrid and not as recurrent.
    expected_flags = {"is_recurrent": False, "is_hybrid": True}
    _assert_repeated_full_prompt_preserves_state(
        llama_cpp_hybrid_model_path, **expected_flags
    )


def test_real_llama_embeddings(llama_cpp_embedding_model_path):
model = llama_cpp.Llama(
llama_cpp_embedding_model_path,
Expand Down