diff --git a/docs/configurations.md b/docs/configurations.md index 9d8bf2d2..61a1c39c 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -136,7 +136,7 @@ Every domain below is shown as a single table that lists **all** constants Thuki ### `[inference]` -Thuki reaches a model through a **provider**. `active_provider` names which one is used; each provider is described by a `[[inference.providers]]` block. Phase 1 ships two providers: **Ollama** (reached over HTTP at a configurable URL, local or remote) and a **Built-in (Thuki)** entry reserved for an upcoming bundled engine. A fresh install defaults to the Ollama provider. +Thuki reaches a model through a **provider**. `active_provider` names which one is used; each provider is described by a `[[inference.providers]]` block. Phase 1 ships two providers: **Ollama** (reached over HTTP at a configurable URL, local or remote) and a **Built-in (Thuki)** entry reserved for an upcoming bundled engine. A fresh install defaults to the Ollama provider. You can also add **OpenAI-compatible** providers (LM Studio, Jan, llama-server, etc.) by specifying `kind = "openai"` and a valid `base_url`. Each provider keeps its own selected `model`. Thuki discovers installed models live from Ollama's `/api/tags` endpoint and lets you pick one from the in-app model picker (or the Providers section of Settings); the choice is written to that provider's `model` field. When no model is installed and none has been chosen, Thuki refuses to dispatch a chat request and surfaces a "Pick a model" prompt. Pull a model with `ollama pull ` and select it. @@ -154,10 +154,11 @@ Each `[[inference.providers]]` block has these fields: | Field | Description | | :--------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------- | | `id` | Stable identifier referenced by `active_provider`. The `builtin` and `ollama` ids are seeded automatically. | -| `kind` | `builtin` or `ollama`. Any other kind is dropped on load. Determines how Thuki talks to the provider (the Ollama kind uses Ollama's native API). | +| `kind` | `"builtin"`, `"ollama"`, or `"openai"`. Any other kind is dropped on load. Determines how Thuki talks to the provider. | | `label` | Human-readable name shown in Settings. | -| `base_url` | For the Ollama kind: where Thuki reaches the server (defaults to `http://127.0.0.1:11434`; point it at another machine to use remote Ollama). Empty for the built-in kind. A provider of kind `ollama` with an empty `base_url` is dropped and re-seeded at the localhost default. | +| `base_url` | For the `ollama` and `openai` kinds: the server's base URL. For `ollama`, defaults to `http://127.0.0.1:11434` if empty (then re-seeded). For `openai`, must be a valid `http://` or `https://` URL; a provider with an empty or non-http(s) URL is dropped without healing. Empty for the `builtin` kind. | | `model` | The model selected for this provider, written when you pick one. Empty means "none chosen yet". | +| `vision` | For `openai`-kind providers only: set to `true` if the selected model accepts image inputs. OpenAI-compatible local servers expose no capability probe, so this is declared manually. Ignored for `builtin` and `ollama` (capabilities are resolved from the manifest or Ollama's `/api/show`). Defaults to `false`. | If the active model has been removed from Ollama between launches, Thuki silently falls back to the first installed model the next time you open the picker. If no models are installed at all, the next request surfaces a "Model not found" error with the exact `ollama pull ` command to run. @@ -180,6 +181,7 @@ The table below also lists the baked-in safety limits that govern Thuki's commun | `MAX_HF_API_BODY_BYTES` | `4 MiB` | No | Defense-in-depth bound on attacker-controlled data from a remote service, mirroring `MAX_OLLAMA_TAGS_BODY_BYTES`. | — | The largest Hugging Face API response body (repo file listings) Thuki will accept while resolving a model to download. Larger responses are rejected mid-stream and the request returns an error. | | `HF_API_TIMEOUT_SECS` | `15 s` | No | Protocol cap on a hung remote service so the download UI cannot stall on metadata resolution; 15 s is generous for a small metadata call over the internet. | — | How long Thuki waits for a Hugging Face API metadata call (repo file listing) to respond before giving up. Applies to resolving pasted repo ids and listing a repo's GGUF files, not to the model download itself. | | `HF_BASE_URL` | `https://huggingface.co` | No | Single origin for model metadata and downloads; the sha256-pinning and provenance model assume the canonical Hub. Pointing downloads at an arbitrary mirror would bypass the integrity guarantees that make the curated starter registry safe. | — | The Hugging Face origin Thuki uses for all model metadata calls and blob downloads. Every starter in the registry pins a repo at an exact revision and carries a sha256 digest verified on install; those digests are read from this origin and only meaningful against it. | +| `MAX_SSE_LINE_BYTES` | `1 MiB` | No | Defense-in-depth bound on attacker-controlled stream data. A malicious or broken chat server could otherwise grow a single stream line without limit and exhaust memory. | — | The longest single Server-Sent-Events line Thuki accepts while streaming a chat response from an OpenAI-compatible (`/v1`) server. A stream line exceeding this aborts the response with an error. | ### `[prompt]` diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 39a9aed2..8e48ff81 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1924,6 +1924,18 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "keyring" +version = "3.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eebcc3aff044e5944a8fbaf69eb277d11986064cba30c468730e8b9909fb551c" +dependencies = [ + "log", + "security-framework 2.11.1", + "security-framework 3.7.0", + "zeroize", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -3228,7 +3240,7 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.7.0", ] [[package]] @@ -3256,7 +3268,7 @@ dependencies = [ "rustls-native-certs", "rustls-platform-verifier-android", "rustls-webpki", - "security-framework", + "security-framework 3.7.0", "security-framework-sys", "webpki-root-certs", "windows-sys 0.61.2", @@ -3367,6 +3379,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.11.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + [[package]] name = "security-framework" version = "3.7.0" @@ -4305,6 +4330,7 @@ dependencies = [ "futures-util", "html-escape", "image", + "keyring", "libc", "mockito", "objc2", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index df20fb47..0acb8402 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -45,6 +45,7 @@ async-trait = "0.1" semver = "1" sha2 = "0.10" libc = "0.2" +keyring = { version = "3", features = ["apple-native"] } [target.'cfg(target_os = "macos")'.dependencies] tauri-nspanel = { git = "https://github.com/ahkohd/tauri-nspanel", branch = "v2.1" } diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index ec521614..552ba8bf 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -93,6 +93,9 @@ pub fn apply_capability_filter(messages: &mut [ChatMessage], caps: &Capabilities pub enum EngineErrorKind { /// Ollama process is not running (connection refused / timeout). EngineUnreachable, + /// The bundled engine's sidecar process failed to launch or crashed before + /// passing its health check. + EngineStartFailed, /// The requested model has not been pulled yet (HTTP 404). ModelNotFound, /// No active model has been selected. The user must pick a model from @@ -115,36 +118,363 @@ pub fn no_model_selected_error() -> EngineError { } } -/// Returns the error to emit when the active provider's kind has no Phase-1 -/// implementation, or `None` when the kind is the native Ollama path. Pure so -/// the routing decision is unit-tested even though `ask_model` is coverage-off. -/// In Phase 1 the only functional kind is `ollama`; the built-in engine and -/// the generic OpenAI-compatible kind arrive in Phase 2. -pub(crate) fn unsupported_provider_error(kind: &str, label: &str) -> Option { - use crate::config::defaults::PROVIDER_KIND_OLLAMA; - if kind == PROVIDER_KIND_OLLAMA { - return None; - } - let who = if label.trim().is_empty() { - "This provider" - } else { - label - }; - Some(EngineError { - kind: EngineErrorKind::EngineUnreachable, - message: format!("{who} is not available in this version of Thuki yet."), - }) -} - /// Structured error emitted over the streaming channel. /// Rust owns all user-facing copy; the frontend only uses `kind` for styling. -#[derive(Clone, Serialize, Debug)] +#[derive(Clone, Serialize, Debug, PartialEq)] pub struct EngineError { pub kind: EngineErrorKind, /// Final user-facing string. First line is the title, remainder is the subtitle. pub message: String, } +/// How a chat turn reaches its inference backend, decided once per request +/// from the active provider's kind. +#[derive(Debug, PartialEq, Eq)] +pub enum ChatRoute { + /// Native Ollama `/api/chat` streaming at the provider's base URL. + OllamaNative { + /// Full `/api/chat` endpoint. + endpoint: String, + }, + /// Generic OpenAI-compatible `/v1` streaming at the provider's base URL. + /// The API key is fetched later by provider id so the Keychain read + /// happens only on the path that needs it. + V1 { + base_url: String, + api_key_provider: Option, + }, + /// The bundled engine: resolve the installed model, ensure the sidecar + /// serves it, then stream via the `/v1` client at the engine's port. + Builtin { + /// The active provider's `model` field: the manifest id. + model_id: String, + }, +} + +/// Decides the chat route from the resolved config. Pure so the routing +/// decision is unit-tested even though `ask_model` is coverage-off. +/// +/// Errors: +/// - unknown/empty kind (defensive; the loader drops unknown kinds and +/// repairs a dangling `active_provider` pointer), +/// - `builtin` with an empty model (`NoModelSelected`, pointing the user at +/// the Settings model pick). +pub fn resolve_chat_route( + inference: &crate::config::schema::InferenceSection, +) -> Result { + use crate::config::defaults::{ + PROVIDER_KIND_BUILTIN, PROVIDER_KIND_OLLAMA, PROVIDER_KIND_OPENAI, + }; + match inference.active_provider_kind() { + PROVIDER_KIND_OLLAMA => Ok(ChatRoute::OllamaNative { + endpoint: format!( + "{}/api/chat", + inference.active_provider_base_url().trim_end_matches('/') + ), + }), + PROVIDER_KIND_OPENAI => Ok(ChatRoute::V1 { + base_url: inference + .active_provider_base_url() + .trim_end_matches('/') + .to_string(), + api_key_provider: Some(inference.active_provider.clone()), + }), + PROVIDER_KIND_BUILTIN => { + let model = inference.active_provider_model(); + if model.is_empty() { + return Err(EngineError { + kind: EngineErrorKind::NoModelSelected, + message: "No model selected\nPick or download a model in Settings.".to_string(), + }); + } + Ok(ChatRoute::Builtin { + model_id: model.to_string(), + }) + } + _ => Err(EngineError { + kind: EngineErrorKind::Other, + message: "Something went wrong\nThe active provider has an unknown kind.".to_string(), + }), + } +} + +/// Maps an installed-model manifest row onto the engine [`Target`] the +/// runner spawns: blob-store paths for the weights and optional mmproj plus +/// the configured context size. +/// +/// [`Target`]: crate::engine::state::Target +pub fn builtin_target( + conn: &rusqlite::Connection, + store: &crate::models::storage::ModelStore, + model_id: &str, + num_ctx: u32, +) -> Result { + let row = crate::models::manifest::get(conn, model_id).map_err(|e| EngineError { + kind: EngineErrorKind::Other, + message: format!("Something went wrong\nCould not read the installed-model manifest: {e}"), + })?; + let Some(model) = row else { + return Err(EngineError { + kind: EngineErrorKind::ModelNotFound, + message: "The selected model is not installed.\nPick or download a model in Settings." + .to_string(), + }); + }; + Ok(crate::engine::state::Target { + model_path: store.blob_path(&model.sha256), + mmproj_path: model + .mmproj_sha256 + .as_deref() + .map(|sha| store.blob_path(sha)), + num_ctx, + }) +} + +/// Parses llama-server's `GET /props` response and reports whether the +/// loaded model accepts image input. The flag lives at `modalities.vision`; +/// an absent field, a non-boolean value, or a malformed body all collapse to +/// `false` so the gate fails closed (images are stripped rather than letting +/// llama-server reject the whole request). +pub(crate) fn parse_props_vision(body: &[u8]) -> bool { + serde_json::from_slice::(body) + .ok() + .and_then(|v| { + v.get("modalities") + .and_then(|m| m.get("vision")) + .and_then(|b| b.as_bool()) + }) + .unwrap_or(false) +} + +/// Asks the serving llama-server whether the loaded model accepts images +/// (`GET /props`). Any transport or read failure collapses to `false`, +/// matching [`parse_props_vision`]'s fail-closed contract. +async fn fetch_builtin_vision(client: &reqwest::Client, base_url: &str) -> bool { + match client.get(format!("{base_url}/props")).send().await { + Ok(resp) => match resp.bytes().await { + Ok(bytes) => parse_props_vision(&bytes), + Err(_) => false, + }, + Err(_) => false, + } +} + +/// Runs the built-in-engine stage of a chat turn: mark activity, ensure the +/// engine serves `target`, then stream via the `/v1` client at the engine's +/// port. Pulled out of [`ask_model`] so the ensure-error mapping is covered +/// by tests: +/// - `Superseded` becomes a terminal `Cancelled` (a newer settings change +/// preempted this request; never an engine-start failure), +/// - `StartFailed` becomes a typed `EngineStartFailed` error. +/// +/// When the outgoing messages carry images, the serving llama-server is asked +/// whether the loaded model actually accepts them (`/props` runtime gate); +/// a non-vision model gets the images stripped through the same +/// [`apply_capability_filter`] path and stderr notice the cache-driven filter +/// uses, instead of letting the whole request fail. +/// +/// Returns the accumulated assistant content (empty on the error paths) so +/// the caller's persistence tail treats every route identically. +pub(crate) async fn stream_builtin_chat( + engine: &crate::engine::runner::EngineHandle, + target: crate::engine::state::Target, + model_id: String, + mut messages: Vec, + client: &reqwest::Client, + cancel_token: CancellationToken, + on_chunk: impl Fn(StreamChunk), +) -> String { + engine.touch(); + match engine.ensure_loaded(target).await { + Ok(port) => { + let base_url = format!("http://127.0.0.1:{port}"); + let carries_images = messages + .iter() + .any(|m| m.images.as_ref().is_some_and(|imgs| !imgs.is_empty())); + if carries_images && !fetch_builtin_vision(client, &base_url).await { + let stats = apply_capability_filter(&mut messages, &Capabilities::default()); + if stats.stripped_images > 0 { + eprintln!( + "thuki: [capability filter] model={} stripped_images={}", + model_id, stats.stripped_images + ); + } + } + crate::openai::stream_openai_chat( + crate::openai::OpenAiChatParams { + base_url, + model: model_id, + messages, + api_key: None, + }, + client, + cancel_token, + on_chunk, + ) + .await + } + Err(crate::engine::runner::EnsureError::Superseded) => { + on_chunk(StreamChunk::Cancelled); + String::new() + } + Err(crate::engine::runner::EnsureError::StartFailed(detail)) => { + on_chunk(StreamChunk::Error(EngineError { + kind: EngineErrorKind::EngineStartFailed, + message: format!("Thuki's engine could not start.\n{detail}"), + })); + String::new() + } + } +} + +/// Reads the API key for an `openai`-kind provider from the secret store. +/// Errors degrade to `None` with a stderr log: a missing or unreadable key +/// must not block a keyless local `/v1` server. +pub(crate) fn resolve_provider_api_key( + store: &dyn crate::keychain::SecretStore, + provider_id: Option<&str>, +) -> Option { + let id = provider_id?; + match store.get(id) { + Ok(key) => key, + Err(e) => { + eprintln!("thuki: [keychain] failed to read the api key for provider '{id}': {e}"); + None + } + } +} + +/// How LLM calls reach the active provider, decided once per pipeline turn. +/// +/// Downstream of [`ChatRoute`]: the route names the provider kind, the +/// transport is the resolved wire target. `Builtin` routes collapse into +/// `V1` here because once the engine sidecar is serving, it is just another +/// keyless OpenAI-compatible server at a loopback port. +#[derive(Clone, PartialEq)] +pub enum LlmTransport { + /// Native Ollama `/api/chat` at the provider's base URL. + OllamaNative { + /// Full `/api/chat` endpoint. + endpoint: String, + }, + /// Generic OpenAI-compatible `/v1` server: an `openai`-kind provider + /// (key already resolved) or the built-in engine (keyless, engine port). + V1 { + base_url: String, + api_key: Option, + }, +} + +impl LlmTransport { + /// Human-readable endpoint label for forensic trace records. + pub fn endpoint_label(&self) -> String { + match self { + LlmTransport::OllamaNative { endpoint } => endpoint.clone(), + LlmTransport::V1 { base_url, .. } => format!("{base_url}/v1/chat/completions"), + } + } +} + +impl std::fmt::Debug for LlmTransport { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LlmTransport::OllamaNative { endpoint } => f + .debug_struct("OllamaNative") + .field("endpoint", endpoint) + .finish(), + LlmTransport::V1 { base_url, api_key } => f + .debug_struct("V1") + .field("base_url", base_url) + .field("api_key", &api_key.as_ref().map(|_| "")) + .finish(), + } + } +} + +/// Picks the model slug for a pipeline turn. `Builtin` routes carry their +/// model in the provider config (already validated non-empty by +/// `resolve_chat_route`); every other kind keeps the picker-backed fallback +/// whose `None` means "no model selected". +/// +/// Used by both the search pipeline and title generation so the selection +/// logic stays in one place. +pub fn model_for_route(route: &ChatRoute, fallback: Option) -> Option { + match route { + ChatRoute::Builtin { model_id } => Some(model_id.clone()), + _ => fallback, + } +} + +/// Error from [`resolve_llm_transport`]. Splits the engine-ensure outcomes so +/// each caller can map them into its own vocabulary: `Superseded` is a +/// cancellation (a newer settings change preempted the request, never a +/// failure), `Engine` carries a typed user-facing error. +#[derive(Debug, PartialEq)] +pub enum TransportError { + /// A newer settings change preempted the engine ensure. + Superseded, + /// A typed engine error (start failure, missing manifest row, ...). + Engine(EngineError), +} + +/// Resolves a [`ChatRoute`] into the [`LlmTransport`] a pipeline turn streams +/// through. `OllamaNative` passes through; `V1` resolves the provider's API +/// key; `Builtin` maps the manifest row to an engine [`Target`], marks +/// activity, and ensures the sidecar serves it, yielding a keyless `V1` +/// transport at the engine's loopback port. +/// +/// `num_ctx` is consumed only by the builtin arm: the context size is a +/// launch property of the llama-server process, not a per-request knob. +/// +/// [`Target`]: crate::engine::state::Target +pub(crate) async fn resolve_llm_transport( + route: ChatRoute, + db: &crate::history::Database, + store: &crate::models::storage::ModelStore, + engine: &crate::engine::runner::EngineHandle, + secrets: &dyn crate::keychain::SecretStore, + num_ctx: u32, +) -> Result { + match route { + ChatRoute::OllamaNative { endpoint } => Ok(LlmTransport::OllamaNative { endpoint }), + ChatRoute::V1 { + base_url, + api_key_provider, + } => Ok(LlmTransport::V1 { + base_url, + api_key: resolve_provider_api_key(secrets, api_key_provider.as_deref()), + }), + ChatRoute::Builtin { model_id } => { + // Resolve the manifest row inside a scope so the connection guard + // drops before the ensure await. A poisoned lock is recovered: + // the connection itself is not invalidated by an unrelated panic. + let target = { + let conn = match db.0.lock() { + Ok(conn) => conn, + Err(poisoned) => poisoned.into_inner(), + }; + builtin_target(&conn, store, &model_id, num_ctx).map_err(TransportError::Engine)? + }; + engine.touch(); + match engine.ensure_loaded(target).await { + Ok(port) => Ok(LlmTransport::V1 { + base_url: format!("http://127.0.0.1:{port}"), + api_key: None, + }), + Err(crate::engine::runner::EnsureError::Superseded) => { + Err(TransportError::Superseded) + } + Err(crate::engine::runner::EnsureError::StartFailed(detail)) => { + Err(TransportError::Engine(EngineError { + kind: EngineErrorKind::EngineStartFailed, + message: format!("Thuki's engine could not start.\n{detail}"), + })) + } + } + } + } +} + /// Pulls the human-readable reason out of an Ollama error payload. Ollama /// returns `{"error":"..."}` on every non-2xx status from `/api/chat`; when /// the body is empty, malformed, or missing the `error` key we return @@ -586,40 +916,35 @@ pub async fn ask_model( active_model: State<'_, crate::models::ActiveModelState>, capabilities_cache: State<'_, ModelCapabilitiesCache>, trace_recorder: State<'_, std::sync::Arc>, + db: State<'_, crate::history::Database>, + model_store: State<'_, crate::models::storage::ModelStore>, + engine: State<'_, crate::engine::runner::EngineHandle>, + secrets: State<'_, crate::keychain::Secrets>, ) -> Result<(), String> { // Snapshot the config once so all downstream reads (endpoint, prompt, model) // see a consistent view even if the user edits Settings mid-stream. let config = config.read().clone(); - // Route by provider kind. Phase 1 implements only the native Ollama path; - // a non-Ollama active provider (the built-in engine) cannot serve yet, so - // bail with a typed, provider-labeled error regardless of model selection. - { - let kind = config.inference.active_provider_kind(); - let label = config - .inference - .active() - .map(|p| p.label.as_str()) - .unwrap_or(""); - if let Some(err) = unsupported_provider_error(kind, label) { + // Route by the active provider's kind: native Ollama, the built-in + // engine, or a generic OpenAI-compatible server. The decision is made + // once here; the streaming dispatch below consumes it. + let route = match resolve_chat_route(&config.inference) { + Ok(route) => route, + Err(err) => { let _ = on_event.send(StreamChunk::Error(err)); return Ok(()); } - } + }; - let endpoint = format!( - "{}/api/chat", - config - .inference - .active_provider_base_url() - .trim_end_matches('/') - ); - // Snapshot the active model slug; drop the guard before any `.await`. - let model_name = { + // Snapshot the picker-backed active model; drop the guard before any + // `.await`. It is only a fallback: `Builtin` routes carry their model in + // the provider config (kept fresh by `persist_active_provider_model`), so + // a builtin chat must not depend on this snapshot. + let snapshot = { let guard = active_model.0.lock().map_err(|e| e.to_string())?; guard.clone() }; - let Some(model_name) = model_name else { + let Some(model_name) = model_for_route(&route, snapshot) else { // Defense in depth: the onboarding gate already refuses to open the // overlay without a selected model, so this branch only fires if the // user removed their last installed model with `ollama rm` between @@ -752,27 +1077,81 @@ pub async fn ask_model( let token_count_for_pump = std::sync::Arc::clone(&token_count_atomic); let recorder_for_pump = std::sync::Arc::clone(&bound_recorder); - let accumulated = stream_ollama_chat( - OllamaChatParams { - endpoint, - model: model_name, - messages, - think, - keep_alive, - num_ctx: config.inference.num_ctx, - }, - &client, - cancel_token.clone(), - |chunk| { - // Mirror the user-visible chunk into the trace before - // forwarding it to the frontend. Token / ThinkingToken - // chunks land as discrete trace events; terminal chunks are - // summarized below by `AssistantComplete`. - record_chunk_to_trace(&chunk, &recorder_for_pump, &token_count_for_pump); - let _ = on_event.send(chunk); - }, - ) - .await; + // Mirror every user-visible chunk into the trace before forwarding it + // to the frontend. Token / ThinkingToken chunks land as discrete trace + // events; terminal chunks are summarized below by `AssistantComplete`. + // Captures by reference only, so the closure is Copy and each route arm + // can consume it. + let pump = |chunk: StreamChunk| { + record_chunk_to_trace(&chunk, &recorder_for_pump, &token_count_for_pump); + let _ = on_event.send(chunk); + }; + + // Every arm returns the accumulated assistant content (empty when the + // turn ended in a pre-stream error), so the persistence tail below is + // identical for all three routes. + let accumulated = match route { + ChatRoute::OllamaNative { endpoint } => { + stream_ollama_chat( + OllamaChatParams { + endpoint, + model: model_name, + messages, + think, + keep_alive, + num_ctx: config.inference.num_ctx, + }, + &client, + cancel_token.clone(), + pump, + ) + .await + } + ChatRoute::Builtin { model_id } => { + // Resolve the manifest row to blob-store paths inside a scope so + // the connection guard drops before any `.await`. + let target = { + let conn = db.0.lock().map_err(|e| e.to_string())?; + builtin_target(&conn, &model_store, &model_id, config.inference.num_ctx) + }; + match target { + Ok(target) => { + stream_builtin_chat( + &engine, + target, + model_id, + messages, + &client, + cancel_token.clone(), + pump, + ) + .await + } + Err(err) => { + pump(StreamChunk::Error(err)); + String::new() + } + } + } + ChatRoute::V1 { + base_url, + api_key_provider, + } => { + let api_key = resolve_provider_api_key(secrets.0.as_ref(), api_key_provider.as_deref()); + crate::openai::stream_openai_chat( + crate::openai::OpenAiChatParams { + base_url, + model: model_name, + messages, + api_key, + }, + &client, + cancel_token.clone(), + pump, + ) + .await + } + }; let stream_ended_ms = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -1847,26 +2226,6 @@ mod tests { ); } - #[test] - fn unsupported_provider_error_passes_ollama_and_flags_other_kinds() { - use crate::config::defaults::{PROVIDER_KIND_BUILTIN, PROVIDER_KIND_OLLAMA}; - // The native Ollama path proceeds (no error). - assert!(unsupported_provider_error(PROVIDER_KIND_OLLAMA, "Ollama").is_none()); - // The built-in kind has no Phase-1 implementation: typed error, labeled. - let err = unsupported_provider_error(PROVIDER_KIND_BUILTIN, "Built-in (Thuki)").unwrap(); - assert_eq!(err.kind, EngineErrorKind::EngineUnreachable); - assert!(err.message.contains("Built-in (Thuki)")); - // An empty label falls back to a generic noun. - let unlabeled = unsupported_provider_error(PROVIDER_KIND_BUILTIN, "").unwrap(); - assert!(unlabeled.message.contains("This provider")); - // Any non-Ollama kind is flagged, not just the built-in: the gate keys - // on `kind != ollama`, so a future provider kind also bails cleanly - // rather than falling through to the unreachable Ollama HTTP path. - let other = unsupported_provider_error("openai", "Cloud").unwrap(); - assert_eq!(other.kind, EngineErrorKind::EngineUnreachable); - assert!(other.message.contains("Cloud")); - } - #[test] fn engine_error_kinds_serialize_as_pascal_case() { // Wire format contract: every kind must serialize verbatim in @@ -1875,6 +2234,7 @@ mod tests { // and error routing without failing any other test. let cases = [ (EngineErrorKind::EngineUnreachable, "EngineUnreachable"), + (EngineErrorKind::EngineStartFailed, "EngineStartFailed"), (EngineErrorKind::ModelNotFound, "ModelNotFound"), (EngineErrorKind::NoModelSelected, "NoModelSelected"), (EngineErrorKind::Other, "Other"), @@ -2538,4 +2898,965 @@ mod tests { ); assert_eq!(mock.snapshot().len(), 0); } + + // ─── resolve_chat_route ───────────────────────────────────────────── + + /// Helper: an `InferenceSection` whose single provider `p1` is active. + fn inference_with_provider( + kind: &str, + base_url: &str, + model: &str, + ) -> crate::config::schema::InferenceSection { + use crate::config::schema::{InferenceSection, Provider}; + InferenceSection { + active_provider: "p1".to_string(), + providers: vec![Provider { + id: "p1".to_string(), + kind: kind.to_string(), + label: "Test".to_string(), + base_url: base_url.to_string(), + model: model.to_string(), + vision: false, + }], + ..Default::default() + } + } + + #[test] + fn resolve_chat_route_ollama() { + use crate::config::defaults::PROVIDER_KIND_OLLAMA; + let inference = + inference_with_provider(PROVIDER_KIND_OLLAMA, "http://127.0.0.1:11434/", ""); + assert_eq!( + resolve_chat_route(&inference).unwrap(), + ChatRoute::OllamaNative { + endpoint: "http://127.0.0.1:11434/api/chat".to_string(), + } + ); + } + + #[test] + fn resolve_chat_route_openai() { + use crate::config::defaults::PROVIDER_KIND_OPENAI; + let inference = + inference_with_provider(PROVIDER_KIND_OPENAI, "http://localhost:8080/", "qwen3"); + assert_eq!( + resolve_chat_route(&inference).unwrap(), + ChatRoute::V1 { + base_url: "http://localhost:8080".to_string(), + api_key_provider: Some("p1".to_string()), + } + ); + } + + #[test] + fn resolve_chat_route_builtin() { + use crate::config::defaults::PROVIDER_KIND_BUILTIN; + let inference = inference_with_provider(PROVIDER_KIND_BUILTIN, "", "org/repo:m.gguf"); + assert_eq!( + resolve_chat_route(&inference).unwrap(), + ChatRoute::Builtin { + model_id: "org/repo:m.gguf".to_string(), + } + ); + } + + #[test] + fn resolve_chat_route_no_model_selected() { + use crate::config::defaults::PROVIDER_KIND_BUILTIN; + let inference = inference_with_provider(PROVIDER_KIND_BUILTIN, "", ""); + let err = resolve_chat_route(&inference).unwrap_err(); + assert_eq!(err.kind, EngineErrorKind::NoModelSelected); + assert!(err.message.contains("Settings")); + } + + #[test] + fn resolve_chat_route_unknown_kind() { + let inference = inference_with_provider("weird", "http://x", "m"); + let err = resolve_chat_route(&inference).unwrap_err(); + assert_eq!(err.kind, EngineErrorKind::Other); + assert!(err.message.contains("unknown kind")); + } + + // ─── builtin_target ───────────────────────────────────────────────── + + /// Helper: a complete manifest row keyed by `id` with the given hashes. + fn installed_model( + id: &str, + sha256: &str, + mmproj_sha256: Option<&str>, + ) -> crate::models::manifest::InstalledModel { + crate::models::manifest::InstalledModel { + id: id.to_string(), + display_name: format!("Model {id}"), + repo: "org/repo".to_string(), + revision: "a".repeat(40), + file_name: format!("{id}.gguf"), + sha256: sha256.to_string(), + size_bytes: 1_000_000, + quant: "Q4_K_M".to_string(), + vision: mmproj_sha256.is_some(), + thinking: false, + mmproj_file: mmproj_sha256.map(|_| format!("{id}-mmproj.gguf")), + mmproj_sha256: mmproj_sha256.map(str::to_string), + } + } + + #[test] + fn builtin_target_maps_manifest_row() { + let conn = crate::database::open_in_memory().unwrap(); + let dir = tempfile::tempdir().unwrap(); + let store = crate::models::storage::ModelStore::new(dir.path().to_path_buf()).unwrap(); + crate::models::manifest::insert( + &conn, + &installed_model("org/repo:v.gguf", "sha_w", Some("sha_mm")), + ) + .unwrap(); + crate::models::manifest::insert(&conn, &installed_model("org/repo:t.gguf", "sha_t", None)) + .unwrap(); + + let vision = builtin_target(&conn, &store, "org/repo:v.gguf", 4096).unwrap(); + assert_eq!(vision.model_path, store.blob_path("sha_w")); + assert_eq!(vision.mmproj_path, Some(store.blob_path("sha_mm"))); + assert_eq!(vision.num_ctx, 4096); + + let text = builtin_target(&conn, &store, "org/repo:t.gguf", DEFAULT_NUM_CTX).unwrap(); + assert_eq!(text.model_path, store.blob_path("sha_t")); + assert_eq!(text.mmproj_path, None); + assert_eq!(text.num_ctx, DEFAULT_NUM_CTX); + } + + #[test] + fn builtin_target_missing_row_is_model_not_found() { + let conn = crate::database::open_in_memory().unwrap(); + let dir = tempfile::tempdir().unwrap(); + let store = crate::models::storage::ModelStore::new(dir.path().to_path_buf()).unwrap(); + let err = builtin_target(&conn, &store, "org/repo:gone.gguf", 4096).unwrap_err(); + assert_eq!(err.kind, EngineErrorKind::ModelNotFound); + assert!(err.message.contains("Settings")); + } + + #[test] + fn builtin_target_manifest_read_error_is_other() { + // A bare connection without the schema makes `manifest::get` fail. + let conn = rusqlite::Connection::open_in_memory().unwrap(); + let dir = tempfile::tempdir().unwrap(); + let store = crate::models::storage::ModelStore::new(dir.path().to_path_buf()).unwrap(); + let err = builtin_target(&conn, &store, "org/repo:m.gguf", 4096).unwrap_err(); + assert_eq!(err.kind, EngineErrorKind::Other); + assert!(err.message.contains("manifest")); + } + + // ─── resolve_provider_api_key ─────────────────────────────────────── + + #[test] + fn resolve_provider_api_key_reads_key_and_misses_to_none() { + use crate::keychain::SecretStore; + let store = crate::keychain::FakeSecretStore::new(); + store.set("p1", "sk-test").unwrap(); + assert_eq!( + resolve_provider_api_key(&store, Some("p1")), + Some("sk-test".to_string()) + ); + assert_eq!(resolve_provider_api_key(&store, Some("absent")), None); + assert_eq!(resolve_provider_api_key(&store, None), None); + } + + /// A secret store whose reads always fail, for the degrade-to-None path. + struct FailingSecretStore; + + impl crate::keychain::SecretStore for FailingSecretStore { + fn set(&self, _provider_id: &str, _secret: &str) -> Result<(), String> { + Err("locked".to_string()) + } + fn get(&self, _provider_id: &str) -> Result, String> { + Err("locked".to_string()) + } + fn delete(&self, _provider_id: &str) -> Result<(), String> { + Err("locked".to_string()) + } + } + + #[test] + fn resolve_provider_api_key_error_degrades_to_none() { + use crate::keychain::SecretStore; + assert_eq!( + resolve_provider_api_key(&FailingSecretStore, Some("p1")), + None + ); + // The other trait methods fail too; the chat path never calls them. + assert!(FailingSecretStore.set("p1", "sk").is_err()); + assert!(FailingSecretStore.delete("p1").is_err()); + } + + // ─── Ollama native path regression ────────────────────────────────── + + /// Locks the native `/api/chat` wire contract across the routing change: + /// the exact request body (model, messages, stream, think, options + /// {temperature, top_p, top_k, num_ctx}, keep_alive) must be identical + /// to the pre-routing Phase 1 payload. + #[tokio::test] + async fn ollama_request_body_unchanged() { + use crate::config::defaults::PROVIDER_KIND_OLLAMA; + let mut server = mockito::Server::new_async().await; + + // The endpoint comes from the route resolver, exactly as `ask_model` + // dispatches it. + let inference = + inference_with_provider(PROVIDER_KIND_OLLAMA, &format!("{}/", server.url()), ""); + let endpoint = format!("{}/api/chat", server.url()); + assert_eq!( + resolve_chat_route(&inference).unwrap(), + ChatRoute::OllamaNative { + endpoint: endpoint.clone(), + } + ); + + let expected_body = serde_json::json!({ + "model": "gemma3:12b", + "messages": [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"} + ], + "stream": true, + "think": false, + "options": { + "temperature": 1.0, + "top_p": 0.95, + "top_k": 64, + "num_ctx": DEFAULT_NUM_CTX + }, + "keep_alive": "10m" + }); + let mock = server + .mock("POST", "/api/chat") + .match_body(mockito::Matcher::Json(expected_body)) + .with_body(chat_line("", true)) + .create_async() + .await; + + let client = reqwest::Client::new(); + let (_, callback) = collect_chunks(); + stream_ollama_chat( + OllamaChatParams { + endpoint, + model: "gemma3:12b".to_string(), + messages: vec![ + ChatMessage { + role: "system".to_string(), + content: "sys".to_string(), + images: None, + }, + ChatMessage { + role: "user".to_string(), + content: "hi".to_string(), + images: None, + }, + ], + think: false, + keep_alive: Some("10m".to_string()), + num_ctx: DEFAULT_NUM_CTX, + }, + &client, + CancellationToken::new(), + callback, + ) + .await; + + mock.assert_async().await; + } + + // ─── stream_builtin_chat ──────────────────────────────────────────── + + /// Scriptable [`crate::engine::process::EngineProcess`] for the built-in + /// route tests: hands out a fixed port, optionally fails every spawn, + /// and either answers health probes with 200 or hangs them forever so a + /// test can preempt the in-flight ensure. + struct ScriptedEngineProcess { + port: u16, + spawn_error: Option, + healthy: bool, + } + + struct ScriptedChild { + exit_tx: tokio::sync::watch::Sender, + exit_rx: tokio::sync::watch::Receiver, + } + + #[async_trait::async_trait] + impl crate::engine::process::EngineChild for ScriptedChild { + async fn wait_exit(&mut self) { + let _ = self.exit_rx.wait_for(|exited| *exited).await; + } + async fn kill(&mut self) { + let _ = self.exit_tx.send(true); + } + } + + #[async_trait::async_trait] + impl crate::engine::process::EngineProcess for ScriptedEngineProcess { + async fn spawn( + &self, + _args: &crate::engine::process::SpawnArgs, + ) -> Result, String> { + if let Some(ref message) = self.spawn_error { + return Err(message.clone()); + } + let (exit_tx, exit_rx) = tokio::sync::watch::channel(false); + Ok(Box::new(ScriptedChild { exit_tx, exit_rx })) + } + fn free_port(&self) -> Result { + Ok(self.port) + } + async fn health_probe(&self, _port: u16) -> Result { + if !self.healthy { + // Hangs until the poll task is dropped by a kill; the + // answer below is only ever reached on the healthy path. + std::future::pending::<()>().await; + } + Ok(200) + } + } + + /// Helper: an [`crate::engine::state::Target`] with placeholder paths. + fn engine_target() -> crate::engine::state::Target { + crate::engine::state::Target { + model_path: std::path::PathBuf::from("/tmp/m.gguf"), + mmproj_path: None, + num_ctx: DEFAULT_NUM_CTX, + } + } + + /// Helper: an [`crate::engine::runner::EngineHandle`] over a scripted + /// process with idle unload disabled. + fn spawn_engine(process: ScriptedEngineProcess) -> crate::engine::runner::EngineHandle { + crate::engine::runner::EngineHandle::spawn( + Arc::new(process), + 0, + std::time::Duration::from_secs(3600), + ) + } + + #[tokio::test] + async fn stream_builtin_chat_streams_from_engine_port() { + let mut server = mockito::Server::new_async().await; + let port: u16 = server + .url() + .rsplit(':') + .next() + .unwrap() + .parse() + .expect("mockito url ends in a port"); + let mock = server + .mock("POST", "/v1/chat/completions") + .with_header("content-type", "text/event-stream") + .with_body("data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n") + .create_async() + .await; + + let engine = spawn_engine(ScriptedEngineProcess { + port, + spawn_error: None, + healthy: true, + }); + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_builtin_chat( + &engine, + engine_target(), + "org/repo:m.gguf".to_string(), + vec![], + &client, + CancellationToken::new(), + callback, + ) + .await; + + mock.assert_async().await; + assert_eq!(accumulated, "Hi"); + let chunks = chunks.lock().unwrap(); + assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "Hi")); + assert_eq!( + std::mem::discriminant(&chunks[1]), + std::mem::discriminant(&StreamChunk::Done) + ); + engine.shutdown().await; + } + + #[tokio::test] + async fn superseded_ensure_emits_cancelled() { + // Health probes hang, so the ensure stays in flight until the + // unload preempts it. + let engine = spawn_engine(ScriptedEngineProcess { + port: 1, + spawn_error: None, + healthy: false, + }); + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + + let task = { + let engine = engine.clone(); + tokio::spawn(async move { + stream_builtin_chat( + &engine, + engine_target(), + "org/repo:m.gguf".to_string(), + vec![], + &client, + CancellationToken::new(), + callback, + ) + .await + }) + }; + + // Wait until the spawn landed and the health poll is in flight, + // then preempt the waiting ensure. + let mut status = engine.status(); + status + .wait_for(|s| s.state == "starting") + .await + .expect("actor is running"); + engine.unload().await; + + let accumulated = task.await.unwrap(); + assert_eq!(accumulated, ""); + let chunks = chunks.lock().unwrap(); + assert_eq!(chunks.len(), 1, "exactly one terminal chunk"); + assert_eq!( + std::mem::discriminant(&chunks[0]), + std::mem::discriminant(&StreamChunk::Cancelled) + ); + engine.shutdown().await; + } + + #[tokio::test] + async fn start_failed_maps_engine_start_failed() { + let engine = spawn_engine(ScriptedEngineProcess { + port: 1, + spawn_error: Some("spawn boom".to_string()), + healthy: true, + }); + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_builtin_chat( + &engine, + engine_target(), + "org/repo:m.gguf".to_string(), + vec![], + &client, + CancellationToken::new(), + callback, + ) + .await; + + assert_eq!(accumulated, ""); + let chunks = chunks.lock().unwrap(); + assert_eq!(chunks.len(), 1, "exactly one terminal chunk"); + assert!(matches!( + &chunks[0], + StreamChunk::Error(e) + if e.kind == EngineErrorKind::EngineStartFailed + && e.message.starts_with("Thuki's engine could not start.\n") + && e.message.contains("spawn boom") + )); + engine.shutdown().await; + } + + // ─── /props runtime vision gate ───────────────────────────────────── + + #[test] + fn parse_props_vision_true_false_absent_malformed() { + assert!(parse_props_vision(br#"{"modalities":{"vision":true}}"#)); + assert!(!parse_props_vision(br#"{"modalities":{"vision":false}}"#)); + assert!(!parse_props_vision(br#"{"modalities":{}}"#), "absent flag"); + assert!(!parse_props_vision(br#"{}"#), "absent modalities"); + assert!( + !parse_props_vision(br#"{"modalities":{"vision":"yes"}}"#), + "non-boolean flag fails closed" + ); + assert!(!parse_props_vision(b"not json"), "malformed body"); + } + + #[tokio::test] + async fn fetch_builtin_vision_transport_error_fails_closed() { + let client = reqwest::Client::new(); + assert!(!fetch_builtin_vision(&client, "http://127.0.0.1:1").await); + } + + /// A 2xx `/props` response whose body dies mid-read (connection closed + /// before the promised Content-Length) fails closed, like every other + /// gate failure mode. + #[tokio::test] + async fn fetch_builtin_vision_body_read_failure_fails_closed() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut req_buf = [0u8; 8192]; + let _ = stream.read(&mut req_buf).await; + // Promise more bytes than are sent, then shut down. + let response = + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 1000\r\n\r\n{\"modalities\""; + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.shutdown().await; + }); + + let client = reqwest::Client::new(); + assert!(!fetch_builtin_vision(&client, &format!("http://127.0.0.1:{port}")).await); + } + + /// Messages carrying one image, as the gate sees them after the + /// capability snapshot is built. + fn image_message() -> Vec { + vec![ChatMessage { + role: "user".to_string(), + content: "hi".to_string(), + images: Some(vec!["QUJD".to_string()]), + }] + } + + /// Drives `stream_builtin_chat` against a mockito server acting as the + /// engine port, with `/props` scripted to report `vision` and the chat + /// mock matching `expected_chat_body`. Returns once both mocks assert. + async fn run_props_gate_case(vision: bool, expected_chat_body: &str) { + let mut server = mockito::Server::new_async().await; + let port: u16 = server + .url() + .rsplit(':') + .next() + .unwrap() + .parse() + .expect("mockito url ends in a port"); + let props_mock = server + .mock("GET", "/props") + .with_status(200) + .with_body(format!(r#"{{"modalities":{{"vision":{vision}}}}}"#)) + .create_async() + .await; + let chat_mock = server + .mock("POST", "/v1/chat/completions") + .match_body(mockito::Matcher::PartialJsonString( + expected_chat_body.to_string(), + )) + .with_header("content-type", "text/event-stream") + .with_body("data: [DONE]\n") + .create_async() + .await; + + let engine = spawn_engine(ScriptedEngineProcess { + port, + spawn_error: None, + healthy: true, + }); + let client = reqwest::Client::new(); + let (_chunks, callback) = collect_chunks(); + stream_builtin_chat( + &engine, + engine_target(), + "org/repo:m.gguf".to_string(), + image_message(), + &client, + CancellationToken::new(), + callback, + ) + .await; + + props_mock.assert_async().await; + chat_mock.assert_async().await; + engine.shutdown().await; + } + + #[tokio::test] + async fn props_gate_strips_images_when_vision_unloaded() { + // vision:false -> the image part is stripped, so the wire message + // collapses to the plain-string content shape. + run_props_gate_case(false, r#"{"messages":[{"role":"user","content":"hi"}]}"#).await; + } + + #[tokio::test] + async fn props_gate_keeps_images_when_vision_supported() { + // vision:true -> the multipart content shape with the image part + // reaches the wire untouched. + run_props_gate_case( + true, + r#"{"messages":[{"role":"user","content":[{"type":"text","text":"hi"},{"type":"image_url","image_url":{"url":"data:image/jpeg;base64,QUJD"}}]}]}"#, + ) + .await; + } + + // ─── LlmTransport / resolve_llm_transport ─────────────────────────── + + #[test] + fn llm_transport_endpoint_label_names_the_wire_target() { + let native = LlmTransport::OllamaNative { + endpoint: "http://127.0.0.1:11434/api/chat".to_string(), + }; + assert_eq!(native.endpoint_label(), "http://127.0.0.1:11434/api/chat"); + let v1 = LlmTransport::V1 { + base_url: "http://localhost:8080".to_string(), + api_key: None, + }; + assert_eq!( + v1.endpoint_label(), + "http://localhost:8080/v1/chat/completions" + ); + } + + #[test] + fn llm_transport_debug_redacts_api_key() { + let with_key = LlmTransport::V1 { + base_url: "https://api.openai.com".to_string(), + api_key: Some("sk-supersecret".to_string()), + }; + let debug = format!("{with_key:?}"); + assert!( + !debug.contains("sk-supersecret"), + "key must not appear in Debug output" + ); + assert!( + debug.contains(""), + "redacted placeholder must be present" + ); + + let no_key = LlmTransport::V1 { + base_url: "http://127.0.0.1:8080".to_string(), + api_key: None, + }; + let debug_none = format!("{no_key:?}"); + assert!(debug_none.contains("None"), "None key must show as None"); + + // OllamaNative has no key field; just verify it formats without panic. + let native = LlmTransport::OllamaNative { + endpoint: "http://127.0.0.1:11434/api/chat".to_string(), + }; + let debug_native = format!("{native:?}"); + assert!(debug_native.contains("OllamaNative")); + } + + // ─── model_for_route ──────────────────────────────────────────────────── + + #[test] + fn model_for_route_prefers_builtin_provider_model() { + let route = ChatRoute::Builtin { + model_id: "org/repo:m.gguf".to_string(), + }; + assert_eq!( + model_for_route(&route, Some("gemma3:12b".to_string())), + Some("org/repo:m.gguf".to_string()) + ); + assert_eq!( + model_for_route(&route, None), + Some("org/repo:m.gguf".to_string()) + ); + } + + #[test] + fn model_for_route_keeps_fallback_for_non_builtin_routes() { + let ollama = ChatRoute::OllamaNative { + endpoint: "http://127.0.0.1:11434/api/chat".to_string(), + }; + assert_eq!( + model_for_route(&ollama, Some("gemma3:12b".to_string())), + Some("gemma3:12b".to_string()) + ); + assert_eq!(model_for_route(&ollama, None), None); + + let v1 = ChatRoute::V1 { + base_url: "http://localhost:8080".to_string(), + api_key_provider: None, + }; + assert_eq!( + model_for_route(&v1, Some("gpt-4o".to_string())), + Some("gpt-4o".to_string()) + ); + assert_eq!(model_for_route(&v1, None), None); + } + + /// Helper: a `Database` over a fresh in-memory schema. + fn test_db() -> crate::history::Database { + crate::history::Database(StdMutex::new(crate::database::open_in_memory().unwrap())) + } + + /// Helper: a `ModelStore` rooted in a fresh temp dir, plus the dir guard. + fn test_store() -> (tempfile::TempDir, crate::models::storage::ModelStore) { + let dir = tempfile::tempdir().unwrap(); + let store = crate::models::storage::ModelStore::new(dir.path().to_path_buf()).unwrap(); + (dir, store) + } + + #[tokio::test] + async fn resolve_llm_transport_passes_ollama_endpoint_through() { + let db = test_db(); + let (_dir, store) = test_store(); + let engine = spawn_engine(ScriptedEngineProcess { + port: 1, + spawn_error: None, + healthy: true, + }); + let secrets = crate::keychain::FakeSecretStore::new(); + let transport = resolve_llm_transport( + ChatRoute::OllamaNative { + endpoint: "http://127.0.0.1:11434/api/chat".to_string(), + }, + &db, + &store, + &engine, + &secrets, + DEFAULT_NUM_CTX, + ) + .await + .unwrap(); + assert_eq!( + transport, + LlmTransport::OllamaNative { + endpoint: "http://127.0.0.1:11434/api/chat".to_string(), + } + ); + engine.shutdown().await; + } + + #[tokio::test] + async fn resolve_llm_transport_v1_resolves_api_key() { + use crate::keychain::SecretStore; + let db = test_db(); + let (_dir, store) = test_store(); + let engine = spawn_engine(ScriptedEngineProcess { + port: 1, + spawn_error: None, + healthy: true, + }); + let secrets = crate::keychain::FakeSecretStore::new(); + secrets.set("p1", "sk-test").unwrap(); + let transport = resolve_llm_transport( + ChatRoute::V1 { + base_url: "http://localhost:8080".to_string(), + api_key_provider: Some("p1".to_string()), + }, + &db, + &store, + &engine, + &secrets, + DEFAULT_NUM_CTX, + ) + .await + .unwrap(); + assert_eq!( + transport, + LlmTransport::V1 { + base_url: "http://localhost:8080".to_string(), + api_key: Some("sk-test".to_string()), + } + ); + engine.shutdown().await; + } + + #[tokio::test] + async fn resolve_llm_transport_builtin_ensures_engine() { + let db = test_db(); + { + let conn = db.0.lock().unwrap(); + crate::models::manifest::insert( + &conn, + &installed_model("org/repo:m.gguf", "sha_w", None), + ) + .unwrap(); + } + let (_dir, store) = test_store(); + let engine = spawn_engine(ScriptedEngineProcess { + port: 4242, + spawn_error: None, + healthy: true, + }); + let secrets = crate::keychain::FakeSecretStore::new(); + let transport = resolve_llm_transport( + ChatRoute::Builtin { + model_id: "org/repo:m.gguf".to_string(), + }, + &db, + &store, + &engine, + &secrets, + DEFAULT_NUM_CTX, + ) + .await + .unwrap(); + assert_eq!( + transport, + LlmTransport::V1 { + base_url: "http://127.0.0.1:4242".to_string(), + api_key: None, + } + ); + // The ensure landed: the engine reports the loaded model. + assert_eq!(engine.status().borrow().state, "loaded"); + engine.shutdown().await; + } + + #[tokio::test] + async fn resolve_llm_transport_builtin_missing_row_is_engine_error() { + let db = test_db(); + let (_dir, store) = test_store(); + let engine = spawn_engine(ScriptedEngineProcess { + port: 1, + spawn_error: None, + healthy: true, + }); + let secrets = crate::keychain::FakeSecretStore::new(); + let err = resolve_llm_transport( + ChatRoute::Builtin { + model_id: "org/repo:gone.gguf".to_string(), + }, + &db, + &store, + &engine, + &secrets, + DEFAULT_NUM_CTX, + ) + .await + .unwrap_err(); + assert!(matches!( + err, + TransportError::Engine(e) if e.kind == EngineErrorKind::ModelNotFound + )); + engine.shutdown().await; + } + + #[tokio::test] + async fn resolve_llm_transport_recovers_poisoned_db_lock() { + let db = test_db(); + { + let conn = db.0.lock().unwrap(); + crate::models::manifest::insert( + &conn, + &installed_model("org/repo:m.gguf", "sha_w", None), + ) + .unwrap(); + } + // Poison the connection mutex with an unrelated panic; the resolver + // must recover the guard rather than fail the turn. + let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let _guard = db.0.lock().unwrap(); + panic!("poison"); + })); + assert!(db.0.lock().is_err(), "mutex must be poisoned"); + + let (_dir, store) = test_store(); + let engine = spawn_engine(ScriptedEngineProcess { + port: 4243, + spawn_error: None, + healthy: true, + }); + let secrets = crate::keychain::FakeSecretStore::new(); + let transport = resolve_llm_transport( + ChatRoute::Builtin { + model_id: "org/repo:m.gguf".to_string(), + }, + &db, + &store, + &engine, + &secrets, + DEFAULT_NUM_CTX, + ) + .await + .unwrap(); + assert_eq!( + transport, + LlmTransport::V1 { + base_url: "http://127.0.0.1:4243".to_string(), + api_key: None, + } + ); + engine.shutdown().await; + } + + #[tokio::test] + async fn resolve_llm_transport_superseded_and_start_failed_map() { + // StartFailed: every spawn errors out. + let db = test_db(); + { + let conn = db.0.lock().unwrap(); + crate::models::manifest::insert( + &conn, + &installed_model("org/repo:m.gguf", "sha_w", None), + ) + .unwrap(); + } + let (_dir, store) = test_store(); + let secrets = crate::keychain::FakeSecretStore::new(); + let engine = spawn_engine(ScriptedEngineProcess { + port: 1, + spawn_error: Some("spawn boom".to_string()), + healthy: true, + }); + let err = resolve_llm_transport( + ChatRoute::Builtin { + model_id: "org/repo:m.gguf".to_string(), + }, + &db, + &store, + &engine, + &secrets, + DEFAULT_NUM_CTX, + ) + .await + .unwrap_err(); + assert!(matches!( + err, + TransportError::Engine(ref e) + if e.kind == EngineErrorKind::EngineStartFailed + && e.message.starts_with("Thuki's engine could not start.\n") + && e.message.contains("spawn boom") + )); + engine.shutdown().await; + + // Superseded: health probes hang, so the in-flight ensure can be + // preempted by an unload. + let engine = spawn_engine(ScriptedEngineProcess { + port: 1, + spawn_error: None, + healthy: false, + }); + let task = { + let engine = engine.clone(); + let db = test_db(); + { + let conn = db.0.lock().unwrap(); + crate::models::manifest::insert( + &conn, + &installed_model("org/repo:m.gguf", "sha_w", None), + ) + .unwrap(); + } + let (_dir2, store2) = test_store(); + tokio::spawn(async move { + let secrets = crate::keychain::FakeSecretStore::new(); + resolve_llm_transport( + ChatRoute::Builtin { + model_id: "org/repo:m.gguf".to_string(), + }, + &db, + &store2, + &engine, + &secrets, + DEFAULT_NUM_CTX, + ) + .await + }) + }; + let mut status = engine.status(); + status + .wait_for(|s| s.state == "starting") + .await + .expect("actor is running"); + engine.unload().await; + let err = task.await.unwrap().unwrap_err(); + assert_eq!(err, TransportError::Superseded); + engine.shutdown().await; + } } diff --git a/src-tauri/src/config/defaults.rs b/src-tauri/src/config/defaults.rs index 82d31687..8b0b2698 100644 --- a/src-tauri/src/config/defaults.rs +++ b/src-tauri/src/config/defaults.rs @@ -14,9 +14,15 @@ pub const PROVIDER_ID_BUILTIN: &str = "builtin"; pub const PROVIDER_ID_OLLAMA: &str = "ollama"; /// Provider kinds understood by the loader. Providers with any other kind are -/// dropped during resolution. +/// dropped during resolution. Recognized kinds: `"builtin"`, `"ollama"`, +/// `"openai"`. pub const PROVIDER_KIND_BUILTIN: &str = "builtin"; pub const PROVIDER_KIND_OLLAMA: &str = "ollama"; +/// Any OpenAI-compatible local or remote inference server (LM Studio, Jan, +/// llama-server, etc.). Requires a valid http(s) `base_url`; providers with +/// an empty or non-http(s) URL are dropped rather than healed (unlike Ollama, +/// there is no sensible localhost default for arbitrary /v1 servers). +pub const PROVIDER_KIND_OPENAI: &str = "openai"; /// Human-readable provider labels shown in Settings. pub const DEFAULT_BUILTIN_LABEL: &str = "Built-in (Thuki)"; @@ -105,6 +111,11 @@ pub const ENGINE_COMMAND_QUEUE_CAPACITY: usize = 64; /// user-tunable: pure IPC hygiene, invisible below the UI refresh rate. pub const DOWNLOAD_PROGRESS_MIN_INTERVAL_MS: u64 = 500; +/// Maximum accepted length of a single Server-Sent-Events line from a /v1 +/// streaming response. Bounds attacker-controlled data from a chat server +/// (a malicious or broken server cannot grow a single line unboundedly). +pub const MAX_SSE_LINE_BYTES: usize = 1024 * 1024; + /// Built-in secretary persona prompt. User overrides via `[prompt] system` in /// the config file. The slash-command appendix is composed on top at load time /// and is never written back to the file. diff --git a/src-tauri/src/config/loader.rs b/src-tauri/src/config/loader.rs index a636ca11..ba77fcc7 100644 --- a/src-tauri/src/config/loader.rs +++ b/src-tauri/src/config/loader.rs @@ -312,10 +312,17 @@ pub(crate) fn resolve(config: &mut AppConfig) { /// panics on user input. fn resolve_inference(inf: &mut crate::config::schema::InferenceSection) { use crate::config::defaults::{ - DEFAULT_ACTIVE_PROVIDER, PROVIDER_KIND_BUILTIN, PROVIDER_KIND_OLLAMA, + DEFAULT_ACTIVE_PROVIDER, PROVIDER_ID_BUILTIN, PROVIDER_ID_OLLAMA, PROVIDER_KIND_BUILTIN, + PROVIDER_KIND_OLLAMA, PROVIDER_KIND_OPENAI, }; use crate::config::schema::{builtin_provider, ollama_provider}; + // Snapshot the file shape before any provider synthesis or reseed: a + // pre-providers file (no [[inference.providers]] array) deserializes to + // an empty list, while fresh-seeded defaults and new-shape files always + // carry providers. Consumed by the active-pointer pin at the end. + let is_pre_providers_file = inf.providers.is_empty(); + // num_ctx + keep_warm: unchanged clamping (Ollama-path knobs). clamp_u32( &mut inf.num_ctx, @@ -338,10 +345,8 @@ fn resolve_inference(inf: &mut crate::config::schema::InferenceSection) { // Migration: a pre-providers file has `ollama_url` and no `providers`. // Carry the URL onto a synthesized Ollama provider; the active model is // attached later during startup orchestration (it lives in SQLite). The - // active pointer is left to the dangling-pointer repair below: a migrated - // config either omits `active_provider` (serde defaults it to the Phase-1 - // default of `ollama`) or names something the repair resets to that same - // default, so existing Ollama users land on the Ollama provider either way. + // active pointer is handled by the pre-providers pin at the end of this + // function. if let Some(legacy) = inf.legacy_ollama_url.take() { if inf.providers.is_empty() { let url = if legacy.trim().is_empty() { @@ -372,10 +377,27 @@ fn resolve_inference(inf: &mut crate::config::schema::InferenceSection) { } } - // Drop unknown-kind providers and non-builtin providers with no base_url. + // Drop unknown-kind providers and network providers with no valid base_url. + // builtin: always kept (URL not required). + // ollama: kept when base_url is non-empty (Ollama heal loop above already + // reset bad schemes; an empty URL is dropped and the reseed below + // restores the localhost default). + // openai: kept only when base_url is a valid http(s) URL. Unlike Ollama + // there is no sensible localhost default for arbitrary /v1 servers, + // so an empty or non-http(s) URL is dropped without healing. inf.providers.retain(|p| match p.kind.as_str() { PROVIDER_KIND_BUILTIN => true, PROVIDER_KIND_OLLAMA => !p.base_url.trim().is_empty(), + PROVIDER_KIND_OPENAI => { + let ok = is_http_url(&p.base_url); + if !ok { + eprintln!( + "thuki: [config] dropping openai provider '{}': base_url must be a non-empty http(s) URL", + p.id + ); + } + ok + } other => { eprintln!("thuki: [config] dropping provider with unknown kind '{other}'"); false @@ -405,6 +427,23 @@ fn resolve_inference(inf: &mut crate::config::schema::InferenceSection) { } inf.active_provider = DEFAULT_ACTIVE_PROVIDER.to_string(); } + + // A pre-providers file (no [[inference.providers]] array) predates the + // built-in engine: that user runs Ollama. Pin the pointer explicitly so + // the compiled default (which favors the built-in engine from Phase 2 on) + // only ever applies to fresh installs and new-shape files. Covers both + // legacy shapes: with an ollama_url key and without one. An explicit + // active_provider equal to the compiled default, or naming the built-in + // provider, is also overridden here: in a pre-providers file neither + // value can refer to a working built-in provider (none existed when the + // file was written). + if is_pre_providers_file + && (inf.active_provider.trim().is_empty() + || inf.active_provider == PROVIDER_ID_BUILTIN + || inf.active_provider == DEFAULT_ACTIVE_PROVIDER) + { + inf.active_provider = PROVIDER_ID_OLLAMA.to_string(); + } } /// True when `url` is an absolute http(s) URL. Used to keep a malformed or diff --git a/src-tauri/src/config/migrate.rs b/src-tauri/src/config/migrate.rs index 7b86856f..12384f46 100644 --- a/src-tauri/src/config/migrate.rs +++ b/src-tauri/src/config/migrate.rs @@ -2,12 +2,16 @@ //! startup orchestration (SQLite active-model fold-in). Kept pure so both //! halves are unit-tested without a Tauri app or a real SQLite connection. +use super::defaults::PROVIDER_KIND_OLLAMA; use super::schema::AppConfig; /// Attaches a legacy SQLite `active_model` onto the active provider's `model` -/// field when that provider has no model yet. Returns true if it mutated the -/// config (so startup can decide whether to persist). No-op when `legacy` is -/// empty/whitespace or the active provider already has a model. +/// field when that provider is Ollama-kind and has no model yet. The legacy +/// slug is by definition an Ollama model name, so it never attaches to a +/// provider of any other kind. Returns true if it mutated the config (so +/// startup can decide whether to persist). No-op when `legacy` is +/// empty/whitespace, the active provider is not Ollama-kind, or it already +/// has a model. pub fn attach_legacy_active_model(config: &mut AppConfig, legacy: Option<&str>) -> bool { let Some(model) = legacy.map(str::trim).filter(|m| !m.is_empty()) else { return false; @@ -19,6 +23,9 @@ pub fn attach_legacy_active_model(config: &mut AppConfig, legacy: Option<&str>) .iter_mut() .find(|p| p.id == active_id) { + if provider.kind != PROVIDER_KIND_OLLAMA { + return false; + } if provider.model.trim().is_empty() { provider.model = model.to_string(); return true; diff --git a/src-tauri/src/config/schema.rs b/src-tauri/src/config/schema.rs index 43ade2b2..1e9da3ad 100644 --- a/src-tauri/src/config/schema.rs +++ b/src-tauri/src/config/schema.rs @@ -26,7 +26,7 @@ use super::defaults::{ DEFAULT_TEXT_BASE_PX, DEFAULT_TEXT_FONT_WEIGHT, DEFAULT_TEXT_LETTER_SPACING_PX, DEFAULT_TEXT_LINE_HEIGHT, DEFAULT_TOP_K_URLS, DEFAULT_UPDATER_AUTO_CHECK, DEFAULT_UPDATER_CHECK_INTERVAL_HOURS, DEFAULT_UPDATER_MANIFEST_URL, PROVIDER_ID_BUILTIN, - PROVIDER_ID_OLLAMA, PROVIDER_KIND_BUILTIN, PROVIDER_KIND_OLLAMA, + PROVIDER_ID_OLLAMA, PROVIDER_KIND_BUILTIN, PROVIDER_KIND_OLLAMA, PROVIDER_KIND_OPENAI, }; /// A single configured inference provider. Exactly one is active at a time @@ -38,15 +38,23 @@ use super::defaults::{ pub struct Provider { /// Stable identifier referenced by `active_provider`. pub id: String, - /// Provider kind: `"builtin"` or `"ollama"`. Unknown kinds are dropped by - /// the loader. + /// Provider kind: `"builtin"`, `"ollama"`, or `"openai"`. Unknown kinds + /// are dropped by the loader. pub kind: String, /// Human-readable name shown in Settings. pub label: String, - /// Base URL for network providers (Ollama). Empty for the built-in engine. + /// Base URL for network providers (Ollama, OpenAI-compatible). Empty for + /// the built-in engine. pub base_url: String, /// The model selected for this provider. Empty means "none chosen yet". pub model: String, + /// Manual vision flag for `openai`-kind providers. OpenAI-compatible local + /// servers expose no capability probe, so the user declares whether the + /// selected model accepts image inputs. Ignored for `builtin` and `ollama`, + /// whose capabilities are resolved from the manifest or Ollama's + /// `/api/show` response. + #[serde(default)] + pub vision: bool, } /// The built-in provider record (Thuki's own engine; no URL). @@ -57,6 +65,7 @@ pub fn builtin_provider() -> Provider { label: DEFAULT_BUILTIN_LABEL.to_string(), base_url: String::new(), model: String::new(), + vision: false, } } @@ -68,6 +77,21 @@ pub fn ollama_provider(base_url: &str) -> Provider { label: DEFAULT_OLLAMA_LABEL.to_string(), base_url: base_url.to_string(), model: String::new(), + vision: false, + } +} + +/// An OpenAI-compatible provider record with the given id, label, and base URL. +/// `vision` defaults to `false`; the caller or user sets it to `true` when the +/// selected model accepts image inputs. +pub fn openai_provider(id: &str, label: &str, base_url: &str) -> Provider { + Provider { + id: id.to_string(), + kind: PROVIDER_KIND_OPENAI.to_string(), + label: label.to_string(), + base_url: base_url.to_string(), + model: String::new(), + vision: false, } } diff --git a/src-tauri/src/config/tests.rs b/src-tauri/src/config/tests.rs index c02f1902..eb1e67cc 100644 --- a/src-tauri/src/config/tests.rs +++ b/src-tauri/src/config/tests.rs @@ -24,14 +24,14 @@ use super::defaults::{ DEFAULT_TEXT_FONT_WEIGHT, DEFAULT_TEXT_LETTER_SPACING_PX, DEFAULT_TEXT_LINE_HEIGHT, DEFAULT_TOP_K_URLS, DEFAULT_UPDATER_CHECK_INTERVAL_HOURS, DEFAULT_UPDATER_MANIFEST_URL, PROVIDER_ID_BUILTIN, PROVIDER_ID_OLLAMA, PROVIDER_KIND_BUILTIN, PROVIDER_KIND_OLLAMA, - SLASH_COMMAND_PROMPT_APPENDIX, + PROVIDER_KIND_OPENAI, SLASH_COMMAND_PROMPT_APPENDIX, }; use super::error::ConfigError; -use super::loader::{compose_system_prompt, load_from_path}; +use super::loader::{compose_system_prompt, load_from_path, resolve}; use super::migrate::{attach_legacy_active_model, toml_has_providers}; use super::schema::{ - AppConfig, BehaviorSection, DebugSection, InferenceSection, PromptSection, Provider, - QuoteSection, SearchSection, UpdaterSection, WindowSection, + ollama_provider, openai_provider, AppConfig, BehaviorSection, DebugSection, InferenceSection, + PromptSection, Provider, QuoteSection, SearchSection, UpdaterSection, WindowSection, }; use super::writer::atomic_write; @@ -1480,11 +1480,13 @@ fn provider_constructors_carry_expected_fields() { assert_eq!(b.id, PROVIDER_ID_BUILTIN); assert_eq!(b.kind, PROVIDER_KIND_BUILTIN); assert!(b.base_url.is_empty()); + assert!(!b.vision); let o = super::schema::ollama_provider("http://x:1"); assert_eq!(o.id, PROVIDER_ID_OLLAMA); assert_eq!(o.kind, PROVIDER_KIND_OLLAMA); assert_eq!(o.base_url, "http://x:1"); + assert!(!o.vision); } #[test] @@ -1802,6 +1804,121 @@ fn new_shape_with_model_roundtrips_through_toml() { ); } +// ── inference providers: pre-providers active pin ─────────────────────────── + +#[test] +fn pre_providers_file_with_url_pins_active_to_ollama() { + // A pre-providers file carrying an ollama_url: the user runs Ollama, so + // the active pointer must land on the Ollama provider regardless of the + // compiled default, and the URL must survive the migration. + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write( + &path, + "[inference]\nollama_url = \"http://10.0.0.5:11434\"\n", + ) + .unwrap(); + let c = load_from_path(&path).unwrap(); + assert_eq!(c.inference.active_provider, PROVIDER_ID_OLLAMA); + let ollama = c + .inference + .providers + .iter() + .find(|p| p.id == PROVIDER_ID_OLLAMA) + .unwrap(); + assert_eq!(ollama.base_url, "http://10.0.0.5:11434"); +} + +#[test] +fn pre_providers_file_without_url_key_pins_active_to_ollama() { + // A pre-providers file WITHOUT an ollama_url key (the user never changed + // the URL) is still a pre-providers file: providers are reseeded and the + // active pointer must land on the Ollama provider. + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write(&path, "[inference]\nnum_ctx = 4096\n").unwrap(); + let c = load_from_path(&path).unwrap(); + assert_eq!(c.inference.active_provider, PROVIDER_ID_OLLAMA); + assert!(c + .inference + .providers + .iter() + .any(|p| p.kind == PROVIDER_KIND_BUILTIN)); + assert!(c + .inference + .providers + .iter() + .any(|p| p.kind == PROVIDER_KIND_OLLAMA)); + assert_eq!(c.inference.num_ctx, 4096); +} + +#[test] +fn pre_providers_explicit_custom_active_keeps_it() { + // An explicit active_provider naming the Ollama provider in a + // pre-providers file survives resolution unchanged. + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write( + &path, + "[inference]\nactive_provider = \"ollama\"\nollama_url = \"http://10.0.0.5:11434\"\n", + ) + .unwrap(); + let c = load_from_path(&path).unwrap(); + assert_eq!(c.inference.active_provider, PROVIDER_ID_OLLAMA); +} + +#[test] +fn pre_providers_explicit_builtin_is_pinned_to_ollama() { + // A pre-providers file cannot legitimately point at the built-in provider + // (none existed when the file was written), so an explicit "builtin" is + // overridden to the Ollama provider. + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write( + &path, + "[inference]\nactive_provider = \"builtin\"\nnum_ctx = 4096\n", + ) + .unwrap(); + let c = load_from_path(&path).unwrap(); + assert_eq!(c.inference.active_provider, PROVIDER_ID_OLLAMA); +} + +#[test] +fn new_shape_config_active_untouched() { + // A new-shape file (explicit [[inference.providers]]) is never pinned: + // an explicit "builtin" choice is respected. + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write( + &path, + r#" + [inference] + active_provider = "builtin" + [[inference.providers]] + id = "builtin" + kind = "builtin" + label = "Built-in (Thuki)" + [[inference.providers]] + id = "ollama" + kind = "ollama" + label = "Ollama" + base_url = "http://127.0.0.1:11434" + "#, + ) + .unwrap(); + let c = load_from_path(&path).unwrap(); + assert_eq!(c.inference.active_provider, PROVIDER_ID_BUILTIN); +} + +#[test] +fn fresh_seed_uses_compiled_default() { + // A fresh-seeded config (schema Default = default_providers()) is NOT a + // pre-providers file: the compiled default pointer is left alone. + let mut c = AppConfig::default(); + resolve(&mut c); + assert_eq!(c.inference.active_provider, DEFAULT_ACTIVE_PROVIDER); +} + // ── inference providers: migrate helpers ───────────────────────────────────── #[test] @@ -1817,11 +1934,15 @@ fn attach_legacy_active_model_sets_model_on_active_provider() { #[test] fn attach_legacy_active_model_targets_the_active_provider_only() { // The legacy slug must land on the *active* provider, never on some other - // provider that merely happens to have an empty model. Make the built-in - // active (empty) and give Ollama a pre-existing model: attach writes the - // built-in and leaves Ollama untouched. + // provider that merely happens to have an empty model. Add a second + // Ollama-kind provider, make it active (empty model), and give the default + // Ollama entry a pre-existing model: attach writes the active one and + // leaves the other untouched. let mut c = AppConfig::default(); - c.inference.active_provider = PROVIDER_ID_BUILTIN.to_string(); + let mut remote = ollama_provider("http://10.0.0.9:11434"); + remote.id = "ollama-remote".to_string(); + c.inference.providers.push(remote); + c.inference.active_provider = "ollama-remote".to_string(); if let Some(ollama) = c .inference .providers @@ -1831,13 +1952,13 @@ fn attach_legacy_active_model_targets_the_active_provider_only() { ollama.model = "existing:7b".to_string(); } assert!(attach_legacy_active_model(&mut c, Some("legacy:1b"))); - let builtin = c + let remote = c .inference .providers .iter() - .find(|p| p.id == PROVIDER_ID_BUILTIN) + .find(|p| p.id == "ollama-remote") .unwrap(); - assert_eq!(builtin.model, "legacy:1b"); + assert_eq!(remote.model, "legacy:1b"); let ollama = c .inference .providers @@ -1847,6 +1968,30 @@ fn attach_legacy_active_model_targets_the_active_provider_only() { assert_eq!(ollama.model, "existing:7b"); } +#[test] +fn legacy_model_attaches_only_to_ollama_kind_provider() { + // The legacy SQLite slug is by definition an Ollama model name. When the + // active provider is not Ollama-kind (post-flip: a fresh builtin default), + // the slug must NOT attach: writing an Ollama slug onto the built-in + // provider would make chat fail with ModelNotFound. + let mut c = AppConfig::default(); + c.inference.active_provider = PROVIDER_ID_BUILTIN.to_string(); + assert!(!attach_legacy_active_model(&mut c, Some("phi4:14b"))); + let builtin = c + .inference + .providers + .iter() + .find(|p| p.id == PROVIDER_ID_BUILTIN) + .unwrap(); + assert!(builtin.model.is_empty()); + + // Active = Ollama-kind with an empty model: attaches as before. + let mut c = AppConfig::default(); + c.inference.active_provider = PROVIDER_ID_OLLAMA.to_string(); + assert!(attach_legacy_active_model(&mut c, Some("phi4:14b"))); + assert_eq!(c.inference.active_provider_model(), "phi4:14b"); +} + #[test] fn attach_legacy_active_model_ignores_empty_and_missing_provider() { let mut c = AppConfig::default(); @@ -1879,4 +2024,203 @@ fn provider_struct_default_is_all_empty() { assert!(p.kind.is_empty()); assert!(p.base_url.is_empty()); assert!(p.model.is_empty()); + assert!(!p.vision); +} + +// ── inference providers: openai kind ──────────────────────────────────────── + +#[test] +fn openai_provider_constructor_shape() { + let p = openai_provider("lmstudio", "LM Studio", "http://localhost:1234"); + assert_eq!(p.id, "lmstudio"); + assert_eq!(p.kind, PROVIDER_KIND_OPENAI); + assert_eq!(p.label, "LM Studio"); + assert_eq!(p.base_url, "http://localhost:1234"); + assert!(p.model.is_empty()); + assert!(!p.vision); +} + +#[test] +fn openai_kind_with_url_is_kept() { + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write( + &path, + r#" + [inference] + active_provider = "lmstudio" + [[inference.providers]] + id = "builtin" + kind = "builtin" + label = "Built-in (Thuki)" + [[inference.providers]] + id = "ollama" + kind = "ollama" + label = "Ollama" + base_url = "http://127.0.0.1:11434" + [[inference.providers]] + id = "lmstudio" + kind = "openai" + label = "LM Studio" + base_url = "http://localhost:1234" + "#, + ) + .unwrap(); + let c = load_from_path(&path).unwrap(); + let p = c + .inference + .providers + .iter() + .find(|p| p.id == "lmstudio") + .expect("openai provider should be retained"); + assert_eq!(p.kind, PROVIDER_KIND_OPENAI); + assert_eq!(p.base_url, "http://localhost:1234"); +} + +#[test] +fn openai_kind_without_url_is_dropped() { + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write( + &path, + r#" + [inference] + active_provider = "ollama" + [[inference.providers]] + id = "builtin" + kind = "builtin" + label = "Built-in (Thuki)" + [[inference.providers]] + id = "ollama" + kind = "ollama" + label = "Ollama" + base_url = "http://127.0.0.1:11434" + [[inference.providers]] + id = "lmstudio" + kind = "openai" + label = "LM Studio" + base_url = "" + "#, + ) + .unwrap(); + let c = load_from_path(&path).unwrap(); + assert!( + !c.inference.providers.iter().any(|p| p.id == "lmstudio"), + "openai provider with empty base_url must be dropped" + ); +} + +#[test] +fn openai_kind_bad_scheme_is_dropped() { + // Both a non-http(s) scheme and a bare host without a scheme are rejected. + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write( + &path, + r#" + [inference] + active_provider = "ollama" + [[inference.providers]] + id = "builtin" + kind = "builtin" + label = "Built-in (Thuki)" + [[inference.providers]] + id = "ollama" + kind = "ollama" + label = "Ollama" + base_url = "http://127.0.0.1:11434" + [[inference.providers]] + id = "bad-scheme" + kind = "openai" + label = "Bad" + base_url = "file:///x" + [[inference.providers]] + id = "no-scheme" + kind = "openai" + label = "No scheme" + base_url = "localhost:1234" + "#, + ) + .unwrap(); + let c = load_from_path(&path).unwrap(); + assert!( + !c.inference.providers.iter().any(|p| p.id == "bad-scheme"), + "openai provider with file:// scheme must be dropped" + ); + assert!( + !c.inference.providers.iter().any(|p| p.id == "no-scheme"), + "openai provider with scheme-less host must be dropped" + ); +} + +#[test] +fn provider_vision_flag_roundtrips() { + // A TOML file with vision=true on an openai provider survives load unmodified. + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write( + &path, + r#" + [inference] + active_provider = "jan" + [[inference.providers]] + id = "builtin" + kind = "builtin" + label = "Built-in (Thuki)" + [[inference.providers]] + id = "ollama" + kind = "ollama" + label = "Ollama" + base_url = "http://127.0.0.1:11434" + [[inference.providers]] + id = "jan" + kind = "openai" + label = "Jan" + base_url = "http://localhost:1337" + vision = true + "#, + ) + .unwrap(); + let c = load_from_path(&path).unwrap(); + let jan = c + .inference + .providers + .iter() + .find(|p| p.id == "jan") + .expect("jan provider must be retained"); + assert!(jan.vision, "vision=true must round-trip through TOML load"); +} + +#[test] +fn unknown_kind_still_dropped() { + // Regression: adding openai must not affect the unknown-kind drop path. + let dir = fresh_temp_dir(); + let path = config_path_in(&dir); + std::fs::write( + &path, + r#" + [inference] + active_provider = "ollama" + [[inference.providers]] + id = "builtin" + kind = "builtin" + label = "Built-in (Thuki)" + [[inference.providers]] + id = "ollama" + kind = "ollama" + label = "Ollama" + base_url = "http://127.0.0.1:11434" + [[inference.providers]] + id = "cloud" + kind = "anthropic" + label = "Cloud" + base_url = "https://api.anthropic.com" + "#, + ) + .unwrap(); + let c = load_from_path(&path).unwrap(); + assert!( + !c.inference.providers.iter().any(|p| p.id == "cloud"), + "provider with unknown kind must still be dropped" + ); } diff --git a/src-tauri/src/history.rs b/src-tauri/src/history.rs index f631cd1f..5057420a 100644 --- a/src-tauri/src/history.rs +++ b/src-tauri/src/history.rs @@ -247,11 +247,66 @@ pub fn delete_conversation( Ok(()) } -/// Generates a short AI title for a saved conversation by asking Ollama. -/// Runs as a fire-and-forget background task - the frontend polls or -/// refreshes the list to see the updated title. +/// Runs the title-generation LLM call against the resolved transport and +/// returns the accumulated raw response. The native Ollama arm keeps the +/// exact `stream_ollama_chat` parameters the title path has always sent; +/// `/v1` transports accumulate through `openai::stream_openai_chat`, which +/// honors the same `StreamChunk` contract. +/// +/// `num_ctx` feeds only the native arm: it is NOT sent on `/v1` (a launch +/// property of the builtin engine; informational for openai-kind servers). +/// +/// Deliberately does NOT share the search pipeline's streaming branch: title +/// generation has no per-chunk side effects, trace recording, or search +/// events, so it shares only the `stream_*` primitives. +pub(crate) async fn generate_title_text( + transport: &crate::commands::LlmTransport, + model: String, + title_messages: Vec, + client: &reqwest::Client, + num_ctx: u32, +) -> String { + let cancel_token = tokio_util::sync::CancellationToken::new(); + match transport { + crate::commands::LlmTransport::OllamaNative { endpoint } => { + crate::commands::stream_ollama_chat( + crate::commands::OllamaChatParams { + endpoint: endpoint.clone(), + model, + messages: title_messages, + think: false, + keep_alive: None, + num_ctx, + }, + client, + cancel_token, + |_| {}, // No per-chunk side effects; we use the accumulated return value. + ) + .await + } + crate::commands::LlmTransport::V1 { base_url, api_key } => { + crate::openai::stream_openai_chat( + crate::openai::OpenAiChatParams { + base_url: base_url.clone(), + model, + messages: title_messages, + api_key: api_key.clone(), + }, + client, + cancel_token, + |_| {}, // No per-chunk side effects; we use the accumulated return value. + ) + .await + } + } +} + +/// Generates a short AI title for a saved conversation by asking the active +/// provider. Runs as a fire-and-forget background task - the frontend polls +/// or refreshes the list to see the updated title. #[cfg_attr(coverage_nightly, coverage(off))] #[cfg_attr(not(coverage), tauri::command)] +#[allow(clippy::too_many_arguments)] pub async fn generate_title( conversation_id: String, messages: Vec, @@ -259,6 +314,9 @@ pub async fn generate_title( db: State<'_, Database>, client: State<'_, reqwest::Client>, app_config: State<'_, parking_lot::RwLock>, + model_store: State<'_, crate::models::storage::ModelStore>, + engine: State<'_, crate::engine::runner::EngineHandle>, + secrets: State<'_, crate::keychain::Secrets>, ) -> Result<(), String> { let app_config = app_config.read().clone(); // Build a condensed context for title generation. @@ -294,27 +352,37 @@ pub async fn generate_title( }, ]; - let endpoint = format!( - "{}/api/chat", - app_config - .inference - .active_provider_base_url() - .trim_end_matches('/') - ); + // Route by the active provider's kind, mirroring `ask_model`. Title + // generation is best-effort background work, so a route or transport + // failure (no model picked, engine start failure, ensure superseded) + // skips the title silently rather than surfacing an error. + let Ok(route) = crate::commands::resolve_chat_route(&app_config.inference) else { + return Ok(()); + }; + // The builtin route carries its model in the provider config; the + // sidecar serves that manifest id, not the frontend-stamped slug. + let Some(model) = crate::commands::model_for_route(&route, Some(model)) else { + return Ok(()); + }; + let Ok(transport) = crate::commands::resolve_llm_transport( + route, + &db, + &model_store, + &engine, + secrets.0.as_ref(), + app_config.inference.num_ctx, + ) + .await + else { + return Ok(()); + }; - let cancel_token = tokio_util::sync::CancellationToken::new(); - let accumulated = crate::commands::stream_ollama_chat( - crate::commands::OllamaChatParams { - endpoint, - model, - messages: title_messages, - think: false, - keep_alive: None, - num_ctx: app_config.inference.num_ctx, - }, + let accumulated = generate_title_text( + &transport, + model, + title_messages, &client, - cancel_token, - |_| {}, // No per-chunk side effects; we use the accumulated return value. + app_config.inference.num_ctx, ) .await; @@ -500,6 +568,105 @@ mod tests { assert!(title.ends_with("...")); } + fn title_messages() -> Vec { + vec![ + ChatMessage { + role: "system".to_string(), + content: "sys".to_string(), + images: None, + }, + ChatMessage { + role: "user".to_string(), + content: "summarize".to_string(), + images: None, + }, + ] + } + + /// Locks the native title-generation wire contract across the transport + /// change: the exact `/api/chat` body (stream, think, options, no + /// keep_alive) must be identical to the pre-routing payload. + #[tokio::test] + async fn title_gen_on_ollama_unchanged() { + use crate::config::defaults::DEFAULT_NUM_CTX; + let mut server = mockito::Server::new_async().await; + let expected_body = serde_json::json!({ + "model": "gemma3:12b", + "messages": [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "summarize"} + ], + "stream": true, + "think": false, + "options": { + "temperature": 1.0, + "top_p": 0.95, + "top_k": 64, + "num_ctx": DEFAULT_NUM_CTX + } + }); + let mock = server + .mock("POST", "/api/chat") + .match_body(mockito::Matcher::Json(expected_body)) + .with_body("{\"message\":{\"content\":\"My Title\"},\"done\":true}\n") + .create_async() + .await; + + let client = reqwest::Client::new(); + let transport = crate::commands::LlmTransport::OllamaNative { + endpoint: format!("{}/api/chat", server.url()), + }; + let accumulated = generate_title_text( + &transport, + "gemma3:12b".to_string(), + title_messages(), + &client, + DEFAULT_NUM_CTX, + ) + .await; + + mock.assert_async().await; + assert_eq!(accumulated, "My Title"); + } + + /// A `/v1` transport accumulates the title through the OpenAI-compatible + /// streaming client, honoring the provider's API key. + #[tokio::test] + async fn title_gen_on_v1() { + use wiremock::matchers::{header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + let body = "data: {\"choices\":[{\"delta\":{\"content\":\"V1\"}}]}\n\n\ + data: {\"choices\":[{\"delta\":{\"content\":\" Title\"}}]}\n\n\ + data: [DONE]\n"; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(header("authorization", "Bearer sk-test")) + .respond_with( + ResponseTemplate::new(200).set_body_raw(body.as_bytes(), "text/event-stream"), + ) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let transport = crate::commands::LlmTransport::V1 { + base_url: server.uri(), + api_key: Some("sk-test".to_string()), + }; + let accumulated = generate_title_text( + &transport, + "any-model".to_string(), + title_messages(), + &client, + crate::config::defaults::DEFAULT_NUM_CTX, + ) + .await; + + assert_eq!(accumulated, "V1 Title"); + } + #[test] fn title_truncation_over_100_chars() { let mut title = "a".repeat(150); diff --git a/src-tauri/src/keychain.rs b/src-tauri/src/keychain.rs new file mode 100644 index 00000000..f87afed3 --- /dev/null +++ b/src-tauri/src/keychain.rs @@ -0,0 +1,248 @@ +//! Keychain integration for per-provider API key storage. +//! +//! API keys for `openai`-kind providers are stored in the macOS Keychain +//! under [`KEYCHAIN_SERVICE`]. The Keychain is the only place keys ever +//! live; they are never written to the TOML config and are never returned +//! to the frontend (only existence is queryable via [`has_provider_api_key`]). +//! +//! ## Extension points +//! +//! The [`SecretStore`] trait decouples business logic from the Keychain so +//! command handlers and callers in later tasks can be tested with +//! [`FakeSecretStore`] without touching the real user Keychain. + +use std::sync::Arc; + +// ─── Service constant ──────────────────────────────────────────────────────── + +/// Keychain service name under which per-provider API keys are stored. +/// Account = provider id. Stable: changing it orphans existing entries. +pub const KEYCHAIN_SERVICE: &str = "com.quietnode.thuki.provider-api-key"; + +// ─── Trait ─────────────────────────────────────────────────────────────────── + +pub trait SecretStore: Send + Sync + 'static { + fn set(&self, provider_id: &str, secret: &str) -> Result<(), String>; + fn get(&self, provider_id: &str) -> Result, String>; + /// Deleting a missing entry is `Ok`. + fn delete(&self, provider_id: &str) -> Result<(), String>; +} + +// ─── keyring-backed implementation ─────────────────────────────────────────── + +/// macOS Keychain backend via the `keyring` crate. Thin wrapper: every method +/// body is a direct `keyring::Entry` call plus error mapping. +/// +/// Not covered by the cargo coverage gate: this is a direct OS call with no +/// branching logic of its own; logic lives in callers tested with +/// [`FakeSecretStore`]. +pub struct KeyringStore; + +#[cfg_attr(coverage_nightly, coverage(off))] +impl SecretStore for KeyringStore { + fn set(&self, provider_id: &str, secret: &str) -> Result<(), String> { + keyring::Entry::new(KEYCHAIN_SERVICE, provider_id) + .map_err(|e| e.to_string())? + .set_password(secret) + .map_err(|e| e.to_string()) + } + + fn get(&self, provider_id: &str) -> Result, String> { + match keyring::Entry::new(KEYCHAIN_SERVICE, provider_id) + .map_err(|e| e.to_string())? + .get_password() + { + Ok(pw) => Ok(Some(pw)), + Err(keyring::Error::NoEntry) => Ok(None), + Err(e) => Err(e.to_string()), + } + } + + fn delete(&self, provider_id: &str) -> Result<(), String> { + match keyring::Entry::new(KEYCHAIN_SERVICE, provider_id) + .map_err(|e| e.to_string())? + .delete_credential() + { + Ok(()) => Ok(()), + Err(keyring::Error::NoEntry) => Ok(()), + Err(e) => Err(e.to_string()), + } + } +} + +// ─── In-memory fake (tests only) ───────────────────────────────────────────── + +/// In-memory [`SecretStore`] for unit tests. Available crate-wide during +/// `cargo test` so other modules' tests can construct it without touching the +/// real user Keychain. +#[cfg(test)] +pub(crate) struct FakeSecretStore { + map: std::sync::Mutex>, +} + +#[cfg(test)] +impl FakeSecretStore { + pub(crate) fn new() -> Self { + Self { + map: std::sync::Mutex::new(std::collections::HashMap::new()), + } + } +} + +#[cfg(test)] +impl SecretStore for FakeSecretStore { + fn set(&self, provider_id: &str, secret: &str) -> Result<(), String> { + self.map + .lock() + .unwrap() + .insert(provider_id.to_string(), secret.to_string()); + Ok(()) + } + + fn get(&self, provider_id: &str) -> Result, String> { + Ok(self.map.lock().unwrap().get(provider_id).cloned()) + } + + fn delete(&self, provider_id: &str) -> Result<(), String> { + self.map.lock().unwrap().remove(provider_id); + Ok(()) + } +} + +// ─── Newtype wrapper for Tauri managed state ───────────────────────────────── + +/// Newtype around `Arc` so Tauri's managed-state system can +/// hold the trait object. (`State>` fights the type system +/// because Tauri's `Manager::manage` requires `T: Any + Send + Sync`; wrapping +/// in a named newtype satisfies that bound cleanly.) +pub struct Secrets(pub Arc); + +// ─── Input validation ──────────────────────────────────────────────────────── + +/// Pure, tested validation for `set_provider_api_key` inputs. +/// +/// Returns `Err` when: +/// - `provider_id` is empty or longer than 128 bytes. +/// - `key` is empty or longer than 4096 bytes. +pub fn validate_key_input(provider_id: &str, key: &str) -> Result<(), String> { + if provider_id.is_empty() { + return Err("provider_id must not be empty".to_string()); + } + if provider_id.len() > 128 { + return Err("provider_id must be at most 128 bytes".to_string()); + } + if key.is_empty() { + return Err("key must not be empty".to_string()); + } + if key.len() > 4096 { + return Err("key must be at most 4096 bytes".to_string()); + } + Ok(()) +} + +// ─── Tauri commands ─────────────────────────────────────────────────────────── + +/// Stores an API key for a provider in the macOS Keychain. +/// +/// Validates inputs, then delegates to the managed [`SecretStore`]. +/// The secret value never crosses the IPC boundary in any response. +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg_attr(not(coverage), tauri::command)] +pub fn set_provider_api_key( + provider_id: String, + key: String, + store: tauri::State<'_, Secrets>, +) -> Result<(), String> { + validate_key_input(&provider_id, &key)?; + store.0.set(&provider_id, &key) +} + +/// Removes an API key for a provider from the macOS Keychain. +/// +/// Deleting a missing entry succeeds silently. +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg_attr(not(coverage), tauri::command)] +pub fn clear_provider_api_key( + provider_id: String, + store: tauri::State<'_, Secrets>, +) -> Result<(), String> { + store.0.delete(&provider_id) +} + +/// Returns `true` if an API key exists for the provider, `false` otherwise. +/// +/// The secret value is never included in the response. +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg_attr(not(coverage), tauri::command)] +pub fn has_provider_api_key( + provider_id: String, + store: tauri::State<'_, Secrets>, +) -> Result { + store.0.get(&provider_id).map(|o| o.is_some()) +} + +// ─── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + // Service name is load-bearing: changing it orphans existing Keychain entries. + // This test makes any accidental rename visible in review. + #[test] + fn service_name_is_stable() { + assert_eq!(KEYCHAIN_SERVICE, "com.quietnode.thuki.provider-api-key"); + } + + #[test] + fn validate_key_input_rejects_empty_and_oversize() { + // empty provider_id + assert!(validate_key_input("", "somekey").is_err()); + // empty key + assert!(validate_key_input("provider1", "").is_err()); + // provider_id exactly 128 bytes: ok + let id_128 = "a".repeat(128); + assert!(validate_key_input(&id_128, "somekey").is_ok()); + // provider_id 129 bytes: err + let id_129 = "a".repeat(129); + assert!(validate_key_input(&id_129, "somekey").is_err()); + // key exactly 4096 bytes: ok + let key_4096 = "k".repeat(4096); + assert!(validate_key_input("provider1", &key_4096).is_ok()); + // key 4097 bytes: err + let key_4097 = "k".repeat(4097); + assert!(validate_key_input("provider1", &key_4097).is_err()); + } + + #[test] + fn fake_store_set_get_delete_roundtrip() { + let store = FakeSecretStore::new(); + + // set then get returns the value + store.set("prov-a", "sk-secret123").unwrap(); + assert_eq!( + store.get("prov-a").unwrap(), + Some("sk-secret123".to_string()) + ); + + // overwrite works + store.set("prov-a", "sk-new").unwrap(); + assert_eq!(store.get("prov-a").unwrap(), Some("sk-new".to_string())); + + // delete removes the entry + store.delete("prov-a").unwrap(); + assert_eq!(store.get("prov-a").unwrap(), None); + + // has-key logic (mirrors the command body) + assert!(!store.get("prov-a").unwrap().is_some()); + store.set("prov-b", "key").unwrap(); + assert!(store.get("prov-b").unwrap().is_some()); + } + + #[test] + fn fake_delete_missing_is_ok() { + let store = FakeSecretStore::new(); + // deleting an entry that was never set must succeed + assert!(store.delete("nonexistent").is_ok()); + } +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 9a657166..eb7356db 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -25,6 +25,7 @@ pub mod images; pub mod models; pub mod ocr; pub mod onboarding; +pub mod openai; pub mod screenshot; pub mod search; pub mod settings_commands; @@ -37,6 +38,7 @@ mod activator; #[cfg(target_os = "macos")] mod cg_displays; pub mod context; +pub mod keychain; pub mod permissions; pub mod replace; @@ -254,43 +256,84 @@ fn show_overlay(app_handle: &tauri::AppHandle, ctx: crate::context::ActivationCo // Pre-load the active model so the user's first message does not pay // the cold-start penalty. Fires on all show paths: double-tap, tray, - // and first-launch auto-show. - let warmup_model = app_handle - .state::() - .0 - .lock() - .ok() - .and_then(|g| g.clone()); - if let Some(model) = warmup_model { - let warmup_config = app_handle - .state::>() - .read() - .clone(); - let endpoint = format!( - "{}/api/chat", - warmup_config - .inference - .active_provider_base_url() - .trim_end_matches('/') - ); - let system_prompt = warmup_config.prompt.resolved_system.clone(); - let keep_alive = if warmup_config.inference.keep_warm_inactivity_minutes == 0 { - None - } else { - Some(warmup::keep_alive_string( - warmup_config.inference.keep_warm_inactivity_minutes, - )) - }; - let num_ctx = warmup_config.inference.num_ctx; - let client = app_handle.state::().inner().clone(); - app_handle.state::().fire( - endpoint, - model, - system_prompt, - client, - keep_alive, - num_ctx, - ); + // and first-launch auto-show. Branches by the active provider's kind: + // Ollama keeps its native /api/chat warmup, the built-in engine gets a + // /v1 prime ONLY when it is already serving (summoning the overlay must + // never load a model implicitly), and openai providers get no warmup + // (nothing local to warm). + let warmup_kind = app_handle + .state::>() + .read() + .inference + .active_provider_kind() + .to_string(); + match warmup_kind.as_str() { + crate::config::defaults::PROVIDER_KIND_OLLAMA => { + let warmup_model = app_handle + .state::() + .0 + .lock() + .ok() + .and_then(|g| g.clone()); + if let Some(model) = warmup_model { + let warmup_config = app_handle + .state::>() + .read() + .clone(); + let endpoint = format!( + "{}/api/chat", + warmup_config + .inference + .active_provider_base_url() + .trim_end_matches('/') + ); + let system_prompt = warmup_config.prompt.resolved_system.clone(); + let keep_alive = if warmup_config.inference.keep_warm_inactivity_minutes == 0 { + None + } else { + Some(warmup::keep_alive_string( + warmup_config.inference.keep_warm_inactivity_minutes, + )) + }; + let num_ctx = warmup_config.inference.num_ctx; + let client = app_handle.state::().inner().clone(); + app_handle.state::().fire( + endpoint, + model, + system_prompt, + client, + keep_alive, + num_ctx, + ); + } + } + crate::config::defaults::PROVIDER_KIND_BUILTIN => { + let status = app_handle + .state::() + .status() + .borrow() + .clone(); + if let Some(port) = warmup::builtin_prime_port(&status) { + let (model, system_prompt) = { + let cfg = app_handle + .state::>() + .read() + .clone(); + ( + cfg.inference.active_provider_model().to_string(), + cfg.prompt.resolved_system.clone(), + ) + }; + let client = app_handle.state::().inner().clone(); + tauri::async_runtime::spawn(warmup::prime_builtin( + port, + model, + system_prompt, + client, + )); + } + } + _ => {} } // Extract before building local_ctx to avoid an extra clone. @@ -1473,6 +1516,28 @@ fn config_file_has_providers(path: &std::path::Path) -> bool { .unwrap_or(false) } +/// Path to the bundled `llama-server` sidecar binary. +/// +/// Debug builds run straight from the repo, so the target-triple-suffixed +/// binary in `src-tauri/binaries/` is used directly. Bundled builds rely on +/// Tauri's `externalBin` handling, which installs the sidecar next to the +/// app executable (`Contents/MacOS`) with the target-triple suffix stripped, +/// so it resolves relative to `current_exe()`. Verified manually against the +/// packaged app layout (see the release checklist). +#[cfg_attr(coverage_nightly, coverage(off))] +fn engine_sidecar_path() -> std::path::PathBuf { + if cfg!(debug_assertions) { + std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("binaries") + .join("llama-server-aarch64-apple-darwin") + } else { + std::env::current_exe() + .ok() + .and_then(|exe| exe.parent().map(|dir| dir.join("llama-server"))) + .unwrap_or_else(|| std::path::PathBuf::from("llama-server")) + } +} + #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { let mut builder = tauri::Builder::default(); @@ -1820,6 +1885,48 @@ pub fn run() { app.manage(model_store); app.manage(models::DownloadState::default()); + // ── Keychain secret store ────────────────────────────── + app.manage(keychain::Secrets(std::sync::Arc::new( + keychain::KeyringStore, + ))); + + // ── Built-in inference engine runner ─────────────────── + // One actor owns the bundled llama-server lifecycle: at most one + // process, kill-then-start on model switch, idle unload. Spawned + // inside block_on so the actor task lands on Tauri's tokio + // runtime (setup itself runs outside a runtime context). + let engine_idle_minutes = app + .state::>() + .read() + .inference + .idle_unload_minutes; + let engine_client = app.state::().inner().clone(); + let engine = tauri::async_runtime::block_on(async move { + engine::runner::EngineHandle::spawn( + std::sync::Arc::new(engine::process::TokioEngineProcess { + binary: engine_sidecar_path(), + client: engine_client, + }), + engine_idle_minutes, + std::time::Duration::from_secs( + crate::config::defaults::ENGINE_IDLE_CHECK_INTERVAL_SECS, + ), + ) + }); + // Forward every engine lifecycle change to the frontend, + // mirroring how warmup events are emitted. + { + let status_handle = app.handle().clone(); + let mut status_rx = engine.status(); + tauri::async_runtime::spawn(async move { + while status_rx.changed().await.is_ok() { + let status = status_rx.borrow_and_update().clone(); + let _ = status_handle.emit("engine:status", status); + } + }); + } + app.manage(engine); + // ── Orphaned image cleanup (startup + periodic) ───────── run_image_cleanup(app.handle()); spawn_periodic_image_cleanup(app.handle().clone()); @@ -1952,17 +2059,22 @@ pub fn run() { #[cfg(not(coverage))] updater::commands::reset_and_relaunch_for_grant, #[cfg(not(coverage))] - updater::commands::consume_pending_grant_resume + updater::commands::consume_pending_grant_resume, + #[cfg(not(coverage))] + keychain::set_provider_api_key, + #[cfg(not(coverage))] + keychain::clear_provider_api_key, + #[cfg(not(coverage))] + keychain::has_provider_api_key ]) .build(tauri::generate_context!()) .expect("error while building tauri application") - .run(|app_handle, event| { - if let RunEvent::WindowEvent { + .run(|app_handle, event| match event { + RunEvent::WindowEvent { label, event: tauri::WindowEvent::CloseRequested { api, .. }, .. - } = event - { + } => { if label == "main" { api.prevent_close(); @@ -1984,6 +2096,17 @@ pub fn run() { } } } + RunEvent::Exit => { + // Kill the built-in engine sidecar and confirm its exit so + // no orphan llama-server survives quit. The actor runs on + // the tokio runtime, so block_on here cannot deadlock. + let engine = app_handle + .state::() + .inner() + .clone(); + tauri::async_runtime::block_on(async move { engine.shutdown().await }); + } + _ => {} }); } diff --git a/src-tauri/src/models/mod.rs b/src-tauri/src/models/mod.rs index e98860a5..43fe7530 100644 --- a/src-tauri/src/models/mod.rs +++ b/src-tauri/src/models/mod.rs @@ -31,6 +31,7 @@ use crate::config::defaults::{ DEFAULT_OLLAMA_SHOW_REQUEST_TIMEOUT_SECS, DEFAULT_OLLAMA_TAGS_REQUEST_TIMEOUT_SECS, HF_API_TIMEOUT_SECS, HF_BASE_URL, MAX_HF_API_BODY_BYTES, MAX_MODEL_SLUG_LEN, MAX_OLLAMA_SHOW_BODY_BYTES, MAX_OLLAMA_TAGS_BODY_BYTES, PROVIDER_ID_BUILTIN, + PROVIDER_KIND_BUILTIN, PROVIDER_KIND_OPENAI, }; use crate::config::AppConfig; @@ -308,7 +309,10 @@ fn read_provider_model_context( /// Writes `slug` onto the active provider's `model` field in config.toml and /// swaps the resolved result into the shared in-memory config. Replaces the -/// former SQLite `set_config(ACTIVE_MODEL_KEY, ...)` persistence. +/// former SQLite `set_config(ACTIVE_MODEL_KEY, ...)` persistence. When the +/// written provider is the active one, also refreshes the managed +/// [`ActiveModelState`] mirror so chat sees the new selection without a +/// restart (e.g. a builtin download finishing via `finalize_install`). #[cfg_attr(coverage_nightly, coverage(off))] fn persist_active_provider_model( app: &tauri::AppHandle, @@ -320,10 +324,37 @@ fn persist_active_provider_model( let resolved = crate::settings_commands::write_provider_field_to_disk(&path, provider_id, "model", slug) .map_err(|e| e.to_string())?; + let mirror = should_refresh_active_model(provider_id, &resolved); *config.write() = resolved; + if let Some(mirror) = mirror { + let active = app.state::(); + let mut guard = active.0.lock().map_err(|e| e.to_string())?; + *guard = mirror; + } Ok(()) } +/// Decides whether a provider-model write must be mirrored into the managed +/// [`ActiveModelState`]. Returns `Some(new_value)` only when `provider_id` is +/// the resolved config's active provider (the mirror tracks the active +/// provider only); the value is the resolved model with empty mapped to +/// `None` (the delete-model clear path writes ""). Pure so the decision is +/// unit-tested even though the persisting wrapper is coverage-off. +pub(crate) fn should_refresh_active_model( + provider_id: &str, + resolved: &AppConfig, +) -> Option> { + if resolved.inference.active_provider != provider_id { + return None; + } + Some( + resolved + .inference + .active_provider_model_opt() + .map(str::to_string), + ) +} + /// Pure helper that shapes the `get_model_picker_state` payload. Extracted so /// the three states (unreachable, reachable + empty, reachable + populated) /// can be unit-tested without spinning up a Tauri runtime or an HTTP server. @@ -715,22 +746,108 @@ pub struct ModelCapabilitiesCache(pub Mutex, cache: tauri::State<'_, ModelCapabilitiesCache>, config: tauri::State<'_, parking_lot::RwLock>, + db: tauri::State<'_, crate::history::Database>, ) -> Result, String> { - let (provider_id, base_url) = { + let (provider_id, base_url, kind, provider_model, provider_vision) = { let c = config.read(); ( c.inference.active_provider.clone(), c.inference.active_provider_base_url().to_string(), + c.inference.active_provider_kind().to_string(), + c.inference.active_provider_model().to_string(), + c.inference.active().map(|p| p.vision).unwrap_or(false), ) }; - let installed = fetch_installed_model_names(&client, &base_url).await?; - Ok(reconcile_capabilities(&client, &cache, &provider_id, &base_url, &installed).await) + match kind.as_str() { + PROVIDER_KIND_BUILTIN => { + let rows = { + let conn = db.0.lock().map_err(|e| e.to_string())?; + manifest::list(&conn).map_err(|e| e.to_string())? + }; + let caps = builtin_capabilities_from_manifest(&rows); + cache_capabilities(&cache, &provider_id, &caps); + Ok(caps) + } + PROVIDER_KIND_OPENAI => { + let caps = openai_capabilities(&provider_model, provider_vision); + cache_capabilities(&cache, &provider_id, &caps); + Ok(caps) + } + _ => { + let installed = fetch_installed_model_names(&client, &base_url).await?; + Ok(reconcile_capabilities(&client, &cache, &provider_id, &base_url, &installed).await) + } + } +} + +/// Capability map for the built-in provider, derived from the installed-model +/// manifest. Each row carries the curated vision/thinking flags recorded at +/// download time; `max_images` stays `None` because llama-server imposes no +/// fixed per-request image cap. +pub(crate) fn builtin_capabilities_from_manifest( + rows: &[manifest::InstalledModel], +) -> HashMap { + rows.iter() + .map(|row| { + ( + row.id.clone(), + Capabilities { + vision: row.vision, + thinking: row.thinking, + max_images: None, + }, + ) + }) + .collect() +} + +/// Capability map for an `openai`-kind provider: a single entry for the +/// configured model, driven by the provider's manual vision flag (generic +/// `/v1` servers expose no capability probe). Thinking stays `false`: there +/// is no declared reasoning-token contract to honor. An empty model (none +/// configured yet) yields an empty map. +pub(crate) fn openai_capabilities(model: &str, vision: bool) -> HashMap { + if model.is_empty() { + return HashMap::new(); + } + HashMap::from([( + model.to_string(), + Capabilities { + vision, + thinking: false, + max_images: None, + }, + )]) +} + +/// Writes a resolved capability map through to the cache under +/// `(provider_id, model)` keys, mirroring the Ollama reconcile path's +/// write-through so `ask_model`'s per-request filter finds the entries. +/// Best-effort: a poisoned lock skips the write (the map is still returned +/// to the caller). +pub(crate) fn cache_capabilities( + cache: &ModelCapabilitiesCache, + provider_id: &str, + caps: &HashMap, +) { + if let Ok(mut guard) = cache.0.lock() { + for (model, c) in caps { + guard.insert((provider_id.to_string(), model.clone()), c.clone()); + } + } } /// Pure-ish helper extracted so tests can drive the cache + fetch loop @@ -2710,6 +2827,93 @@ mod tests { assert!(result["x"].vision); } + // ── Non-Ollama capability resolution ───────────────────────────────────── + + /// Manifest row literal with the given capability flags. + fn manifest_row(id: &str, vision: bool, thinking: bool) -> manifest::InstalledModel { + manifest::InstalledModel { + id: id.to_string(), + display_name: format!("Model {id}"), + repo: "org/repo".to_string(), + revision: "a".repeat(40), + file_name: format!("{id}.gguf"), + sha256: "b".repeat(64), + size_bytes: 1_000_000, + quant: "Q4_K_M".to_string(), + vision, + thinking, + mmproj_file: None, + mmproj_sha256: None, + } + } + + #[test] + fn builtin_capabilities_come_from_manifest() { + // Round-trip through a real in-memory manifest so the rows carry + // exactly what the download recorded. + let conn = crate::database::open_in_memory().unwrap(); + manifest::insert(&conn, &manifest_row("org/repo:vis.gguf", true, false)).unwrap(); + manifest::insert(&conn, &manifest_row("org/repo:think.gguf", false, true)).unwrap(); + let rows = manifest::list(&conn).unwrap(); + + let caps = builtin_capabilities_from_manifest(&rows); + + assert_eq!(caps.len(), 2); + assert!(caps["org/repo:vis.gguf"].vision); + assert!(!caps["org/repo:vis.gguf"].thinking); + assert!(!caps["org/repo:think.gguf"].vision); + assert!(caps["org/repo:think.gguf"].thinking); + assert!(caps.values().all(|c| c.max_images.is_none())); + } + + #[test] + fn builtin_capabilities_empty_manifest_yields_empty_map() { + assert!(builtin_capabilities_from_manifest(&[]).is_empty()); + } + + #[test] + fn openai_capabilities_use_provider_vision_flag() { + let with_vision = openai_capabilities("gpt-4o", true); + assert_eq!(with_vision.len(), 1); + assert!(with_vision["gpt-4o"].vision); + assert!(!with_vision["gpt-4o"].thinking); + assert_eq!(with_vision["gpt-4o"].max_images, None); + + let without_vision = openai_capabilities("local-llm", false); + assert!(!without_vision["local-llm"].vision); + + assert!( + openai_capabilities("", true).is_empty(), + "no configured model means nothing to report" + ); + } + + #[test] + fn cache_capabilities_writes_through_under_provider_key() { + let cache = ModelCapabilitiesCache::default(); + let caps = + builtin_capabilities_from_manifest(&[manifest_row("org/repo:vis.gguf", true, true)]); + + cache_capabilities(&cache, "builtin", &caps); + + let guard = cache.0.lock().unwrap(); + let entry = &guard[&("builtin".to_string(), "org/repo:vis.gguf".to_string())]; + assert!(entry.vision); + assert!(entry.thinking); + } + + #[test] + fn cache_capabilities_poisoned_lock_is_best_effort() { + let cache = ModelCapabilitiesCache::default(); + let cache_ref = std::panic::AssertUnwindSafe(&cache.0); + let _ = std::panic::catch_unwind(|| { + let _guard = cache_ref.0.lock().unwrap(); + panic!("poison"); + }); + // Must not panic; the write is silently skipped. + cache_capabilities(&cache, "builtin", &openai_capabilities("m", true)); + } + // ── Model library: starter options ─────────────────────────────────────── /// Build a fresh store rooted at a temporary directory. @@ -3375,6 +3579,51 @@ mod tests { assert_eq!(builtin_provider_model(&cfg), ""); } + // ── should_refresh_active_model ────────────────────────────────────────── + + /// Helper: an `AppConfig` whose single provider `id` is active with `model`. + fn config_with_active_provider(id: &str, model: &str) -> AppConfig { + use crate::config::schema::Provider; + let mut cfg = AppConfig::default(); + cfg.inference.active_provider = id.to_string(); + cfg.inference.providers = vec![Provider { + id: id.to_string(), + kind: PROVIDER_KIND_BUILTIN.to_string(), + label: "Test".to_string(), + base_url: String::new(), + model: model.to_string(), + vision: false, + }]; + cfg + } + + #[test] + fn should_refresh_active_model_mirrors_active_provider_write() { + // Writing the active provider's model refreshes the mirror with the + // resolved slug (the download-finished path). + let cfg = config_with_active_provider("builtin", "o/r:w.gguf"); + assert_eq!( + should_refresh_active_model("builtin", &cfg), + Some(Some("o/r:w.gguf".to_string())) + ); + } + + #[test] + fn should_refresh_active_model_clears_mirror_on_empty_slug() { + // The delete-model path writes "": the mirror must clear, not keep a + // stale slug. + let cfg = config_with_active_provider("builtin", ""); + assert_eq!(should_refresh_active_model("builtin", &cfg), Some(None)); + } + + #[test] + fn should_refresh_active_model_ignores_non_active_provider() { + // A write to a provider that is not active never touches the mirror; + // it tracks the active provider only. + let cfg = config_with_active_provider("ollama", "gemma3:12b"); + assert_eq!(should_refresh_active_model("builtin", &cfg), None); + } + // ── Model library: system RAM probe ────────────────────────────────────── #[test] diff --git a/src-tauri/src/openai.rs b/src-tauri/src/openai.rs new file mode 100644 index 00000000..13607dd1 --- /dev/null +++ b/src-tauri/src/openai.rs @@ -0,0 +1,1206 @@ +//! Generic OpenAI-compatible `/v1` chat client. +//! +//! The twin of the native Ollama path in [`crate::commands`]: a streaming SSE +//! chat call ([`stream_openai_chat`]) that emits the exact same +//! [`StreamChunk`] channel contract as `stream_ollama_chat`, and a +//! non-streaming structured-output call ([`request_openai_json`]) that mirrors +//! the search pipeline's `request_json`. Used by the `builtin` (local +//! llama-server) and `openai` provider kinds. + +use futures_util::StreamExt; +use serde::Deserialize; +use tokio_util::sync::CancellationToken; + +use crate::commands::{ChatMessage, EngineError, EngineErrorKind, StreamChunk}; +use crate::config::defaults::MAX_SSE_LINE_BYTES; + +/// Groups the per-request parameters for [`stream_openai_chat`], mirroring +/// `OllamaChatParams` on the native path. +pub struct OpenAiChatParams { + /// Server origin without a trailing slash; the client appends + /// `/v1/chat/completions`. + pub base_url: String, + pub model: String, + pub messages: Vec, + /// Sent as a `Bearer` authorization header when `Some`. + pub api_key: Option, +} + +/// Error returned by [`request_openai_json`]. Mirrors the classification the +/// search pipeline's `request_json` applies to its `SearchError` variants: +/// transport failures (including the per-request timeout) map to +/// `Unreachable`, non-2xx statuses to `Http`, unusable bodies to `BadBody`, +/// and token cancellation to `Cancelled`. +#[derive(Debug, PartialEq)] +pub enum OpenAiError { + /// The server could not be reached (connect, transport, or timeout). + Unreachable(String), + /// The server answered with a non-2xx status; carries the response body. + Http(u16, String), + /// The response body could not be read or did not match the expected shape. + BadBody(String), + /// The caller's cancellation token fired before the response was read. + Cancelled, +} + +// ─── Wire types ────────────────────────────────────────────────────────────── + +/// `choices[i].delta` object in a `/v1/chat/completions` SSE event. Unknown +/// fields are ignored so vendor extensions never break parsing. +#[derive(Deserialize, Default)] +struct SseDelta { + #[serde(default)] + content: Option, + #[serde(default)] + reasoning_content: Option, +} + +/// A single entry of `choices` in an SSE event. +#[derive(Deserialize)] +struct SseChoice { + #[serde(default)] + delta: SseDelta, +} + +/// The JSON payload of one `data:` SSE line. +#[derive(Deserialize)] +struct SseEvent { + #[serde(default)] + choices: Vec, +} + +/// `choices[i].message` object in a non-streaming `/v1/chat/completions` +/// response. +#[derive(Deserialize)] +struct JsonChoiceMessage { + #[serde(default)] + content: String, +} + +/// A single entry of `choices` in a non-streaming response. +#[derive(Deserialize)] +struct JsonChoice { + message: JsonChoiceMessage, +} + +/// Top-level non-streaming `/v1/chat/completions` response body. +#[derive(Deserialize)] +struct JsonResponseBody { + #[serde(default)] + choices: Vec, +} + +// ─── Message conversion ────────────────────────────────────────────────────── + +/// Converts a [`ChatMessage`] into the OpenAI wire message shape. Text-only +/// messages keep `content` as a plain JSON string. Messages carrying images +/// switch `content` to the multipart form: a text part followed by one +/// `image_url` data-URI part per base64 image. +pub(crate) fn to_openai_message(msg: &ChatMessage) -> serde_json::Value { + match &msg.images { + Some(images) if !images.is_empty() => { + let mut parts = vec![serde_json::json!({"type": "text", "text": msg.content})]; + for b64 in images { + parts.push(serde_json::json!({ + "type": "image_url", + "image_url": {"url": format!("data:image/jpeg;base64,{b64}")}, + })); + } + serde_json::json!({"role": msg.role, "content": parts}) + } + _ => serde_json::json!({"role": msg.role, "content": msg.content}), + } +} + +// ─── Error classification ──────────────────────────────────────────────────── + +/// Maps a reqwest connection/transport error to a provider-neutral +/// [`EngineError`], mirroring `classify_stream_error` on the native path: +/// connect/timeout failures are `EngineUnreachable`, everything else +/// (e.g. a connection reset mid-stream) is `Other`. +fn classify_v1_transport_error(e: &reqwest::Error) -> EngineError { + if e.is_connect() || e.is_timeout() { + EngineError { + kind: EngineErrorKind::EngineUnreachable, + message: format!("The inference server could not be reached.\n{e}"), + } + } else { + EngineError { + kind: EngineErrorKind::Other, + message: + "Something went wrong\nThe connection to the inference server was interrupted." + .to_string(), + } + } +} + +/// Maps a non-2xx HTTP status from a `/v1` server to a provider-neutral +/// [`EngineError`], mirroring `classify_http_error` on the native path. +fn classify_v1_http_error(status: u16, model_name: &str) -> EngineError { + match status { + 404 => EngineError { + kind: EngineErrorKind::ModelNotFound, + message: format!("Model not found\nThe server has no model named '{model_name}'."), + }, + 401 | 403 => EngineError { + kind: EngineErrorKind::Other, + message: format!( + "Something went wrong\nAuthentication failed (HTTP {status}). Check the provider's API key." + ), + }, + _ => EngineError { + kind: EngineErrorKind::Other, + message: format!("Something went wrong\nHTTP {status}"), + }, + } +} + +/// Error emitted when the buffered unterminated SSE line exceeds +/// [`MAX_SSE_LINE_BYTES`]; the stream is aborted to bound memory. +fn oversize_sse_line_error() -> EngineError { + EngineError { + kind: EngineErrorKind::Other, + message: "Something went wrong\nThe inference server sent an oversized stream line." + .to_string(), + } +} + +// ─── Streaming chat ────────────────────────────────────────────────────────── + +/// Streams a `/v1/chat/completions` request (`stream: true`) and emits the +/// same [`StreamChunk`] contract as `stream_ollama_chat`: +/// `choices[0].delta.content` becomes [`StreamChunk::Token`], +/// `choices[0].delta.reasoning_content` becomes +/// [`StreamChunk::ThinkingToken`], and `data: [DONE]` (or the stream ending +/// without it) becomes [`StreamChunk::Done`]. Exactly one terminal chunk +/// (`Done`, `Cancelled`, or `Error`) is emitted per call. +/// +/// No sampling parameters are sent: the server and model defaults apply. +/// Returns the accumulated assistant content, mirroring `stream_ollama_chat`. +pub async fn stream_openai_chat( + params: OpenAiChatParams, + client: &reqwest::Client, + cancel_token: CancellationToken, + on_chunk: impl Fn(StreamChunk), +) -> String { + let OpenAiChatParams { + base_url, + model, + messages, + api_key, + } = params; + let body = serde_json::json!({ + "model": model, + "messages": messages.iter().map(to_openai_message).collect::>(), + "stream": true, + }); + let mut request = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&body); + if let Some(ref key) = api_key { + request = request.bearer_auth(key); + } + + let mut accumulated = String::new(); + + let response = match request.send().await { + Ok(response) => response, + Err(e) => { + on_chunk(StreamChunk::Error(classify_v1_transport_error(&e))); + return accumulated; + } + }; + + if !response.status().is_success() { + let status = response.status().as_u16(); + on_chunk(StreamChunk::Error(classify_v1_http_error(status, &model))); + return accumulated; + } + + let mut stream = response.bytes_stream(); + let mut buffer: Vec = Vec::new(); + + loop { + tokio::select! { + biased; + _ = cancel_token.cancelled() => { + // Drop the stream - closes the HTTP connection, which + // signals the server to stop inference. + drop(stream); + on_chunk(StreamChunk::Cancelled); + return accumulated; + } + chunk_opt = stream.next() => { + match chunk_opt { + Some(Ok(bytes)) => { + buffer.extend_from_slice(&bytes); + + while let Some(idx) = buffer.iter().position(|&b| b == b'\n') { + let line_bytes = buffer.drain(..=idx).collect::>(); + // Mirror the native path: a non-UTF-8 line is + // silently skipped. + let Ok(line_text) = String::from_utf8(line_bytes) else { + continue; + }; + // trim handles the \r of \r\n line endings and + // collapses blank event-separator lines. + let trimmed = line_text.trim(); + // SSE comments, `event:` lines, and anything else + // that is not a data line are ignored. + let Some(payload) = trimmed.strip_prefix("data: ") else { + continue; + }; + if payload == "[DONE]" { + on_chunk(StreamChunk::Done); + return accumulated; + } + // Mirror the native path's tolerance: a data line + // that does not parse is silently skipped. + let Ok(event) = serde_json::from_str::(payload) else { + continue; + }; + if let Some(choice) = event.choices.first() { + if let Some(thinking) = choice + .delta + .reasoning_content + .as_deref() + .filter(|s| !s.is_empty()) + { + on_chunk(StreamChunk::ThinkingToken(thinking.to_string())); + } + if let Some(token) = + choice.delta.content.as_deref().filter(|s| !s.is_empty()) + { + accumulated.push_str(token); + on_chunk(StreamChunk::Token(token.to_string())); + } + } + } + + // Bound the unterminated line a malicious or broken + // server can make us buffer. + if buffer.len() > MAX_SSE_LINE_BYTES { + on_chunk(StreamChunk::Error(oversize_sse_line_error())); + return accumulated; + } + } + Some(Err(e)) => { + on_chunk(StreamChunk::Error(classify_v1_transport_error(&e))); + return accumulated; + } + None => { + // The server closed the stream without a [DONE] + // marker. Emit a terminal Done so the frontend always + // leaves its streaming state (mirrors the native + // path's missing-done-marker handling). + on_chunk(StreamChunk::Done); + return accumulated; + } + } + } + } + } +} + +// ─── Non-streaming structured output ───────────────────────────────────────── + +/// Builds the `/v1/chat/completions` request body for a structured-output +/// (non-streaming, temperature 0) call. Used by both [`request_openai_json`] +/// (the live wire call) and the search pipeline's trace helper so the logged +/// body always mirrors the wire exactly. +pub(crate) fn json_request_body( + model: &str, + messages: &[ChatMessage], + schema: serde_json::Value, + max_tokens: i32, +) -> serde_json::Value { + serde_json::json!({ + "model": model, + "messages": messages.iter().map(to_openai_message).collect::>(), + "stream": false, + "temperature": 0, + "max_tokens": max_tokens, + "response_format": { + "type": "json_schema", + "json_schema": {"name": "out", "strict": true, "schema": schema}, + }, + }) +} + +/// Sends a single non-streaming `/v1/chat/completions` request with a strict +/// json-schema `response_format` and returns `choices[0].message.content`. +/// The structured-output twin of the search pipeline's `request_json`: +/// temperature 0 for deterministic classification, a per-call wall-clock +/// `timeout_secs`, and the same cancellation discipline. +#[allow(clippy::too_many_arguments)] +pub async fn request_openai_json( + base_url: &str, + model: &str, + client: &reqwest::Client, + messages: Vec, + schema: serde_json::Value, + api_key: Option<&str>, + timeout_secs: u64, + max_tokens: i32, + cancel_token: &CancellationToken, +) -> Result { + let body = json_request_body(model, &messages, schema, max_tokens); + let mut request = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&body) + .timeout(std::time::Duration::from_secs(timeout_secs)); + if let Some(key) = api_key { + request = request.bearer_auth(key); + } + + let response = tokio::select! { + biased; + _ = cancel_token.cancelled() => return Err(OpenAiError::Cancelled), + res = request.send() => res.map_err(|e| OpenAiError::Unreachable(e.to_string()))?, + }; + + if !response.status().is_success() { + let status = response.status().as_u16(); + let body_text = response.text().await.unwrap_or_default(); + return Err(OpenAiError::Http(status, body_text)); + } + + let raw_body = tokio::select! { + biased; + _ = cancel_token.cancelled() => return Err(OpenAiError::Cancelled), + body = response.text() => body.map_err(|e| OpenAiError::BadBody(e.to_string()))?, + }; + let parsed: JsonResponseBody = + serde_json::from_str(&raw_body).map_err(|e| OpenAiError::BadBody(e.to_string()))?; + parsed + .choices + .into_iter() + .next() + .map(|choice| choice.message.content) + .ok_or_else(|| OpenAiError::BadBody("response contained no choices".to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::{Arc, Mutex}; + use wiremock::matchers::{body_partial_json, header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + fn collect_chunks() -> (Arc>>, impl Fn(StreamChunk)) { + let chunks: Arc>> = Arc::new(Mutex::new(Vec::new())); + let chunks_clone = chunks.clone(); + let callback = move |chunk: StreamChunk| { + chunks_clone.lock().unwrap().push(chunk); + }; + (chunks, callback) + } + + fn user_message(content: &str) -> ChatMessage { + ChatMessage { + role: "user".to_string(), + content: content.to_string(), + images: None, + } + } + + fn chat_params(base_url: String) -> OpenAiChatParams { + OpenAiChatParams { + base_url, + model: "test-model".to_string(), + messages: vec![user_message("hi")], + api_key: None, + } + } + + /// Helper: an SSE data line carrying a content delta. + fn sse_content_line(token: &str) -> String { + format!("data: {{\"choices\":[{{\"delta\":{{\"content\":\"{token}\"}}}}]}}\n\n") + } + + async fn mount_sse(server: &MockServer, body: impl Into>) { + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body.into(), "text/event-stream")) + .expect(1) + .mount(server) + .await; + } + + // ── stream_openai_chat ────────────────────────────────────────────────── + + #[tokio::test] + async fn streams_tokens_from_sse() { + let server = MockServer::start().await; + let body = format!( + "{}{}data: {{\"choices\":[{{\"delta\":{{}}}}]}}\n\ndata: [DONE]\n", + sse_content_line("Hello"), + sse_content_line(" world"), + ); + mount_sse(&server, body.into_bytes()).await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "Hello")); + assert!(matches!(&chunks[1], StreamChunk::Token(t) if t == " world")); + assert!(matches!(&chunks[2], StreamChunk::Done)); + assert_eq!(chunks.len(), 3, "exactly one terminal Done"); + assert_eq!(accumulated, "Hello world"); + + // Lock the wire contract: stream:true is sent and no sampling + // parameters override the server/model defaults. + let requests = server.received_requests().await.unwrap(); + let sent: serde_json::Value = serde_json::from_slice(&requests[0].body).unwrap(); + assert_eq!(sent["stream"], serde_json::json!(true)); + assert!(sent.get("temperature").is_none()); + assert!(sent.get("top_p").is_none()); + } + + /// SSE lines arriving split across TCP segments must be reassembled + /// through the line buffer before parsing. + #[tokio::test] + async fn streams_tokens_split_across_chunks() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut req_buf = [0u8; 8192]; + let _ = stream.read(&mut req_buf).await; + + let sse = format!("{}data: [DONE]\n", sse_content_line("Hello")); + let header = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n", + sse.len() + ); + let _ = stream.write_all(header.as_bytes()).await; + // Split the first data line mid-JSON across two writes. + let (first, rest) = sse.split_at(20); + let _ = stream.write_all(first.as_bytes()).await; + let _ = stream.flush().await; + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + let _ = stream.write_all(rest.as_bytes()).await; + let _ = stream.shutdown().await; + }); + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_openai_chat( + chat_params(format!("http://127.0.0.1:{port}")), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "Hello")); + assert!(matches!(&chunks[1], StreamChunk::Done)); + assert_eq!(chunks.len(), 2); + assert_eq!(accumulated, "Hello"); + } + + #[tokio::test] + async fn reasoning_content_maps_to_thinking_token() { + let server = MockServer::start().await; + let body = format!( + "data: {{\"choices\":[{{\"delta\":{{\"reasoning_content\":\"hmm\"}}}}]}}\n\n{}data: [DONE]\n", + sse_content_line("answer"), + ); + mount_sse(&server, body.into_bytes()).await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert!(matches!(&chunks[0], StreamChunk::ThinkingToken(t) if t == "hmm")); + assert!(matches!(&chunks[1], StreamChunk::Token(t) if t == "answer")); + assert!(matches!(&chunks[2], StreamChunk::Done)); + assert_eq!( + accumulated, "answer", + "thinking tokens must not be accumulated as content" + ); + } + + /// `data: [DONE]` terminates the stream immediately: anything the server + /// sends afterwards is never parsed and no second terminal chunk appears. + #[tokio::test] + async fn done_marker_ends_stream() { + let server = MockServer::start().await; + let body = format!( + "{}data: [DONE]\n{}", + sse_content_line("A"), + sse_content_line("ignored"), + ); + mount_sse(&server, body.into_bytes()).await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "A")); + assert!(matches!(&chunks[1], StreamChunk::Done)); + assert_eq!(chunks.len(), 2); + assert_eq!(accumulated, "A"); + } + + /// A server that closes the stream without `[DONE]` must still produce a + /// terminal Done (mirrors the native path's missing-done-marker fix). + #[tokio::test] + async fn stream_end_without_done_marker_emits_done() { + let server = MockServer::start().await; + let body = format!("{}{}", sse_content_line("A"), sse_content_line("B")); + mount_sse(&server, body.into_bytes()).await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "A")); + assert!(matches!(&chunks[1], StreamChunk::Token(t) if t == "B")); + assert!(matches!(&chunks[2], StreamChunk::Done)); + assert_eq!(chunks.len(), 3); + assert_eq!(accumulated, "AB"); + } + + /// Mirrors the native path's policy for unparseable lines: malformed data + /// lines, SSE comments, `event:` lines, and non-UTF-8 lines are all + /// silently skipped; the stream continues. + #[tokio::test] + async fn malformed_data_line_policy() { + let server = MockServer::start().await; + let mut body = Vec::new(); + body.extend_from_slice(b"data: this is not json\n"); + body.extend_from_slice(b": sse comment\n"); + body.extend_from_slice(b"event: ping\n"); + body.extend_from_slice(b"\xFF\xFE\n"); + body.extend_from_slice(b"data: {\"choices\":[]}\n"); + body.extend_from_slice(b"data: {\"choices\":[{}]}\n"); + body.extend_from_slice(sse_content_line("ok").as_bytes()); + body.extend_from_slice(b"data: [DONE]\n"); + mount_sse(&server, body).await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "ok")); + assert!(matches!(&chunks[1], StreamChunk::Done)); + assert_eq!(chunks.len(), 2, "skipped lines must emit nothing"); + assert_eq!(accumulated, "ok"); + } + + #[tokio::test] + async fn connect_refused_maps_engine_unreachable() { + // Bind then drop a listener so the port is closed. + let port = { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + listener.local_addr().unwrap().port() + }; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_openai_chat( + chat_params(format!("http://127.0.0.1:{port}")), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert_eq!(chunks.len(), 1); + assert!(matches!( + &chunks[0], + StreamChunk::Error(e) if e.kind == EngineErrorKind::EngineUnreachable + && e.message.starts_with("The inference server could not be reached.") + )); + assert_eq!(accumulated, ""); + } + + #[tokio::test] + async fn http_404_maps_model_not_found() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(404)) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert_eq!(chunks.len(), 1); + assert!(matches!( + &chunks[0], + StreamChunk::Error(e) if e.kind == EngineErrorKind::ModelNotFound + && e.message.contains("test-model") + )); + } + + #[tokio::test] + async fn http_401_maps_other_with_auth_message() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(401)) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert_eq!(chunks.len(), 1); + assert!(matches!( + &chunks[0], + StreamChunk::Error(e) if e.kind == EngineErrorKind::Other + && e.message.contains("Authentication failed (HTTP 401)") + )); + } + + #[tokio::test] + async fn http_500_maps_other_with_status() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(500)) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert_eq!(chunks.len(), 1); + assert!(matches!( + &chunks[0], + StreamChunk::Error(e) if e.kind == EngineErrorKind::Other + && e.message.contains("HTTP 500") + )); + } + + /// 403 takes the same auth branch as 401. + #[test] + fn http_403_classifies_with_auth_message() { + let error = classify_v1_http_error(403, "m"); + assert_eq!(error.kind, EngineErrorKind::Other); + assert!(error.message.contains("Authentication failed (HTTP 403)")); + } + + #[tokio::test] + async fn cancel_emits_cancelled() { + let server = MockServer::start().await; + let body = format!("{}data: [DONE]\n", sse_content_line("never")); + mount_sse(&server, body.into_bytes()).await; + + let client = reqwest::Client::new(); + let token = CancellationToken::new(); + token.cancel(); + let (chunks, callback) = collect_chunks(); + let accumulated = + stream_openai_chat(chat_params(server.uri()), &client, token, callback).await; + + let chunks = chunks.lock().unwrap(); + assert_eq!(chunks.len(), 1); + assert!(matches!(&chunks[0], StreamChunk::Cancelled)); + assert_eq!(accumulated, ""); + } + + #[tokio::test] + async fn oversize_sse_line_aborts_with_other() { + let server = MockServer::start().await; + // A single unterminated data line just over the cap; no newline ever + // arrives, so the buffered length check must abort the stream. + let mut body = b"data: ".to_vec(); + body.extend(std::iter::repeat_n(b'a', MAX_SSE_LINE_BYTES + 1)); + mount_sse(&server, body).await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert_eq!(chunks.len(), 1); + assert!(matches!( + &chunks[0], + StreamChunk::Error(e) if e.kind == EngineErrorKind::Other + && e.message.contains("oversized stream line") + )); + assert_eq!(accumulated, ""); + } + + /// A connection reset mid-stream surfaces as an Error chunk with kind + /// Other (mirrors the native path: not a connect/timeout failure), and + /// no Done is emitted after it. + #[tokio::test] + async fn mid_stream_error_maps_other() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut req_buf = [0u8; 8192]; + let _ = stream.read(&mut req_buf).await; + + let first_line = sse_content_line("A"); + // Promise more bytes than are sent, then shut down: the client + // sees a truncated body as a mid-stream transport error. + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{}", + first_line.len() + 64, + first_line + ); + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.shutdown().await; + }); + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let accumulated = stream_openai_chat( + chat_params(format!("http://127.0.0.1:{port}")), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let chunks = chunks.lock().unwrap(); + assert!(chunks + .iter() + .any(|chunk| matches!(chunk, StreamChunk::Token(t) if t == "A"))); + let error = chunks.iter().find_map(|chunk| match chunk { + StreamChunk::Error(error) => Some(error), + _ => None, + }); + assert_eq!(error.unwrap().kind, EngineErrorKind::Other); + assert!(chunks + .iter() + .all(|chunk| !matches!(chunk, StreamChunk::Done))); + assert_eq!(accumulated, "A"); + } + + #[tokio::test] + async fn api_key_sent_as_bearer() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(header("authorization", "Bearer sk-test")) + .respond_with( + ResponseTemplate::new(200).set_body_raw("data: [DONE]\n", "text/event-stream"), + ) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let (chunks, callback) = collect_chunks(); + let mut params = chat_params(server.uri()); + params.api_key = Some("sk-test".to_string()); + stream_openai_chat(params, &client, CancellationToken::new(), callback).await; + + let chunks = chunks.lock().unwrap(); + assert!(matches!(&chunks[0], StreamChunk::Done)); + } + + #[tokio::test] + async fn no_api_key_sends_no_authorization_header() { + let server = MockServer::start().await; + mount_sse(&server, b"data: [DONE]\n".to_vec()).await; + + let client = reqwest::Client::new(); + let (_, callback) = collect_chunks(); + stream_openai_chat( + chat_params(server.uri()), + &client, + CancellationToken::new(), + callback, + ) + .await; + + let requests = server.received_requests().await.unwrap(); + assert_eq!(requests.len(), 1); + assert!(!requests[0].headers.contains_key("authorization")); + } + + // ── to_openai_message ─────────────────────────────────────────────────── + + #[test] + fn text_only_message_keeps_plain_string_content() { + let msg = user_message("hello"); + assert_eq!( + to_openai_message(&msg), + serde_json::json!({"role": "user", "content": "hello"}) + ); + } + + #[test] + fn empty_images_vec_keeps_plain_string_content() { + let msg = ChatMessage { + role: "user".to_string(), + content: "hello".to_string(), + images: Some(vec![]), + }; + assert_eq!( + to_openai_message(&msg), + serde_json::json!({"role": "user", "content": "hello"}) + ); + } + + #[test] + fn images_serialize_as_content_parts() { + let msg = ChatMessage { + role: "user".to_string(), + content: "what is this?".to_string(), + images: Some(vec!["QUJD".to_string(), "REVG".to_string()]), + }; + assert_eq!( + to_openai_message(&msg), + serde_json::json!({ + "role": "user", + "content": [ + {"type": "text", "text": "what is this?"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,QUJD"}}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,REVG"}}, + ], + }) + ); + } + + // ── request_openai_json ───────────────────────────────────────────────── + + #[tokio::test] + async fn json_request_uses_response_format_and_extracts_content() { + let server = MockServer::start().await; + let schema = serde_json::json!({ + "type": "object", + "properties": {"a": {"type": "integer"}}, + }); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(header("authorization", "Bearer sk-json")) + .and(body_partial_json(serde_json::json!({ + "model": "test-model", + "stream": false, + "temperature": 0, + "max_tokens": 256, + "response_format": { + "type": "json_schema", + "json_schema": {"name": "out", "strict": true, "schema": schema}, + }, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "choices": [{"message": {"content": "{\"a\":1}"}}], + }))) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let result = request_openai_json( + &server.uri(), + "test-model", + &client, + vec![user_message("classify")], + schema.clone(), + Some("sk-json"), + 5, + 256, + &CancellationToken::new(), + ) + .await; + + assert_eq!(result, Ok("{\"a\":1}".to_string())); + } + + #[tokio::test] + async fn json_request_http_error_maps() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(500).set_body_string("boom")) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let result = request_openai_json( + &server.uri(), + "m", + &client, + vec![user_message("q")], + serde_json::json!({}), + None, + 5, + 64, + &CancellationToken::new(), + ) + .await; + + assert_eq!(result, Err(OpenAiError::Http(500, "boom".to_string()))); + } + + #[tokio::test] + async fn json_request_cancel_maps_cancelled() { + // No mock mounted: the pre-cancelled token must win the biased + // select before any request is sent. + let server = MockServer::start().await; + + let token = CancellationToken::new(); + token.cancel(); + let client = reqwest::Client::new(); + let result = request_openai_json( + &server.uri(), + "m", + &client, + vec![user_message("q")], + serde_json::json!({}), + None, + 5, + 64, + &token, + ) + .await; + + assert_eq!(result, Err(OpenAiError::Cancelled)); + } + + /// The per-call timeout surfaces through reqwest's send error, which maps + /// to Unreachable (mirrors the native `request_json`, where a timeout is + /// a transport error and maps to `LlmUnavailable`). + #[tokio::test] + async fn json_request_timeout_maps_unreachable() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"choices": []})) + .set_delay(std::time::Duration::from_secs(5)), + ) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let result = request_openai_json( + &server.uri(), + "m", + &client, + vec![user_message("q")], + serde_json::json!({}), + None, + 1, + 64, + &CancellationToken::new(), + ) + .await; + + assert!(matches!(result, Err(OpenAiError::Unreachable(_)))); + } + + /// A 2xx response whose body dies mid-read (connection closed before the + /// promised Content-Length) maps to BadBody, mirroring the native + /// `request_json` where a body-read failure is `LlmBadJson`. + #[tokio::test] + async fn json_request_body_read_failure_maps_bad_body() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut req_buf = [0u8; 8192]; + let _ = stream.read(&mut req_buf).await; + // Promise more bytes than are sent, then shut down. + let response = + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 1000\r\n\r\n{\"choices\""; + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.shutdown().await; + }); + + let client = reqwest::Client::new(); + let result = request_openai_json( + &format!("http://127.0.0.1:{port}"), + "m", + &client, + vec![user_message("q")], + serde_json::json!({}), + None, + 5, + 64, + &CancellationToken::new(), + ) + .await; + + assert!(matches!(result, Err(OpenAiError::BadBody(_)))); + } + + #[tokio::test] + async fn json_request_malformed_body_maps_bad_body() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(200).set_body_string("not json")) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let result = request_openai_json( + &server.uri(), + "m", + &client, + vec![user_message("q")], + serde_json::json!({}), + None, + 5, + 64, + &CancellationToken::new(), + ) + .await; + + assert!(matches!(result, Err(OpenAiError::BadBody(_)))); + } + + #[tokio::test] + async fn json_request_empty_choices_maps_bad_body() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "choices": [], + }))) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let result = request_openai_json( + &server.uri(), + "m", + &client, + vec![user_message("q")], + serde_json::json!({}), + None, + 5, + 64, + &CancellationToken::new(), + ) + .await; + + assert_eq!( + result, + Err(OpenAiError::BadBody( + "response contained no choices".to_string() + )) + ); + } + + #[tokio::test] + async fn json_request_omits_authorization_without_key() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "choices": [{"message": {"content": "ok"}}], + }))) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let result = request_openai_json( + &server.uri(), + "m", + &client, + vec![user_message("q")], + serde_json::json!({}), + None, + 5, + 64, + &CancellationToken::new(), + ) + .await; + assert_eq!(result, Ok("ok".to_string())); + + let requests = server.received_requests().await.unwrap(); + assert_eq!(requests.len(), 1); + assert!(!requests[0].headers.contains_key("authorization")); + } +} diff --git a/src-tauri/src/search/llm.rs b/src-tauri/src/search/llm.rs index 8ad49aab..f030d1aa 100644 --- a/src-tauri/src/search/llm.rs +++ b/src-tauri/src/search/llm.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; use tokio_util::sync::CancellationToken; -use crate::commands::ChatMessage; +use crate::commands::{ChatMessage, LlmTransport}; use super::types::{ Action, JudgeVerdict, RouterJudgeOutput, SearchError, SearxResult, Sufficiency, @@ -336,6 +336,126 @@ fn transport_error( Err(SearchError::LlmUnavailable) } +/// Maps a [`crate::openai::OpenAiError`] from the `/v1` structured-output +/// client onto the search-pipeline error vocabulary, mirroring the +/// classification [`request_json`] applies on the native path. +fn map_openai_error(err: crate::openai::OpenAiError) -> SearchError { + match err { + crate::openai::OpenAiError::Cancelled => SearchError::Cancelled, + crate::openai::OpenAiError::Unreachable(_) => SearchError::LlmUnavailable, + crate::openai::OpenAiError::Http(status, _) => SearchError::LlmHttp(status), + crate::openai::OpenAiError::BadBody(_) => SearchError::LlmBadJson, + } +} + +/// `/v1` twin of [`request_json`]: sends the structured-output request via +/// [`crate::openai::request_openai_json`] and emits the same forensic +/// [`RecorderEvent::LlmCall`] record. +/// +/// `num_predict` translates to the wire's `max_tokens`. `num_ctx` is NOT +/// sent on `/v1`: for the built-in engine the context size is a launch +/// property of the llama-server process, and for `openai`-kind servers it is +/// informational only (spec 6.5). +#[allow(clippy::too_many_arguments)] +async fn request_json_v1( + base_url: &str, + api_key: Option<&str>, + model: &str, + client: &reqwest::Client, + messages: Vec, + format: serde_json::Value, + cancel_token: &CancellationToken, + timeout_secs: u64, + num_predict: i32, + recorder: &Arc, + stage: &str, +) -> Result { + let endpoint = format!("{base_url}/v1/chat/completions"); + // Build the trace body via the same helper request_openai_json uses so + // the recorded body always mirrors the actual wire shape. + let request_body_value = + crate::openai::json_request_body(model, &messages, format.clone(), num_predict); + let started = std::time::Instant::now(); + let result = crate::openai::request_openai_json( + base_url, + model, + client, + messages, + format, + api_key, + timeout_secs, + num_predict, + cancel_token, + ) + .await; + let (response_raw, error) = match &result { + Ok(content) => (Some(content.clone()), None), + Err(e) => (None, Some(format!("{e:?}"))), + }; + recorder.record(RecorderEvent::LlmCall { + stage: stage.to_string(), + endpoint, + request_body: request_body_value, + response_raw, + latency_ms: started.elapsed().as_millis() as u64, + error, + }); + result.map_err(map_openai_error) +} + +/// Dispatches a structured-output request to the active transport: the +/// native Ollama path keeps the exact [`request_json`] wire body; `/v1` +/// servers go through [`request_json_v1`]. +#[allow(clippy::too_many_arguments)] +async fn request_structured( + transport: &LlmTransport, + model: &str, + client: &reqwest::Client, + messages: Vec, + format: serde_json::Value, + cancel_token: &CancellationToken, + timeout_secs: u64, + num_ctx: u32, + num_predict: i32, + recorder: &Arc, + stage: &str, +) -> Result { + match transport { + LlmTransport::OllamaNative { endpoint } => { + request_json( + endpoint, + model, + client, + messages, + format, + cancel_token, + timeout_secs, + num_ctx, + num_predict, + recorder, + stage, + ) + .await + } + LlmTransport::V1 { base_url, api_key } => { + request_json_v1( + base_url, + api_key.as_deref(), + model, + client, + messages, + format, + cancel_token, + timeout_secs, + num_predict, + recorder, + stage, + ) + .await + } + } +} + // ─── Merged router+judge call ──────────────────────────────────────────────── /// Merged router+judge call that returns [`RouterJudgeOutput`] in a single @@ -361,7 +481,7 @@ fn transport_error( /// web search, because malformed router output should fail closed. #[allow(clippy::too_many_arguments)] pub async fn call_router_merged( - endpoint: &str, + transport: &LlmTransport, model: &str, client: &reqwest::Client, history: &[ChatMessage], @@ -380,8 +500,8 @@ pub async fn call_router_merged( // First attempt: standard prompt. let messages = build_router_messages(&system, history, query); - let raw = request_json( - endpoint, + let raw = request_structured( + transport, model, client, messages, @@ -406,8 +526,8 @@ pub async fn call_router_merged( "{query}\n\nReply with ONLY the JSON object described by the system prompt. No prose, no markdown fences, no explanation." ); let retry_messages = build_router_messages(&system, history, &strict_query); - let retry_raw = request_json( - endpoint, + let retry_raw = request_structured( + transport, model, client, retry_messages, @@ -544,7 +664,7 @@ fn parse_router_sufficiency(value: &str) -> Option { /// parse error. #[allow(clippy::too_many_arguments)] pub async fn call_judge( - endpoint: &str, + transport: &LlmTransport, model: &str, client: &reqwest::Client, query: &str, @@ -581,8 +701,8 @@ pub async fn call_judge( images: None, }, ]; - let raw = request_json( - endpoint, + let raw = request_structured( + transport, model, client, messages, @@ -628,8 +748,8 @@ pub async fn call_judge( images: None, }, ]; - let retry_raw = request_json( - endpoint, + let retry_raw = request_structured( + transport, model, client, retry_messages, @@ -1086,6 +1206,14 @@ mod router_judge_tests { ))) } + /// Call-shape helper: wraps a bare `/api/chat` endpoint into the native + /// transport the router/judge callers now take. + fn native(endpoint: impl Into) -> LlmTransport { + LlmTransport::OllamaNative { + endpoint: endpoint.into(), + } + } + // ── build_judge_user_message ───────────────────────────────────────────── #[test] @@ -1279,7 +1407,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let output = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -1314,7 +1442,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let output = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -1361,7 +1489,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let output = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -1407,7 +1535,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let output = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -1432,7 +1560,7 @@ mod router_judge_tests { let token = CancellationToken::new(); token.cancel(); let err = call_router_merged( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "m", &client, &[], @@ -1463,7 +1591,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let err = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -1494,7 +1622,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let err = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -1534,7 +1662,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let err = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -1581,7 +1709,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let output = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -1630,7 +1758,7 @@ mod router_judge_tests { text: "s".into(), }]; let verdict = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -1674,7 +1802,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let verdict = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -1729,7 +1857,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let verdict = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -1760,7 +1888,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let err = call_judge( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "m", &client, "q", @@ -1791,7 +1919,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let err = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -1825,7 +1953,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let verdict = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -1855,7 +1983,7 @@ mod router_judge_tests { let token = CancellationToken::new(); token.cancel(); let err = call_judge( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "m", &client, "q", @@ -1886,7 +2014,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let verdict = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -1929,7 +2057,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let verdict = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -2001,7 +2129,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -2050,7 +2178,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let verdict = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -2092,7 +2220,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let err = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -2122,7 +2250,7 @@ mod router_judge_tests { // call_router_merged calls request_json internally; a 503 maps to // SearchError::LlmHttp(503). let err = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -2159,7 +2287,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let output = call_router_merged( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, &[], @@ -2199,7 +2327,7 @@ mod router_judge_tests { let client = reqwest::Client::new(); let token = CancellationToken::new(); let _ = call_judge( - &format!("{}/api/chat", server.uri()), + &native(format!("{}/api/chat", server.uri())), "m", &client, "q", @@ -2213,4 +2341,296 @@ mod router_judge_tests { .await .unwrap(); } + + // ── native body regression ─────────────────────────────────────────────── + + /// Locks the native structured-output wire body across the transport + /// change: `format` must carry the JSON schema and `options` must still + /// carry `num_predict` and `num_ctx` exactly as before. + #[tokio::test] + async fn native_request_body_carries_format_and_options() { + let server = MockServer::start().await; + let captured: std::sync::Arc>> = + std::sync::Arc::new(std::sync::Mutex::new(None)); + let captured_clone = captured.clone(); + Mock::given(method("POST")) + .and(path("/api/chat")) + .respond_with(move |req: &wiremock::Request| { + let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap(); + *captured_clone.lock().unwrap() = Some(body); + ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "message": { + "role": "assistant", + "content": "{\"action\":\"proceed\",\"clarifying_question\":null,\"history_sufficiency\":\"sufficient\",\"optimized_query\":\"q\"}" + }, + "done": true + })) + }) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let token = CancellationToken::new(); + call_router_merged( + &native(format!("{}/api/chat", server.uri())), + "m", + &client, + &[], + "q", + "2026-04-18", + &token, + ROUTER_TIMEOUT_SECS, + crate::config::defaults::DEFAULT_NUM_CTX, + &noop_recorder(), + ) + .await + .unwrap(); + + let body = captured.lock().unwrap().clone().expect("body captured"); + assert_eq!(body["format"], router_output_schema()); + assert_eq!(body["stream"], serde_json::json!(false)); + assert_eq!( + body["options"]["num_predict"], + serde_json::json!(ROUTER_MAX_TOKENS) + ); + assert_eq!( + body["options"]["num_ctx"], + serde_json::json!(crate::config::defaults::DEFAULT_NUM_CTX) + ); + assert_eq!(body["options"]["temperature"], serde_json::json!(0.0)); + assert_eq!(body["options"]["top_p"], serde_json::json!(1.0)); + assert_eq!(body["options"]["top_k"], serde_json::json!(1)); + } + + // ── /v1 transport ──────────────────────────────────────────────────────── + + /// Call-shape helper: a `/v1` transport pointing at a wiremock server. + fn v1(base_url: impl Into, api_key: Option<&str>) -> LlmTransport { + LlmTransport::V1 { + base_url: base_url.into(), + api_key: api_key.map(str::to_string), + } + } + + #[tokio::test] + async fn router_on_v1_uses_response_format() { + let server = MockServer::start().await; + let captured: std::sync::Arc>> = + std::sync::Arc::new(std::sync::Mutex::new(None)); + let captured_clone = captured.clone(); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(move |req: &wiremock::Request| { + let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap(); + *captured_clone.lock().unwrap() = Some(body); + ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "choices": [{ + "message": { + "role": "assistant", + "content": "{\"action\":\"proceed\",\"clarifying_question\":null,\"history_sufficiency\":\"insufficient\",\"optimized_query\":\"curl CVE\"}" + } + }] + })) + }) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let token = CancellationToken::new(); + let output = call_router_merged( + &v1(server.uri(), None), + "m", + &client, + &[], + "tell me about curl CVE", + "2026-04-18", + &token, + ROUTER_TIMEOUT_SECS, + crate::config::defaults::DEFAULT_NUM_CTX, + &noop_recorder(), + ) + .await + .unwrap(); + assert!(matches!(output.action, Action::Proceed)); + assert_eq!(output.optimized_query.as_deref(), Some("curl CVE")); + + let body = captured.lock().unwrap().clone().expect("body captured"); + assert_eq!( + body["response_format"]["json_schema"]["schema"], + router_output_schema() + ); + assert_eq!(body["temperature"], serde_json::json!(0)); + // num_predict translates to max_tokens; num_ctx is NOT sent on /v1. + assert_eq!(body["max_tokens"], serde_json::json!(ROUTER_MAX_TOKENS)); + assert!(body.get("options").is_none()); + assert!(body.get("num_ctx").is_none()); + } + + #[tokio::test] + async fn judge_on_v1_sends_bearer() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(wiremock::matchers::header( + "authorization", + "Bearer sk-test", + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "choices": [{ + "message": { + "role": "assistant", + "content": "{\"sufficiency\":\"sufficient\",\"reasoning\":\"covered\",\"gap_queries\":[]}" + } + }] + }))) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let token = CancellationToken::new(); + let verdict = call_judge( + &v1(server.uri(), Some("sk-test")), + "m", + &client, + "q", + &[], + &token, + crate::config::defaults::DEFAULT_JUDGE_TIMEOUT_S, + crate::config::defaults::DEFAULT_NUM_CTX, + JudgeStage::Snippet, + &noop_recorder(), + ) + .await + .unwrap(); + assert!(matches!( + verdict.sufficiency, + crate::search::types::Sufficiency::Sufficient + )); + } + + #[tokio::test] + async fn v1_http_error_maps_to_llm_http() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(503)) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let token = CancellationToken::new(); + let err = call_router_merged( + &v1(server.uri(), None), + "m", + &client, + &[], + "q", + "2026-04-18", + &token, + ROUTER_TIMEOUT_SECS, + crate::config::defaults::DEFAULT_NUM_CTX, + &noop_recorder(), + ) + .await + .unwrap_err(); + assert_eq!(err, SearchError::LlmHttp(503)); + } + + #[test] + fn map_openai_error_covers_every_variant() { + use crate::openai::OpenAiError; + assert_eq!( + map_openai_error(OpenAiError::Cancelled), + SearchError::Cancelled + ); + assert_eq!( + map_openai_error(OpenAiError::Unreachable("refused".into())), + SearchError::LlmUnavailable + ); + assert_eq!( + map_openai_error(OpenAiError::Http(429, "slow down".into())), + SearchError::LlmHttp(429) + ); + assert_eq!( + map_openai_error(OpenAiError::BadBody("not json".into())), + SearchError::LlmBadJson + ); + } + + /// The trace body recorded by `request_json_v1` must mirror the actual + /// wire shape sent by `request_openai_json`: same keys, same structure, + /// no hand-built approximations (e.g. the old non-wire key + /// "response_format_schema" or the missing "temperature"). + #[tokio::test] + async fn v1_trace_body_mirrors_wire_shape() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "choices": [{ + "message": { + "role": "assistant", + "content": "{\"action\":\"proceed\",\"clarifying_question\":null,\"history_sufficiency\":\"insufficient\",\"optimized_query\":\"q\"}" + } + }] + }))) + .mount(&server) + .await; + + let mock = std::sync::Arc::new(crate::trace::recorder::MockRecorder::new()); + let inner: std::sync::Arc = mock.clone(); + let bound = std::sync::Arc::new(crate::trace::BoundRecorder::new( + inner, + ConversationId::new("test-v1-trace"), + )); + + let client = reqwest::Client::new(); + let token = CancellationToken::new(); + call_router_merged( + &v1(server.uri(), None), + "the-model", + &client, + &[], + "q", + "2026-04-18", + &token, + ROUTER_TIMEOUT_SECS, + crate::config::defaults::DEFAULT_NUM_CTX, + &bound, + ) + .await + .unwrap(); + + // The mock recorder captures exactly the events request_json_v1 emits. + // call_router_merged emits exactly one LlmCall (no retry needed). + let snapshot = mock.snapshot(); + assert_eq!(snapshot.len(), 1, "exactly one LlmCall event expected"); + let body = snapshot[0] + .1 + .llm_call_request_body() + .expect("snapshot[0] must be an LlmCall event") + .clone(); + + // Wire-shape keys that must be present. + assert_eq!(body["model"], serde_json::json!("the-model")); + assert_eq!(body["stream"], serde_json::json!(false)); + assert_eq!(body["temperature"], serde_json::json!(0)); + assert_eq!(body["max_tokens"], serde_json::json!(ROUTER_MAX_TOKENS)); + assert_eq!( + body["response_format"]["type"], + serde_json::json!("json_schema") + ); + assert_eq!( + body["response_format"]["json_schema"]["name"], + serde_json::json!("out") + ); + assert_eq!( + body["response_format"]["json_schema"]["schema"], + router_output_schema() + ); + + // Non-wire key from the old approximation must not be present. + assert!(body.get("response_format_schema").is_none()); + } } diff --git a/src-tauri/src/search/mod.rs b/src-tauri/src/search/mod.rs index ea21c1bb..2595cfec 100644 --- a/src-tauri/src/search/mod.rs +++ b/src-tauri/src/search/mod.rs @@ -69,29 +69,26 @@ pub async fn search_pipeline( app_config: State<'_, parking_lot::RwLock>, active_model_state: State<'_, ActiveModelState>, trace_recorder: State<'_, Arc>, + db: State<'_, crate::history::Database>, + model_store: State<'_, crate::models::storage::ModelStore>, + engine: State<'_, crate::engine::runner::EngineHandle>, + secrets: State<'_, crate::keychain::Secrets>, ) -> Result<(), String> { // Snapshot the config once so the entire pipeline sees a consistent view // even if the user edits Settings while a search is in flight. let app_config = app_config.read().clone(); - // Route by provider kind, mirroring `ask_model`. Phase 1 implements only - // the native Ollama path; a non-Ollama active provider cannot serve a - // search turn, so surface the same typed "not available yet" error the - // chat path emits instead of building a hostless `/api/chat` endpoint. - { - let kind = app_config.inference.active_provider_kind(); - let label = app_config - .inference - .active() - .map(|p| p.label.as_str()) - .unwrap_or(""); - if let Some(err) = crate::commands::unsupported_provider_error(kind, label) { - let _ = on_event.send(SearchEvent::Error { - message: err.message, - }); + // Route by the active provider's kind, mirroring `ask_model`. A builtin + // provider with no model picked surfaces as `NoModelSelected` here, so + // the frontend keeps `is_first_turn` armed exactly like the + // ActiveModelState bail below. + let route = match crate::commands::resolve_chat_route(&app_config.inference) { + Ok(route) => route, + Err(err) => { + let _ = on_event.send(route_failure_event(err)); return Ok(()); } - } + }; // Resolve the runtime search view from the loaded TOML. The single // source of truth lives in `config::defaults`; the loader has already @@ -101,12 +98,13 @@ pub async fn search_pipeline( // Snapshot the active model slug once from the picker-backed // ActiveModelState; drop the guard before any `.await` so we never - // hold a `MutexGuard` across an await point. - let model_name = { + // hold a `MutexGuard` across an await point. Builtin routes carry their + // model in the provider config instead (see `commands::model_for_route`). + let active_model = { let guard = active_model_state.0.lock().map_err(|e| e.to_string())?; guard.clone() }; - let Some(model_name) = model_name else { + let Some(model_name) = crate::commands::model_for_route(&route, active_model) else { // Mirrors the chat-path gate: refuse to dispatch with no active // model. The frontend strip already steers the user to the picker // before this point, so this branch is defense-in-depth for the @@ -133,13 +131,25 @@ pub async fn search_pipeline( return Ok(()); } - let ollama_endpoint = format!( - "{}/api/chat", - app_config - .inference - .active_provider_base_url() - .trim_end_matches('/') - ); + // Resolve the wire transport. For the builtin route this marks engine + // activity and ensures the sidecar serves the selected model before any + // pipeline stage issues an LLM call. + let transport = match crate::commands::resolve_llm_transport( + route, + &db, + &model_store, + &engine, + secrets.0.as_ref(), + app_config.inference.num_ctx, + ) + .await + { + Ok(transport) => transport, + Err(err) => { + let _ = on_event.send(transport_failure_event(err)); + return Ok(()); + } + }; let cancel_token = CancellationToken::new(); generation.set_token(cancel_token.clone()); @@ -197,7 +207,7 @@ pub async fn search_pipeline( let token_count = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)); let router = pipeline::DefaultRouterJudge::new( - ollama_endpoint.clone(), + transport.clone(), model_name.clone(), (*client).clone(), cancel_token.clone(), @@ -207,7 +217,7 @@ pub async fn search_pipeline( Arc::clone(&recorder), ); let judge = pipeline::DefaultJudge::new( - ollama_endpoint.clone(), + transport.clone(), model_name.clone(), (*client).clone(), cancel_token.clone(), @@ -219,7 +229,7 @@ pub async fn search_pipeline( let recorder_for_pump = Arc::clone(&recorder); let token_count_for_pump = Arc::clone(&token_count); let result = pipeline::run_agentic( - &ollama_endpoint, + &transport, &searxng_endpoint, &runtime_config.reader_url, &model_name, @@ -287,3 +297,76 @@ pub async fn search_pipeline( generation.clear_token(); Ok(()) } + +/// Maps a [`crate::commands::resolve_chat_route`] failure onto the search +/// event stream. A builtin provider with no model picked must surface as the +/// typed `NoModelSelected` event (keeping the frontend's `is_first_turn` +/// armed), not as a generic error bubble; every other route failure carries +/// its user-facing message. +fn route_failure_event(err: crate::commands::EngineError) -> SearchEvent { + if err.kind == crate::commands::EngineErrorKind::NoModelSelected { + SearchEvent::NoModelSelected + } else { + SearchEvent::Error { + message: err.message, + } + } +} + +/// Maps a [`crate::commands::resolve_llm_transport`] failure onto the search +/// event stream. `Superseded` means a newer settings change preempted the +/// engine ensure: a cancellation, never an error. Engine failures (start +/// failure, missing manifest row) carry their user-facing message. +fn transport_failure_event(err: crate::commands::TransportError) -> SearchEvent { + match err { + crate::commands::TransportError::Superseded => SearchEvent::Cancelled, + crate::commands::TransportError::Engine(e) => SearchEvent::Error { message: e.message }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::commands::{EngineError, EngineErrorKind, TransportError}; + + #[test] + fn route_failure_event_maps_no_model_to_typed_event() { + let event = route_failure_event(EngineError { + kind: EngineErrorKind::NoModelSelected, + message: "No model selected\nPick or download a model in Settings.".to_string(), + }); + assert!(matches!(event, SearchEvent::NoModelSelected)); + } + + #[test] + fn route_failure_event_maps_other_kinds_to_error_message() { + let event = route_failure_event(EngineError { + kind: EngineErrorKind::Other, + message: "Something went wrong\nThe active provider has an unknown kind.".to_string(), + }); + assert!(matches!( + event, + SearchEvent::Error { message } if message.contains("unknown kind") + )); + } + + #[test] + fn transport_failure_event_maps_superseded_to_cancelled() { + assert!(matches!( + transport_failure_event(TransportError::Superseded), + SearchEvent::Cancelled + )); + } + + #[test] + fn transport_failure_event_maps_engine_error_to_message() { + let event = transport_failure_event(TransportError::Engine(EngineError { + kind: EngineErrorKind::EngineStartFailed, + message: "Thuki's engine could not start.\nspawn boom".to_string(), + })); + assert!(matches!( + event, + SearchEvent::Error { message } if message.contains("could not start") + )); + } +} diff --git a/src-tauri/src/search/pipeline.rs b/src-tauri/src/search/pipeline.rs index 25f89d44..7e8c41fe 100644 --- a/src-tauri/src/search/pipeline.rs +++ b/src-tauri/src/search/pipeline.rs @@ -27,7 +27,8 @@ use async_trait::async_trait; use tokio_util::sync::CancellationToken; use crate::commands::{ - stream_ollama_chat, ChatMessage, ConversationHistory, OllamaChatParams, StreamChunk, + stream_ollama_chat, ChatMessage, ConversationHistory, LlmTransport, OllamaChatParams, + StreamChunk, }; use super::chunker; @@ -425,16 +426,22 @@ fn can_answer_from_history(history_snapshot: &[ChatMessage]) -> bool { !history_snapshot.is_empty() } -/// Runs a streaming Ollama call, translating `StreamChunk` events into -/// `SearchEvent` events and persisting the completed assistant turn on normal -/// completion (or partial completion via cancellation). +/// Runs a streaming chat call against the active transport, translating +/// `StreamChunk` events into `SearchEvent` events and persisting the +/// completed assistant turn on normal completion (or partial completion via +/// cancellation). +/// +/// The native Ollama path keeps its exact `stream_ollama_chat` parameters; +/// `/v1` transports stream through `openai::stream_openai_chat`, which emits +/// the same `StreamChunk` contract, so one callback pump (and its `saw_done` +/// terminal accounting) serves both arms. /// /// `warnings` and `metadata` are forwarded to `persist_turn`; the DB columns /// for these fields were added in Task 17. The frontend serializes and passes /// them back via `persist_message` when saving the turn. #[allow(clippy::too_many_arguments)] async fn run_streaming_branch( - endpoint: &str, + transport: &LlmTransport, model: &str, client: &reqwest::Client, cancel_token: CancellationToken, @@ -449,10 +456,11 @@ async fn run_streaming_branch( recorder: &Arc, stage: &str, ) { + let endpoint_label = transport.endpoint_label(); // Snapshot the request body before streaming starts so the trace can show // exactly what prompt the synthesis call was sent. let request_body = serde_json::json!({ - "endpoint": endpoint, + "endpoint": endpoint_label, "model": model, "messages": messages.iter().map(|m| serde_json::json!({ "role": m.role, @@ -465,35 +473,57 @@ async fn run_streaming_branch( let token_count_for_callback = token_count.clone(); let saw_done = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); let saw_done_for_callback = saw_done.clone(); - let accumulated = stream_ollama_chat( - OllamaChatParams { - endpoint: endpoint.to_string(), - model: model.to_string(), - messages, - think: false, - keep_alive: None, - num_ctx, - }, - client, - cancel_token, - |chunk| match chunk { - StreamChunk::Done => { - saw_done_for_callback.store(true, Ordering::SeqCst); - } - other => { - if matches!(other, StreamChunk::Token(_)) { - token_count_for_callback.fetch_add(1, Ordering::SeqCst); - } - on_event(translate_chunk(other)) + let pump = |chunk: StreamChunk| match chunk { + StreamChunk::Done => { + saw_done_for_callback.store(true, Ordering::SeqCst); + } + other => { + if matches!(other, StreamChunk::Token(_)) { + token_count_for_callback.fetch_add(1, Ordering::SeqCst); } - }, - ) - .await; + on_event(translate_chunk(other)) + } + }; + let accumulated = match transport { + LlmTransport::OllamaNative { endpoint } => { + stream_ollama_chat( + OllamaChatParams { + endpoint: endpoint.to_string(), + model: model.to_string(), + messages, + think: false, + keep_alive: None, + num_ctx, + }, + client, + cancel_token, + pump, + ) + .await + } + // num_ctx is NOT sent on /v1: for the builtin engine it is a launch + // property of the llama-server process, and for openai-kind servers + // it is informational only (spec 6.5). + LlmTransport::V1 { base_url, api_key } => { + crate::openai::stream_openai_chat( + crate::openai::OpenAiChatParams { + base_url: base_url.clone(), + model: model.to_string(), + messages, + api_key: api_key.clone(), + }, + client, + cancel_token, + pump, + ) + .await + } + }; record_streaming_llm_call( recorder, stage, - endpoint, + &endpoint_label, request_body, &accumulated, token_count.load(Ordering::SeqCst), @@ -616,7 +646,7 @@ pub trait JudgeCaller: Send + Sync { /// Tests must NOT use this struct directly as it would hit a real Ollama /// instance. Inject a mock [`RouterJudgeCaller`] instead. pub struct DefaultRouterJudge { - endpoint: String, + transport: LlmTransport, model: String, client: reqwest::Client, cancel: CancellationToken, @@ -630,9 +660,9 @@ impl DefaultRouterJudge { /// Constructs a `DefaultRouterJudge` that delegates to /// [`llm::call_router_merged`]. /// - /// - `endpoint`: fully-qualified `/api/chat` URL (e.g. - /// `http://127.0.0.1:11434/api/chat`). - /// - `model`: Ollama model identifier (e.g. `"mistral"`). + /// - `transport`: the resolved wire target (native `/api/chat` endpoint + /// or a `/v1` server). + /// - `model`: model identifier (e.g. `"mistral"`). /// - `client`: shared `reqwest::Client`; the Tauri command clones it from /// `AppState`. /// - `cancel`: the pipeline's cancellation token; races against the HTTP @@ -645,7 +675,7 @@ impl DefaultRouterJudge { #[cfg_attr(coverage_nightly, coverage(off))] #[allow(clippy::too_many_arguments)] pub fn new( - endpoint: String, + transport: LlmTransport, model: String, client: reqwest::Client, cancel: CancellationToken, @@ -655,7 +685,7 @@ impl DefaultRouterJudge { recorder: Arc, ) -> Self { Self { - endpoint, + transport, model, client, cancel, @@ -676,7 +706,7 @@ impl RouterJudgeCaller for DefaultRouterJudge { query: &str, ) -> Result { call_router_merged( - &self.endpoint, + &self.transport, &self.model, &self.client, history, @@ -699,7 +729,7 @@ impl RouterJudgeCaller for DefaultRouterJudge { /// /// Tests must NOT use this struct directly. Inject a mock [`JudgeCaller`]. pub struct DefaultJudge { - endpoint: String, + transport: LlmTransport, model: String, client: reqwest::Client, cancel: CancellationToken, @@ -711,8 +741,9 @@ pub struct DefaultJudge { impl DefaultJudge { /// Constructs a `DefaultJudge` that delegates to [`llm::call_judge`]. /// - /// - `endpoint`: fully-qualified `/api/chat` URL. - /// - `model`: Ollama model identifier. + /// - `transport`: the resolved wire target (native `/api/chat` endpoint + /// or a `/v1` server). + /// - `model`: model identifier. /// - `client`: shared `reqwest::Client`. /// - `cancel`: the pipeline's cancellation token; races against the HTTP /// call inside `call_judge`. @@ -722,7 +753,7 @@ impl DefaultJudge { #[cfg_attr(coverage_nightly, coverage(off))] #[allow(clippy::too_many_arguments)] pub fn new( - endpoint: String, + transport: LlmTransport, model: String, client: reqwest::Client, cancel: CancellationToken, @@ -731,7 +762,7 @@ impl DefaultJudge { recorder: Arc, ) -> Self { Self { - endpoint, + transport, model, client, cancel, @@ -752,7 +783,7 @@ impl JudgeCaller for DefaultJudge { stage: JudgeStage, ) -> Result { call_judge( - &self.endpoint, + &self.transport, &self.model, &self.client, query, @@ -800,7 +831,7 @@ fn lock_or_recover(mutex: &std::sync::Mutex) -> std::sync::MutexGuard<'_, /// Shared immutable inputs used by the extracted search-pipeline stages. struct SearchExecutionContext<'a> { - ollama_endpoint: &'a str, + transport: &'a LlmTransport, searxng_endpoint: &'a str, model: &'a str, client: &'a reqwest::Client, @@ -1058,7 +1089,7 @@ async fn stream_synthesis_from_sources( (shared.on_event)(SearchEvent::Composing); run_streaming_branch( - shared.ollama_endpoint, + shared.transport, shared.model, shared.client, shared.cancel_token.clone(), @@ -1165,7 +1196,7 @@ async fn run_history_answer_branch( ); run_streaming_branch( - shared.ollama_endpoint, + shared.transport, shared.model, shared.client, shared.cancel_token.clone(), @@ -1790,7 +1821,7 @@ async fn run_gap_refinement_loop( /// immediately on cancel rather than waiting for round-trips to complete. #[allow(clippy::too_many_arguments)] pub async fn run_agentic( - ollama_endpoint: &str, + transport: &LlmTransport, searxng_endpoint: &str, reader_base_url: &str, model: &str, @@ -1861,7 +1892,7 @@ pub async fn run_agentic( }; let shared = SearchExecutionContext { - ollama_endpoint, + transport, searxng_endpoint, model, client, @@ -2186,7 +2217,7 @@ pub async fn run_agentic( emit_trace(on_event, compose_step); on_event(SearchEvent::Composing); run_streaming_branch( - ollama_endpoint, + transport, model, client, cancel_token, @@ -2667,6 +2698,14 @@ mod tests { } } + /// Call-shape helper: wraps a bare `/api/chat` endpoint into the native + /// transport `run_agentic` and the streaming branch now take. + fn native(endpoint: impl Into) -> LlmTransport { + LlmTransport::OllamaNative { + endpoint: endpoint.into(), + } + } + // ── today_iso ─────────────────────────────────────────────────────────── #[test] @@ -2841,7 +2880,7 @@ mod tests { let (_, cb) = collect_events(); run_streaming_branch( - &format!("{}/api/chat", server.url()), + &native(format!("{}/api/chat", server.url())), "m", &client, token, @@ -2864,6 +2903,79 @@ mod tests { assert!(h.messages.lock().unwrap().is_empty()); } + // ── run_streaming_branch: /v1 transport ────────────────────────────────── + + /// A `/v1` transport must stream through `openai::stream_openai_chat` and + /// drive the exact same pump contract as the native arm: Token events, + /// `saw_done` accounting that yields a terminal `Done`, and persistence + /// of the accumulated turn. + #[tokio::test] + async fn synthesis_on_v1_streams() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + let body = "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n\ + data: {\"choices\":[{\"delta\":{\"content\":\" there\"}}]}\n\n\ + data: [DONE]\n"; + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with( + ResponseTemplate::new(200).set_body_raw(body.as_bytes(), "text/event-stream"), + ) + .expect(1) + .mount(&server) + .await; + + let client = reqwest::Client::new(); + let token = CancellationToken::new(); + let h = ConversationHistory::new(); + let (events, cb) = collect_events(); + let transport = LlmTransport::V1 { + base_url: server.uri(), + api_key: None, + }; + + run_streaming_branch( + &transport, + "m", + &client, + token, + vec![make_user_msg("q")], + &h, + 0, + make_user_msg("q"), + Vec::new(), + None, + &cb, + DEFAULT_NUM_CTX, + &(Arc::new(crate::trace::BoundRecorder::noop_for( + crate::trace::ConversationId::new("test-conv-pipeline"), + ))), + "synthesis", + ) + .await; + + let events = events.lock().unwrap(); + assert!(matches!( + &events[0], + SearchEvent::Token { content } if content == "Hi" + )); + assert!(matches!( + &events[1], + SearchEvent::Token { content } if content == " there" + )); + assert!( + matches!(&events[2], SearchEvent::Done { .. }), + "saw_done must surface as a terminal Done event" + ); + assert_eq!(events.len(), 3); + + let conv = h.messages.lock().unwrap(); + assert_eq!(conv.len(), 2, "user + assistant turn persisted"); + assert_eq!(conv[1].content, "Hi there"); + } + // ── DefaultRouterJudge / DefaultJudge construction ─────────────────────── #[test] @@ -2873,7 +2985,7 @@ mod tests { crate::trace::ConversationId::new("test-conv-pipeline"), )); let _judge = DefaultRouterJudge::new( - "http://127.0.0.1:11434/api/chat".into(), + native("http://127.0.0.1:11434/api/chat"), "mistral".into(), reqwest::Client::new(), cancel, @@ -2891,7 +3003,7 @@ mod tests { crate::trace::ConversationId::new("test-conv-pipeline"), )); let _judge = DefaultJudge::new( - "http://127.0.0.1:11434/api/chat".into(), + native("http://127.0.0.1:11434/api/chat"), "mistral".into(), reqwest::Client::new(), cancel, @@ -3187,6 +3299,14 @@ mod agentic_tests { Arc::new(BoundRecorder::noop_for(ConversationId::new(TEST_CONV_ID))) } + /// Call-shape helper: wraps a bare `/api/chat` endpoint into the native + /// transport `run_agentic` now takes. + fn native(endpoint: impl Into) -> LlmTransport { + LlmTransport::OllamaNative { + endpoint: endpoint.into(), + } + } + /// Constructs a mock recorder + an `Arc` wrapping it /// that the pipeline needs. Returns both so tests can pass the /// bound recorder to `run_agentic` while still introspecting the @@ -3442,7 +3562,7 @@ mod agentic_tests { let judge = QueueJudge(std::sync::Mutex::new(VecDeque::new())); let err = run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -3484,7 +3604,7 @@ mod agentic_tests { let judge = QueueJudge(std::sync::Mutex::new(VecDeque::new())); let err = run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -3527,7 +3647,7 @@ mod agentic_tests { let judge = QueueJudge(std::sync::Mutex::new(VecDeque::new())); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -3623,7 +3743,7 @@ mod agentic_tests { let (mock_recorder, recorder_view) = mock_recorder_pair(); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -3685,7 +3805,7 @@ mod agentic_tests { let (mock_recorder, recorder_view) = mock_recorder_pair(); let _ = run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -3736,7 +3856,7 @@ mod agentic_tests { let judge = QueueJudge(std::sync::Mutex::new(VecDeque::new())); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -3805,7 +3925,7 @@ mod agentic_tests { let judge = QueueJudge(std::sync::Mutex::new(VecDeque::new())); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -3886,7 +4006,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), "http://127.0.0.1:1", "m", @@ -3975,7 +4095,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), "http://127.0.0.1:1", "m", @@ -4033,7 +4153,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), "http://127.0.0.1:1", "m", @@ -4112,7 +4232,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), "http://127.0.0.1:1", "m", @@ -4212,7 +4332,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -4304,7 +4424,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -4397,7 +4517,7 @@ mod agentic_tests { runtime.pipeline_wall_clock_budget_s = 0; run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -4507,7 +4627,7 @@ mod agentic_tests { runtime.pipeline_input_char_budget = 10; run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -4637,7 +4757,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -4758,7 +4878,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -4842,7 +4962,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), "http://127.0.0.1:1", "m", @@ -4894,7 +5014,7 @@ mod agentic_tests { let judge = QueueJudge(std::sync::Mutex::new(VecDeque::new())); let err = run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx.url()), "http://127.0.0.1:1", "m", @@ -4966,7 +5086,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), "http://127.0.0.1:1", "m", @@ -5041,7 +5161,7 @@ mod agentic_tests { let (events, cb) = collect_events(); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -5087,7 +5207,7 @@ mod agentic_tests { let (events, cb) = collect_events(); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -5134,7 +5254,7 @@ mod agentic_tests { let judge = QueueJudge(std::sync::Mutex::new(VecDeque::new())); let err = run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx.url()), "http://127.0.0.1:1", "m", @@ -5167,7 +5287,7 @@ mod agentic_tests { let judge = QueueJudge(std::sync::Mutex::new(VecDeque::new())); let err = run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), "http://127.0.0.1:1/search", "http://127.0.0.1:1", "m", @@ -5208,7 +5328,7 @@ mod agentic_tests { let judge = QueueJudge(std::sync::Mutex::new(VecDeque::new())); let err = run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx.url()), "http://127.0.0.1:1", "m", @@ -5266,7 +5386,7 @@ mod agentic_tests { )); let err = run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -5331,7 +5451,7 @@ mod agentic_tests { }; run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx.url()), "http://127.0.0.1:1", "m", @@ -5406,7 +5526,7 @@ mod agentic_tests { }); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -5490,7 +5610,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -5595,7 +5715,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -5678,7 +5798,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -5787,7 +5907,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -5941,7 +6061,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -6064,7 +6184,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -6217,7 +6337,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -6336,7 +6456,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -6469,7 +6589,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -6585,7 +6705,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), reader_base_url, "m", @@ -6663,7 +6783,7 @@ mod agentic_tests { }); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx_server.uri()), "http://127.0.0.1:1", "m", @@ -6781,7 +6901,7 @@ mod agentic_tests { ); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -6873,7 +6993,7 @@ mod agentic_tests { )); let err = run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -6961,7 +7081,7 @@ mod agentic_tests { }); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -7081,7 +7201,7 @@ mod agentic_tests { }; run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -7216,7 +7336,7 @@ mod agentic_tests { // The gap reader is slow; after CancelsAfterGapSearxng fires, the reader // call sees the cancelled token. run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx_server.uri()), &gap_reader_server.uri(), "m", @@ -7323,7 +7443,7 @@ mod agentic_tests { )); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -7443,7 +7563,7 @@ mod agentic_tests { )); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -7566,7 +7686,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_base, "m", @@ -7689,7 +7809,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -7825,7 +7945,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -7950,7 +8070,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -8058,7 +8178,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -8151,7 +8271,7 @@ mod agentic_tests { ); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -8233,7 +8353,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), "http://127.0.0.1:1", "m", @@ -8382,7 +8502,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -8455,7 +8575,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -8574,7 +8694,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", @@ -8673,7 +8793,7 @@ mod agentic_tests { )); run_agentic( - &format!("{}/api/chat", ollama_server.uri()), + &native(format!("{}/api/chat", ollama_server.uri())), &format!("{}/search", searx_server.uri()), &reader_server.uri(), "m", diff --git a/src-tauri/src/settings_commands.rs b/src-tauri/src/settings_commands.rs index 64b2f502..20e1cdf8 100644 --- a/src-tauri/src/settings_commands.rs +++ b/src-tauri/src/settings_commands.rs @@ -102,6 +102,32 @@ pub(crate) fn trace_enabled_changed(prior_enabled: bool, resolved: &AppConfig) - resolved.debug.trace_enabled != prior_enabled } +/// Returns the new `[inference] idle_unload_minutes` value when the post-write +/// `AppConfig` changed it relative to the pre-write snapshot, `None` when it +/// is unchanged. Pulled out so the predicate is covered by tests instead of +/// riding inside the coverage-off Tauri command bodies that forward the new +/// value to the running engine actor. +pub(crate) fn idle_unload_minutes_changed(prior_minutes: u32, resolved: &AppConfig) -> Option { + let new_minutes = resolved.inference.idle_unload_minutes; + (new_minutes != prior_minutes).then_some(new_minutes) +} + +/// Forwards a changed `[inference] idle_unload_minutes` value to the engine +/// runner actor so the new idle-unload policy applies without a restart. +/// Spawned because the config commands are synchronous while the actor's +/// mailbox is async. Thin dispatch; the predicate and the actor's +/// `SetIdleMinutes` handling are both tested on their own. +#[cfg_attr(coverage_nightly, coverage(off))] +fn forward_idle_unload_minutes(app: &AppHandle, prior_minutes: u32, resolved: &AppConfig) { + if let Some(minutes) = idle_unload_minutes_changed(prior_minutes, resolved) { + let engine = app + .state::() + .inner() + .clone(); + tauri::async_runtime::spawn(async move { engine.set_idle_minutes(minutes).await }); + } +} + // ─── Tauri command surface ────────────────────────────────────────────────── /// Returns the current resolved `AppConfig` snapshot. @@ -131,7 +157,13 @@ pub fn set_config_field( trace_recorder: State<'_, std::sync::Arc>, ) -> Result { let path = config_path(&app)?; - let prior_trace_enabled = state.read().debug.trace_enabled; + let (prior_trace_enabled, prior_idle_unload_minutes) = { + let guard = state.read(); + ( + guard.debug.trace_enabled, + guard.inference.idle_unload_minutes, + ) + }; let resolved = { let mut guard = state.write(); let resolved = write_field_to_disk(&path, §ion, &key, value)?; @@ -149,6 +181,9 @@ pub fn set_config_field( let new_inner = crate::build_trace_inner(&app, resolved.debug.trace_enabled); trace_recorder.replace(new_inner); } + // Forward an `[inference] idle_unload_minutes` change to the engine + // runner so the new idle policy applies without restarting Thuki. + forward_idle_unload_minutes(&app, prior_idle_unload_minutes, &resolved); emit_config_updated(&app); Ok(resolved) } @@ -304,7 +339,13 @@ pub fn reset_config( trace_recorder: State<'_, std::sync::Arc>, ) -> Result { let path = config_path(&app)?; - let prior_trace_enabled = state.read().debug.trace_enabled; + let (prior_trace_enabled, prior_idle_unload_minutes) = { + let guard = state.read(); + ( + guard.debug.trace_enabled, + guard.inference.idle_unload_minutes, + ) + }; let resolved = { let mut guard = state.write(); let resolved = reset_section_on_disk(&path, section.as_deref())?; @@ -319,6 +360,9 @@ pub fn reset_config( let new_inner = crate::build_trace_inner(&app, resolved.debug.trace_enabled); trace_recorder.replace(new_inner); } + // A whole-file or `[inference]` reset restores the default idle-unload + // policy; forward it so the engine runner picks it up immediately. + forward_idle_unload_minutes(&app, prior_idle_unload_minutes, &resolved); emit_config_updated(&app); Ok(resolved) } @@ -385,7 +429,13 @@ pub fn reload_config_from_disk( trace_recorder: State<'_, std::sync::Arc>, ) -> Result { let path = config_path(&app)?; - let prior_trace_enabled = state.read().debug.trace_enabled; + let (prior_trace_enabled, prior_idle_unload_minutes) = { + let guard = state.read(); + ( + guard.debug.trace_enabled, + guard.inference.idle_unload_minutes, + ) + }; let resolved = { let mut guard = state.write(); let resolved = config::load_from_path(&path)?; @@ -399,6 +449,9 @@ pub fn reload_config_from_disk( let new_inner = crate::build_trace_inner(&app, resolved.debug.trace_enabled); trace_recorder.replace(new_inner); } + // Manual edits to `[inference] idle_unload_minutes` reach the engine + // runner through the same refresh path. + forward_idle_unload_minutes(&app, prior_idle_unload_minutes, &resolved); emit_config_updated(&app); Ok(resolved) } diff --git a/src-tauri/src/settings_commands/tests.rs b/src-tauri/src/settings_commands/tests.rs index f38045c8..aaf3b10c 100644 --- a/src-tauri/src/settings_commands/tests.rs +++ b/src-tauri/src/settings_commands/tests.rs @@ -12,8 +12,8 @@ use serde_json::json; use toml_edit::DocumentMut; use super::{ - coerce_json_to_toml, is_allowed_field, is_allowed_section, json_type_name, - json_value_to_toml_item, patch_document, read_document, reset_section_on_disk, + coerce_json_to_toml, idle_unload_minutes_changed, is_allowed_field, is_allowed_section, + json_type_name, json_value_to_toml_item, patch_document, read_document, reset_section_on_disk, trace_enabled_changed, write_field_to_disk, write_provider_field_to_disk, }; use crate::config::defaults::{ALLOWED_FIELDS, ALLOWED_SECTIONS}; @@ -1066,6 +1066,22 @@ fn trace_enabled_changed_returns_false_when_value_unchanged() { assert!(!trace_enabled_changed(false, &cfg)); } +// ─── idle_unload_minutes_changed ───────────────────────────────────────────── + +#[test] +fn idle_unload_minutes_changed_returns_new_value_on_change() { + let mut cfg = AppConfig::default(); + cfg.inference.idle_unload_minutes = 45; + assert_eq!(idle_unload_minutes_changed(0, &cfg), Some(45)); +} + +#[test] +fn idle_unload_minutes_changed_returns_none_when_unchanged() { + let mut cfg = AppConfig::default(); + cfg.inference.idle_unload_minutes = 45; + assert_eq!(idle_unload_minutes_changed(45, &cfg), None); +} + // ─── Helpers ───────────────────────────────────────────────────────────────── fn matches_type_mismatch(err: &ConfigError, section: &str, key: &str) { diff --git a/src-tauri/src/trace/recorder.rs b/src-tauri/src/trace/recorder.rs index 9d28e54d..fe98a21d 100644 --- a/src-tauri/src/trace/recorder.rs +++ b/src-tauri/src/trace/recorder.rs @@ -327,6 +327,17 @@ impl RecorderEvent { pub fn is_turn_end(&self) -> bool { matches!(self, RecorderEvent::TurnEnd { .. }) } + + /// Returns the `request_body` from an [`RecorderEvent::LlmCall`] event, + /// or `None` for any other variant. Provides a branch-free extraction + /// path for tests that assert on the recorded wire body. + #[cfg(test)] + pub(crate) fn llm_call_request_body(&self) -> Option<&serde_json::Value> { + match self { + RecorderEvent::LlmCall { request_body, .. } => Some(request_body), + _ => None, + } + } } /// Per-URL outcome inside a [`RecorderEvent::ReaderBatch`]. @@ -972,6 +983,22 @@ mod tests { assert!(!RecorderEvent::AssistantTokens { chunk: "x".into() }.is_turn_end()); } + #[test] + fn llm_call_request_body_returns_body_for_llm_call_and_none_for_others() { + let ev = RecorderEvent::LlmCall { + stage: "router".into(), + endpoint: "http://x/v1/chat/completions".into(), + request_body: json!({"model": "m"}), + response_raw: None, + latency_ms: 1, + error: None, + }; + assert_eq!(ev.llm_call_request_body(), Some(&json!({"model": "m"}))); + + let other = RecorderEvent::AssistantTokens { chunk: "hi".into() }; + assert!(other.llm_call_request_body().is_none()); + } + #[test] fn noop_recorder_swallows_every_event() { let r = NoopRecorder; diff --git a/src-tauri/src/warmup.rs b/src-tauri/src/warmup.rs index 7eea6050..c3dd148f 100644 --- a/src-tauri/src/warmup.rs +++ b/src-tauri/src/warmup.rs @@ -4,7 +4,9 @@ use std::sync::{ }; use tauri::{Emitter, Manager}; -use crate::config::defaults::VRAM_POLL_INTERVAL_SECS; +use crate::config::defaults::{ + PROVIDER_KIND_BUILTIN, PROVIDER_KIND_OLLAMA, VRAM_POLL_INTERVAL_SECS, +}; type InFlightSlot = Arc, String, u32)>>>; type OnLoaded = Arc; @@ -61,6 +63,83 @@ pub fn keep_alive_string(minutes: i32) -> String { } } +/// True when the VRAM poller should query Ollama's `/api/ps` on this tick. +/// The poller observes Ollama's VRAM only: the built-in engine publishes its +/// lifecycle through the engine status watch and an `openai` provider has no +/// local memory to observe, so any non-Ollama active provider skips the HTTP +/// call entirely. +pub(crate) fn vram_poll_active(kind: &str) -> bool { + kind == PROVIDER_KIND_OLLAMA +} + +/// The engine port to prime, when the built-in engine already serves a model. +/// `None` for every other lifecycle state: summoning the overlay must never +/// load a model implicitly (loads happen on explicit chat or download). +pub(crate) fn builtin_prime_port(status: &crate::engine::runner::EngineStatus) -> Option { + if status.state == "loaded" { + status.port + } else { + None + } +} + +/// Builds the prime request body for the built-in engine: a plain +/// `/v1/chat/completions` completion carrying the resolved system prompt and +/// a one-token budget. llama-server's prompt cache (on by default) keeps the +/// system prefix in KV so the first real message skips its prefill. +pub(crate) fn builtin_prime_body(model: &str, system_prompt: &str) -> serde_json::Value { + serde_json::json!({ + "model": model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "ok"} + ], + "max_tokens": 1, + "stream": false + }) +} + +/// Fires the built-in engine prime request at the serving port. Best-effort, +/// mirroring `run_warmup`'s error handling: every failure (transport or HTTP) +/// is silently ignored. Deliberately does NOT touch the engine's idle clock: +/// priming is app-summon activity, not user chat; if it touched, idle-unload +/// would never fire for a user who keeps summoning the overlay without +/// chatting. +pub(crate) async fn prime_builtin( + port: u16, + model: String, + system_prompt: String, + client: reqwest::Client, +) { + let body = builtin_prime_body(&model, &system_prompt); + let _ = client + .post(format!("http://127.0.0.1:{port}/v1/chat/completions")) + .json(&body) + .send() + .await; +} + +/// Built-in arm of `evict_model`: stops the engine sidecar and resolves once +/// the process exit is confirmed. The `warmup:model-evicted` emit stays in +/// the thin Tauri command because it needs an `AppHandle`. +pub(crate) async fn evict_builtin(engine: &crate::engine::runner::EngineHandle) { + engine.unload().await; +} + +/// Built-in arm of `get_loaded_model`: the provider's configured model id +/// when the engine status watch reports a loaded model, `None` otherwise +/// (including when no model has been picked yet). +pub(crate) fn builtin_loaded_model( + status: &crate::engine::runner::EngineStatus, + model_id: &str, +) -> Option { + if status.state == "loaded" && !model_id.is_empty() { + Some(model_id.to_string()) + } else { + None + } +} + impl Default for WarmupState { #[cfg_attr(coverage_nightly, coverage(off))] fn default() -> Self { @@ -151,34 +230,52 @@ pub fn warm_up_model( models: tauri::State, config: tauri::State>, client: tauri::State, + engine: tauri::State, ) { - let model = models.0.lock().ok().and_then(|g| g.clone()); - if let Some(model) = model { - let cfg = config.read(); - let endpoint = format!( - "{}/api/chat", - cfg.inference - .active_provider_base_url() - .trim_end_matches('/') - ); - let system_prompt = cfg.prompt.resolved_system.clone(); - let keep_alive = if cfg.inference.keep_warm_inactivity_minutes == 0 { - None - } else { - Some(keep_alive_string( - cfg.inference.keep_warm_inactivity_minutes, - )) - }; - let num_ctx = cfg.inference.num_ctx; - drop(cfg); - warmup.fire( - endpoint, - model, - system_prompt, - client.inner().clone(), - keep_alive, - num_ctx, - ); + let kind = config.read().inference.active_provider_kind().to_string(); + match kind.as_str() { + PROVIDER_KIND_OLLAMA => { + let model = models.0.lock().ok().and_then(|g| g.clone()); + if let Some(model) = model { + let cfg = config.read(); + let endpoint = format!( + "{}/api/chat", + cfg.inference + .active_provider_base_url() + .trim_end_matches('/') + ); + let system_prompt = cfg.prompt.resolved_system.clone(); + let keep_alive = if cfg.inference.keep_warm_inactivity_minutes == 0 { + None + } else { + Some(keep_alive_string( + cfg.inference.keep_warm_inactivity_minutes, + )) + }; + let num_ctx = cfg.inference.num_ctx; + drop(cfg); + warmup.fire( + endpoint, + model, + system_prompt, + client.inner().clone(), + keep_alive, + num_ctx, + ); + } + } + PROVIDER_KIND_BUILTIN => { + let status = engine.status().borrow().clone(); + if let Some(port) = builtin_prime_port(&status) { + let cfg = config.read(); + let model = cfg.inference.active_provider_model().to_string(); + let system_prompt = cfg.prompt.resolved_system.clone(); + drop(cfg); + let client = client.inner().clone(); + tauri::async_runtime::spawn(prime_builtin(port, model, system_prompt, client)); + } + } + _ => {} } } @@ -216,28 +313,43 @@ pub(crate) async fn get_loaded_model_request( Ok(if found { Some(model.to_string()) } else { None }) } -/// Returns the active model's name if it is currently loaded in Ollama's VRAM, -/// `None` if no model is selected or the selected model is not running. +/// Returns the active model's name if it is currently loaded, `None` if no +/// model is selected or nothing is running. Branches by the active provider's +/// kind: Ollama queries `/api/ps`, the built-in engine reads its own status +/// watch, and `openai` providers always report `None` (there is no local +/// memory to observe). #[tauri::command] #[cfg_attr(coverage_nightly, coverage(off))] pub async fn get_loaded_model( models: tauri::State<'_, crate::models::ActiveModelState>, config: tauri::State<'_, parking_lot::RwLock>, client: tauri::State<'_, reqwest::Client>, + engine: tauri::State<'_, crate::engine::runner::EngineHandle>, ) -> Result, String> { - let model = models.0.lock().ok().and_then(|g| g.clone()); - if let Some(model) = model { - let endpoint = format!( - "{}/api/ps", - config - .read() - .inference - .active_provider_base_url() - .trim_end_matches('/') - ); - get_loaded_model_request(&endpoint, &model, client.inner()).await - } else { - Ok(None) + let kind = config.read().inference.active_provider_kind().to_string(); + match kind.as_str() { + PROVIDER_KIND_BUILTIN => { + let model_id = config.read().inference.active_provider_model().to_string(); + let status = engine.status().borrow().clone(); + Ok(builtin_loaded_model(&status, &model_id)) + } + PROVIDER_KIND_OLLAMA => { + let model = models.0.lock().ok().and_then(|g| g.clone()); + if let Some(model) = model { + let endpoint = format!( + "{}/api/ps", + config + .read() + .inference + .active_provider_base_url() + .trim_end_matches('/') + ); + get_loaded_model_request(&endpoint, &model, client.inner()).await + } else { + Ok(None) + } + } + _ => Ok(None), } } @@ -264,11 +376,14 @@ pub(crate) async fn evict_model_request( .map_err(|e| e.to_string()) } -/// Unloads the active model from Ollama's VRAM immediately. +/// Unloads the active model from local memory immediately. Branches by the +/// active provider's kind: Ollama gets the `/api/generate keep_alive:"0"` +/// request, the built-in engine unloads its sidecar process, and `openai` +/// providers are a no-op (there is no local memory to release). /// -/// Delegates to `evict_model_request`; returns an error string on failure so -/// the frontend can react (e.g. reset the eject button state). Emits -/// `warmup:model-evicted` on success so the Settings panel updates live. +/// The Ollama arm delegates to `evict_model_request`; returns an error string +/// on failure so the frontend can react (e.g. reset the eject button state). +/// Emits `warmup:model-evicted` on success so the Settings panel updates live. #[tauri::command] #[cfg_attr(coverage_nightly, coverage(off))] pub async fn evict_model( @@ -277,22 +392,36 @@ pub async fn evict_model( models: tauri::State<'_, crate::models::ActiveModelState>, config: tauri::State<'_, parking_lot::RwLock>, client: tauri::State<'_, reqwest::Client>, + engine: tauri::State<'_, crate::engine::runner::EngineHandle>, ) -> Result<(), String> { - let model = models.0.lock().ok().and_then(|g| g.clone()); - if let Some(model) = model { - let endpoint = format!( - "{}/api/generate", - config - .read() - .inference - .active_provider_base_url() - .trim_end_matches('/') - ); - evict_model_request(&endpoint, &model, client.inner()).await?; - // Suppress any in-flight warmup callback so a slow warmup that - // completes after the eviction request does not re-announce the model. - warmup.mark_evicted(); - let _ = app_handle.emit("warmup:model-evicted", ()); + let kind = config.read().inference.active_provider_kind().to_string(); + match kind.as_str() { + PROVIDER_KIND_BUILTIN => { + // No mark_evicted() here: the WarmupState in-flight slot is only + // armed by fire(), which is never called for builtin providers. + // There is no Ollama-era warmup callback to suppress. + evict_builtin(&engine).await; + let _ = app_handle.emit("warmup:model-evicted", ()); + } + PROVIDER_KIND_OLLAMA => { + let model = models.0.lock().ok().and_then(|g| g.clone()); + if let Some(model) = model { + let endpoint = format!( + "{}/api/generate", + config + .read() + .inference + .active_provider_base_url() + .trim_end_matches('/') + ); + evict_model_request(&endpoint, &model, client.inner()).await?; + // Suppress any in-flight warmup callback so a slow warmup that + // completes after the eviction request does not re-announce the model. + warmup.mark_evicted(); + let _ = app_handle.emit("warmup:model-evicted", ()); + } + } + _ => {} } Ok(()) } @@ -316,6 +445,20 @@ pub fn spawn_vram_poller(app_handle: tauri::AppHandle) { loop { ticker.tick().await; + // The poller is Ollama-specific: skip the tick entirely (no HTTP + // call) while any other provider kind is active. `prev` is left + // untouched so a later switch back to Ollama resumes transition + // detection from the last observed Ollama state. + let kind = app_handle + .state::>() + .read() + .inference + .active_provider_kind() + .to_string(); + if !vram_poll_active(&kind) { + continue; + } + let model = app_handle .state::() .0 @@ -1249,4 +1392,145 @@ mod tests { "slot clears even when eviction suppresses the callback" ); } + + // ── Provider-kind branching ────────────────────────────────────────────── + + #[test] + fn vram_poller_tick_skips_non_ollama() { + assert!(vram_poll_active("ollama"), "ollama keeps polling /api/ps"); + assert!(!vram_poll_active("builtin"), "builtin must not hit Ollama"); + assert!(!vram_poll_active("openai"), "openai has no VRAM to observe"); + assert!(!vram_poll_active(""), "unresolved kind must not poll"); + } + + /// EngineStatus literal for the prime/loaded-model decision tests. + fn engine_status(state: &str, port: Option) -> crate::engine::runner::EngineStatus { + crate::engine::runner::EngineStatus { + state: state.to_string(), + model_path: String::new(), + port, + error: None, + } + } + + #[test] + fn prime_skipped_when_engine_not_loaded() { + assert_eq!(builtin_prime_port(&engine_status("stopped", None)), None); + assert_eq!(builtin_prime_port(&engine_status("starting", None)), None); + assert_eq!(builtin_prime_port(&engine_status("failed", None)), None); + assert_eq!( + builtin_prime_port(&engine_status("loaded", Some(40123))), + Some(40123) + ); + } + + #[tokio::test] + async fn builtin_prime_request_hits_v1_with_max_tokens_1() { + let mut server = Server::new_async().await; + let mock = server + .mock("POST", "/v1/chat/completions") + .match_body(mockito::Matcher::PartialJsonString( + r#"{"model":"org/repo:m.gguf","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"ok"}],"max_tokens":1,"stream":false}"#.to_string(), + )) + .with_status(200) + .with_body("{}") + .create_async() + .await; + + let port: u16 = server + .url() + .rsplit(':') + .next() + .unwrap() + .parse() + .expect("mockito url ends in a port"); + prime_builtin( + port, + "org/repo:m.gguf".to_string(), + SYS.to_string(), + reqwest::Client::new(), + ) + .await; + + mock.assert_async().await; + } + + #[test] + fn get_loaded_model_builtin_from_status() { + assert_eq!( + builtin_loaded_model(&engine_status("loaded", Some(40123)), "org/repo:m.gguf"), + Some("org/repo:m.gguf".to_string()) + ); + assert_eq!( + builtin_loaded_model(&engine_status("stopped", None), "org/repo:m.gguf"), + None + ); + assert_eq!( + builtin_loaded_model(&engine_status("loaded", Some(40123)), ""), + None, + "no picked model means nothing to report even while loaded" + ); + } + + // ── evict_builtin against a scripted engine ────────────────────────────── + + /// Minimal scriptable engine process: spawns instantly and answers every + /// health probe with 200, so `ensure_loaded` resolves without a real + /// llama-server. + struct InstantEngineProcess; + + struct InstantChild { + exit_tx: tokio::sync::watch::Sender, + exit_rx: tokio::sync::watch::Receiver, + } + + #[async_trait::async_trait] + impl crate::engine::process::EngineChild for InstantChild { + async fn wait_exit(&mut self) { + let _ = self.exit_rx.wait_for(|exited| *exited).await; + } + async fn kill(&mut self) { + let _ = self.exit_tx.send(true); + } + } + + #[async_trait::async_trait] + impl crate::engine::process::EngineProcess for InstantEngineProcess { + async fn spawn( + &self, + _args: &crate::engine::process::SpawnArgs, + ) -> Result, String> { + let (exit_tx, exit_rx) = tokio::sync::watch::channel(false); + Ok(Box::new(InstantChild { exit_tx, exit_rx })) + } + fn free_port(&self) -> Result { + Ok(40123) + } + async fn health_probe(&self, _port: u16) -> Result { + Ok(200) + } + } + + #[tokio::test] + async fn evict_on_builtin_calls_runner_unload() { + let engine = crate::engine::runner::EngineHandle::spawn( + Arc::new(InstantEngineProcess), + 0, + Duration::from_secs(3600), + ); + engine + .ensure_loaded(crate::engine::state::Target { + model_path: std::path::PathBuf::from("/tmp/m.gguf"), + mmproj_path: None, + num_ctx: DEFAULT_NUM_CTX, + }) + .await + .expect("scripted engine loads"); + assert_eq!(engine.status().borrow().state, "loaded"); + + evict_builtin(&engine).await; + + assert_eq!(engine.status().borrow().state, "stopped"); + engine.shutdown().await; + } } diff --git a/src-tauri/tests/search_pipeline_e2e.rs b/src-tauri/tests/search_pipeline_e2e.rs index 7ec19f4e..7c2f404e 100644 --- a/src-tauri/tests/search_pipeline_e2e.rs +++ b/src-tauri/tests/search_pipeline_e2e.rs @@ -16,7 +16,7 @@ use tokio_util::sync::CancellationToken; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; -use thuki_agent_lib::commands::ConversationHistory; +use thuki_agent_lib::commands::{ConversationHistory, LlmTransport}; use thuki_agent_lib::config::defaults::DEFAULT_NUM_CTX; use thuki_agent_lib::search::{ run_agentic, Action, JudgeCaller, JudgeSource, JudgeVerdict, RouterJudgeCaller, @@ -48,6 +48,14 @@ fn opt_trace_recorder(label: &str) -> Arc { // ── fixtures ────────────────────────────────────────────────────────────────── +/// Call-shape helper: wraps a bare `/api/chat` endpoint into the native +/// transport `run_agentic` now takes. +fn native(endpoint: impl Into) -> LlmTransport { + LlmTransport::OllamaNative { + endpoint: endpoint.into(), + } +} + /// Collects events emitted by the pipeline via a closure. fn collect_events() -> (Arc>>, impl Fn(SearchEvent)) { let events = Arc::new(Mutex::new(Vec::::new())); @@ -242,7 +250,7 @@ async fn happy_path_snippets_sufficient_streams_answer() { let judge = QueueJudge(Mutex::new(vec![verdict_sufficient()].into_iter().collect())); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -379,7 +387,7 @@ async fn reader_escalation_with_chunks_sufficient() { )); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), &reader_server.uri(), "m", @@ -464,7 +472,7 @@ async fn reader_unavailable_degrades_to_snippets_and_warns() { // Deliberately pass a reader base URL that nothing is listening on. run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &format!("{}/search", searx.url()), "http://127.0.0.1:1", "m", @@ -567,7 +575,7 @@ async fn exhausted_gap_loop_warns_iteration_cap_and_streams_fallback() { let searx_endpoint = format!("{}/search", searx_server.uri()); run_agentic( - &format!("{}/api/chat", ollama.url()), + &native(format!("{}/api/chat", ollama.url())), &searx_endpoint, &reader_server.uri(), "m", @@ -663,7 +671,7 @@ async fn cancel_midloop_does_not_persist_and_emits_cancelled() { }); run_agentic( - "http://127.0.0.1:1/api/chat", + &native("http://127.0.0.1:1/api/chat"), &format!("{}/search", searx.url()), &reader_server.uri(), "m", diff --git a/src/components/ErrorCard.tsx b/src/components/ErrorCard.tsx index 29a7dd4d..9deba6ad 100644 --- a/src/components/ErrorCard.tsx +++ b/src/components/ErrorCard.tsx @@ -7,6 +7,8 @@ interface ErrorCardProps { const barColors: Record = { EngineUnreachable: '#ef4444', + // Same red as EngineUnreachable: a sidecar crash is equally severe. + EngineStartFailed: '#ef4444', ModelNotFound: '#f59e0b', // Same accent as ModelNotFound: this is a configuration/setup nudge, // not a daemon failure, so the warning hue (amber) is the right read. diff --git a/src/components/__tests__/ErrorCard.test.tsx b/src/components/__tests__/ErrorCard.test.tsx index 3b083f6f..bef3a993 100644 --- a/src/components/__tests__/ErrorCard.test.tsx +++ b/src/components/__tests__/ErrorCard.test.tsx @@ -40,6 +40,24 @@ describe('ErrorCard', () => { expect(bar?.getAttribute('data-kind')).toBe('EngineUnreachable'); }); + it('applies red accent bar for EngineStartFailed', () => { + const { container } = render( + , + ); + const bar = container.querySelector('[data-error-bar]'); + expect(bar).not.toBeNull(); + expect(bar?.getAttribute('data-kind')).toBe('EngineStartFailed'); + // JSDOM normalizes hex to rgb; assert the same red family as EngineUnreachable. + expect((bar as HTMLElement | null)?.style.background).toBe( + 'rgb(239, 68, 68)', + ); + }); + it('applies amber accent bar for ModelNotFound', () => { const { container } = render(