Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backend_service/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,11 @@ def run_agent_loop_streaming(
# consumed so the assistant bubble doesn't show raw call
# JSON next to the rendered ToolCallCard (FU-040).
text = _strip_tool_call_xml(result.text)
chunk_size = 4
# The final answer is already fully computed (tool-calling turns
# are non-streaming), so the old 4-char dribble just added fake
# latency + yields. Emit in larger chunks; the SSE layer coalesces
# these further and the user sees the answer near-instantly.
chunk_size = 48
for i in range(0, len(text), chunk_size):
yield {"token": text[i:i + chunk_size]}

Expand Down
32 changes: 32 additions & 0 deletions backend_service/inference/llama_cpp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@
"frequency_penalty",
"presence_penalty",
"stop",
# Modern anti-repetition / quality samplers llama-server supports
# natively. Forward-only: builds that don't recognise them ignore the
# field, so old binaries are unaffected. DRY beats plain repeat_penalty
# at killing verbatim loops; XTC adds creative variety; top-n-sigma is
# a temperature-stable truncator.
"dry_multiplier",
"dry_base",
"dry_allowed_length",
"xtc_probability",
"xtc_threshold",
"top_n_sigma",
# Phase 3.3: per-token confidence info. llama-server returns
# top-k alternatives with their logprobs in each delta when
# `logprobs: true` + `top_logprobs: N` are set.
Expand Down Expand Up @@ -421,6 +432,7 @@ def _build_command(
fit_enabled: bool,
is_fallback: bool,
speculative_decoding: bool = False,
fused_attention: bool = False,
canonical_repo: str | None = None,
model_ref: str = "",
) -> tuple[list[str], str | None, bool, str | None]:
Expand Down Expand Up @@ -449,6 +461,19 @@ def _build_command(
str(max(256, context_tokens)),
"--jinja",
]
# Reuse the single slot's KV cache across chat turns: a growing
# conversation re-prefills only the new suffix instead of the whole
# history (turn-2+ TTFT drops sharply on long chats). Forward-gated
# on binary support so older llama-server builds are unaffected.
if _llama_server_supports(binary, "--cache-reuse"):
command.extend(["--cache-reuse", "256"])
# Honour the user's fused-attention toggle. It was plumbed into
# load_model + stored on LoadedModelInfo but never emitted as a
# flag. Flash attention is a large decode + KV-memory win on Metal
# and is required by the quantized KV cache types. Opt-in via the
# existing flag so a model/quant combo that dislikes it can disable.
if fused_attention and _llama_server_supports(binary, "--flash-attn"):
command.extend(["--flash-attn", "on"])
if _llama_server_supports(binary, "--reasoning-format"):
command.extend(["--reasoning-format", "deepseek"])
if _llama_server_supports(binary, "--reasoning"):
Expand Down Expand Up @@ -660,6 +685,7 @@ def load_model(
fit_enabled=fit_enabled,
is_fallback=is_fallback,
speculative_decoding=speculative_decoding,
fused_attention=fused_attention,
canonical_repo=canonical_repo,
model_ref=model_ref,
)
Expand Down Expand Up @@ -791,6 +817,9 @@ def generate(
"temperature": temperature,
"max_tokens": max_tokens,
"stream": False,
# Reuse the slot's cached prompt prefix across turns (pairs with
# the server's --cache-reuse) so unchanged history isn't reprocessed.
"cache_prompt": True,
}
if tools:
payload["tools"] = tools
Expand Down Expand Up @@ -884,6 +913,9 @@ def stream_generate(
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True,
# Reuse the slot's cached prompt prefix across turns (pairs with
# the server's --cache-reuse) so unchanged history isn't reprocessed.
"cache_prompt": True,
}
if tools:
payload["tools"] = tools
Expand Down
11 changes: 11 additions & 0 deletions backend_service/mlx_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from backend_service import mlx_worker_lifecycle as _lifecycle
from backend_service import mlx_worker_speculative as _speculative
from backend_service import mlx_worker_generate as _generate
from backend_service import mlx_worker_prompt_cache as _prompt_cache

# Phase 1f-4: model + runtime introspection helpers now live in
# ``backend_service.mlx_worker_diagnostics``. Re-export so existing imports
Expand Down Expand Up @@ -127,6 +128,13 @@ def __init__(self) -> None:
# delimiters via ``reasoning_delimiters_for``. Default
# (``<think>...</think>``) still applies when ``None``.
self._loaded_model_ref: str | None = None
# Tier 4: persistent single-slot prompt cache for native-strategy chat
# so follow-up turns prefill only the new suffix. Managed by
# backend_service.mlx_worker_prompt_cache; invalidated on any model
# load / unload / profile change.
self._persist_cache: Any | None = None
self._persist_tokens: list[int] = []
self._persist_cache_model_ref: str | None = None

def handle(self, request: dict[str, Any]) -> dict[str, Any] | None:
op = request.get("op")
Expand All @@ -148,12 +156,15 @@ def handle(self, request: dict[str, Any]) -> dict[str, Any] | None:
raise ValueError(f"Unsupported worker operation: {op}")

def load_model(self, request: dict[str, Any]) -> dict[str, Any]:
    """Drop the persistent prompt cache, then delegate to ``_lifecycle.load_model``.

    Returns whatever ``_lifecycle.load_model(self, request)`` returns.
    """
    # Invalidate before loading so the prompt-cache slot is empty by the time
    # the lifecycle helper runs.
    _prompt_cache.invalidate(self)
    return _lifecycle.load_model(self, request)

def unload_model(self) -> dict[str, Any]:
    """Drop the persistent prompt cache, then delegate to ``_lifecycle.unload_model``.

    Returns whatever ``_lifecycle.unload_model(self)`` returns.
    """
    # Clear the prompt-cache slot before the lifecycle helper runs.
    _prompt_cache.invalidate(self)
    return _lifecycle.unload_model(self)

def update_profile(self, request: dict[str, Any]) -> dict[str, Any]:
    """Drop the persistent prompt cache, then delegate to ``_lifecycle.update_profile``.

    Returns whatever ``_lifecycle.update_profile(self, request)`` returns.
    """
    # Clear the prompt-cache slot before the lifecycle helper applies the
    # profile change.
    _prompt_cache.invalidate(self)
    return _lifecycle.update_profile(self, request)

def _apply_cache_profile(
Expand Down
62 changes: 52 additions & 10 deletions backend_service/mlx_worker_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from backend_service.mlx_worker_request import (
_apply_mlx_seed,
_build_mlx_logits_processors,
_build_mlx_sampler,
_extract_top_logprobs,
_format_tools_for_prompt,
Expand All @@ -46,6 +47,7 @@
strip_harmony_boilerplate,
)
from backend_service.runaway_guard import RunawayGuard
from backend_service import mlx_worker_prompt_cache as _prompt_cache


if TYPE_CHECKING:
Expand Down Expand Up @@ -109,24 +111,32 @@ def generate_standard(state: WorkerState, request: dict[str, Any]) -> dict[str,
system_prompt=system_prompt,
)
sampler = _build_mlx_sampler(request)
prompt_cache, runtime_note = state._make_cache()
runtime_note = _merge_runtime_notes(runtime_note, prompt_note)
runtime_fields = state._runtime_fields(prompt_cache=prompt_cache)
acq = _prompt_cache.acquire(state, prompt_text)
prompt_cache = acq.cache
prompt_feed = acq.prompt_feed
managed = acq.managed
runtime_note = _merge_runtime_notes(acq.note, prompt_note)
runtime_fields = state._runtime_fields(prompt_cache=acq.fields_cache)
transcript_fallback = _plain_chat_fallback_active(prompt_note)

runaway_guard = RunawayGuard()
runaway_stopped = False
generated_ids: list[int] = []
try:
text_parts: list[str] = []
last_response = None
for response in stream_generate(
state.model,
state.tokenizer,
prompt_text,
prompt_feed,
max_tokens=int(request.get("maxTokens") or 256),
sampler=sampler,
logits_processors=_build_mlx_logits_processors(request),
prompt_cache=prompt_cache,
):
_tok = getattr(response, "token", None)
if isinstance(_tok, int):
generated_ids.append(_tok)
if response.text:
text_parts.append(response.text)
try:
Expand All @@ -135,8 +145,20 @@ def generate_standard(state: WorkerState, request: dict[str, Any]) -> dict[str,
runaway_stopped = True
break
last_response = response
if managed:
_prompt_cache.commit(
state,
cache=prompt_cache,
commit_tokens=acq.commit_tokens,
generated_ids=generated_ids,
model_ref=state._loaded_model_ref,
)
except (ValueError, RuntimeError, TypeError, AttributeError) as exc:
_should_retry = (
was_managed = managed
if managed:
_prompt_cache.invalidate(state)
managed = False
_should_retry = was_managed or (
prompt_cache is not None
and _should_retry_cache_failure(exc)
)
Expand Down Expand Up @@ -319,10 +341,13 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None:
system_prompt=system_prompt,
)
sampler = _build_mlx_sampler(request)
prompt_cache, runtime_note = state._make_cache()
runtime_note = _merge_runtime_notes(runtime_note, prompt_note)
acq = _prompt_cache.acquire(state, prompt_text)
prompt_cache = acq.cache
prompt_feed = acq.prompt_feed
managed = acq.managed
runtime_note = _merge_runtime_notes(acq.note, prompt_note)
runtime_note = _merge_runtime_notes(runtime_note, speculative_stream_fallback_note)
runtime_fields = state._runtime_fields(prompt_cache=prompt_cache)
runtime_fields = state._runtime_fields(prompt_cache=acq.fields_cache)
transcript_fallback = _plain_chat_fallback_active(prompt_note)

thinking_mode = request.get("thinkingMode") or "off"
Expand All @@ -336,6 +361,7 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None:
transcript_trimmed = False
runaway_guard = RunawayGuard()
runaway_stopped = False
generated_ids: list[int] = []
# Phase 3.3 follow-up: when the request opted into logprobs,
# extract top-k per token via the helper and forward inline
# with each text chunk.
Expand All @@ -346,11 +372,15 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None:
for response in mlx_stream_generate(
state.model,
state.tokenizer,
prompt_text,
prompt_feed,
max_tokens=int(request.get("maxTokens") or 256),
sampler=sampler,
logits_processors=_build_mlx_logits_processors(request),
prompt_cache=prompt_cache,
):
_tok = getattr(response, "token", None)
if isinstance(_tok, int):
generated_ids.append(_tok)
if response.text:
# Check for runaway loops before emitting
try:
Expand Down Expand Up @@ -392,8 +422,20 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None:
transcript_trimmed = transcript_trimmed or transcript_filter.stopped
if visible_text:
_emit({"ok": True, "chunk": {"text": visible_text}})
if managed:
_prompt_cache.commit(
state,
cache=prompt_cache,
commit_tokens=acq.commit_tokens,
generated_ids=generated_ids,
model_ref=state._loaded_model_ref,
)
except (ValueError, RuntimeError, TypeError, AttributeError) as exc:
_should_retry = (
was_managed = managed
if managed:
_prompt_cache.invalidate(state)
managed = False
_should_retry = was_managed or (
prompt_cache is not None
and _should_retry_cache_failure(exc)
)
Expand Down
122 changes: 122 additions & 0 deletions backend_service/mlx_worker_prompt_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Per-session MLX prompt-cache reuse (tier 4 of the chat-LLM review).

Native-strategy chat turns re-prefill the *entire* conversation every time
(`prompt_cache=None` → mlx-lm builds a fresh cache + processes the whole
prompt). This module keeps one persistent mlx-lm prompt cache on the
worker and reuses the longest matching token prefix across turns: trim the
divergent tail off the cache, prefill only the new suffix, then re-commit
the cache keyed by ``prompt_tokens + generated_tokens``. A single-slot port
of mlx-lm's server reuse logic (``LRUPromptCache.fetch_nearest_cache``).

Correctness invariant: the persisted token list ALWAYS equals the cache's
positional contents (prompt + generated), so the next turn's common-prefix
trim is exact. Any uncertainty — compression strategy active, model
changed, cache not trimmable (SSM/Mamba/rotating-full, mlx-lm #980),
tokenisation failure, no common prefix, partial trim — falls back to a
fresh full prefill, i.e. identical output to the pre-cache path, just
without the speedup. Gated to the ``native`` strategy; compression caches
(turboquant / triattention) keep their existing per-call path untouched.
"""

from __future__ import annotations

from collections import namedtuple
from typing import Any

# Result record handed back by ``acquire``. Fields, in positional order:
#   cache          object passed to stream_generate as ``prompt_cache``.
#   prompt_feed    what to pass as the ``prompt`` argument: the suffix token
#                  list on a reuse hit, the full token list on a fresh native
#                  cache, or the original prompt_text string on the
#                  compression / fallback path.
#   note           runtime note from _make_cache (compression fallback
#                  messages).
#   commit_tokens  full prompt token list to re-key after generation; None
#                  when no native cache is being managed.
#   fields_cache   value to feed _runtime_fields: None for native, the
#                  compression cache otherwise, so the strategy badge stays
#                  right.
#   managed        True only when a native persistent cache is owned and
#                  should be committed after generation.
Acquired = namedtuple(
    "Acquired",
    ["cache", "prompt_feed", "note", "commit_tokens", "fields_cache", "managed"],
)


def _common_prefix_len(a: list[int], b: list[int]) -> int:
n = 0
for x, y in zip(a, b):
if x != y:
break
n += 1
return n


def _native_result(cache: Any | None, full_tokens: list[int], prompt_text: str, note: str | None) -> Acquired:
    """Build the ``Acquired`` record for a fresh native cache, or for giving up.

    With a usable ``cache`` the record is managed: the full token list is both
    the prompt feed and the commit key. With ``cache`` set to ``None`` the
    record is unmanaged and feeds the original ``prompt_text`` string, i.e.
    the behaviour from before prompt-cache reuse existed.
    """
    if cache is None:
        # No managed cache could be built: no cache, raw prompt text, nothing
        # to commit afterwards.
        return Acquired(
            cache=None,
            prompt_feed=prompt_text,
            note=note,
            commit_tokens=None,
            fields_cache=None,
            managed=False,
        )
    return Acquired(
        cache=cache,
        prompt_feed=full_tokens,
        note=note,
        commit_tokens=full_tokens,
        fields_cache=None,
        managed=True,
    )


def acquire(state: Any, prompt_text: str) -> Acquired:
    """Choose the prompt cache and prompt feed for one generation call.

    Args:
        state: worker state. Must provide ``_make_cache()`` returning
            ``(cache, note)``, plus ``tokenizer`` and ``model``; the
            ``_persist_*`` / ``_loaded_model_ref`` attributes are read with
            ``getattr`` defaults.
        prompt_text: the fully rendered prompt string.

    Returns:
        An ``Acquired`` record. Three shapes are possible:

        * compression strategy (``_make_cache`` returned a cache): that cache,
          the prompt string, ``managed=False``;
        * managed native cache: a fresh or reused cache, a token list as the
          prompt feed (only the new suffix on a reuse hit), ``managed=True``;
        * give-up fallback: no cache, the prompt string, ``managed=False``.

    Side effects:
        Once the native path has tokenised the prompt, the persistent slot on
        ``state`` is emptied. The caller owns the returned cache for the
        duration of the call and must ``commit`` it to persist it again.
    """
    base_cache, note = state._make_cache()
    if base_cache is not None:
        # Compression strategy: unchanged behaviour, no persistence.
        return Acquired(base_cache, prompt_text, note, None, base_cache, False)

    # Native strategy — manage a persistent single-slot cache.
    try:
        from mlx_lm.models.cache import (  # noqa: PLC0415
            can_trim_prompt_cache,
            make_prompt_cache,
            trim_prompt_cache,
        )

        tokenizer = state.tokenizer
        # Tokenise the way mlx-lm does for a *string* prompt: special tokens
        # are only added when the text does not already start with BOS.
        # Encoding with the tokenizer's defaults could prepend a second BOS
        # to a chat-templated prompt, so the token feed would no longer match
        # what the string prompt path produces.
        bos_token = getattr(tokenizer, "bos_token", None)
        add_special_tokens = bos_token is None or not prompt_text.startswith(bos_token)
        full_tokens = list(
            tokenizer.encode(prompt_text, add_special_tokens=add_special_tokens)
        )
    except Exception:  # noqa: BLE001 — any failure → safe full-reprocess fallback
        return Acquired(None, prompt_text, note, None, None, False)

    def _fresh() -> Any | None:
        # Best effort: None sends _native_result down the unmanaged fallback.
        try:
            return make_prompt_cache(state.model)
        except Exception:  # noqa: BLE001
            return None

    model_ref = getattr(state, "_loaded_model_ref", None)
    persist = getattr(state, "_persist_cache", None)
    persist_tokens = getattr(state, "_persist_tokens", None) or []
    persist_ref = getattr(state, "_persist_cache_model_ref", None)

    # Take ownership of the slot before touching the cache. The cache is
    # trimmed below and then extended in place by generation, so until the
    # caller commits, the stored token list no longer describes it. Leaving
    # that pair on ``state`` would let a later turn trim against stale tokens
    # if this call ends without commit/invalidate (an uncaught exception type,
    # an abandoned generator). With the slot empty, the worst case is a fresh
    # full prefill on the next turn.
    invalidate(state)

    # Reset conditions: nothing cached, different model, empty history.
    if persist is None or persist_ref != model_ref or not persist_tokens:
        return _native_result(_fresh(), full_tokens, prompt_text, note)

    try:
        if not can_trim_prompt_cache(persist):
            return _native_result(_fresh(), full_tokens, prompt_text, note)
        # Always leave >=1 token to process live (mlx-lm does the same).
        common = min(_common_prefix_len(persist_tokens, full_tokens), len(full_tokens) - 1)
        if common <= 0:
            return _native_result(_fresh(), full_tokens, prompt_text, note)
        num_to_trim = len(persist_tokens) - common
        if num_to_trim > 0:
            trimmed = trim_prompt_cache(persist, num_to_trim)
            if trimmed != num_to_trim:
                # Couldn't roll back cleanly — don't risk a spliced mismatch.
                return _native_result(_fresh(), full_tokens, prompt_text, note)
        # Reuse hit: cache now holds exactly the common prefix; prefill suffix.
        return Acquired(persist, full_tokens[common:], note, full_tokens, None, True)
    except Exception:  # noqa: BLE001
        return _native_result(_fresh(), full_tokens, prompt_text, note)


def commit(state: Any, *, cache: Any, commit_tokens: list[int] | None, generated_ids: list[int], model_ref: str | None) -> None:
    """Store ``cache`` on ``state`` keyed by prompt tokens plus generated tokens.

    Does nothing when ``cache`` or ``commit_tokens`` is ``None``. Otherwise
    writes ``_persist_cache``, ``_persist_tokens`` (``commit_tokens`` followed
    by the ``int`` entries of ``generated_ids``) and
    ``_persist_cache_model_ref`` on ``state``.
    """
    if cache is not None and commit_tokens is not None:
        state._persist_cache = cache
        # Non-int entries in generated_ids are dropped from the stored key.
        state._persist_tokens = [
            *commit_tokens,
            *(tok for tok in generated_ids if isinstance(tok, int)),
        ]
        state._persist_cache_model_ref = model_ref


def invalidate(state: Any) -> None:
    """Reset the persistent prompt-cache slot on ``state`` to its empty values.

    Sets ``_persist_cache`` and ``_persist_cache_model_ref`` to ``None`` and
    ``_persist_tokens`` to a new empty list.
    """
    (
        state._persist_cache,
        state._persist_tokens,
        state._persist_cache_model_ref,
    ) = (None, [], None)
Loading