diff --git a/backend_service/agent.py b/backend_service/agent.py index 277380e..8050384 100644 --- a/backend_service/agent.py +++ b/backend_service/agent.py @@ -485,7 +485,11 @@ def run_agent_loop_streaming( # consumed so the assistant bubble doesn't show raw call # JSON next to the rendered ToolCallCard (FU-040). text = _strip_tool_call_xml(result.text) - chunk_size = 4 + # The final answer is already fully computed (tool-calling turns + # are non-streaming), so the old 4-char dribble just added fake + # latency + yields. Emit in larger chunks; the SSE layer coalesces + # these further and the user sees the answer near-instantly. + chunk_size = 48 for i in range(0, len(text), chunk_size): yield {"token": text[i:i + chunk_size]} diff --git a/backend_service/inference/llama_cpp_engine.py b/backend_service/inference/llama_cpp_engine.py index d62af35..3d9e884 100644 --- a/backend_service/inference/llama_cpp_engine.py +++ b/backend_service/inference/llama_cpp_engine.py @@ -92,6 +92,17 @@ "frequency_penalty", "presence_penalty", "stop", + # Modern anti-repetition / quality samplers llama-server supports + # natively. Forward-only: builds that don't recognise them ignore the + # field, so old binaries are unaffected. DRY beats plain repeat_penalty + # at killing verbatim loops; XTC adds creative variety; top-n-sigma is + # a temperature-stable truncator. + "dry_multiplier", + "dry_base", + "dry_allowed_length", + "xtc_probability", + "xtc_threshold", + "top_n_sigma", # Phase 3.3: per-token confidence info. llama-server returns # top-k alternatives with their logprobs in each delta when # `logprobs: true` + `top_logprobs: N` are set. 
@@ -421,6 +432,7 @@ def _build_command( fit_enabled: bool, is_fallback: bool, speculative_decoding: bool = False, + fused_attention: bool = False, canonical_repo: str | None = None, model_ref: str = "", ) -> tuple[list[str], str | None, bool, str | None]: @@ -449,6 +461,19 @@ def _build_command( str(max(256, context_tokens)), "--jinja", ] + # Reuse the single slot's KV cache across chat turns: a growing + # conversation re-prefills only the new suffix instead of the whole + # history (turn-2+ TTFT drops sharply on long chats). Forward-gated + # on binary support so older llama-server builds are unaffected. + if _llama_server_supports(binary, "--cache-reuse"): + command.extend(["--cache-reuse", "256"]) + # Honour the user's fused-attention toggle. It was plumbed into + # load_model + stored on LoadedModelInfo but never emitted as a + # flag. Flash attention is a large decode + KV-memory win on Metal + # and is required by the quantized KV cache types. Opt-in via the + # existing flag so a model/quant combo that dislikes it can disable. + if fused_attention and _llama_server_supports(binary, "--flash-attn"): + command.extend(["--flash-attn", "on"]) if _llama_server_supports(binary, "--reasoning-format"): command.extend(["--reasoning-format", "deepseek"]) if _llama_server_supports(binary, "--reasoning"): @@ -660,6 +685,7 @@ def load_model( fit_enabled=fit_enabled, is_fallback=is_fallback, speculative_decoding=speculative_decoding, + fused_attention=fused_attention, canonical_repo=canonical_repo, model_ref=model_ref, ) @@ -791,6 +817,9 @@ def generate( "temperature": temperature, "max_tokens": max_tokens, "stream": False, + # Reuse the slot's cached prompt prefix across turns (pairs with + # the server's --cache-reuse) so unchanged history isn't reprocessed. 
+ "cache_prompt": True, } if tools: payload["tools"] = tools @@ -884,6 +913,9 @@ def stream_generate( "temperature": temperature, "max_tokens": max_tokens, "stream": True, + # Reuse the slot's cached prompt prefix across turns (pairs with + # the server's --cache-reuse) so unchanged history isn't reprocessed. + "cache_prompt": True, } if tools: payload["tools"] = tools diff --git a/backend_service/mlx_worker.py b/backend_service/mlx_worker.py index c7a0e52..f3acfc6 100644 --- a/backend_service/mlx_worker.py +++ b/backend_service/mlx_worker.py @@ -59,6 +59,7 @@ from backend_service import mlx_worker_lifecycle as _lifecycle from backend_service import mlx_worker_speculative as _speculative from backend_service import mlx_worker_generate as _generate +from backend_service import mlx_worker_prompt_cache as _prompt_cache # Phase 1f-4: model + runtime introspection helpers now live in # ``backend_service.mlx_worker_diagnostics``. Re-export so existing imports @@ -127,6 +128,13 @@ def __init__(self) -> None: # delimiters via ``reasoning_delimiters_for``. Default # (``...``) still applies when ``None``. self._loaded_model_ref: str | None = None + # Tier 4: persistent single-slot prompt cache for native-strategy chat + # so follow-up turns prefill only the new suffix. Managed by + # backend_service.mlx_worker_prompt_cache; invalidated on any model + # load / unload / profile change. 
+ self._persist_cache: Any | None = None + self._persist_tokens: list[int] = [] + self._persist_cache_model_ref: str | None = None def handle(self, request: dict[str, Any]) -> dict[str, Any] | None: op = request.get("op") @@ -148,12 +156,15 @@ def handle(self, request: dict[str, Any]) -> dict[str, Any] | None: raise ValueError(f"Unsupported worker operation: {op}") def load_model(self, request: dict[str, Any]) -> dict[str, Any]: + _prompt_cache.invalidate(self) return _lifecycle.load_model(self, request) def unload_model(self) -> dict[str, Any]: + _prompt_cache.invalidate(self) return _lifecycle.unload_model(self) def update_profile(self, request: dict[str, Any]) -> dict[str, Any]: + _prompt_cache.invalidate(self) return _lifecycle.update_profile(self, request) def _apply_cache_profile( diff --git a/backend_service/mlx_worker_generate.py b/backend_service/mlx_worker_generate.py index 7157631..2d7a65d 100644 --- a/backend_service/mlx_worker_generate.py +++ b/backend_service/mlx_worker_generate.py @@ -34,6 +34,7 @@ ) from backend_service.mlx_worker_request import ( _apply_mlx_seed, + _build_mlx_logits_processors, _build_mlx_sampler, _extract_top_logprobs, _format_tools_for_prompt, @@ -46,6 +47,7 @@ strip_harmony_boilerplate, ) from backend_service.runaway_guard import RunawayGuard +from backend_service import mlx_worker_prompt_cache as _prompt_cache if TYPE_CHECKING: @@ -109,24 +111,32 @@ def generate_standard(state: WorkerState, request: dict[str, Any]) -> dict[str, system_prompt=system_prompt, ) sampler = _build_mlx_sampler(request) - prompt_cache, runtime_note = state._make_cache() - runtime_note = _merge_runtime_notes(runtime_note, prompt_note) - runtime_fields = state._runtime_fields(prompt_cache=prompt_cache) + acq = _prompt_cache.acquire(state, prompt_text) + prompt_cache = acq.cache + prompt_feed = acq.prompt_feed + managed = acq.managed + runtime_note = _merge_runtime_notes(acq.note, prompt_note) + runtime_fields = 
state._runtime_fields(prompt_cache=acq.fields_cache) transcript_fallback = _plain_chat_fallback_active(prompt_note) runaway_guard = RunawayGuard() runaway_stopped = False + generated_ids: list[int] = [] try: text_parts: list[str] = [] last_response = None for response in stream_generate( state.model, state.tokenizer, - prompt_text, + prompt_feed, max_tokens=int(request.get("maxTokens") or 256), sampler=sampler, + logits_processors=_build_mlx_logits_processors(request), prompt_cache=prompt_cache, ): + _tok = getattr(response, "token", None) + if isinstance(_tok, int): + generated_ids.append(_tok) if response.text: text_parts.append(response.text) try: @@ -135,8 +145,20 @@ def generate_standard(state: WorkerState, request: dict[str, Any]) -> dict[str, runaway_stopped = True break last_response = response + if managed: + _prompt_cache.commit( + state, + cache=prompt_cache, + commit_tokens=acq.commit_tokens, + generated_ids=generated_ids, + model_ref=state._loaded_model_ref, + ) except (ValueError, RuntimeError, TypeError, AttributeError) as exc: - _should_retry = ( + was_managed = managed + if managed: + _prompt_cache.invalidate(state) + managed = False + _should_retry = was_managed or ( prompt_cache is not None and _should_retry_cache_failure(exc) ) @@ -319,10 +341,13 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None: system_prompt=system_prompt, ) sampler = _build_mlx_sampler(request) - prompt_cache, runtime_note = state._make_cache() - runtime_note = _merge_runtime_notes(runtime_note, prompt_note) + acq = _prompt_cache.acquire(state, prompt_text) + prompt_cache = acq.cache + prompt_feed = acq.prompt_feed + managed = acq.managed + runtime_note = _merge_runtime_notes(acq.note, prompt_note) runtime_note = _merge_runtime_notes(runtime_note, speculative_stream_fallback_note) - runtime_fields = state._runtime_fields(prompt_cache=prompt_cache) + runtime_fields = state._runtime_fields(prompt_cache=acq.fields_cache) transcript_fallback = 
_plain_chat_fallback_active(prompt_note) thinking_mode = request.get("thinkingMode") or "off" @@ -336,6 +361,7 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None: transcript_trimmed = False runaway_guard = RunawayGuard() runaway_stopped = False + generated_ids: list[int] = [] # Phase 3.3 follow-up: when the request opted into logprobs, # extract top-k per token via the helper and forward inline # with each text chunk. @@ -346,11 +372,15 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None: for response in mlx_stream_generate( state.model, state.tokenizer, - prompt_text, + prompt_feed, max_tokens=int(request.get("maxTokens") or 256), sampler=sampler, + logits_processors=_build_mlx_logits_processors(request), prompt_cache=prompt_cache, ): + _tok = getattr(response, "token", None) + if isinstance(_tok, int): + generated_ids.append(_tok) if response.text: # Check for runaway loops before emitting try: @@ -392,8 +422,20 @@ def stream_generate(state: WorkerState, request: dict[str, Any]) -> None: transcript_trimmed = transcript_trimmed or transcript_filter.stopped if visible_text: _emit({"ok": True, "chunk": {"text": visible_text}}) + if managed: + _prompt_cache.commit( + state, + cache=prompt_cache, + commit_tokens=acq.commit_tokens, + generated_ids=generated_ids, + model_ref=state._loaded_model_ref, + ) except (ValueError, RuntimeError, TypeError, AttributeError) as exc: - _should_retry = ( + was_managed = managed + if managed: + _prompt_cache.invalidate(state) + managed = False + _should_retry = was_managed or ( prompt_cache is not None and _should_retry_cache_failure(exc) ) diff --git a/backend_service/mlx_worker_prompt_cache.py b/backend_service/mlx_worker_prompt_cache.py new file mode 100644 index 0000000..4ccfbea --- /dev/null +++ b/backend_service/mlx_worker_prompt_cache.py @@ -0,0 +1,122 @@ +"""Per-session MLX prompt-cache reuse (tier 4 of the chat-LLM review). 
+ +Native-strategy chat turns re-prefill the *entire* conversation every time +(`prompt_cache=None` → mlx-lm builds a fresh cache + processes the whole +prompt). This module keeps one persistent mlx-lm prompt cache on the +worker and reuses the longest matching token prefix across turns: trim the +divergent tail off the cache, prefill only the new suffix, then re-commit +the cache keyed by ``prompt_tokens + generated_tokens``. A single-slot port +of mlx-lm's server reuse logic (``LRUPromptCache.fetch_nearest_cache``). + +Correctness invariant: the persisted token list ALWAYS equals the cache's +positional contents (prompt + generated), so the next turn's common-prefix +trim is exact. Any uncertainty — compression strategy active, model +changed, cache not trimmable (SSM/Mamba/rotating-full, mlx-lm #980), +tokenisation failure, no common prefix, partial trim — falls back to a +fresh full prefill, i.e. identical output to the pre-cache path, just +without the speedup. Gated to the ``native`` strategy; compression caches +(turboquant / triattention) keep their existing per-call path untouched. 
from __future__ import annotations

from collections import namedtuple
from typing import Any

# Result of acquire(). Fields:
# cache:         object passed to stream_generate as prompt_cache
# prompt_feed:   what to pass as the `prompt` arg (suffix token list on a
#                reuse hit, full token list on a fresh native cache, or the
#                original prompt_text string for the compression / fallback path)
# note:          runtime note from _make_cache (compression fallback msgs)
# commit_tokens: full prompt token list to re-key after generation (None when
#                not managing a native cache)
# fields_cache:  value to feed _runtime_fields (None for native, the
#                compression cache otherwise) so the strategy badge stays right
# managed:       True only when we own a native persistent cache to commit
Acquired = namedtuple(
    "Acquired", "cache prompt_feed note commit_tokens fields_cache managed"
)


def _common_prefix_len(a: list[int], b: list[int]) -> int:
    """Return the number of leading positions at which ``a`` and ``b`` agree."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def _encode_prompt(tokenizer: Any, prompt_text: str) -> list[int]:
    """Tokenise ``prompt_text`` the way mlx-lm tokenises a string prompt.

    mlx-lm's ``stream_generate`` skips special tokens when the text already
    starts with the tokenizer's BOS token. A bare ``encode(prompt_text)``
    always adds them, so a chat-templated prompt that already carries BOS
    would get a second one and the managed path would no longer produce the
    same tokens as the string path it replaces.
    NOTE(review): mirrors upstream mlx-lm behaviour; re-check when bumping
    the mlx-lm pin.
    """
    bos = getattr(tokenizer, "bos_token", None)
    add_special_tokens = bos is None or not prompt_text.startswith(bos)
    return list(tokenizer.encode(prompt_text, add_special_tokens=add_special_tokens))


def _native_result(cache: Any | None, full_tokens: list[int], prompt_text: str, note: str | None) -> Acquired:
    """Wrap a fresh-native-cache outcome (or a give-up fallback)."""
    if cache is not None:
        return Acquired(cache, full_tokens, note, full_tokens, None, True)
    # Couldn't build a managed cache → behave exactly like before.
    return Acquired(None, prompt_text, note, None, None, False)


def acquire(state: Any, prompt_text: str) -> Acquired:
    """Pick the prompt cache and prompt feed for one generation call.

    Args:
        state: worker state; must provide ``_make_cache()``, ``tokenizer`` and
            ``model``. ``_loaded_model_ref`` and the ``_persist_*`` slot
            attributes are read with ``getattr`` defaults.
        prompt_text: the fully rendered prompt string.

    Returns:
        An ``Acquired`` tuple. ``managed`` is True only when the caller is
        expected to ``commit`` the cache after a successful generation.

    The persistent slot is *checked out* before the cache is touched: the
    state's slot is emptied here and only refilled by ``commit``. Trimming
    mutates the cache in place, so leaving the old token list on the state
    would describe a cache that no longer matches it if the turn ends
    without a commit or an explicit invalidate.
    """
    base_cache, note = state._make_cache()
    if base_cache is not None:
        # Compression strategy: unchanged behaviour, no persistence.
        return Acquired(base_cache, prompt_text, note, None, base_cache, False)

    unmanaged = Acquired(None, prompt_text, note, None, None, False)

    # Native strategy — manage a persistent single-slot cache.
    try:
        from mlx_lm.models.cache import (  # noqa: PLC0415
            can_trim_prompt_cache,
            make_prompt_cache,
            trim_prompt_cache,
        )

        full_tokens = _encode_prompt(state.tokenizer, prompt_text)
    except Exception:  # noqa: BLE001 — any failure → safe full-reprocess fallback
        return unmanaged
    if not full_tokens:
        # An empty token list is not a usable prompt feed; hand the original
        # string back so the call behaves like the pre-cache path.
        return unmanaged

    def _fresh() -> Acquired:
        try:
            cache = make_prompt_cache(state.model)
        except Exception:  # noqa: BLE001
            cache = None
        return _native_result(cache, full_tokens, prompt_text, note)

    model_ref = getattr(state, "_loaded_model_ref", None)
    persist = getattr(state, "_persist_cache", None)
    persist_tokens = list(getattr(state, "_persist_tokens", None) or [])
    persist_ref = getattr(state, "_persist_cache_model_ref", None)

    # Check the slot out. From here the old cache is either trimmed in place
    # (reuse) or superseded (fresh); either way the state must not keep
    # describing it until commit() re-keys it.
    invalidate(state)

    # Reset conditions: nothing cached, different model, empty history.
    if persist is None or persist_ref != model_ref or not persist_tokens:
        return _fresh()

    try:
        if not can_trim_prompt_cache(persist):
            return _fresh()
        # Always leave >=1 token to process live (mlx-lm does the same).
        common = min(_common_prefix_len(persist_tokens, full_tokens), len(full_tokens) - 1)
        if common <= 0:
            return _fresh()
        num_to_trim = len(persist_tokens) - common
        if num_to_trim > 0 and trim_prompt_cache(persist, num_to_trim) != num_to_trim:
            # Couldn't roll back cleanly — don't risk a spliced mismatch.
            return _fresh()
    except Exception:  # noqa: BLE001
        return _fresh()
    # Reuse hit: cache now holds exactly the common prefix; prefill suffix.
    return Acquired(persist, full_tokens[common:], note, full_tokens, None, True)


def commit(state: Any, *, cache: Any, commit_tokens: list[int] | None, generated_ids: list[int], model_ref: str | None) -> None:
    """Persist the cache keyed by prompt + generated tokens (positional truth).

    No-op when ``cache`` or ``commit_tokens`` is None. Non-int entries in
    ``generated_ids`` are dropped.
    """
    if cache is None or commit_tokens is None:
        return
    state._persist_cache = cache
    state._persist_tokens = list(commit_tokens) + [t for t in generated_ids if isinstance(t, int)]
    state._persist_cache_model_ref = model_ref


def invalidate(state: Any) -> None:
    """Empty the persistent slot so the next ``acquire`` starts fresh."""
    state._persist_cache = None
    state._persist_tokens = []
    state._persist_cache_model_ref = None
+ + mlx-lm applies repetition penalty via ``logits_processors``, NOT through + ``make_sampler`` — so the UI's ``repeat_penalty`` was silently dropped + when only the sampler was wired. Returns None when no (or a no-op 1.0) + penalty is requested, so callers can pass ``logits_processors=None`` (the + mlx-lm default). Signature-filtered like the sampler for cross-version + robustness. + """ + import inspect + + samplers = request.get("samplers") or {} + if not isinstance(samplers, dict): + return None + raw = samplers.get("repeat_penalty", samplers.get("repetition_penalty")) + try: + penalty = float(raw) if raw is not None else None + except (TypeError, ValueError): + penalty = None + if penalty is None or abs(penalty - 1.0) < 1e-6: + return None + + try: + from mlx_lm.sample_utils import make_logits_processors + + kwargs: dict[str, Any] = {"repetition_penalty": penalty} + ctx = samplers.get("repeat_penalty_context") or samplers.get("repetition_context_size") + if ctx is not None: + try: + kwargs["repetition_context_size"] = int(ctx) + except (TypeError, ValueError): + pass + sig = inspect.signature(make_logits_processors) + allowed = set(sig.parameters.keys()) + filtered = {k: v for k, v in kwargs.items() if k in allowed} + return make_logits_processors(**filtered) + except Exception: + return None + + def _sampler_seed(request: dict[str, Any]) -> int | None: samplers = request.get("samplers") or {} if not isinstance(samplers, dict): diff --git a/backend_service/models/__init__.py b/backend_service/models/__init__.py index 4c43b62..e2f9414 100644 --- a/backend_service/models/__init__.py +++ b/backend_service/models/__init__.py @@ -151,6 +151,14 @@ class GenerateRequest(BaseModel): mirostatMode: Literal[0, 1, 2] | None = None mirostatTau: float | None = Field(default=None, ge=0.0, le=10.0) mirostatEta: float | None = Field(default=None, ge=0.0, le=1.0) + # Modern samplers (tier 2). XTC drops top tokens for variety; DRY + # penalises repeated multi-token sequences. 
llama-server applies all; + # mlx-lm applies XTC via make_sampler and ignores DRY (llama-only). + xtcProbability: float | None = Field(default=None, ge=0.0, le=1.0) + xtcThreshold: float | None = Field(default=None, ge=0.0, le=1.0) + dryMultiplier: float | None = Field(default=None, ge=0.0, le=4.0) + dryBase: float | None = Field(default=None, ge=0.0, le=8.0) + dryAllowedLength: int | None = Field(default=None, ge=0, le=64) seed: int | None = Field(default=None, ge=0, le=2**31 - 1) # Constrained decoding: when set, llama-server enforces a JSON schema # via its `response_format: {type: "json_schema", json_schema: {...}}` @@ -268,6 +276,15 @@ class OpenAIChatCompletionRequest(BaseModel): presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0) seed: int | None = Field(default=None, ge=0, le=2**31 - 1) stop: list[str] | str | None = None + # Non-standard but widely-accepted local-server sampler fields. Mapped + # into the runtime sampler dict in state/openai_compat.py for parity with + # the native chat route (llama-server takes these natively; the MLX worker + # consumes min_p + repeat_penalty). 
+ min_p: float | None = Field(default=None, ge=0.0, le=1.0) + repeat_penalty: float | None = Field(default=None, ge=0.0, le=2.0) + mirostat: int | None = Field(default=None, ge=0, le=2) + mirostat_tau: float | None = Field(default=None, ge=0.0) + mirostat_eta: float | None = Field(default=None, ge=0.0) response_format: dict[str, Any] | None = None diff --git a/backend_service/state/__init__.py b/backend_service/state/__init__.py index 57b3931..248bbcc 100644 --- a/backend_service/state/__init__.py +++ b/backend_service/state/__init__.py @@ -35,6 +35,7 @@ _build_sampler_overrides, _clean_prompt_for_title, _compose_chat_system_prompt, + _history_token_budget, _legacy_title_from_prompt, _normalize_remote_provider_api_base, _read_text_tail, diff --git a/backend_service/state/_helpers.py b/backend_service/state/_helpers.py index fee56df..b38e597 100644 --- a/backend_service/state/_helpers.py +++ b/backend_service/state/_helpers.py @@ -57,6 +57,14 @@ def _put(dst: str, value: Any) -> None: overrides["mirostat"] = mirostat_mode _put("mirostat_tau", getattr(request, "mirostatTau", None)) _put("mirostat_eta", getattr(request, "mirostatEta", None)) + # Modern samplers (tier 2): XTC (both engines) + DRY (llama only). + # Engine-side key names; llama-server forwards them via + # _LLAMA_SAMPLER_KEYS, mlx-lm reads xtc_* in _build_mlx_sampler. 
def _estimate_tokens(text: str) -> int:
    """Cheap, deliberately CONSERVATIVE token estimate (no tokenizer here).

    ASCII text is counted at ~3 chars/token (vs the ~4 typical for English)
    and every non-ASCII character is counted as a full token. A flat
    ``len(text) // 3`` under-counts CJK and other non-Latin text, where a
    character is commonly one token or more, and that is exactly the case in
    which the history window would overfill. Over-counting only makes the
    window smaller, which is the safe direction for a budget. Not suitable
    for billing.

    Returns at least 1, including for the empty string.
    """
    # encode(..., "ignore") drops every non-ASCII code point, so the byte
    # length is the ASCII character count without a Python-level loop.
    ascii_chars = len(text.encode("ascii", "ignore"))
    wide_chars = len(text) - ascii_chars
    return ascii_chars // 3 + wide_chars + 1


def _history_token_budget(
    *,
    context_tokens: int,
    max_tokens: int,
    system_prompt: str | None,
    prompt: str | None,
) -> int:
    """Token budget left for *prior* history after reserving room for the
    system prompt, the current user prompt, the generation, and chat-template
    overhead. Floors at 512 so a single recent turn is always kept.

    ``None`` / falsy inputs count as empty text or zero tokens. A negative
    ``max_tokens`` is treated as zero so it cannot enlarge the budget.
    """
    reserved = (
        _estimate_tokens(system_prompt or "")
        + _estimate_tokens(prompt or "")
        + max(0, int(max_tokens or 0))
        + 512  # chat-template + role-tag + tool-schema overhead headroom
    )
    return max(512, int(context_tokens or 0) - reserved)
@@ -79,10 +120,17 @@ def _build_history_with_reasoning( When `preserve_reasoning` is true and an assistant message has a `reasoning` field captured by ThinkingTokenFilter on a previous turn, the reasoning is re-emitted inside `...` tags ahead of - the visible answer. Reasoning-capable models (Qwen3, DeepSeek R1, etc.) - consume this naturally on follow-up turns; non-reasoning models will - treat it as inline text. Falsy / missing reasoning is skipped, so this - is safe to call unconditionally. + the visible answer. (Upstream chat templates for Qwen3 / DeepSeek-R1 + actually strip prior reasoning, so the live chat path now passes + `preserve_reasoning=False`; the option is kept for callers that want it.) + Falsy / missing reasoning is skipped, so this is safe to call + unconditionally. + + When `token_budget` is set, a sliding window keeps every system message + plus the NEWEST conversation turns that fit the budget (estimated, no + tokenizer), dropping the oldest. This bounds prompt growth across a long + chat — preventing silent truncation on llama.cpp and out-of-context + errors on MLX. ``None`` disables windowing (unchanged behaviour). """ history: list[dict[str, Any]] = [] for message in messages: @@ -97,7 +145,26 @@ def _build_history_with_reasoning( if reasoning_str: text = f"\n{reasoning_str}\n\n\n{text}" history.append({"role": role, "text": text}) - return history + + if token_budget is None or token_budget <= 0: + return history + + # System messages are always kept; window the conversation tail. + system_msgs = [m for m in history if m["role"] == "system"] + convo = [m for m in history if m["role"] != "system"] + used = sum(_estimate_tokens(m["text"]) for m in system_msgs) + kept_tail: list[dict[str, Any]] = [] + for message in reversed(convo): + cost = _estimate_tokens(message["text"]) + # Always keep the most recent turn even if it alone blows the budget; + # dropping the latest context is worse than a small overflow the + # engine can still truncate. 
+ if kept_tail and used + cost > token_budget: + break + used += cost + kept_tail.append(message) + kept_tail.reverse() + return system_msgs + kept_tail _TITLE_LEADING_PATTERNS = [ diff --git a/backend_service/state/generation.py b/backend_service/state/generation.py index 15098f4..1ace636 100644 --- a/backend_service/state/generation.py +++ b/backend_service/state/generation.py @@ -35,6 +35,7 @@ _build_history_with_reasoning, _build_sampler_overrides, _compose_chat_system_prompt, + _history_token_budget, ) @@ -144,7 +145,17 @@ def generate(state: ChaosEngineState, request: GenerateRequest) -> dict[str, Any history = _build_history_with_reasoning( session["messages"], - preserve_reasoning=(effective_thinking_mode == "auto"), + # Don't replay prior reasoning — upstream chat templates + # (Qwen3 / DeepSeek-R1) strip it, and re-feeding it bloats the + # prompt every turn. token_budget windows the oldest turns out so + # a long chat can't silently overflow the context. + preserve_reasoning=False, + token_budget=_history_token_budget( + context_tokens=desired_context_tokens, + max_tokens=request.maxTokens, + system_prompt=request.systemPrompt, + prompt=request.prompt, + ), ) session["messages"].append({"role": "user", "text": request.prompt, "metrics": None}) session["updatedAt"] = state._time_label() @@ -393,7 +404,17 @@ def generate_stream(state: ChaosEngineState, request: GenerateRequest): history = _build_history_with_reasoning( session["messages"], - preserve_reasoning=(effective_thinking_mode == "auto"), + # Don't replay prior reasoning — upstream chat templates + # (Qwen3 / DeepSeek-R1) strip it, and re-feeding it bloats the + # prompt every turn. token_budget windows the oldest turns out so + # a long chat can't silently overflow the context. 
+ preserve_reasoning=False, + token_budget=_history_token_budget( + context_tokens=desired_context_tokens, + max_tokens=request.maxTokens, + system_prompt=request.systemPrompt, + prompt=request.prompt, + ), ) session["messages"].append({"role": "user", "text": request.prompt, "metrics": None}) session["updatedAt"] = state._time_label() @@ -599,6 +620,24 @@ def _maybe_emit_generating_phase() -> str: ttft_seconds = round(time.perf_counter() - gen_start, 3) return f"data: {json.dumps({'phase': 'generating', 'ttftSeconds': ttft_seconds})}\n\n" + # Token coalescing: batch visible token frames so a fast decoder + # doesn't pay a json.dumps + SSE frame per token. Flush on size, a + # short time window, any non-token event, or stream end. Disabled + # when per-token logprobs are requested (they must stay 1:1 aligned). + _COALESCE_CHARS = 24 + _COALESCE_SECS = 0.05 + _coalesce_tokens = not (request.logprobs and int(request.logprobs) > 0) + _tok: dict[str, Any] = {"buf": [], "chars": 0, "started": 0.0} + + def _flush_tokens() -> str: + if not _tok["buf"]: + return "" + merged = "".join(_tok["buf"]) + _tok["buf"] = [] + _tok["chars"] = 0 + _tok["started"] = 0.0 + return f"data: {json.dumps({'token': merged})}\n\n" + try: if enable_tools: from backend_service.agent import run_agent_loop_streaming @@ -619,7 +658,20 @@ def _maybe_emit_generating_phase() -> str: if phase_event: yield phase_event full_text += event["token"] - yield f"data: {json.dumps({'token': event['token']})}\n\n" + if _coalesce_tokens: + if not _tok["buf"]: + _tok["started"] = time.perf_counter() + _tok["buf"].append(event["token"]) + _tok["chars"] += len(event["token"]) + if ( + _tok["chars"] >= _COALESCE_CHARS + or time.perf_counter() - _tok["started"] >= _COALESCE_SECS + ): + _f = _flush_tokens() + if _f: + yield _f + else: + yield f"data: {json.dumps({'token': event['token']})}\n\n" if len(full_text) > runaway_char_budget: runaway_triggered = True cancelled = True @@ -628,8 +680,14 @@ def 
_maybe_emit_generating_phase() -> str: phase_event = _maybe_emit_generating_phase() if phase_event: yield phase_event + _f = _flush_tokens() + if _f: + yield _f yield f"data: {json.dumps({'toolCallStart': event['tool_call_start']})}\n\n" elif "tool_call_result" in event: + _f = _flush_tokens() + if _f: + yield _f agent_tool_calls.append(event["tool_call_result"]) yield f"data: {json.dumps({'toolCallResult': event['tool_call_result']})}\n\n" elif event.get("done"): @@ -653,16 +711,35 @@ def _maybe_emit_generating_phase() -> str: phase_event = _maybe_emit_generating_phase() if phase_event: yield phase_event + _f = _flush_tokens() + if _f: + yield _f full_reasoning += chunk.reasoning yield f"data: {json.dumps({'reasoning': chunk.reasoning})}\n\n" if chunk.reasoning_done: + _f = _flush_tokens() + if _f: + yield _f yield f"data: {json.dumps({'reasoningDone': True})}\n\n" if chunk.text: phase_event = _maybe_emit_generating_phase() if phase_event: yield phase_event full_text += chunk.text - yield f"data: {json.dumps({'token': chunk.text})}\n\n" + if _coalesce_tokens: + if not _tok["buf"]: + _tok["started"] = time.perf_counter() + _tok["buf"].append(chunk.text) + _tok["chars"] += len(chunk.text) + if ( + _tok["chars"] >= _COALESCE_CHARS + or time.perf_counter() - _tok["started"] >= _COALESCE_SECS + ): + _f = _flush_tokens() + if _f: + yield _f + else: + yield f"data: {json.dumps({'token': chunk.text})}\n\n" # Phase 3.3: forward per-token logprobs when # the inference layer captured them. 
if chunk.token_logprobs: @@ -730,6 +807,9 @@ def _maybe_emit_generating_phase() -> str: f"{p_avail:.1f} GB, " f"pressure={p_pressure:.0f}%.", ) + _f = _flush_tokens() + if _f: + yield _f yield ( "data: " + json.dumps({ @@ -762,6 +842,9 @@ def _maybe_emit_generating_phase() -> str: "chat", "warning", f"[{model_tag}] Thermal warning: critical.", ) + _f = _flush_tokens() + if _f: + yield _f yield ( "data: " + json.dumps({ @@ -794,11 +877,20 @@ def _maybe_emit_generating_phase() -> str: chaosengine.active_requests = max(0, chaosengine.active_requests - 1) chaosengine.add_log("chat", "error", f"[{model_tag}] Streaming failed: {exc}") chaosengine.clear_chat_cancel(session_id_for_cancel) + _f = _flush_tokens() + if _f: + yield _f yield f"data: {json.dumps({'error': str(exc)})}\n\n" return finally: chaosengine.clear_chat_cancel(session_id_for_cancel) + # Flush any tokens still buffered by the coalescer before the + # terminal done / cancelled events (covers normal end + all breaks). + _f = _flush_tokens() + if _f: + yield _f + if cancelled: yield f"data: {json.dumps({'cancelled': True})}\n\n" if runaway_loop_reason is not None: diff --git a/backend_service/state/openai_compat.py b/backend_service/state/openai_compat.py index b25dedd..a5a3cb0 100644 --- a/backend_service/state/openai_compat.py +++ b/backend_service/state/openai_compat.py @@ -236,6 +236,19 @@ def openai_chat_completion( oai_samplers["seed"] = request.seed if request.stop is not None: oai_samplers["stop"] = request.stop if isinstance(request.stop, list) else [request.stop] + # Parity with the native chat route's sampler set: min_p, repeat_penalty, + # and mirostat were silently dropped on the /v1 path. llama-server takes + # these key names natively; the MLX worker consumes min_p + repeat_penalty. 
+ if request.min_p is not None: + oai_samplers["min_p"] = request.min_p + if request.repeat_penalty is not None: + oai_samplers["repeat_penalty"] = request.repeat_penalty + if request.mirostat is not None: + oai_samplers["mirostat"] = request.mirostat + if request.mirostat_tau is not None: + oai_samplers["mirostat_tau"] = request.mirostat_tau + if request.mirostat_eta is not None: + oai_samplers["mirostat_eta"] = request.mirostat_eta # Phase 2.13: pull a JSON schema out of OpenAI's response_format # envelope so the constrained-decode path lights up. Anything diff --git a/src/components/SamplerPanel.tsx b/src/components/SamplerPanel.tsx index 9df721e..7f58c33 100644 --- a/src/components/SamplerPanel.tsx +++ b/src/components/SamplerPanel.tsx @@ -194,6 +194,39 @@ export function SamplerPanel({ overrides, onChange, disabled }: SamplerPanelProp disabled={disabled} onChange={(v) => patch("repeatPenalty", v)} /> + patch("xtcProbability", v)} + /> + patch("xtcThreshold", v)} + /> + patch("dryMultiplier", v)} + /> { expect(samplerPayload({ topP: 0.9, topK: null, seed: null })).toEqual({ topP: 0.9 }); }); + it("projects modern samplers (xtc + dry)", () => { + expect( + samplerPayload({ xtcProbability: 0.5, xtcThreshold: 0.1, dryMultiplier: 0.8 }), + ).toEqual({ xtcProbability: 0.5, xtcThreshold: 0.1, dryMultiplier: 0.8 }); + }); + + it("round-trips modern samplers through storage", () => { + writeSamplerOverrides("sx", { xtcProbability: 0.5, dryMultiplier: 0.8 }); + expect(readSamplerOverrides("sx")).toEqual({ xtcProbability: 0.5, dryMultiplier: 0.8 }); + }); + it("parses jsonSchemaText into jsonSchema when valid", () => { const schemaText = '{"type":"object","properties":{"answer":{"type":"string"}}}'; expect(samplerPayload({ jsonSchemaText: schemaText })).toEqual({ diff --git a/src/features/chat/samplerOverrides.ts b/src/features/chat/samplerOverrides.ts index 4bcf226..07007e1 100644 --- a/src/features/chat/samplerOverrides.ts +++ b/src/features/chat/samplerOverrides.ts @@ 
-21,6 +21,9 @@ const NUMERIC_KEYS = [ "seed", "mirostatTau", "mirostatEta", + "xtcProbability", + "xtcThreshold", + "dryMultiplier", ] as const; function storageKey(sessionId: string): string { @@ -95,6 +98,9 @@ export function samplerPayload(overrides: SamplerOverrides): Record; /** * Phase 3.3: when set, asks llama-server to return top-k logprobs @@ -255,6 +260,9 @@ export interface SamplerOverrides { mirostatMode?: 0 | 1 | 2 | null; mirostatTau?: number | null; mirostatEta?: number | null; + xtcProbability?: number | null; + xtcThreshold?: number | null; + dryMultiplier?: number | null; /** * Phase 2.2: opt-in constrained decoding. Raw JSON-schema text the * user typed in the SamplerPanel. Parsed at send-time and forwarded diff --git a/tests/test_backend_service.py b/tests/test_backend_service.py index ff1c6af..c5b04a1 100644 --- a/tests/test_backend_service.py +++ b/tests/test_backend_service.py @@ -1350,6 +1350,30 @@ def test_openai_completion_forwards_sampler_fields(self): self.assertEqual(runtime_kwargs["samplers"]["stop"], ["END"]) self.assertIn("properties", runtime_kwargs["json_schema"]) + def test_openai_completion_forwards_extended_samplers(self): + # Parity fix: min_p / repeat_penalty / mirostat were dropped on the + # /v1 path. They must now reach the runtime sampler dict. 
+ response = self.client.post( + "/v1/chat/completions", + json={ + "model": "google/gemma-4-E4B-it", + "messages": [{"role": "user", "content": "test"}], + "max_tokens": 16, + "min_p": 0.05, + "repeat_penalty": 1.15, + "mirostat": 2, + "mirostat_tau": 5.0, + "mirostat_eta": 0.1, + }, + ) + self.assertEqual(response.status_code, 200) + samplers = self.client.app.state.chaosengine.runtime.last_generate_kwargs["samplers"] + self.assertEqual(samplers["min_p"], 0.05) + self.assertEqual(samplers["repeat_penalty"], 1.15) + self.assertEqual(samplers["mirostat"], 2) + self.assertEqual(samplers["mirostat_tau"], 5.0) + self.assertEqual(samplers["mirostat_eta"], 0.1) + def test_openai_completion_omits_sampler_dict_when_none_set(self): response = self.client.post( "/v1/chat/completions", diff --git a/tests/test_history_with_reasoning.py b/tests/test_history_with_reasoning.py index 74f8da4..0f45e97 100644 --- a/tests/test_history_with_reasoning.py +++ b/tests/test_history_with_reasoning.py @@ -9,6 +9,7 @@ import unittest from backend_service.state import _build_history_with_reasoning +from backend_service.state._helpers import _estimate_tokens, _history_token_budget class BuildHistoryWithReasoningTests(unittest.TestCase): @@ -65,5 +66,61 @@ def test_preserves_message_order(self): self.assertIn("R2", history[3]["text"]) +class HistoryTokenWindowTests(unittest.TestCase): + def test_token_budget_none_keeps_all(self): + messages = [{"role": "user", "text": "x" * 300} for _ in range(6)] + history = _build_history_with_reasoning(messages, preserve_reasoning=False, token_budget=None) + self.assertEqual(len(history), 6) + + def test_windows_oldest_turns_out(self): + # Each 30-char text ~= 11 estimated tokens; budget 25 keeps 2 newest. 
+ messages = [ + {"role": "user", "text": "a" * 30}, + {"role": "assistant", "text": "b" * 30}, + {"role": "user", "text": "c" * 30}, + {"role": "assistant", "text": "d" * 30}, + ] + history = _build_history_with_reasoning(messages, preserve_reasoning=False, token_budget=25) + self.assertEqual([h["text"] for h in history], ["c" * 30, "d" * 30]) + + def test_always_keeps_latest_turn_even_if_over_budget(self): + messages = [{"role": "user", "text": "z" * 300}] + history = _build_history_with_reasoning(messages, preserve_reasoning=False, token_budget=10) + self.assertEqual(len(history), 1) + self.assertEqual(history[0]["text"], "z" * 300) + + def test_system_messages_always_kept(self): + messages = [ + {"role": "system", "text": "s" * 30}, + {"role": "user", "text": "u" * 300}, + {"role": "assistant", "text": "a" * 300}, + {"role": "user", "text": "n" * 9}, + ] + history = _build_history_with_reasoning(messages, preserve_reasoning=False, token_budget=20) + roles = [h["role"] for h in history] + self.assertIn("system", roles) + self.assertEqual(history[-1]["text"], "n" * 9) + self.assertNotIn("u" * 300, [h["text"] for h in history]) + + def test_estimate_tokens_is_conservative(self): + # ~3 chars/token (over-estimates English so the window stays safe). 
+ self.assertEqual(_estimate_tokens(""), 1) + self.assertEqual(_estimate_tokens("abc"), 2) + self.assertEqual(_estimate_tokens("a" * 30), 11) + + def test_history_token_budget_reserves_and_floors(self): + budget = _history_token_budget( + context_tokens=2000, max_tokens=256, system_prompt="x" * 30, prompt="y" * 30 + ) + # 2000 - (11 + 11 + 256 + 512) = 1210 + self.assertEqual(budget, 1210) + + def test_history_token_budget_floor_512(self): + budget = _history_token_budget( + context_tokens=100, max_tokens=256, system_prompt=None, prompt=None + ) + self.assertEqual(budget, 512) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_mlx_prompt_cache.py b/tests/test_mlx_prompt_cache.py new file mode 100644 index 0000000..593e419 --- /dev/null +++ b/tests/test_mlx_prompt_cache.py @@ -0,0 +1,180 @@ +"""Tests for the MLX per-session prompt-cache reuse logic (tier 4). + +Exercises backend_service/mlx_worker_prompt_cache.py with a fake worker +state and patched mlx-lm cache primitives — no real model load. The +correctness contract under test: the persisted token list always equals +the cache's positional contents, and any uncertainty falls back to a fresh +full prefill. 
+""" + +import unittest +from unittest import mock + +from backend_service import mlx_worker_prompt_cache as pc + +CACHE_MOD = "mlx_lm.models.cache" + + +class FakeCache: + """Sentinel standing in for an mlx-lm prompt cache.""" + + def __init__(self, label): + self.label = label + + +class FakeState: + def __init__(self, *, base_cache=None, base_note=None, tokens=None, model_ref="m"): + self._base = (base_cache, base_note) + self._tokens = list(tokens or []) + self.model = object() + self._loaded_model_ref = model_ref + self.tokenizer = self + self._persist_cache = None + self._persist_tokens = [] + self._persist_cache_model_ref = None + + def _make_cache(self): + return self._base + + def encode(self, _text): # stands in for tokenizer.encode + return list(self._tokens) + + +class CommonPrefixTests(unittest.TestCase): + def test_common_prefix_len(self): + self.assertEqual(pc._common_prefix_len([1, 2, 3], [1, 2, 9]), 2) + self.assertEqual(pc._common_prefix_len([1, 2], [9]), 0) + self.assertEqual(pc._common_prefix_len([1, 2, 3], [1, 2, 3, 4]), 3) + + +class AcquireCompressionTests(unittest.TestCase): + def test_compression_strategy_passthrough(self): + comp = FakeCache("compression") + state = FakeState(base_cache=comp, base_note="cn") + acq = pc.acquire(state, "p-text") + self.assertIs(acq.cache, comp) + self.assertEqual(acq.prompt_feed, "p-text") # string, unchanged + self.assertFalse(acq.managed) + self.assertIs(acq.fields_cache, comp) + self.assertIsNone(acq.commit_tokens) + + +class AcquireNativeTests(unittest.TestCase): + def _patches(self, *, can_trim=True, trim=lambda c, n: n, fresh_label="fresh"): + return ( + mock.patch(f"{CACHE_MOD}.make_prompt_cache", return_value=FakeCache(fresh_label)), + mock.patch(f"{CACHE_MOD}.can_trim_prompt_cache", return_value=can_trim), + mock.patch(f"{CACHE_MOD}.trim_prompt_cache", side_effect=trim), + ) + + def test_fresh_native_cache_full_prefill(self): + state = FakeState(base_cache=None, tokens=[1, 2, 3]) + with 
self._patches()[0], self._patches()[1], self._patches()[2]: + acq = pc.acquire(state, "ignored") + self.assertTrue(acq.managed) + self.assertIsInstance(acq.cache, FakeCache) + self.assertEqual(acq.prompt_feed, [1, 2, 3]) # full token list + self.assertEqual(acq.commit_tokens, [1, 2, 3]) + self.assertIsNone(acq.fields_cache) + + def test_reuse_hit_feeds_only_suffix_no_trim(self): + persist = FakeCache("persist") + state = FakeState(base_cache=None, tokens=[1, 2, 3, 4, 5], model_ref="m") + state._persist_cache = persist + state._persist_tokens = [1, 2, 3] + state._persist_cache_model_ref = "m" + m1, m2, m3 = self._patches() + with m1, m2, m3 as trim: + acq = pc.acquire(state, "ignored") + self.assertIs(acq.cache, persist) # reused, not fresh + self.assertEqual(acq.prompt_feed, [4, 5]) # suffix only + self.assertEqual(acq.commit_tokens, [1, 2, 3, 4, 5]) + trim.assert_not_called() # num_to_trim == 0 + + def test_reuse_with_divergence_trims_tail(self): + persist = FakeCache("persist") + state = FakeState(base_cache=None, tokens=[1, 2, 3, 4], model_ref="m") + state._persist_cache = persist + state._persist_tokens = [1, 2, 3, 9, 9] # diverges after index 3 + state._persist_cache_model_ref = "m" + m1, m2, m3 = self._patches() + with m1, m2, m3 as trim: + acq = pc.acquire(state, "ignored") + self.assertIs(acq.cache, persist) + trim.assert_called_once_with(persist, 2) # 5 cached - 3 common + self.assertEqual(acq.prompt_feed, [4]) # full[3:] + + def test_reset_on_model_change(self): + state = FakeState(base_cache=None, tokens=[1, 2, 3], model_ref="new") + state._persist_cache = FakeCache("stale") + state._persist_tokens = [1, 2, 3] + state._persist_cache_model_ref = "old" + m1, m2, m3 = self._patches() + with m1, m2, m3: + acq = pc.acquire(state, "ignored") + self.assertEqual(acq.prompt_feed, [1, 2, 3]) # fresh → full prefill + self.assertEqual(acq.cache.label, "fresh") + + def test_reset_when_cache_not_trimmable(self): + state = FakeState(base_cache=None, tokens=[1, 2, 3, 
4], model_ref="m") + state._persist_cache = FakeCache("persist") + state._persist_tokens = [1, 2, 3] + state._persist_cache_model_ref = "m" + m1, m2, m3 = self._patches(can_trim=False) + with m1, m2, m3: + acq = pc.acquire(state, "ignored") + self.assertEqual(acq.cache.label, "fresh") + self.assertEqual(acq.prompt_feed, [1, 2, 3, 4]) + + def test_reset_when_no_common_prefix(self): + state = FakeState(base_cache=None, tokens=[7, 8, 9], model_ref="m") + state._persist_cache = FakeCache("persist") + state._persist_tokens = [1, 2, 3] + state._persist_cache_model_ref = "m" + m1, m2, m3 = self._patches() + with m1, m2, m3: + acq = pc.acquire(state, "ignored") + self.assertEqual(acq.cache.label, "fresh") + self.assertEqual(acq.prompt_feed, [7, 8, 9]) + + def test_partial_trim_falls_back_to_fresh(self): + state = FakeState(base_cache=None, tokens=[1, 2, 3, 4], model_ref="m") + state._persist_cache = FakeCache("persist") + state._persist_tokens = [1, 2, 3, 9, 9] + state._persist_cache_model_ref = "m" + # trim returns fewer than requested → unsafe → fresh + m1, m2, m3 = self._patches(trim=lambda c, n: n - 1) + with m1, m2, m3: + acq = pc.acquire(state, "ignored") + self.assertEqual(acq.cache.label, "fresh") + self.assertEqual(acq.prompt_feed, [1, 2, 3, 4]) + + +class CommitInvalidateTests(unittest.TestCase): + def test_commit_accounting_is_prompt_plus_generated(self): + state = FakeState() + cache = FakeCache("c") + pc.commit(state, cache=cache, commit_tokens=[1, 2, 3], generated_ids=[4, 5], model_ref="m") + self.assertIs(state._persist_cache, cache) + self.assertEqual(state._persist_tokens, [1, 2, 3, 4, 5]) + self.assertEqual(state._persist_cache_model_ref, "m") + + def test_commit_noop_when_not_managed(self): + state = FakeState() + pc.commit(state, cache=None, commit_tokens=None, generated_ids=[4], model_ref="m") + self.assertIsNone(state._persist_cache) + self.assertEqual(state._persist_tokens, []) + + def test_invalidate_clears(self): + state = FakeState() + 
state._persist_cache = FakeCache("c") + state._persist_tokens = [1, 2] + state._persist_cache_model_ref = "m" + pc.invalidate(state) + self.assertIsNone(state._persist_cache) + self.assertEqual(state._persist_tokens, []) + self.assertIsNone(state._persist_cache_model_ref) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_mlx_worker.py b/tests/test_mlx_worker.py index 7212104..d1f79d0 100644 --- a/tests/test_mlx_worker.py +++ b/tests/test_mlx_worker.py @@ -875,5 +875,41 @@ def test_unload_clears_multimodal_state(self): self.assertFalse(worker.is_multimodal) +class MlxLogitsProcessorTests(unittest.TestCase): + """_build_mlx_logits_processors wires repeat_penalty (mlx-lm applies it + via logits_processors, not the sampler — it was being dropped).""" + + def setUp(self): + from backend_service.mlx_worker_request import _build_mlx_logits_processors + + self._build = _build_mlx_logits_processors + + def test_none_when_no_samplers(self): + self.assertIsNone(self._build({})) + self.assertIsNone(self._build({"samplers": None})) + + def test_none_when_penalty_absent_or_neutral(self): + self.assertIsNone(self._build({"samplers": {"top_p": 0.9}})) + self.assertIsNone(self._build({"samplers": {"repeat_penalty": 1.0}})) + + def test_none_when_penalty_non_numeric(self): + self.assertIsNone(self._build({"samplers": {"repeat_penalty": "oops"}})) + + @unittest.skipUnless( + __import__("importlib").util.find_spec("mlx_lm") is not None, + "mlx-lm not installed", + ) + def test_builds_processors_for_real_penalty(self): + result = self._build({"samplers": {"repeat_penalty": 1.3}}) + self.assertIsNotNone(result) + self.assertTrue(len(result) >= 1) + + def test_accepts_repetition_penalty_alias_without_raising(self): + try: + self._build({"samplers": {"repetition_penalty": 1.2}}) + except Exception as exc: # noqa: BLE001 + self.fail(f"alias parse raised: {exc}") + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_sampler_payload.py 
b/tests/test_sampler_payload.py index 4f63b15..a79f3bd 100644 --- a/tests/test_sampler_payload.py +++ b/tests/test_sampler_payload.py @@ -55,6 +55,30 @@ def test_merges_all_supported_sampler_keys(self): self.assertEqual(payload["mirostat_tau"], 5.0) self.assertEqual(payload["mirostat_eta"], 0.1) + def test_merges_modern_quality_samplers(self): + # DRY / XTC / top-n-sigma were added to _LLAMA_SAMPLER_KEYS; they + # must now flow through to the llama-server payload. + payload: dict = {} + _apply_sampler_kwargs( + payload, + samplers={ + "dry_multiplier": 0.8, + "dry_base": 1.75, + "dry_allowed_length": 2, + "xtc_probability": 0.5, + "xtc_threshold": 0.1, + "top_n_sigma": 1.0, + }, + reasoning_effort=None, + json_schema=None, + ) + self.assertEqual(payload["dry_multiplier"], 0.8) + self.assertEqual(payload["dry_base"], 1.75) + self.assertEqual(payload["dry_allowed_length"], 2) + self.assertEqual(payload["xtc_probability"], 0.5) + self.assertEqual(payload["xtc_threshold"], 0.1) + self.assertEqual(payload["top_n_sigma"], 1.0) + def test_none_values_in_samplers_skip_merge(self): # The frontend may send the union of fields with most set to null — # explicit nulls must not override server defaults. @@ -131,6 +155,19 @@ def test_emits_llama_field_names(self): self.assertEqual(overrides["mirostat_tau"], 5.0) self.assertEqual(overrides["mirostat_eta"], 0.1) + def test_emits_modern_sampler_field_names(self): + # XTC + DRY map to llama/mlx engine-side snake_case keys. 
+ request = SimpleNamespace( + xtcProbability=0.5, xtcThreshold=0.1, + dryMultiplier=0.8, dryBase=1.75, dryAllowedLength=2, + ) + overrides = _build_sampler_overrides(request) + self.assertEqual(overrides["xtc_probability"], 0.5) + self.assertEqual(overrides["xtc_threshold"], 0.1) + self.assertEqual(overrides["dry_multiplier"], 0.8) + self.assertEqual(overrides["dry_base"], 1.75) + self.assertEqual(overrides["dry_allowed_length"], 2) + def test_partial_override_keeps_only_set_fields(self): request = SimpleNamespace( topP=0.9, topK=None, minP=None, repeatPenalty=None,