Skip to content

Commit cdb7a75

Browse files
avion23 and Ralf Waldukat authored
fix: clear prompt for recurrent / hybrid models when only a partial prefix matches (abetlen#2108)
Co-authored-by: Ralf Waldukat <ralf.waldukat@gmail.com>
1 parent 73ee7cd commit cdb7a75

4 files changed

Lines changed: 143 additions & 6 deletions

File tree

.github/workflows/test.yaml

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,11 @@ on:
1010
env:
1111
REPO_ID: lmstudio-community/Qwen3.5-0.8B-GGUF
1212
MODEL_FILE: Qwen3.5-0.8B-Q8_0.gguf
13+
RECURRENT_REPO_ID: QuantFactory/mamba-130m-hf-GGUF
14+
RECURRENT_MODEL_FILE: mamba-130m-hf.Q2_K.gguf
15+
HYBRID_REPO_ID: tiiuae/Falcon-H1-Tiny-90M-Instruct-GGUF
16+
HYBRID_MODEL_FILE: Falcon-H1-Tiny-90M-Instruct-Q2_K.gguf
17+
MODEL_CACHE_KEY: qwen35-q8-mamba130m-q2-falconh1tiny-q2
1318

1419
jobs:
1520
download-model:
@@ -22,12 +27,15 @@ jobs:
2227
- name: Install huggingface-hub
2328
run: pip install huggingface-hub
2429
- name: Download model
25-
run: hf download ${{ env.REPO_ID }} ${{ env.MODEL_FILE }}
30+
run: |
31+
hf download ${{ env.REPO_ID }} ${{ env.MODEL_FILE }}
32+
hf download ${{ env.RECURRENT_REPO_ID }} ${{ env.RECURRENT_MODEL_FILE }}
33+
hf download ${{ env.HYBRID_REPO_ID }} ${{ env.HYBRID_MODEL_FILE }}
2634
- name: Cache model
2735
uses: actions/cache@v4
2836
with:
2937
path: ~/.cache/huggingface/hub
30-
key: ${{ runner.os }}-model-${{ env.REPO_ID }}-${{ env.MODEL_FILE }}
38+
key: ${{ runner.os }}-model-${{ env.MODEL_CACHE_KEY }}
3139

3240
build-linux:
3341
needs: download-model
@@ -49,7 +57,7 @@ jobs:
4957
uses: actions/cache@v4
5058
with:
5159
path: ~/.cache/huggingface/hub
52-
key: ${{ runner.os }}-model-${{ env.REPO_ID }}-${{ env.MODEL_FILE }}
60+
key: ${{ runner.os }}-model-${{ env.MODEL_CACHE_KEY }}
5361
- name: Install dependencies (Linux/MacOS)
5462
run: |
5563
python -m pip install --upgrade pip
@@ -81,7 +89,7 @@ jobs:
8189
uses: actions/cache@v4
8290
with:
8391
path: ~/.cache/huggingface/hub
84-
key: ${{ runner.os }}-model-${{ env.REPO_ID }}-${{ env.MODEL_FILE }}
92+
key: ${{ runner.os }}-model-${{ env.MODEL_CACHE_KEY }}
8593

8694
- name: Install dependencies (Windows)
8795
run: |
@@ -121,7 +129,7 @@ jobs:
121129
uses: actions/cache@v4
122130
with:
123131
path: ~/.cache/huggingface/hub
124-
key: ${{ runner.os }}-model-${{ env.REPO_ID }}-${{ env.MODEL_FILE }}
132+
key: ${{ runner.os }}-model-${{ env.MODEL_CACHE_KEY }}
125133

126134
- name: Install dependencies (Linux/MacOS)
127135
run: |
@@ -157,7 +165,7 @@ jobs:
157165
uses: actions/cache@v4
158166
with:
159167
path: ~/.cache/huggingface/hub
160-
key: ${{ runner.os }}-model-${{ env.REPO_ID }}-${{ env.MODEL_FILE }}
168+
key: ${{ runner.os }}-model-${{ env.MODEL_CACHE_KEY }}
161169

162170
- name: Install dependencies
163171
run: |

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
- fix: clear prompt for recurrent / hybrid models when only a partial prefix matches by @avion23 in #2108
1011
- fix: match Transformers `tojson` in chat template rendering by @CISC in #1486
1112
- fix: use env var configured multimodal library override paths when loading shared libraries by @navratil-matej in #1782
1213
- feat: add Jinja2 loop controls to chat templates by @handshape in #2018

llama_cpp/llama.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -559,6 +559,10 @@ def free_lora_adapter():
559559

560560
self._sampler = None
561561

562+
# Cache recurrent/hybrid model detection to avoid repeated FFI calls
563+
self._is_recurrent = llama_cpp.llama_model_is_recurrent(self._model.model)
564+
self._is_hybrid = llama_cpp.llama_model_is_hybrid(self._model.model)
565+
562566
@property
563567
def ctx(self) -> llama_cpp.llama_context_p:
564568
return self._ctx.ctx
@@ -644,6 +648,11 @@ def reset(self):
644648
"""Reset the model state."""
645649
self.n_tokens = 0
646650

651+
if self._is_recurrent or self._is_hybrid:
652+
mem = llama_cpp.llama_get_memory(self._ctx.ctx)
653+
if mem is not None:
654+
llama_cpp.llama_memory_clear(mem, True)
655+
647656
def eval(self, tokens: Sequence[int]):
648657
"""Evaluate a list of tokens.
649658
@@ -899,6 +908,19 @@ def generate(
899908
longest_prefix += 1
900909
else:
901910
break
911+
912+
# Recurrent and hybrid models cannot rewind state; reset if needed
913+
if (
914+
self._is_recurrent or self._is_hybrid
915+
) and longest_prefix < self.n_tokens:
916+
longest_prefix = 0
917+
reset = True
918+
if self.verbose:
919+
print(
920+
"Llama.generate: recurrent/hybrid model requires full state reset",
921+
file=sys.stderr,
922+
)
923+
902924
if longest_prefix > 0:
903925
if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
904926
reset = False

tests/test_llama.py

Lines changed: 106 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,22 @@ def llama_cpp_embedding_model_path():
7272
return model_path
7373

7474

75+
@pytest.fixture
76+
def llama_cpp_recurrent_model_path():
77+
repo_id = "QuantFactory/mamba-130m-hf-GGUF"
78+
filename = "mamba-130m-hf.Q2_K.gguf"
79+
model_path = hf_hub_download(repo_id, filename)
80+
return model_path
81+
82+
83+
@pytest.fixture
84+
def llama_cpp_hybrid_model_path():
85+
repo_id = "tiiuae/Falcon-H1-Tiny-90M-Instruct-GGUF"
86+
filename = "Falcon-H1-Tiny-90M-Instruct-Q2_K.gguf"
87+
model_path = hf_hub_download(repo_id, filename)
88+
return model_path
89+
90+
7591
def test_real_model(llama_cpp_model_path):
7692
import os
7793

@@ -233,6 +249,96 @@ def logit_processor_func(input_ids, logits):
233249
assert number_1 == number_3
234250

235251

252+
def test_real_llama_repeated_prompt_cache(llama_cpp_model_path):
253+
model = llama_cpp.Llama(
254+
llama_cpp_model_path,
255+
n_ctx=32,
256+
n_batch=32,
257+
n_ubatch=32,
258+
n_threads=multiprocessing.cpu_count(),
259+
n_threads_batch=multiprocessing.cpu_count(),
260+
logits_all=False,
261+
flash_attn=True,
262+
verbose=False,
263+
)
264+
prompt = "The quick brown fox jumps over the lazy dog. The quick brown fox"
265+
266+
output_1 = model.create_completion(
267+
prompt,
268+
max_tokens=6,
269+
temperature=0.0,
270+
seed=1337,
271+
)
272+
output_2 = model.create_completion(
273+
prompt,
274+
max_tokens=6,
275+
temperature=0.0,
276+
seed=1337,
277+
)
278+
279+
assert output_1["choices"][0]["text"] == " jumps over the lazy dog."
280+
assert output_2["choices"][0]["text"] == output_1["choices"][0]["text"]
281+
282+
283+
def _assert_prompt_cache_reset_handles_history_edit(
284+
model_path,
285+
*,
286+
is_recurrent: bool,
287+
is_hybrid: bool,
288+
):
289+
model = llama_cpp.Llama(
290+
model_path,
291+
n_ctx=32,
292+
n_batch=32,
293+
n_ubatch=32,
294+
n_threads=multiprocessing.cpu_count(),
295+
n_threads_batch=multiprocessing.cpu_count(),
296+
logits_all=False,
297+
verbose=False,
298+
)
299+
300+
assert model._is_recurrent is is_recurrent
301+
assert model._is_hybrid is is_hybrid
302+
303+
first_prompt = "The quick brown fox"
304+
second_prompt = "The slow brown fox"
305+
first_tokens = model.tokenize(first_prompt.encode(), add_bos=True, special=True)
306+
second_tokens = model.tokenize(second_prompt.encode(), add_bos=True, special=True)
307+
308+
assert first_tokens != second_tokens
309+
assert first_tokens[0] == second_tokens[0]
310+
311+
first_output = model.create_completion(
312+
first_prompt,
313+
max_tokens=1,
314+
temperature=0.0,
315+
)
316+
assert isinstance(first_output["choices"][0]["text"], str)
317+
318+
second_output = model.create_completion(
319+
second_prompt,
320+
max_tokens=1,
321+
temperature=0.0,
322+
)
323+
assert isinstance(second_output["choices"][0]["text"], str)
324+
325+
326+
def test_recurrent_model_prompt_cache_reset(llama_cpp_recurrent_model_path):
327+
_assert_prompt_cache_reset_handles_history_edit(
328+
llama_cpp_recurrent_model_path,
329+
is_recurrent=True,
330+
is_hybrid=False,
331+
)
332+
333+
334+
def test_hybrid_model_prompt_cache_reset(llama_cpp_hybrid_model_path):
335+
_assert_prompt_cache_reset_handles_history_edit(
336+
llama_cpp_hybrid_model_path,
337+
is_recurrent=False,
338+
is_hybrid=True,
339+
)
340+
341+
236342
def test_real_llama_embeddings(llama_cpp_embedding_model_path):
237343
model = llama_cpp.Llama(
238344
llama_cpp_embedding_model_path,

0 commit comments

Comments
 (0)