diff --git a/backend_service/agent.py b/backend_service/agent.py
index 277380e..8050384 100644
--- a/backend_service/agent.py
+++ b/backend_service/agent.py
@@ -485,7 +485,11 @@ def run_agent_loop_streaming(
# consumed so the assistant bubble doesn't show raw call
# JSON next to the rendered ToolCallCard (FU-040).
text = _strip_tool_call_xml(result.text)
- chunk_size = 4
+ # The final answer is already fully computed (tool-calling turns
+ # are non-streaming), so the old 4-char dribble just added fake
+ # latency + yields. Emit in larger chunks; the SSE layer coalesces
+ # these further and the user sees the answer near-instantly.
+ chunk_size = 48
for i in range(0, len(text), chunk_size):
yield {"token": text[i:i + chunk_size]}
diff --git a/backend_service/inference/llama_cpp_engine.py b/backend_service/inference/llama_cpp_engine.py
index d62af35..3d9e884 100644
--- a/backend_service/inference/llama_cpp_engine.py
+++ b/backend_service/inference/llama_cpp_engine.py
@@ -92,6 +92,17 @@
"frequency_penalty",
"presence_penalty",
"stop",
+ # Modern anti-repetition / quality samplers llama-server supports
+ # natively. Forward-only: builds that don't recognise them ignore the
+ # field, so old binaries are unaffected. DRY beats plain repeat_penalty
+ # at killing verbatim loops; XTC adds creative variety; top-n-sigma is
+ # a temperature-stable truncator.
+ "dry_multiplier",
+ "dry_base",
+ "dry_allowed_length",
+ "xtc_probability",
+ "xtc_threshold",
+ "top_n_sigma",
# Phase 3.3: per-token confidence info. llama-server returns
# top-k alternatives with their logprobs in each delta when
# `logprobs: true` + `top_logprobs: N` are set.
@@ -421,6 +432,7 @@ def _build_command(
fit_enabled: bool,
is_fallback: bool,
speculative_decoding: bool = False,
+ fused_attention: bool = False,
canonical_repo: str | None = None,
model_ref: str = "",
) -> tuple[list[str], str | None, bool, str | None]:
@@ -449,6 +461,19 @@ def _build_command(
str(max(256, context_tokens)),
"--jinja",
]
+ # Reuse the single slot's KV cache across chat turns: a growing
+ # conversation re-prefills only the new suffix instead of the whole
+ # history (turn-2+ TTFT drops sharply on long chats). Forward-gated
+ # on binary support so older llama-server builds are unaffected.
+ if _llama_server_supports(binary, "--cache-reuse"):
+ command.extend(["--cache-reuse", "256"])
+ # Honour the user's fused-attention toggle. It was plumbed into
+ # load_model + stored on LoadedModelInfo but never emitted as a
+ # flag. Flash attention is a large decode + KV-memory win on Metal
+ # and is required by the quantized KV cache types. Opt-in via the
+ # existing flag so a model/quant combo that dislikes it can disable.
+ if fused_attention and _llama_server_supports(binary, "--flash-attn"):
+ command.extend(["--flash-attn", "on"])
if _llama_server_supports(binary, "--reasoning-format"):
command.extend(["--reasoning-format", "deepseek"])
if _llama_server_supports(binary, "--reasoning"):
@@ -660,6 +685,7 @@ def load_model(
fit_enabled=fit_enabled,
is_fallback=is_fallback,
speculative_decoding=speculative_decoding,
+ fused_attention=fused_attention,
canonical_repo=canonical_repo,
model_ref=model_ref,
)
@@ -791,6 +817,9 @@ def generate(
"temperature": temperature,
"max_tokens": max_tokens,
"stream": False,
+ # Reuse the slot's cached prompt prefix across turns (pairs with
+ # the server's --cache-reuse) so unchanged history isn't reprocessed.
+ "cache_prompt": True,
}
if tools:
payload["tools"] = tools
@@ -884,6 +913,9 @@ def stream_generate(
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True,
+ # Reuse the slot's cached prompt prefix across turns (pairs with
+ # the server's --cache-reuse) so unchanged history isn't reprocessed.
+ "cache_prompt": True,
}
if tools:
payload["tools"] = tools
diff --git a/backend_service/mlx_worker.py b/backend_service/mlx_worker.py
index c7a0e52..f3acfc6 100644
--- a/backend_service/mlx_worker.py
+++ b/backend_service/mlx_worker.py
@@ -59,6 +59,7 @@
from backend_service import mlx_worker_lifecycle as _lifecycle
from backend_service import mlx_worker_speculative as _speculative
from backend_service import mlx_worker_generate as _generate
+from backend_service import mlx_worker_prompt_cache as _prompt_cache
# Phase 1f-4: model + runtime introspection helpers now live in
# ``backend_service.mlx_worker_diagnostics``. Re-export so existing imports
@@ -127,6 +128,13 @@ def __init__(self) -> None:
# delimiters via ``reasoning_delimiters_for``. Default
# (``...``) still applies when ``None``.
self._loaded_model_ref: str | None = None
+ # Tier 4: persistent single-slot prompt cache for native-strategy chat
+ # so follow-up turns prefill only the new suffix. Managed by
+ # backend_service.mlx_worker_prompt_cache; invalidated on any model
+ # load / unload / profile change.
+ self._persist_cache: Any | None = None
+ self._persist_tokens: list[int] = []
+ self._persist_cache_model_ref: str | None = None
def handle(self, request: dict[str, Any]) -> dict[str, Any] | None:
op = request.get("op")
@@ -148,12 +156,15 @@ def handle(self, request: dict[str, Any]) -> dict[str, Any] | None:
raise ValueError(f"Unsupported worker operation: {op}")
def load_model(self, request: dict[str, Any]) -> dict[str, Any]:
+ _prompt_cache.invalidate(self)
return _lifecycle.load_model(self, request)
def unload_model(self) -> dict[str, Any]:
+ _prompt_cache.invalidate(self)
return _lifecycle.unload_model(self)
def update_profile(self, request: dict[str, Any]) -> dict[str, Any]:
+ _prompt_cache.invalidate(self)
return _lifecycle.update_profile(self, request)
def _apply_cache_profile(
diff --git a/backend_service/mlx_worker_generate.py b/backend_service/mlx_worker_generate.py
index 7157631..2d7a65d 100644
--- a/backend_service/mlx_worker_generate.py
+++ b/backend_service/mlx_worker_generate.py
@@ -34,6 +34,7 @@
)
from backend_service.mlx_worker_request import (
_apply_mlx_seed,
+ _build_mlx_logits_processors,
_build_mlx_sampler,
_extract_top_logprobs,
_format_tools_for_prompt,
@@ -46,6 +47,7 @@
strip_harmony_boilerplate,
)
from backend_service.runaway_guard import RunawayGuard
+from backend_service import mlx_worker_prompt_cache as _prompt_cache
if TYPE_CHECKING:
@@ -109,24 +111,32 @@ def generate_standard(state: WorkerState, request: dict[str, Any]) -> dict[str,
system_prompt=system_prompt,
)
sampler = _build_mlx_sampler(request)
- prompt_cache, runtime_note = state._make_cache()
- runtime_note = _merge_runtime_notes(runtime_note, prompt_note)
- runtime_fields = state._runtime_fields(prompt_cache=prompt_cache)
+ acq = _prompt_cache.acquire(state, prompt_text)
+ prompt_cache = acq.cache
+ prompt_feed = acq.prompt_feed
+ managed = acq.managed
+ runtime_note = _merge_runtime_notes(acq.note, prompt_note)
+ runtime_fields = state._runtime_fields(prompt_cache=acq.fields_cache)
transcript_fallback = _plain_chat_fallback_active(prompt_note)
runaway_guard = RunawayGuard()
runaway_stopped = False
+ generated_ids: list[int] = []
try:
text_parts: list[str] = []
last_response = None
for response in stream_generate(
state.model,
state.tokenizer,
- prompt_text,
+ prompt_feed,
max_tokens=int(request.get("maxTokens") or 256),
sampler=sampler,
+ logits_processors=_build_mlx_logits_processors(request),
prompt_cache=prompt_cache,
):
+ _tok = getattr(response, "token", None)
+ if isinstance(_tok, int):
+ generated_ids.append(_tok)
if response.text:
text_parts.append(response.text)
try:
@@ -135,8 +145,20 @@ def generate_standard(state: WorkerState, request: dict[str, Any]) -> dict[str,
runaway_stopped = True
break
last_response = response
+ if managed:
+ _prompt_cache.commit(
+ state,
+ cache=prompt_cache,
+ commit_tokens=acq.commit_tokens,
+ generated_ids=generated_ids,
+ model_ref=state._loaded_model_ref,
+ )
except (ValueError, RuntimeError, TypeError, AttributeError) as exc:
- _should_retry = (
+ was_managed = managed
+ if managed:
+ _prompt_cache.invalidate(state)
+ managed = False
+ _should_retry = was_managed or (
prompt_cache is not None
and _should_retry_cache_failure(exc)
)
@@ -319,10 +341,13 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None:
system_prompt=system_prompt,
)
sampler = _build_mlx_sampler(request)
- prompt_cache, runtime_note = state._make_cache()
- runtime_note = _merge_runtime_notes(runtime_note, prompt_note)
+ acq = _prompt_cache.acquire(state, prompt_text)
+ prompt_cache = acq.cache
+ prompt_feed = acq.prompt_feed
+ managed = acq.managed
+ runtime_note = _merge_runtime_notes(acq.note, prompt_note)
runtime_note = _merge_runtime_notes(runtime_note, speculative_stream_fallback_note)
- runtime_fields = state._runtime_fields(prompt_cache=prompt_cache)
+ runtime_fields = state._runtime_fields(prompt_cache=acq.fields_cache)
transcript_fallback = _plain_chat_fallback_active(prompt_note)
thinking_mode = request.get("thinkingMode") or "off"
@@ -336,6 +361,7 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None:
transcript_trimmed = False
runaway_guard = RunawayGuard()
runaway_stopped = False
+ generated_ids: list[int] = []
# Phase 3.3 follow-up: when the request opted into logprobs,
# extract top-k per token via the helper and forward inline
# with each text chunk.
@@ -346,11 +372,15 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None:
for response in mlx_stream_generate(
state.model,
state.tokenizer,
- prompt_text,
+ prompt_feed,
max_tokens=int(request.get("maxTokens") or 256),
sampler=sampler,
+ logits_processors=_build_mlx_logits_processors(request),
prompt_cache=prompt_cache,
):
+ _tok = getattr(response, "token", None)
+ if isinstance(_tok, int):
+ generated_ids.append(_tok)
if response.text:
# Check for runaway loops before emitting
try:
@@ -392,8 +422,20 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None:
transcript_trimmed = transcript_trimmed or transcript_filter.stopped
if visible_text:
_emit({"ok": True, "chunk": {"text": visible_text}})
+ if managed:
+ _prompt_cache.commit(
+ state,
+ cache=prompt_cache,
+ commit_tokens=acq.commit_tokens,
+ generated_ids=generated_ids,
+ model_ref=state._loaded_model_ref,
+ )
except (ValueError, RuntimeError, TypeError, AttributeError) as exc:
- _should_retry = (
+ was_managed = managed
+ if managed:
+ _prompt_cache.invalidate(state)
+ managed = False
+ _should_retry = was_managed or (
prompt_cache is not None
and _should_retry_cache_failure(exc)
)
diff --git a/backend_service/mlx_worker_prompt_cache.py b/backend_service/mlx_worker_prompt_cache.py
new file mode 100644
index 0000000..4ccfbea
--- /dev/null
+++ b/backend_service/mlx_worker_prompt_cache.py
@@ -0,0 +1,122 @@
+"""Per-session MLX prompt-cache reuse (tier 4 of the chat-LLM review).
+
+Native-strategy chat turns re-prefill the *entire* conversation every time
+(`prompt_cache=None` → mlx-lm builds a fresh cache + processes the whole
+prompt). This module keeps one persistent mlx-lm prompt cache on the
+worker and reuses the longest matching token prefix across turns: trim the
+divergent tail off the cache, prefill only the new suffix, then re-commit
+the cache keyed by ``prompt_tokens + generated_tokens``. A single-slot port
+of mlx-lm's server reuse logic (``LRUPromptCache.fetch_nearest_cache``).
+
+Correctness invariant: the persisted token list ALWAYS equals the cache's
+positional contents (prompt + generated), so the next turn's common-prefix
+trim is exact. Any uncertainty — compression strategy active, model
+changed, cache not trimmable (SSM/Mamba/rotating-full, mlx-lm #980),
+tokenisation failure, no common prefix, partial trim — falls back to a
+fresh full prefill, i.e. identical output to the pre-cache path, just
+without the speedup. Gated to the ``native`` strategy; compression caches
+(turboquant / triattention) keep their existing per-call path untouched.
+"""
+
+from __future__ import annotations
+
+from collections import namedtuple
+from typing import Any
+
+# Fields of the tuple returned by ``acquire``:
+#   cache          object handed to stream_generate as ``prompt_cache``
+#   prompt_feed    value handed to stream_generate as the prompt: the suffix
+#                  token list on a reuse hit, the full token list for a fresh
+#                  native cache, or the original ``prompt_text`` string on
+#                  the compression / fallback path
+#   note           runtime note returned by ``_make_cache`` (compression
+#                  fallback messages)
+#   commit_tokens  full prompt token list to re-key under after generation;
+#                  ``None`` when no native cache is being managed
+#   fields_cache   value given to ``_runtime_fields``: ``None`` for native,
+#                  the compression cache otherwise, so the strategy badge
+#                  stays right
+#   managed        ``True`` only when a native persistent cache is owned and
+#                  must be committed after generation
+Acquired = namedtuple(
+    "Acquired",
+    ["cache", "prompt_feed", "note", "commit_tokens", "fields_cache", "managed"],
+)
+
+
+def _common_prefix_len(a: list[int], b: list[int]) -> int:
+    """Return how many leading positions of *a* and *b* hold equal items."""
+    matched = 0
+    for matched, (left, right) in enumerate(zip(a, b), start=1):
+        if left != right:
+            # ``matched`` already counts the mismatching pair; step back one.
+            return matched - 1
+    # No mismatch: every zipped pair agreed (zero pairs if either is empty).
+    return matched
+
+
+def _native_result(cache: Any | None, full_tokens: list[int], prompt_text: str, note: str | None) -> Acquired:
+    """Package the outcome of an attempt to build a fresh native cache.
+
+    With a cache the result is managed: the full token list serves both as
+    the prompt feed and as the commit key. Without one the result is the
+    unmanaged fallback, which feeds the original ``prompt_text`` string and
+    carries no cache.
+    """
+    if cache is None:
+        # No managed cache could be built: behave exactly like before.
+        return Acquired(
+            cache=None,
+            prompt_feed=prompt_text,
+            note=note,
+            commit_tokens=None,
+            fields_cache=None,
+            managed=False,
+        )
+    return Acquired(
+        cache=cache,
+        prompt_feed=full_tokens,
+        note=note,
+        commit_tokens=full_tokens,
+        fields_cache=None,
+        managed=True,
+    )
+
+
+def acquire(state: Any, prompt_text: str) -> Acquired:
+    """Choose the prompt cache and prompt feed for one generation call.
+
+    Args:
+        state: worker state; must provide ``_make_cache()``, ``tokenizer``
+            and ``model``. The ``_persist_*`` attributes and
+            ``_loaded_model_ref`` are read with ``getattr`` defaults.
+        prompt_text: the fully rendered prompt string.
+
+    Returns:
+        An ``Acquired`` tuple, one of:
+        * compression path (``_make_cache`` returned a cache): that cache,
+          the original string prompt, ``managed=False``;
+        * reuse hit: the persisted cache trimmed to the common token prefix
+          and only the remaining suffix tokens as the prompt feed,
+          ``managed=True``;
+        * fresh native cache: a new cache and the full token list,
+          ``managed=True``;
+        * give-up fallback (import / tokenisation / cache construction
+          failed, or the prompt tokenised to nothing): no cache, the
+          original string prompt, ``managed=False``.
+
+    The persisted slot is detached from ``state`` before it is inspected or
+    trimmed; a later ``commit`` re-installs it. A call that never reaches
+    ``commit`` therefore leaves the slot empty rather than leaving a cache
+    that was mutated in place paired with a stale token list.
+    """
+    base_cache, note = state._make_cache()
+    if base_cache is not None:
+        # Compression strategy: unchanged behaviour, no persistence.
+        return Acquired(base_cache, prompt_text, note, None, base_cache, False)
+
+    # Native strategy — manage a persistent single-slot cache.
+    try:
+        from mlx_lm.models.cache import (  # noqa: PLC0415
+            can_trim_prompt_cache,
+            make_prompt_cache,
+            trim_prompt_cache,
+        )
+
+        # Tokenise the way mlx-lm does when handed a string prompt: special
+        # tokens are skipped if the rendered text already starts with BOS.
+        # Encoding with the default would otherwise prepend a second BOS and
+        # diverge from the string-prompt (pre-cache) path.
+        bos_token = getattr(state.tokenizer, "bos_token", None)
+        add_special_tokens = bos_token is None or not prompt_text.startswith(bos_token)
+        full_tokens = list(
+            state.tokenizer.encode(prompt_text, add_special_tokens=add_special_tokens)
+        )
+    except Exception:  # noqa: BLE001 — any failure → safe full-reprocess fallback
+        return Acquired(None, prompt_text, note, None, None, False)
+    if not full_tokens:
+        # Nothing to feed as a token list; let the string path handle it.
+        return Acquired(None, prompt_text, note, None, None, False)
+
+    def _fresh() -> Any | None:
+        try:
+            return make_prompt_cache(state.model)
+        except Exception:  # noqa: BLE001
+            return None
+
+    model_ref = getattr(state, "_loaded_model_ref", None)
+    persist = getattr(state, "_persist_cache", None)
+    persist_tokens = getattr(state, "_persist_tokens", None) or []
+    persist_ref = getattr(state, "_persist_cache_model_ref", None)
+    # Take ownership of the slot: from here on only the locals above refer
+    # to the persisted cache, so trimming it cannot desynchronise ``state``.
+    invalidate(state)
+
+    # Reset conditions: nothing cached, different model, empty history.
+    if persist is None or persist_ref != model_ref or not persist_tokens:
+        return _native_result(_fresh(), full_tokens, prompt_text, note)
+
+    try:
+        if not can_trim_prompt_cache(persist):
+            return _native_result(_fresh(), full_tokens, prompt_text, note)
+        # Always leave >=1 token to process live (mlx-lm does the same).
+        common = min(_common_prefix_len(persist_tokens, full_tokens), len(full_tokens) - 1)
+        if common <= 0:
+            return _native_result(_fresh(), full_tokens, prompt_text, note)
+        num_to_trim = len(persist_tokens) - common
+        if num_to_trim > 0:
+            trimmed = trim_prompt_cache(persist, num_to_trim)
+            if trimmed != num_to_trim:
+                # Couldn't roll back cleanly — don't risk a spliced mismatch.
+                return _native_result(_fresh(), full_tokens, prompt_text, note)
+        # Reuse hit: cache now holds exactly the common prefix; prefill suffix.
+        return Acquired(persist, full_tokens[common:], note, full_tokens, None, True)
+    except Exception:  # noqa: BLE001
+        return _native_result(_fresh(), full_tokens, prompt_text, note)
+
+
+def commit(state: Any, *, cache: Any, commit_tokens: list[int] | None, generated_ids: list[int], model_ref: str | None) -> None:
+    """Store *cache* on *state*, keyed by prompt tokens + generated tokens.
+
+    Does nothing when either ``cache`` or ``commit_tokens`` is ``None``.
+    Entries of ``generated_ids`` that are not ``int`` are left out of the
+    stored token list.
+    """
+    if cache is not None and commit_tokens is not None:
+        state._persist_cache = cache
+        persisted = list(commit_tokens)
+        persisted.extend(tok for tok in generated_ids if isinstance(tok, int))
+        state._persist_tokens = persisted
+        state._persist_cache_model_ref = model_ref
+
+
+def invalidate(state: Any) -> None:
+    """Empty the persistent slot on *state*: no cache, no tokens, no model ref."""
+    state._persist_cache, state._persist_tokens, state._persist_cache_model_ref = (
+        None,
+        [],
+        None,
+    )
diff --git a/backend_service/mlx_worker_request.py b/backend_service/mlx_worker_request.py
index 6bb1ab7..5c2112e 100644
--- a/backend_service/mlx_worker_request.py
+++ b/backend_service/mlx_worker_request.py
@@ -133,7 +133,10 @@ def _build_mlx_sampler(request: dict[str, Any]) -> Any:
kwargs: dict[str, Any] = {"temp": float(request.get("temperature") or 0.0)}
samplers = request.get("samplers") or {}
if isinstance(samplers, dict):
- for src in ("top_p", "top_k", "min_p"):
+ # XTC (xtc_probability/xtc_threshold) is supported by current
+ # make_sampler and adds creative variety; it survives the signature
+ # filter below on builds that have it and is dropped on older ones.
+ for src in ("top_p", "top_k", "min_p", "xtc_probability", "xtc_threshold"):
value = samplers.get(src)
if value is not None:
kwargs[src] = value
@@ -147,6 +150,47 @@ def _build_mlx_sampler(request: dict[str, Any]) -> Any:
return make_sampler(**filtered)
+def _build_mlx_logits_processors(request: dict[str, Any]) -> Any:
+    """Build mlx-lm logits processors (repetition penalty) from the request.
+
+    mlx-lm applies repetition penalty via ``logits_processors``, NOT through
+    ``make_sampler`` — so the UI's ``repeat_penalty`` was silently dropped
+    when only the sampler was wired.
+
+    Reads ``request["samplers"]``: ``repeat_penalty`` (alias
+    ``repetition_penalty``, used when the first is missing or ``None``) and
+    an optional context size under ``repeat_penalty_context`` /
+    ``repetition_context_size``.
+
+    Returns ``None`` — so callers can pass ``logits_processors=None`` (the
+    mlx-lm default) — when ``samplers`` is not a dict, the penalty is
+    absent, unparsable, non-finite or a no-op 1.0, or when building the
+    processors fails for any reason (including mlx-lm being unavailable).
+    Otherwise returns whatever ``make_logits_processors`` produces.
+    Keyword arguments are signature-filtered like the sampler for
+    cross-version robustness.
+    """
+    import inspect
+    import math
+
+    samplers = request.get("samplers") or {}
+    if not isinstance(samplers, dict):
+        return None
+    # Fall back to the alias when the primary key is missing OR explicitly
+    # None; a nested ``.get`` default would only cover the missing case.
+    raw = samplers.get("repeat_penalty")
+    if raw is None:
+        raw = samplers.get("repetition_penalty")
+    try:
+        penalty = float(raw) if raw is not None else None
+    except (TypeError, ValueError):
+        penalty = None
+    # NaN / inf would slip past the "is it 1.0" check, so reject them here.
+    if penalty is None or not math.isfinite(penalty) or abs(penalty - 1.0) < 1e-6:
+        return None
+
+    try:
+        from mlx_lm.sample_utils import make_logits_processors
+
+        kwargs: dict[str, Any] = {"repetition_penalty": penalty}
+        ctx = samplers.get("repeat_penalty_context") or samplers.get("repetition_context_size")
+        if ctx is not None:
+            try:
+                kwargs["repetition_context_size"] = int(ctx)
+            except (TypeError, ValueError, OverflowError):
+                # Unusable context size: keep the penalty, let mlx-lm pick
+                # its own default window.
+                pass
+        allowed = set(inspect.signature(make_logits_processors).parameters)
+        filtered = {key: value for key, value in kwargs.items() if key in allowed}
+        return make_logits_processors(**filtered)
+    except Exception:  # noqa: BLE001 — best-effort: fall back to no processors
+        return None
+
+
def _sampler_seed(request: dict[str, Any]) -> int | None:
samplers = request.get("samplers") or {}
if not isinstance(samplers, dict):
diff --git a/backend_service/models/__init__.py b/backend_service/models/__init__.py
index 4c43b62..e2f9414 100644
--- a/backend_service/models/__init__.py
+++ b/backend_service/models/__init__.py
@@ -151,6 +151,14 @@ class GenerateRequest(BaseModel):
mirostatMode: Literal[0, 1, 2] | None = None
mirostatTau: float | None = Field(default=None, ge=0.0, le=10.0)
mirostatEta: float | None = Field(default=None, ge=0.0, le=1.0)
+ # Modern samplers (tier 2). XTC drops top tokens for variety; DRY
+ # penalises repeated multi-token sequences. llama-server applies all;
+ # mlx-lm applies XTC via make_sampler and ignores DRY (llama-only).
+ xtcProbability: float | None = Field(default=None, ge=0.0, le=1.0)
+ xtcThreshold: float | None = Field(default=None, ge=0.0, le=1.0)
+ dryMultiplier: float | None = Field(default=None, ge=0.0, le=4.0)
+ dryBase: float | None = Field(default=None, ge=0.0, le=8.0)
+ dryAllowedLength: int | None = Field(default=None, ge=0, le=64)
seed: int | None = Field(default=None, ge=0, le=2**31 - 1)
# Constrained decoding: when set, llama-server enforces a JSON schema
# via its `response_format: {type: "json_schema", json_schema: {...}}`
@@ -268,6 +276,15 @@ class OpenAIChatCompletionRequest(BaseModel):
presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
seed: int | None = Field(default=None, ge=0, le=2**31 - 1)
stop: list[str] | str | None = None
+ # Non-standard but widely-accepted local-server sampler fields. Mapped
+ # into the runtime sampler dict in state/openai_compat.py for parity with
+ # the native chat route (llama-server takes these natively; the MLX worker
+ # consumes min_p + repeat_penalty).
+ min_p: float | None = Field(default=None, ge=0.0, le=1.0)
+ repeat_penalty: float | None = Field(default=None, ge=0.0, le=2.0)
+ mirostat: int | None = Field(default=None, ge=0, le=2)
+ mirostat_tau: float | None = Field(default=None, ge=0.0)
+ mirostat_eta: float | None = Field(default=None, ge=0.0)
response_format: dict[str, Any] | None = None
diff --git a/backend_service/state/__init__.py b/backend_service/state/__init__.py
index 57b3931..248bbcc 100644
--- a/backend_service/state/__init__.py
+++ b/backend_service/state/__init__.py
@@ -35,6 +35,7 @@
_build_sampler_overrides,
_clean_prompt_for_title,
_compose_chat_system_prompt,
+ _history_token_budget,
_legacy_title_from_prompt,
_normalize_remote_provider_api_base,
_read_text_tail,
diff --git a/backend_service/state/_helpers.py b/backend_service/state/_helpers.py
index fee56df..b38e597 100644
--- a/backend_service/state/_helpers.py
+++ b/backend_service/state/_helpers.py
@@ -57,6 +57,14 @@ def _put(dst: str, value: Any) -> None:
overrides["mirostat"] = mirostat_mode
_put("mirostat_tau", getattr(request, "mirostatTau", None))
_put("mirostat_eta", getattr(request, "mirostatEta", None))
+ # Modern samplers (tier 2): XTC (both engines) + DRY (llama only).
+ # Engine-side key names; llama-server forwards them via
+ # _LLAMA_SAMPLER_KEYS, mlx-lm reads xtc_* in _build_mlx_sampler.
+ _put("xtc_probability", getattr(request, "xtcProbability", None))
+ _put("xtc_threshold", getattr(request, "xtcThreshold", None))
+ _put("dry_multiplier", getattr(request, "dryMultiplier", None))
+ _put("dry_base", getattr(request, "dryBase", None))
+ _put("dry_allowed_length", getattr(request, "dryAllowedLength", None))
# Phase 3.3: when the user enables logprobs on a request the
# frontend sends a top-k count; map it onto llama-server's
# `logprobs` + `top_logprobs` parameters so the response delta
@@ -68,10 +76,43 @@ def _put(dst: str, value: Any) -> None:
return overrides
+def _estimate_tokens(text: str) -> int:
+    """Cheap, deliberately CONSERVATIVE token estimate (no tokenizer here).
+
+    ASCII text is assumed to run ~3 chars/token (vs the ~4 typical for
+    English) so the history window UNDER-fills the context rather than
+    risking an overflow the MLX path can't recover from. Non-ASCII
+    characters (CJK, emoji, accented text) are counted as one token each:
+    they typically cost about a token or more apiece, so a flat chars/3
+    ratio would undercount them and defeat the safety margin. Pure-ASCII
+    input yields ``len(text) // 3 + 1``. Off by a constant factor — fine
+    for a safety budget, not for billing.
+    """
+    # Encoding with errors="ignore" drops every non-ASCII character, so the
+    # byte length is the ASCII character count.
+    ascii_chars = len(text.encode("ascii", errors="ignore"))
+    dense_chars = len(text) - ascii_chars
+    return (ascii_chars // 3) + dense_chars + 1
+
+
+def _history_token_budget(
+    *,
+    context_tokens: int,
+    max_tokens: int,
+    system_prompt: str | None,
+    prompt: str | None,
+) -> int:
+    """Return the token budget available for *prior* conversation history.
+
+    The context size is reduced by the estimated system prompt, the
+    estimated current user prompt, the generation allowance ``max_tokens``
+    and a fixed 512-token headroom. Falsy inputs count as empty / zero. The
+    result is never below 512.
+    """
+    # Headroom for chat-template + role-tag + tool-schema overhead.
+    template_headroom = 512
+    budget_floor = 512
+    system_cost = _estimate_tokens(system_prompt or "")
+    prompt_cost = _estimate_tokens(prompt or "")
+    generation_cost = int(max_tokens or 0)
+    reserved = system_cost + prompt_cost + generation_cost + template_headroom
+    remaining = int(context_tokens or 0) - reserved
+    return remaining if remaining > budget_floor else budget_floor
+
+
def _build_history_with_reasoning(
messages: list[dict[str, Any]],
*,
preserve_reasoning: bool,
+ token_budget: int | None = None,
) -> list[dict[str, Any]]:
"""Project a session's stored messages into the history list passed to the
inference layer.
@@ -79,10 +120,17 @@ def _build_history_with_reasoning(
When `preserve_reasoning` is true and an assistant message has a
`reasoning` field captured by ThinkingTokenFilter on a previous turn,
the reasoning is re-emitted inside `...` tags ahead of
- the visible answer. Reasoning-capable models (Qwen3, DeepSeek R1, etc.)
- consume this naturally on follow-up turns; non-reasoning models will
- treat it as inline text. Falsy / missing reasoning is skipped, so this
- is safe to call unconditionally.
+ the visible answer. (Upstream chat templates for Qwen3 / DeepSeek-R1
+ actually strip prior reasoning, so the live chat path now passes
+ `preserve_reasoning=False`; the option is kept for callers that want it.)
+ Falsy / missing reasoning is skipped, so this is safe to call
+ unconditionally.
+
+ When `token_budget` is set, a sliding window keeps every system message
+ plus the NEWEST conversation turns that fit the budget (estimated, no
+ tokenizer), dropping the oldest. This bounds prompt growth across a long
+ chat — preventing silent truncation on llama.cpp and out-of-context
+ errors on MLX. ``None`` disables windowing (unchanged behaviour).
"""
history: list[dict[str, Any]] = []
for message in messages:
@@ -97,7 +145,26 @@ def _build_history_with_reasoning(
if reasoning_str:
text = f"\n{reasoning_str}\n\n\n{text}"
history.append({"role": role, "text": text})
- return history
+
+ if token_budget is None or token_budget <= 0:
+ return history
+
+ # System messages are always kept; window the conversation tail.
+ system_msgs = [m for m in history if m["role"] == "system"]
+ convo = [m for m in history if m["role"] != "system"]
+ used = sum(_estimate_tokens(m["text"]) for m in system_msgs)
+ kept_tail: list[dict[str, Any]] = []
+ for message in reversed(convo):
+ cost = _estimate_tokens(message["text"])
+ # Always keep the most recent turn even if it alone blows the budget;
+ # dropping the latest context is worse than a small overflow the
+ # engine can still truncate.
+ if kept_tail and used + cost > token_budget:
+ break
+ used += cost
+ kept_tail.append(message)
+ kept_tail.reverse()
+ return system_msgs + kept_tail
_TITLE_LEADING_PATTERNS = [
diff --git a/backend_service/state/generation.py b/backend_service/state/generation.py
index 15098f4..1ace636 100644
--- a/backend_service/state/generation.py
+++ b/backend_service/state/generation.py
@@ -35,6 +35,7 @@
_build_history_with_reasoning,
_build_sampler_overrides,
_compose_chat_system_prompt,
+ _history_token_budget,
)
@@ -144,7 +145,17 @@ def generate(state: ChaosEngineState, request: GenerateRequest) -> dict[str, Any
history = _build_history_with_reasoning(
session["messages"],
- preserve_reasoning=(effective_thinking_mode == "auto"),
+ # Don't replay prior reasoning — upstream chat templates
+ # (Qwen3 / DeepSeek-R1) strip it, and re-feeding it bloats the
+ # prompt every turn. token_budget windows the oldest turns out so
+ # a long chat can't silently overflow the context.
+ preserve_reasoning=False,
+ token_budget=_history_token_budget(
+ context_tokens=desired_context_tokens,
+ max_tokens=request.maxTokens,
+ system_prompt=request.systemPrompt,
+ prompt=request.prompt,
+ ),
)
session["messages"].append({"role": "user", "text": request.prompt, "metrics": None})
session["updatedAt"] = state._time_label()
@@ -393,7 +404,17 @@ def generate_stream(state: ChaosEngineState, request: GenerateRequest):
history = _build_history_with_reasoning(
session["messages"],
- preserve_reasoning=(effective_thinking_mode == "auto"),
+ # Don't replay prior reasoning — upstream chat templates
+ # (Qwen3 / DeepSeek-R1) strip it, and re-feeding it bloats the
+ # prompt every turn. token_budget windows the oldest turns out so
+ # a long chat can't silently overflow the context.
+ preserve_reasoning=False,
+ token_budget=_history_token_budget(
+ context_tokens=desired_context_tokens,
+ max_tokens=request.maxTokens,
+ system_prompt=request.systemPrompt,
+ prompt=request.prompt,
+ ),
)
session["messages"].append({"role": "user", "text": request.prompt, "metrics": None})
session["updatedAt"] = state._time_label()
@@ -599,6 +620,24 @@ def _maybe_emit_generating_phase() -> str:
ttft_seconds = round(time.perf_counter() - gen_start, 3)
return f"data: {json.dumps({'phase': 'generating', 'ttftSeconds': ttft_seconds})}\n\n"
+ # Token coalescing: batch visible token frames so a fast decoder
+ # doesn't pay a json.dumps + SSE frame per token. Flush on size, a
+ # short time window, any non-token event, or stream end. Disabled
+ # when per-token logprobs are requested (they must stay 1:1 aligned).
+ _COALESCE_CHARS = 24
+ _COALESCE_SECS = 0.05
+ _coalesce_tokens = not (request.logprobs and int(request.logprobs) > 0)
+ _tok: dict[str, Any] = {"buf": [], "chars": 0, "started": 0.0}
+
+ def _flush_tokens() -> str:
+ if not _tok["buf"]:
+ return ""
+ merged = "".join(_tok["buf"])
+ _tok["buf"] = []
+ _tok["chars"] = 0
+ _tok["started"] = 0.0
+ return f"data: {json.dumps({'token': merged})}\n\n"
+
try:
if enable_tools:
from backend_service.agent import run_agent_loop_streaming
@@ -619,7 +658,20 @@ def _maybe_emit_generating_phase() -> str:
if phase_event:
yield phase_event
full_text += event["token"]
- yield f"data: {json.dumps({'token': event['token']})}\n\n"
+ if _coalesce_tokens:
+ if not _tok["buf"]:
+ _tok["started"] = time.perf_counter()
+ _tok["buf"].append(event["token"])
+ _tok["chars"] += len(event["token"])
+ if (
+ _tok["chars"] >= _COALESCE_CHARS
+ or time.perf_counter() - _tok["started"] >= _COALESCE_SECS
+ ):
+ _f = _flush_tokens()
+ if _f:
+ yield _f
+ else:
+ yield f"data: {json.dumps({'token': event['token']})}\n\n"
if len(full_text) > runaway_char_budget:
runaway_triggered = True
cancelled = True
@@ -628,8 +680,14 @@ def _maybe_emit_generating_phase() -> str:
phase_event = _maybe_emit_generating_phase()
if phase_event:
yield phase_event
+ _f = _flush_tokens()
+ if _f:
+ yield _f
yield f"data: {json.dumps({'toolCallStart': event['tool_call_start']})}\n\n"
elif "tool_call_result" in event:
+ _f = _flush_tokens()
+ if _f:
+ yield _f
agent_tool_calls.append(event["tool_call_result"])
yield f"data: {json.dumps({'toolCallResult': event['tool_call_result']})}\n\n"
elif event.get("done"):
@@ -653,16 +711,35 @@ def _maybe_emit_generating_phase() -> str:
phase_event = _maybe_emit_generating_phase()
if phase_event:
yield phase_event
+ _f = _flush_tokens()
+ if _f:
+ yield _f
full_reasoning += chunk.reasoning
yield f"data: {json.dumps({'reasoning': chunk.reasoning})}\n\n"
if chunk.reasoning_done:
+ _f = _flush_tokens()
+ if _f:
+ yield _f
yield f"data: {json.dumps({'reasoningDone': True})}\n\n"
if chunk.text:
phase_event = _maybe_emit_generating_phase()
if phase_event:
yield phase_event
full_text += chunk.text
- yield f"data: {json.dumps({'token': chunk.text})}\n\n"
+ if _coalesce_tokens:
+ if not _tok["buf"]:
+ _tok["started"] = time.perf_counter()
+ _tok["buf"].append(chunk.text)
+ _tok["chars"] += len(chunk.text)
+ if (
+ _tok["chars"] >= _COALESCE_CHARS
+ or time.perf_counter() - _tok["started"] >= _COALESCE_SECS
+ ):
+ _f = _flush_tokens()
+ if _f:
+ yield _f
+ else:
+ yield f"data: {json.dumps({'token': chunk.text})}\n\n"
# Phase 3.3: forward per-token logprobs when
# the inference layer captured them.
if chunk.token_logprobs:
@@ -730,6 +807,9 @@ def _maybe_emit_generating_phase() -> str:
f"{p_avail:.1f} GB, "
f"pressure={p_pressure:.0f}%.",
)
+ _f = _flush_tokens()
+ if _f:
+ yield _f
yield (
"data: "
+ json.dumps({
@@ -762,6 +842,9 @@ def _maybe_emit_generating_phase() -> str:
"chat", "warning",
f"[{model_tag}] Thermal warning: critical.",
)
+ _f = _flush_tokens()
+ if _f:
+ yield _f
yield (
"data: "
+ json.dumps({
@@ -794,11 +877,20 @@ def _maybe_emit_generating_phase() -> str:
chaosengine.active_requests = max(0, chaosengine.active_requests - 1)
chaosengine.add_log("chat", "error", f"[{model_tag}] Streaming failed: {exc}")
chaosengine.clear_chat_cancel(session_id_for_cancel)
+ _f = _flush_tokens()
+ if _f:
+ yield _f
yield f"data: {json.dumps({'error': str(exc)})}\n\n"
return
finally:
chaosengine.clear_chat_cancel(session_id_for_cancel)
+ # Flush any tokens still buffered by the coalescer before the
+ # terminal done / cancelled events (covers normal end + all breaks).
+ _f = _flush_tokens()
+ if _f:
+ yield _f
+
if cancelled:
yield f"data: {json.dumps({'cancelled': True})}\n\n"
if runaway_loop_reason is not None:
diff --git a/backend_service/state/openai_compat.py b/backend_service/state/openai_compat.py
index b25dedd..a5a3cb0 100644
--- a/backend_service/state/openai_compat.py
+++ b/backend_service/state/openai_compat.py
@@ -236,6 +236,19 @@ def openai_chat_completion(
oai_samplers["seed"] = request.seed
if request.stop is not None:
oai_samplers["stop"] = request.stop if isinstance(request.stop, list) else [request.stop]
+ # Parity with the native chat route's sampler set: min_p, repeat_penalty,
+ # and mirostat were silently dropped on the /v1 path. llama-server takes
+ # these key names natively; the MLX worker consumes min_p + repeat_penalty.
+ if request.min_p is not None:
+ oai_samplers["min_p"] = request.min_p
+ if request.repeat_penalty is not None:
+ oai_samplers["repeat_penalty"] = request.repeat_penalty
+ if request.mirostat is not None:
+ oai_samplers["mirostat"] = request.mirostat
+ if request.mirostat_tau is not None:
+ oai_samplers["mirostat_tau"] = request.mirostat_tau
+ if request.mirostat_eta is not None:
+ oai_samplers["mirostat_eta"] = request.mirostat_eta
# Phase 2.13: pull a JSON schema out of OpenAI's response_format
# envelope so the constrained-decode path lights up. Anything
diff --git a/src/components/SamplerPanel.tsx b/src/components/SamplerPanel.tsx
index 9df721e..7f58c33 100644
--- a/src/components/SamplerPanel.tsx
+++ b/src/components/SamplerPanel.tsx
@@ -194,6 +194,39 @@ export function SamplerPanel({ overrides, onChange, disabled }: SamplerPanelProp
disabled={disabled}
onChange={(v) => patch("repeatPenalty", v)}
/>
+ patch("xtcProbability", v)}
+ />
+ patch("xtcThreshold", v)}
+ />
+ patch("dryMultiplier", v)}
+ />
{
expect(samplerPayload({ topP: 0.9, topK: null, seed: null })).toEqual({ topP: 0.9 });
});
+ it("projects modern samplers (xtc + dry)", () => {
+ expect(
+ samplerPayload({ xtcProbability: 0.5, xtcThreshold: 0.1, dryMultiplier: 0.8 }),
+ ).toEqual({ xtcProbability: 0.5, xtcThreshold: 0.1, dryMultiplier: 0.8 });
+ });
+
+ it("round-trips modern samplers through storage", () => {
+ writeSamplerOverrides("sx", { xtcProbability: 0.5, dryMultiplier: 0.8 });
+ expect(readSamplerOverrides("sx")).toEqual({ xtcProbability: 0.5, dryMultiplier: 0.8 });
+ });
+
it("parses jsonSchemaText into jsonSchema when valid", () => {
const schemaText = '{"type":"object","properties":{"answer":{"type":"string"}}}';
expect(samplerPayload({ jsonSchemaText: schemaText })).toEqual({
diff --git a/src/features/chat/samplerOverrides.ts b/src/features/chat/samplerOverrides.ts
index 4bcf226..07007e1 100644
--- a/src/features/chat/samplerOverrides.ts
+++ b/src/features/chat/samplerOverrides.ts
@@ -21,6 +21,9 @@ const NUMERIC_KEYS = [
"seed",
"mirostatTau",
"mirostatEta",
+ "xtcProbability",
+ "xtcThreshold",
+ "dryMultiplier",
] as const;
function storageKey(sessionId: string): string {
@@ -95,6 +98,9 @@ export function samplerPayload(overrides: SamplerOverrides): Record;
/**
* Phase 3.3: when set, asks llama-server to return top-k logprobs
@@ -255,6 +260,9 @@ export interface SamplerOverrides {
mirostatMode?: 0 | 1 | 2 | null;
mirostatTau?: number | null;
mirostatEta?: number | null;
+ xtcProbability?: number | null;
+ xtcThreshold?: number | null;
+ dryMultiplier?: number | null;
/**
* Phase 2.2: opt-in constrained decoding. Raw JSON-schema text the
* user typed in the SamplerPanel. Parsed at send-time and forwarded
diff --git a/tests/test_backend_service.py b/tests/test_backend_service.py
index ff1c6af..c5b04a1 100644
--- a/tests/test_backend_service.py
+++ b/tests/test_backend_service.py
@@ -1350,6 +1350,30 @@ def test_openai_completion_forwards_sampler_fields(self):
self.assertEqual(runtime_kwargs["samplers"]["stop"], ["END"])
self.assertIn("properties", runtime_kwargs["json_schema"])
+ def test_openai_completion_forwards_extended_samplers(self):
+ # Parity fix: min_p / repeat_penalty / mirostat were dropped on the
+ # /v1 path. They must now reach the runtime sampler dict.
+ response = self.client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "google/gemma-4-E4B-it",
+ "messages": [{"role": "user", "content": "test"}],
+ "max_tokens": 16,
+ "min_p": 0.05,
+ "repeat_penalty": 1.15,
+ "mirostat": 2,
+ "mirostat_tau": 5.0,
+ "mirostat_eta": 0.1,
+ },
+ )
+ self.assertEqual(response.status_code, 200)
+ samplers = self.client.app.state.chaosengine.runtime.last_generate_kwargs["samplers"]
+ self.assertEqual(samplers["min_p"], 0.05)
+ self.assertEqual(samplers["repeat_penalty"], 1.15)
+ self.assertEqual(samplers["mirostat"], 2)
+ self.assertEqual(samplers["mirostat_tau"], 5.0)
+ self.assertEqual(samplers["mirostat_eta"], 0.1)
+
def test_openai_completion_omits_sampler_dict_when_none_set(self):
response = self.client.post(
"/v1/chat/completions",
diff --git a/tests/test_history_with_reasoning.py b/tests/test_history_with_reasoning.py
index 74f8da4..0f45e97 100644
--- a/tests/test_history_with_reasoning.py
+++ b/tests/test_history_with_reasoning.py
@@ -9,6 +9,7 @@
import unittest
from backend_service.state import _build_history_with_reasoning
+from backend_service.state._helpers import _estimate_tokens, _history_token_budget
class BuildHistoryWithReasoningTests(unittest.TestCase):
@@ -65,5 +66,61 @@ def test_preserves_message_order(self):
self.assertIn("R2", history[3]["text"])
+class HistoryTokenWindowTests(unittest.TestCase):
+ def test_token_budget_none_keeps_all(self):
+ messages = [{"role": "user", "text": "x" * 300} for _ in range(6)]
+ history = _build_history_with_reasoning(messages, preserve_reasoning=False, token_budget=None)
+ self.assertEqual(len(history), 6)
+
+ def test_windows_oldest_turns_out(self):
+ # Each 30-char text ~= 11 estimated tokens; budget 25 keeps 2 newest.
+ messages = [
+ {"role": "user", "text": "a" * 30},
+ {"role": "assistant", "text": "b" * 30},
+ {"role": "user", "text": "c" * 30},
+ {"role": "assistant", "text": "d" * 30},
+ ]
+ history = _build_history_with_reasoning(messages, preserve_reasoning=False, token_budget=25)
+ self.assertEqual([h["text"] for h in history], ["c" * 30, "d" * 30])
+
+ def test_always_keeps_latest_turn_even_if_over_budget(self):
+ messages = [{"role": "user", "text": "z" * 300}]
+ history = _build_history_with_reasoning(messages, preserve_reasoning=False, token_budget=10)
+ self.assertEqual(len(history), 1)
+ self.assertEqual(history[0]["text"], "z" * 300)
+
+ def test_system_messages_always_kept(self):
+ messages = [
+ {"role": "system", "text": "s" * 30},
+ {"role": "user", "text": "u" * 300},
+ {"role": "assistant", "text": "a" * 300},
+ {"role": "user", "text": "n" * 9},
+ ]
+ history = _build_history_with_reasoning(messages, preserve_reasoning=False, token_budget=20)
+ roles = [h["role"] for h in history]
+ self.assertIn("system", roles)
+ self.assertEqual(history[-1]["text"], "n" * 9)
+ self.assertNotIn("u" * 300, [h["text"] for h in history])
+
+ def test_estimate_tokens_is_conservative(self):
+ # ~3 chars/token (over-estimates English so the window stays safe).
+ self.assertEqual(_estimate_tokens(""), 1)
+ self.assertEqual(_estimate_tokens("abc"), 2)
+ self.assertEqual(_estimate_tokens("a" * 30), 11)
+
+ def test_history_token_budget_reserves_and_floors(self):
+ budget = _history_token_budget(
+ context_tokens=2000, max_tokens=256, system_prompt="x" * 30, prompt="y" * 30
+ )
+ # 2000 - (11 + 11 + 256 + 512) = 1210
+ self.assertEqual(budget, 1210)
+
+ def test_history_token_budget_floor_512(self):
+ budget = _history_token_budget(
+ context_tokens=100, max_tokens=256, system_prompt=None, prompt=None
+ )
+ self.assertEqual(budget, 512)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_mlx_prompt_cache.py b/tests/test_mlx_prompt_cache.py
new file mode 100644
index 0000000..593e419
--- /dev/null
+++ b/tests/test_mlx_prompt_cache.py
@@ -0,0 +1,180 @@
+"""Tests for the MLX per-session prompt-cache reuse logic (tier 4).
+
+Exercises backend_service/mlx_worker_prompt_cache.py with a fake worker
+state and patched mlx-lm cache primitives — no real model load. The
+correctness contract under test: the persisted token list always equals
+the cache's positional contents, and any uncertainty falls back to a fresh
+full prefill.
+"""
+
+import unittest
+from unittest import mock
+
+from backend_service import mlx_worker_prompt_cache as pc
+
+CACHE_MOD = "mlx_lm.models.cache"
+
+
+class FakeCache:
+ """Sentinel standing in for an mlx-lm prompt cache."""
+
+ def __init__(self, label):
+ self.label = label
+
+
+class FakeState:
+ def __init__(self, *, base_cache=None, base_note=None, tokens=None, model_ref="m"):
+ self._base = (base_cache, base_note)
+ self._tokens = list(tokens or [])
+ self.model = object()
+ self._loaded_model_ref = model_ref
+ self.tokenizer = self
+ self._persist_cache = None
+ self._persist_tokens = []
+ self._persist_cache_model_ref = None
+
+ def _make_cache(self):
+ return self._base
+
+ def encode(self, _text): # stands in for tokenizer.encode
+ return list(self._tokens)
+
+
+class CommonPrefixTests(unittest.TestCase):
+ def test_common_prefix_len(self):
+ self.assertEqual(pc._common_prefix_len([1, 2, 3], [1, 2, 9]), 2)
+ self.assertEqual(pc._common_prefix_len([1, 2], [9]), 0)
+ self.assertEqual(pc._common_prefix_len([1, 2, 3], [1, 2, 3, 4]), 3)
+
+
+class AcquireCompressionTests(unittest.TestCase):
+ def test_compression_strategy_passthrough(self):
+ comp = FakeCache("compression")
+ state = FakeState(base_cache=comp, base_note="cn")
+ acq = pc.acquire(state, "p-text")
+ self.assertIs(acq.cache, comp)
+ self.assertEqual(acq.prompt_feed, "p-text") # string, unchanged
+ self.assertFalse(acq.managed)
+ self.assertIs(acq.fields_cache, comp)
+ self.assertIsNone(acq.commit_tokens)
+
+
+class AcquireNativeTests(unittest.TestCase):
+ def _patches(self, *, can_trim=True, trim=lambda c, n: n, fresh_label="fresh"):
+ return (
+ mock.patch(f"{CACHE_MOD}.make_prompt_cache", return_value=FakeCache(fresh_label)),
+ mock.patch(f"{CACHE_MOD}.can_trim_prompt_cache", return_value=can_trim),
+ mock.patch(f"{CACHE_MOD}.trim_prompt_cache", side_effect=trim),
+ )
+
+ def test_fresh_native_cache_full_prefill(self):
+ state = FakeState(base_cache=None, tokens=[1, 2, 3])
+ with self._patches()[0], self._patches()[1], self._patches()[2]:
+ acq = pc.acquire(state, "ignored")
+ self.assertTrue(acq.managed)
+ self.assertIsInstance(acq.cache, FakeCache)
+ self.assertEqual(acq.prompt_feed, [1, 2, 3]) # full token list
+ self.assertEqual(acq.commit_tokens, [1, 2, 3])
+ self.assertIsNone(acq.fields_cache)
+
+ def test_reuse_hit_feeds_only_suffix_no_trim(self):
+ persist = FakeCache("persist")
+ state = FakeState(base_cache=None, tokens=[1, 2, 3, 4, 5], model_ref="m")
+ state._persist_cache = persist
+ state._persist_tokens = [1, 2, 3]
+ state._persist_cache_model_ref = "m"
+ m1, m2, m3 = self._patches()
+ with m1, m2, m3 as trim:
+ acq = pc.acquire(state, "ignored")
+ self.assertIs(acq.cache, persist) # reused, not fresh
+ self.assertEqual(acq.prompt_feed, [4, 5]) # suffix only
+ self.assertEqual(acq.commit_tokens, [1, 2, 3, 4, 5])
+ trim.assert_not_called() # num_to_trim == 0
+
+ def test_reuse_with_divergence_trims_tail(self):
+ persist = FakeCache("persist")
+ state = FakeState(base_cache=None, tokens=[1, 2, 3, 4], model_ref="m")
+ state._persist_cache = persist
+ state._persist_tokens = [1, 2, 3, 9, 9] # diverges at index 3 (first 3 tokens shared)
+ state._persist_cache_model_ref = "m"
+ m1, m2, m3 = self._patches()
+ with m1, m2, m3 as trim:
+ acq = pc.acquire(state, "ignored")
+ self.assertIs(acq.cache, persist)
+ trim.assert_called_once_with(persist, 2) # 5 cached - 3 common
+ self.assertEqual(acq.prompt_feed, [4]) # full[3:]
+
+ def test_reset_on_model_change(self):
+ state = FakeState(base_cache=None, tokens=[1, 2, 3], model_ref="new")
+ state._persist_cache = FakeCache("stale")
+ state._persist_tokens = [1, 2, 3]
+ state._persist_cache_model_ref = "old"
+ m1, m2, m3 = self._patches()
+ with m1, m2, m3:
+ acq = pc.acquire(state, "ignored")
+ self.assertEqual(acq.prompt_feed, [1, 2, 3]) # fresh → full prefill
+ self.assertEqual(acq.cache.label, "fresh")
+
+ def test_reset_when_cache_not_trimmable(self):
+ state = FakeState(base_cache=None, tokens=[1, 2, 3, 4], model_ref="m")
+ state._persist_cache = FakeCache("persist")
+ state._persist_tokens = [1, 2, 3]
+ state._persist_cache_model_ref = "m"
+ m1, m2, m3 = self._patches(can_trim=False)
+ with m1, m2, m3:
+ acq = pc.acquire(state, "ignored")
+ self.assertEqual(acq.cache.label, "fresh")
+ self.assertEqual(acq.prompt_feed, [1, 2, 3, 4])
+
+ def test_reset_when_no_common_prefix(self):
+ state = FakeState(base_cache=None, tokens=[7, 8, 9], model_ref="m")
+ state._persist_cache = FakeCache("persist")
+ state._persist_tokens = [1, 2, 3]
+ state._persist_cache_model_ref = "m"
+ m1, m2, m3 = self._patches()
+ with m1, m2, m3:
+ acq = pc.acquire(state, "ignored")
+ self.assertEqual(acq.cache.label, "fresh")
+ self.assertEqual(acq.prompt_feed, [7, 8, 9])
+
+ def test_partial_trim_falls_back_to_fresh(self):
+ state = FakeState(base_cache=None, tokens=[1, 2, 3, 4], model_ref="m")
+ state._persist_cache = FakeCache("persist")
+ state._persist_tokens = [1, 2, 3, 9, 9]
+ state._persist_cache_model_ref = "m"
+ # trim returns fewer than requested → unsafe → fresh
+ m1, m2, m3 = self._patches(trim=lambda c, n: n - 1)
+ with m1, m2, m3:
+ acq = pc.acquire(state, "ignored")
+ self.assertEqual(acq.cache.label, "fresh")
+ self.assertEqual(acq.prompt_feed, [1, 2, 3, 4])
+
+
+class CommitInvalidateTests(unittest.TestCase):
+ def test_commit_accounting_is_prompt_plus_generated(self):
+ state = FakeState()
+ cache = FakeCache("c")
+ pc.commit(state, cache=cache, commit_tokens=[1, 2, 3], generated_ids=[4, 5], model_ref="m")
+ self.assertIs(state._persist_cache, cache)
+ self.assertEqual(state._persist_tokens, [1, 2, 3, 4, 5])
+ self.assertEqual(state._persist_cache_model_ref, "m")
+
+ def test_commit_noop_when_not_managed(self):
+ state = FakeState()
+ pc.commit(state, cache=None, commit_tokens=None, generated_ids=[4], model_ref="m")
+ self.assertIsNone(state._persist_cache)
+ self.assertEqual(state._persist_tokens, [])
+
+ def test_invalidate_clears(self):
+ state = FakeState()
+ state._persist_cache = FakeCache("c")
+ state._persist_tokens = [1, 2]
+ state._persist_cache_model_ref = "m"
+ pc.invalidate(state)
+ self.assertIsNone(state._persist_cache)
+ self.assertEqual(state._persist_tokens, [])
+ self.assertIsNone(state._persist_cache_model_ref)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_mlx_worker.py b/tests/test_mlx_worker.py
index 7212104..d1f79d0 100644
--- a/tests/test_mlx_worker.py
+++ b/tests/test_mlx_worker.py
@@ -875,5 +875,41 @@ def test_unload_clears_multimodal_state(self):
self.assertFalse(worker.is_multimodal)
+class MlxLogitsProcessorTests(unittest.TestCase):
+ """_build_mlx_logits_processors wires repeat_penalty (mlx-lm applies it
+ via logits_processors, not the sampler — it was being dropped)."""
+
+ def setUp(self):
+ from backend_service.mlx_worker_request import _build_mlx_logits_processors
+
+ self._build = _build_mlx_logits_processors
+
+ def test_none_when_no_samplers(self):
+ self.assertIsNone(self._build({}))
+ self.assertIsNone(self._build({"samplers": None}))
+
+ def test_none_when_penalty_absent_or_neutral(self):
+ self.assertIsNone(self._build({"samplers": {"top_p": 0.9}}))
+ self.assertIsNone(self._build({"samplers": {"repeat_penalty": 1.0}}))
+
+ def test_none_when_penalty_non_numeric(self):
+ self.assertIsNone(self._build({"samplers": {"repeat_penalty": "oops"}}))
+
+ @unittest.skipUnless(
+ __import__("importlib").util.find_spec("mlx_lm") is not None,
+ "mlx-lm not installed",
+ )
+ def test_builds_processors_for_real_penalty(self):
+ result = self._build({"samplers": {"repeat_penalty": 1.3}})
+ self.assertIsNotNone(result)
+ self.assertTrue(len(result) >= 1)
+
+ def test_accepts_repetition_penalty_alias_without_raising(self):
+ try:
+ self._build({"samplers": {"repetition_penalty": 1.2}})
+ except Exception as exc: # noqa: BLE001
+ self.fail(f"alias parse raised: {exc}")
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_sampler_payload.py b/tests/test_sampler_payload.py
index 4f63b15..a79f3bd 100644
--- a/tests/test_sampler_payload.py
+++ b/tests/test_sampler_payload.py
@@ -55,6 +55,30 @@ def test_merges_all_supported_sampler_keys(self):
self.assertEqual(payload["mirostat_tau"], 5.0)
self.assertEqual(payload["mirostat_eta"], 0.1)
+ def test_merges_modern_quality_samplers(self):
+ # DRY / XTC / top-n-sigma were added to _LLAMA_SAMPLER_KEYS; they
+ # must now flow through to the llama-server payload.
+ payload: dict = {}
+ _apply_sampler_kwargs(
+ payload,
+ samplers={
+ "dry_multiplier": 0.8,
+ "dry_base": 1.75,
+ "dry_allowed_length": 2,
+ "xtc_probability": 0.5,
+ "xtc_threshold": 0.1,
+ "top_n_sigma": 1.0,
+ },
+ reasoning_effort=None,
+ json_schema=None,
+ )
+ self.assertEqual(payload["dry_multiplier"], 0.8)
+ self.assertEqual(payload["dry_base"], 1.75)
+ self.assertEqual(payload["dry_allowed_length"], 2)
+ self.assertEqual(payload["xtc_probability"], 0.5)
+ self.assertEqual(payload["xtc_threshold"], 0.1)
+ self.assertEqual(payload["top_n_sigma"], 1.0)
+
def test_none_values_in_samplers_skip_merge(self):
# The frontend may send the union of fields with most set to null —
# explicit nulls must not override server defaults.
@@ -131,6 +155,19 @@ def test_emits_llama_field_names(self):
self.assertEqual(overrides["mirostat_tau"], 5.0)
self.assertEqual(overrides["mirostat_eta"], 0.1)
+ def test_emits_modern_sampler_field_names(self):
+ # XTC + DRY map to llama/mlx engine-side snake_case keys.
+ request = SimpleNamespace(
+ xtcProbability=0.5, xtcThreshold=0.1,
+ dryMultiplier=0.8, dryBase=1.75, dryAllowedLength=2,
+ )
+ overrides = _build_sampler_overrides(request)
+ self.assertEqual(overrides["xtc_probability"], 0.5)
+ self.assertEqual(overrides["xtc_threshold"], 0.1)
+ self.assertEqual(overrides["dry_multiplier"], 0.8)
+ self.assertEqual(overrides["dry_base"], 1.75)
+ self.assertEqual(overrides["dry_allowed_length"], 2)
+
def test_partial_override_keeps_only_set_fields(self):
request = SimpleNamespace(
topP=0.9, topK=None, minP=None, repeatPenalty=None,