From 653d5dddbe42584d2ff4813a49f158bb7157bc8a Mon Sep 17 00:00:00 2001 From: Ankur Kaul Date: Mon, 15 Jun 2026 03:15:36 +0530 Subject: [PATCH] fix: preserve recurrent/hybrid model state when the full prompt is already cached --- .gitignore | Bin 3295 -> 3486 bytes CHANGELOG.md | 2 ++ llama_cpp/llama.py | 36 +++++++++++++++++++++------ tests/test_llama.py | 59 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index ff773c668415e603581c74c0cf539095f737d5ef..0712d885784cb8919657810f02470044f07c9c2f 100644 GIT binary patch literal 3486 zcmb_fO>f*b5cPS1{SO3Opsis^K+_aya!PEcO^~FEI_ae-78FHk>6J*9q;{?R`o1A4 zb)6>dr3YL5m@f|Byy0Gyy%D@$uhy&5PGM(sagfd{qvO)S3)9Go##_IyN7XlRQc8Vr z0y>Ry7}EX#!QRK+bM|5ivJhsls+SF`n6W_(>SG*Vzj?!2tb5qp+DUM+liMv6FA%}+ z@b?FXU)iY?J@*h=ug=*OllMVtiVG`!&j-F0)>ap#HxK8{)T{d8(jWt!@KmemKUs)M zz!>fZ2%mhw)Szi@&@ z{Qu#5d@Xl@gnU*!aEIA(CsP{zjyqZ9J6TpQ$X20rYam-H?Qt}uxD+iSjGOD#+vo3I zy}Y@7pHjD(US%mq=VKrI2$hM2o%AwY1cN7bbSfaltezNW$u0B3DH~Yp%z(up2ix-? zqnTQ^23jZ9SrIE7x5^e6mEj(OGy&!$pcZz@7lcI-jGNLn>3`5FAZ{OUmBIkoypzmHYmoZ& zY8QqctCY!L?&+h`XBNfEJS;D0LRej|wDEB2gD>v~;Ajzvvx}rc=djad+=J$BvUGB< z=*FqipuUafXzs?gU1m`Qs}%ox$F)b!6Yh~E6YUzDp0=;=xb6(|0s2pA2e|>jsG_$B zF#xw7&H&1~^ndm$$VaE+;?9gtr|LgvFHvs$#a4ut7F7G?QP!wL8GMlwK)!E<<91r) zOJD`_6xXfYDSgixBgR4M5aXxH5Wd>Z)k8Km(+i>$31lM7xFe|u0(YIHVj`i}Bc;o+ zcP7Z%m_1~bopim$U5LFkKoRbeBwuZEg<5BbVYZEI`Ph>oXb`z=h0(3*VopfIPLaJ( z7VM>Qz&MK2aW^qcQt1Q7lNs5dx?NyLt^p@tw$lXu(-QIu$%@z~#$*a1kyQjIbpzgm zmK!zZY(PZH_9zm4P|j$|kblNvVH}l*!+0>VH9#Ksdd>(jpKvpRS9Xd1xe&%RxNmY~ zC(Em}d;a*xCzq%(1O(iJk>M?egMj-Hp>C9=WKkp7Q+%2KQ7<|;)1x}+AU~TU=+42G zW`ct`ObZEcYHeJ+U+lDO#}5<((FEvWD5#QEs|0H#Vr`@rN_tA}={z~es@gE&jQ7QC zgZa@?y zz+A|SXUG|Kr{btk3yB`aglK9|s5~eJGxV3}Zz!YBRT4l3(Ns-p!P^%KX-O~awgalA z#5(I6(= zT-%#ty#5mzm@M*PD~!|YYlcyLfKhI)&`<%d*iz8i;;fcD6dtV9j6H(;d$F{tNfF)x z4sUtb0hf+q%nv;Da%?Pgdvn5gIAmn6@(@4x2b!iaJ~!`v!+(s+n(y;+9#V|Ac0$jX zZ;s8K4ixPp zPsikXGf~Hw;6KuP?s<6Pe8@Ul&iWobv3|M^%F+CuRh};8YcFF6LP?t+_TUVSr^p^{W%

pwPk2^cIZxE9a(81jzI#ddbrpn0}7ndH)R+?g?Fg8(Tg39Cd1y==Si~qR*LQ zZmvrXsDTGOITe%A~dhw_q@B#`q{?z zZuXsW>V;e712-=G1)MqZZy<4j0}|rUzkr|pcFTr~Ju{vU&&+S;$GgApeb%K?mrN2u zP2?NA z7JNjLaC`0UHKF~K4q1G>;xjI_>G9DBzpaGxIQ5m1npF_AV%lDCk`o*szMm7rtjD`+ zEoRK-phc8QB(Q-IKMes%*e|&O$ba;9eYte6x z7aBJFvvw3dSnofCTFsOga_2!h!^^1oN-vXo#-PNnX82rrWCHHqATfh0t+hv>$UPP?HE4 zSQg0j^Mm8Z{_T>vCeWP(BU8@dBO3>YDP^CVSn>kc)ocN{@A>KhjjYF#1L`JC&`)S-ZtU>Ru#VAx(I)5 z?T6dj#T%q!(!+mxRUMvhGnkmFPSWslo55k0WkO|(HId zFn{r3Ay4d*4u6w7STj}|z7)YxJ&Bn4AB=Jb^u|Q$1`Vb0T+RutB1bPBHtq@u-F{cSd<{l|q1XX-69VP94bAVMxxGnk#fc5g3YeF8Ki9aW?qi*-5>iu?;DQV*K| diff --git a/CHANGELOG.md b/CHANGELOG.md index 56c5ffb557..44d7347e17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- fix: preserve recurrent/hybrid model state when the full prompt is already cached + ## [0.3.29] - feat(example): use MTMD batch encoding by @abetlen in #2301 diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 4a09b55ee5..882c005f4d 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -913,13 +913,35 @@ def generate( if ( self._is_recurrent or self._is_hybrid ) and longest_prefix < self.n_tokens: - longest_prefix = 0 - reset = True - if self.verbose: - print( - "Llama.generate: recurrent/hybrid model requires full state reset", - file=sys.stderr, - ) + # Full prompt already cached -> preserve state (no reset). + # zip(self._input_ids, tokens[:-1]) compares N versus N-1 + # tokens so longest_prefix is always < self.n_tokens when + # the cache holds the full prompt. Without this guard the + # recurrent/hybrid branch always fires and wipes the state, + # including multimodal embeddings injected by handlers + # such as MTMDChatHandler. + if ( + len(tokens) > 0 + and longest_prefix >= len(tokens) - 1 + and self.n_tokens == len(tokens) + and self._input_ids[-1] == tokens[-1] + ): + reset = False + tokens = [] + longest_prefix = 0 + if self.verbose: + print( + "Llama.generate: full prompt already cached for hybrid model, skipping reset", + file=sys.stderr, + ) + else: + longest_prefix = 0 + reset = True + if self.verbose: + print( + "Llama.generate: recurrent/hybrid model requires full state reset", + file=sys.stderr, + ) if longest_prefix > 0: if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1): diff --git a/tests/test_llama.py b/tests/test_llama.py index 336d6a6122..0f80c11537 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -339,6 +339,65 @@ def test_hybrid_model_prompt_cache_reset(llama_cpp_hybrid_model_path): ) +def _assert_repeated_full_prompt_preserves_state( + model_path, + *, + is_recurrent: bool, + is_hybrid: bool, +): + model = llama_cpp.Llama( + model_path, + n_ctx=32, + n_batch=32, + n_ubatch=32, + n_threads=multiprocessing.cpu_count(), + n_threads_batch=multiprocessing.cpu_count(), + logits_all=False, + verbose=False, + ) + + assert model._is_recurrent is is_recurrent + assert model._is_hybrid is is_hybrid + + prompt = "The quick brown fox" + max_tokens = 3 + temperature = 0.0 + + output_1 = model.create_completion( + prompt, max_tokens=max_tokens, temperature=temperature + ) + n_tokens_1 = model.n_tokens + + output_2 = model.create_completion( + prompt, max_tokens=max_tokens, temperature=temperature + ) + n_tokens_2 = model.n_tokens + + # Without the fix, generate() resets the state and re-evaluates from + # scratch, keeping n_tokens flat. With the fix, the state is preserved + # and n_tokens keeps growing across repeated calls. + assert n_tokens_2 > n_tokens_1 + assert output_2["choices"][0]["text"] + + +def test_recurrent_model_repeated_full_prompt_preserves_state( + llama_cpp_recurrent_model_path, +): + _assert_repeated_full_prompt_preserves_state( + llama_cpp_recurrent_model_path, + is_recurrent=True, + is_hybrid=False, + ) + + +def test_hybrid_model_repeated_full_prompt_preserves_state(llama_cpp_hybrid_model_path): + _assert_repeated_full_prompt_preserves_state( + llama_cpp_hybrid_model_path, + is_recurrent=False, + is_hybrid=True, + ) + + def test_real_llama_embeddings(llama_cpp_embedding_model_path): model = llama_cpp.Llama( llama_cpp_embedding_model_path,