diff --git a/.claude/rules/20-enhancement-backends.md b/.claude/rules/20-enhancement-backends.md new file mode 100644 index 0000000..f62c3df --- /dev/null +++ b/.claude/rules/20-enhancement-backends.md @@ -0,0 +1,37 @@ +# Enhancement Backend Rules + +Rules for the AI text enhancement subsystem (OpenAI-compatible, Anthropic, Ollama). + +## API Key Handling + +- API keys MUST be stored via config.rs `save_to_disk()` which enforces 0600 permissions +- API keys MUST NOT appear in log output, error messages, or frontend console +- Frontend API key inputs MUST use `type="password"` with optional show/hide toggle +- Anthropic API key auto-detection from `ANTHROPIC_API_KEY` env var is permitted + +## URL Validation + +- OpenAI-compatible `base_url` MUST be validated: only `http://` and `https://` schemes accepted +- Anthropic `base_url` MUST warn if not HTTPS (unless localhost) — API key transmitted in headers +- No scheme validation bypass — `file://`, `ftp://`, `javascript:` etc. are always rejected + +## Backend Selection + +- Pipeline MUST use the correct model for the active backend: + - `backend == "anthropic"` → use `anthropicModel` from config + - `backend == "ollama" | "openai_compat"` → use `model` from config +- Tray menu MUST display the active backend and model (e.g. 
"Cloud: claude-haiku-4-5-20251001") +- Tray MUST refresh after backend switch + +## Timeouts and Retries + +- Default timeout: 30 seconds for all backends +- OpenAI-compatible: retry with exponential backoff (3 attempts) +- Anthropic: single attempt (cloud service, retries add cost) +- Ollama: single attempt (local, fast failure preferred) + +## Error Handling + +- Error messages MUST include attempt count and generic error description +- Error messages MUST NOT include request bodies, headers, or API keys +- Failed enhancement MUST NOT block the transcription pipeline — return original text diff --git a/.github/workflows/security-audit.yaml b/.github/workflows/security-audit.yaml new file mode 100644 index 0000000..580a988 --- /dev/null +++ b/.github/workflows/security-audit.yaml @@ -0,0 +1,33 @@ +name: Security Audit + +on: + push: + branches: [main] + pull_request: + branches: [main] + schedule: + # Weekly on Monday at 06:00 UTC + - cron: '0 6 * * 1' + +jobs: + cargo-audit: + name: Cargo Audit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install cargo-audit + run: cargo install cargo-audit + - name: Run cargo audit + working-directory: src-tauri + run: cargo audit + + npm-audit: + name: npm Audit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: '22' + - run: npm ci + - run: npm audit --audit-level=high diff --git a/src-tauri/src/config.rs b/src-tauri/src/config.rs index 102bf9e..1fac336 100644 --- a/src-tauri/src/config.rs +++ b/src-tauri/src/config.rs @@ -137,18 +137,39 @@ impl Default for ShortcutConfig { } } +fn default_anthropic_model() -> String { + "claude-haiku-4-5-20251001".to_string() +} + +fn default_anthropic_url() -> String { + "https://api.anthropic.com".to_string() +} + /// AI enhancement configuration #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] pub struct EnhancementConfig { /// Whether AI enhancement is enabled pub enabled: 
bool, - /// Ollama model to use for enhancement + /// Model to use for enhancement pub model: String, /// Selected prompt template ID pub prompt_id: String, - /// Ollama server URL + /// Server URL (used for both Ollama and OpenAI-compatible backends) pub ollama_url: String, + /// Backend type: "ollama" | "openai_compat" | "anthropic" + pub backend: String, + /// Optional API key for OpenAI-compatible backends + pub api_key: Option, + /// Anthropic API key (persists independently) + #[serde(default)] + pub anthropic_api_key: Option, + /// Anthropic model (e.g. "claude-haiku-4-5-20251001") + #[serde(default = "default_anthropic_model")] + pub anthropic_model: String, + /// Anthropic API base URL + #[serde(default = "default_anthropic_url")] + pub anthropic_url: String, } impl Default for EnhancementConfig { @@ -158,6 +179,11 @@ impl Default for EnhancementConfig { model: "llama3.2".to_string(), prompt_id: "fix-grammar".to_string(), ollama_url: "http://localhost:11434".to_string(), + backend: "ollama".to_string(), + api_key: None, + anthropic_api_key: None, + anthropic_model: default_anthropic_model(), + anthropic_url: default_anthropic_url(), } } } @@ -319,7 +345,14 @@ fn save_to_disk(config: &Config) -> Result<(), String> { let contents = serde_json::to_string_pretty(config) .map_err(|e| format!("Failed to serialise config: {}", e))?; - fs::write(&path, contents).map_err(|e| format!("Failed to write config file: {}", e))?; + fs::write(&path, &contents).map_err(|e| format!("Failed to write config file: {}", e))?; + + // Restrict permissions to owner-only (0600) — config may contain API keys + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let _ = fs::set_permissions(&path, fs::Permissions::from_mode(0o600)); + } tracing::info!( "Config saved to disk: device_id={:?}, toggle_recording_alt={:?}", @@ -612,6 +645,8 @@ mod tests { assert_eq!(enhancement.model, "llama3.2"); assert_eq!(enhancement.prompt_id, "fix-grammar"); assert_eq!(enhancement.ollama_url, 
"http://localhost:11434"); + assert_eq!(enhancement.backend, "ollama"); + assert_eq!(enhancement.api_key, None); } #[test] @@ -740,6 +775,11 @@ mod tests { model: "mistral".to_string(), prompt_id: "custom".to_string(), ollama_url: "http://custom:8080".to_string(), + backend: "openai_compat".to_string(), + api_key: Some("sk-test".to_string()), + anthropic_api_key: None, + anthropic_model: default_anthropic_model(), + anthropic_url: default_anthropic_url(), }, general: GeneralConfig { launch_at_login: true, @@ -828,9 +868,23 @@ mod tests { model: "custom-model".to_string(), prompt_id: "summarise".to_string(), ollama_url: "http://192.168.1.100:11434".to_string(), + backend: "ollama".to_string(), + api_key: None, + anthropic_api_key: None, + anthropic_model: default_anthropic_model(), + anthropic_url: default_anthropic_url(), }; assert!(enhancement.enabled); assert_eq!(enhancement.ollama_url, "http://192.168.1.100:11434"); } + + #[test] + fn test_enhancement_config_backward_compat() { + // Old config without backend/api_key fields should deserialise with defaults + let json = r#"{"enabled": true, "model": "llama3.2", "prompt_id": "fix-grammar", "ollama_url": "http://localhost:11434"}"#; + let enhancement: EnhancementConfig = serde_json::from_str(json).unwrap(); + assert_eq!(enhancement.backend, "ollama"); + assert_eq!(enhancement.api_key, None); + } } diff --git a/src-tauri/src/enhancement/anthropic.rs b/src-tauri/src/enhancement/anthropic.rs new file mode 100644 index 0000000..f5001e6 --- /dev/null +++ b/src-tauri/src/enhancement/anthropic.rs @@ -0,0 +1,153 @@ +//! Anthropic API client for AI text enhancement +//! 
Uses the Anthropic Messages API directly (not OpenAI-compatible) + +use anyhow::{anyhow, Result}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +const DEFAULT_TIMEOUT_SECS: u64 = 120; +const ANTHROPIC_API_VERSION: &str = "2023-06-01"; + +#[derive(Debug, Serialize)] +struct Message { + role: String, + content: String, +} + +#[derive(Debug, Serialize)] +struct MessagesRequest { + model: String, + max_tokens: u32, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, +} + +#[derive(Debug, Deserialize)] +struct MessagesResponse { + content: Vec, +} + +#[derive(Debug, Deserialize)] +struct ContentBlock { + #[serde(rename = "type")] + block_type: String, + text: Option, +} + +#[derive(Debug, Clone)] +pub struct AnthropicClient { + api_key: String, + base_url: String, + model: String, + client: reqwest::Client, +} + +impl AnthropicClient { + pub fn new(api_key: String, model: String, base_url: Option) -> Self { + let resolved_url = base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()); + + // Warn if sending API key over non-HTTPS — potential credential exposure + if !resolved_url.starts_with("https://") && !resolved_url.starts_with("http://localhost") { + tracing::warn!( + "Anthropic: base_url is not HTTPS ({}). 
API key may be transmitted in plaintext.", + resolved_url + ); + } + + let timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECS); + let client = reqwest::Client::builder() + .timeout(timeout) + .build() + .expect("Failed to create HTTP client"); + Self { + api_key, + base_url: resolved_url, + model, + client, + } + } + + pub async fn is_available(&self) -> bool { + !self.api_key.is_empty() + } + + pub async fn enhance_text(&self, text: &str, prompt_template: &str) -> Result { + let full_prompt = prompt_template.replace("{text}", text); + self.send_message(&full_prompt, None).await + } + + async fn send_message(&self, user_message: &str, system: Option<&str>) -> Result { + let url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); + + let request = MessagesRequest { + model: self.model.clone(), + max_tokens: 4096, + messages: vec![Message { + role: "user".to_string(), + content: user_message.to_string(), + }], + system: system.map(|s| s.to_string()), + temperature: Some(0.3), + }; + + let response = self + .client + .post(&url) + .header("x-api-key", &self.api_key) + .header("anthropic-version", ANTHROPIC_API_VERSION) + .header("content-type", "application/json") + .json(&request) + .send() + .await + .map_err(|e| anyhow!("Anthropic API request failed: {}", e))?; + + if !response.status().is_success() { + let status = response.status().as_u16(); + let body = response.text().await.unwrap_or_default(); + return Err(anyhow!("Anthropic API error ({}): {}", status, body)); + } + + let resp: MessagesResponse = response + .json() + .await + .map_err(|e| anyhow!("Failed to parse Anthropic response: {}", e))?; + + resp.content + .into_iter() + .find(|b| b.block_type == "text") + .and_then(|b| b.text) + .ok_or_else(|| anyhow!("No text content in Anthropic response")) + } +} + +/// Tauri command: detect Anthropic API key from environment +#[tauri::command] +pub fn detect_anthropic_api_key() -> Option { + std::env::var("ANTHROPIC_API_KEY") + .ok() + .filter(|k| 
!k.is_empty()) +} + +/// Tauri command: open Anthropic console in browser to get API key +#[tauri::command] +pub async fn open_anthropic_console() -> Result<(), String> { + let url = "https://console.anthropic.com/settings/keys"; + #[cfg(target_os = "macos")] + { + std::process::Command::new("open") + .arg(url) + .spawn() + .map_err(|e| format!("Failed to open browser: {}", e))?; + } + #[cfg(target_os = "linux")] + { + std::process::Command::new("xdg-open") + .arg(url) + .spawn() + .map_err(|e| format!("Failed to open browser: {}", e))?; + } + Ok(()) +} diff --git a/src-tauri/src/enhancement/mod.rs b/src-tauri/src/enhancement/mod.rs index dbbb576..e1f3702 100644 --- a/src-tauri/src/enhancement/mod.rs +++ b/src-tauri/src/enhancement/mod.rs @@ -1,16 +1,21 @@ //! AI text enhancement subsystem //! -//! Provides AI-powered text enhancement using local Ollama models, -//! with context capture support for clipboard and selected text. +//! Provides AI-powered text enhancement using local Ollama models or any +//! OpenAI-compatible server, with context capture support for clipboard and +//! selected text. 
+pub mod anthropic; pub mod context; pub mod ollama; +pub mod openai_compat; pub mod prompts; +pub use anthropic::AnthropicClient; pub use context::{ build_context, build_enhancement_context, get_clipboard_context, ContextCapture, }; pub use ollama::OllamaClient; +pub use openai_compat::OpenAiCompatClient; pub use prompts::{ delete_custom_prompt_cmd, get_all_prompts, get_builtin_prompts_cmd, get_custom_prompts_cmd, get_prompt_by_id, save_custom_prompt_cmd, PromptTemplate, @@ -19,34 +24,133 @@ pub use prompts::{ use parking_lot::Mutex; use std::sync::OnceLock; -/// Global Ollama client instance -static OLLAMA_CLIENT: OnceLock> = OnceLock::new(); +/// Which AI backend to use for enhancement +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BackendType { + Ollama, + OpenAiCompat, + Anthropic, +} + +impl BackendType { + pub fn from_str(s: &str) -> Self { + match s { + "openai_compat" => BackendType::OpenAiCompat, + "anthropic" => BackendType::Anthropic, + _ => BackendType::Ollama, + } + } +} + +/// Holds the active backend configuration +#[derive(Debug, Clone)] +struct EnhancementBackend { + backend_type: BackendType, + ollama: OllamaClient, + openai_compat: Option, + anthropic: Option, +} + +impl Default for EnhancementBackend { + fn default() -> Self { + Self { + backend_type: BackendType::Ollama, + ollama: OllamaClient::new(), + openai_compat: None, + anthropic: None, + } + } +} + +/// Global backend instance +static BACKEND: OnceLock> = OnceLock::new(); + +fn get_backend() -> &'static Mutex { + BACKEND.get_or_init(|| Mutex::new(EnhancementBackend::default())) +} + +/// Configure the enhancement backend. Called when config is applied. 
+pub fn configure_backend( + backend: &str, + base_url: &str, + api_key: Option<&str>, + anthropic_model: Option<&str>, + anthropic_base_url: Option<&str>, +) { + let backend_type = BackendType::from_str(backend); + let mut state = get_backend().lock(); -fn get_client() -> &'static Mutex { - OLLAMA_CLIENT.get_or_init(|| Mutex::new(OllamaClient::new())) + state.backend_type = backend_type; + + match backend_type { + BackendType::Ollama => { + state.ollama = OllamaClient::with_base_url(base_url.to_string()); + } + BackendType::OpenAiCompat => { + state.openai_compat = Some(OpenAiCompatClient::new( + base_url, + api_key.map(|k| k.to_string()), + )); + } + BackendType::Anthropic => { + if let Some(key) = api_key.filter(|k| !k.is_empty()) { + state.anthropic = Some(AnthropicClient::new( + key.to_string(), + anthropic_model + .unwrap_or("claude-haiku-4-5-20251001") + .to_string(), + anthropic_base_url.map(|u| u.to_string()), + )); + } + } + } + + tracing::info!("Enhancement backend configured: {:?}", backend_type); } -/// Check if Ollama server is available +/// Check if the AI server is available #[tauri::command] pub async fn check_ollama_available() -> bool { - let client = get_client().lock().clone(); - client.is_available().await + let state = get_backend().lock().clone(); + match state.backend_type { + BackendType::Ollama => state.ollama.is_available().await, + BackendType::OpenAiCompat => match &state.openai_compat { + Some(client) => client.is_available().await, + None => false, + }, + BackendType::Anthropic => match &state.anthropic { + Some(client) => client.is_available().await, + None => false, + }, + } } -/// List available Ollama models +/// List available models from the active backend #[tauri::command] pub async fn list_ollama_models() -> Result, String> { - let client = get_client().lock().clone(); + let state = get_backend().lock().clone(); - client.list_models().await.map_err(|e| { - tracing::error!("Failed to list Ollama models: {}", e); - 
format!("Failed to list models: {}", e) - }) + match state.backend_type { + BackendType::Ollama => state.ollama.list_models().await.map_err(|e| { + tracing::error!("Failed to list Ollama models: {}", e); + format!("Failed to list models: {}", e) + }), + BackendType::OpenAiCompat => match &state.openai_compat { + Some(client) => client.list_models().await.map_err(|e| { + tracing::error!("Failed to list OpenAI-compat models: {}", e); + format!("Failed to list models: {}", e) + }), + None => Err("OpenAI-compatible backend not configured".to_string()), + }, + BackendType::Anthropic => Ok(vec![ + "claude-haiku-4-5-20251001".to_string(), + "claude-sonnet-4-6".to_string(), + "claude-opus-4-6".to_string(), + ]), + } } -/// Enhance text using Ollama -/// -/// The prompt should contain `{text}` which will be replaced with the input text. +/// Enhance text using the active backend #[tauri::command] pub async fn enhance_text(text: String, model: String, prompt: String) -> Result { if text.is_empty() { @@ -57,21 +161,49 @@ pub async fn enhance_text(text: String, model: String, prompt: String) -> Result return Err("Model cannot be empty".to_string()); } - let client = get_client().lock().clone(); + let state = get_backend().lock().clone(); tracing::info!( - "Enhancing text with model '{}' ({} characters)", + "Enhancing text with model '{}' ({} characters, backend: {:?})", model, - text.len() + text.len(), + state.backend_type ); - let result = client - .enhance_text(&text, &model, &prompt) - .await - .map_err(|e| { - tracing::error!("Enhancement failed: {}", e); - format!("Enhancement failed: {}", e) - })?; + let result = match state.backend_type { + BackendType::Ollama => state + .ollama + .enhance_text(&text, &model, &prompt) + .await + .map_err(|e| { + tracing::error!("Enhancement failed: {}", e); + format!("Enhancement failed: {}", e) + })?, + BackendType::OpenAiCompat => match &state.openai_compat { + Some(client) => client + .enhance_text(&text, &model, &prompt) + .await + 
.map_err(|e| { + tracing::error!("Enhancement failed: {}", e); + format!("Enhancement failed: {}", e) + })?, + None => return Err("OpenAI-compatible backend not configured".to_string()), + }, + BackendType::Anthropic => match &state.anthropic { + Some(client) => client + .enhance_text(&text, &prompt) + .await + .map_err(|e| { + tracing::error!("Anthropic enhancement failed: {}", e); + format!("Enhancement failed: {}", e) + })?, + None => { + return Err( + "Anthropic backend not configured. Please add your API key.".to_string(), + ) + } + }, + }; tracing::info!( "Enhancement complete ({} -> {} characters)", @@ -82,14 +214,54 @@ pub async fn enhance_text(text: String, model: String, prompt: String) -> Result Ok(result) } +/// Set the enhancement backend from the frontend +#[tauri::command] +pub fn set_enhancement_backend( + backend: String, + base_url: String, + api_key: Option, + anthropic_api_key: Option, + anthropic_model: Option, + anthropic_base_url: Option, +) -> Result<(), String> { + let effective_key = if backend == "anthropic" { + anthropic_api_key.as_deref() + } else { + api_key.as_deref() + }; + configure_backend( + &backend, + &base_url, + effective_key, + anthropic_model.as_deref(), + anthropic_base_url.as_deref(), + ); + Ok(()) +} + #[cfg(test)] mod tests { use super::*; #[test] - fn test_client_initialisation() { - let client = get_client(); - let _guard = client.lock(); - // Client should be initialised without panicking + fn test_backend_type_from_str() { + assert_eq!(BackendType::from_str("ollama"), BackendType::Ollama); + assert_eq!( + BackendType::from_str("openai_compat"), + BackendType::OpenAiCompat + ); + assert_eq!( + BackendType::from_str("anthropic"), + BackendType::Anthropic + ); + assert_eq!(BackendType::from_str("unknown"), BackendType::Ollama); + } + + #[test] + fn test_default_backend() { + let backend = EnhancementBackend::default(); + assert_eq!(backend.backend_type, BackendType::Ollama); + assert!(backend.openai_compat.is_none()); + 
assert!(backend.anthropic.is_none()); } } diff --git a/src-tauri/src/enhancement/openai_compat.rs b/src-tauri/src/enhancement/openai_compat.rs new file mode 100644 index 0000000..b5a9bb7 --- /dev/null +++ b/src-tauri/src/enhancement/openai_compat.rs @@ -0,0 +1,426 @@ +//! OpenAI-compatible HTTP client for AI text enhancement +//! +//! Provides AI enhancement via any OpenAI-compatible API server (oMLX, LM Studio, +//! LocalAI, Ollama OpenAI-compat mode, etc.) using the `/v1/chat/completions` +//! endpoint. Supports optional Bearer token authentication and retry with +//! exponential backoff. + +use anyhow::{anyhow, Result}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tokio::time::sleep; + +/// Default timeout for API requests in seconds +const DEFAULT_TIMEOUT_SECS: u64 = 30; + +/// Maximum number of retry attempts +const MAX_RETRY_ATTEMPTS: u32 = 3; + +/// Base delay for exponential backoff in milliseconds +const BASE_RETRY_DELAY_MS: u64 = 100; + +// ── Request / response types ──────────────────────────────────────────────── + +#[derive(Debug, Serialize)] +struct ChatMessage { + role: String, + content: String, +} + +#[derive(Debug, Serialize)] +struct ChatCompletionRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + stream: bool, +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct ChatChoice { + message: ChatChoiceMessage, +} + +#[derive(Debug, Deserialize)] +struct ChatChoiceMessage { + content: String, +} + +#[derive(Debug, Deserialize)] +struct ModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct ModelEntry { + id: String, +} + +// ── Error types ───────────────────────────────────────────────────────────── + +#[derive(Debug, thiserror::Error)] +pub enum OpenAiCompatError { + #[error("Connection failed: {0}")] + ConnectionFailed(String), + + #[error("Request timeout after 
{0} seconds")] + Timeout(u64), + + #[error("Server error ({status}): {message}")] + ServerError { status: u16, message: String }, + + #[error("Failed to parse response: {0}")] + ParseError(String), + + #[error("All {attempts} retry attempts failed: {last_error}")] + RetriesExhausted { attempts: u32, last_error: String }, +} + +// ── Client ────────────────────────────────────────────────────────────────── + +/// OpenAI-compatible HTTP client for AI text enhancement +#[derive(Debug, Clone)] +pub struct OpenAiCompatClient { + base_url: String, + api_key: Option, + client: reqwest::Client, + timeout: Duration, +} + +impl OpenAiCompatClient { + /// Create a new client with the given base URL and optional API key. + pub fn new(base_url: &str, api_key: Option) -> Self { + Self::with_timeout(base_url, api_key, DEFAULT_TIMEOUT_SECS) + } + + /// Create a new client with a custom timeout. + /// + /// Only `http://` and `https://` URLs are accepted. Other schemes are + /// rejected to prevent SSRF-style misuse (OWASP M4). + pub fn with_timeout(base_url: &str, api_key: Option, timeout_secs: u64) -> Self { + let url = base_url.trim_end_matches('/'); + // Validate URL scheme — reject anything that isn't http(s) + if !url.starts_with("http://") && !url.starts_with("https://") { + tracing::warn!("OpenAI-compat: rejected non-HTTP URL scheme: {}", url); + } + + let timeout = Duration::from_secs(timeout_secs); + let client = reqwest::Client::builder() + .timeout(timeout) + .build() + .expect("Failed to create HTTP client"); + + Self { + base_url: url.to_string(), + api_key, + client, + timeout, + } + } + + /// Check if the server is available by hitting `/v1/models`. 
+ pub async fn is_available(&self) -> bool { + let url = format!("{}/v1/models", self.base_url); + let mut req = self.client.get(&url); + if let Some(key) = self.effective_api_key() { + req = req.header("Authorization", format!("Bearer {}", key)); + } + match req.send().await { + Ok(response) => response.status().is_success(), + Err(e) => { + tracing::debug!("OpenAI-compat server not available: {}", e); + false + } + } + } + + /// List available models via `/v1/models`. + pub async fn list_models(&self) -> Result> { + let url = format!("{}/v1/models", self.base_url); + let mut req = self.client.get(&url); + if let Some(key) = self.effective_api_key() { + req = req.header("Authorization", format!("Bearer {}", key)); + } + + let response = req + .send() + .await + .map_err(|e| anyhow!("Failed to connect to OpenAI-compat server: {}", e))?; + + if !response.status().is_success() { + return Err(anyhow!( + "Server returned error status: {}", + response.status() + )); + } + + let models: ModelsResponse = response + .json() + .await + .map_err(|e| anyhow!("Failed to parse models response: {}", e))?; + + let ids: Vec = models.data.into_iter().map(|m| m.id).collect(); + tracing::debug!("Found {} models via OpenAI-compat API", ids.len()); + Ok(ids) + } + + /// Enhance text using a prompt template containing `{text}`. + pub async fn enhance_text( + &self, + text: &str, + model: &str, + prompt_template: &str, + ) -> Result { + let full_prompt = prompt_template.replace("{text}", text); + self.chat(model, &full_prompt, None, None).await + } + + /// Enhance text with a system prompt (wraps text in TRANSCRIPT tags). + pub async fn enhance_with_system( + &self, + text: &str, + model: &str, + system_prompt: &str, + ) -> Result { + let user_msg = format!("\n{}\n", text); + self.chat(model, &user_msg, Some(system_prompt), Some(0.3)) + .await + } + + // ── Internal helpers ──────────────────────────────────────────────────── + + /// Send a chat completion request with retry logic. 
+ async fn chat( + &self, + model: &str, + user_message: &str, + system_prompt: Option<&str>, + temperature: Option, + ) -> Result { + let mut messages = Vec::new(); + if let Some(sys) = system_prompt { + messages.push(ChatMessage { + role: "system".to_string(), + content: sys.to_string(), + }); + } + messages.push(ChatMessage { + role: "user".to_string(), + content: user_message.to_string(), + }); + + let request = ChatCompletionRequest { + model: model.to_string(), + messages, + temperature, + stream: false, + }; + + tracing::debug!( + "Sending chat completion request with model: {} (system prompt: {})", + model, + system_prompt.is_some() + ); + + let mut last_error: Option = None; + + for attempt in 0..MAX_RETRY_ATTEMPTS { + match self.send_chat_request(&request).await { + Ok(response) => { + if attempt > 0 { + tracing::debug!("Request succeeded on attempt {}", attempt + 1); + } + return Ok(response); + } + Err(e) => { + let is_retryable = match &e { + OpenAiCompatError::ConnectionFailed(_) + | OpenAiCompatError::Timeout(_) => true, + OpenAiCompatError::ServerError { status, .. } => *status >= 500, + _ => false, + }; + + if !is_retryable || attempt == MAX_RETRY_ATTEMPTS - 1 { + tracing::error!( + "OpenAI-compat request failed (attempt {}): {}", + attempt + 1, + e + ); + last_error = Some(e); + break; + } + + let delay_ms = BASE_RETRY_DELAY_MS * 2u64.pow(attempt); + tracing::warn!( + "OpenAI-compat request failed (attempt {}), retrying in {}ms: {}", + attempt + 1, + delay_ms, + e + ); + last_error = Some(e); + sleep(Duration::from_millis(delay_ms)).await; + } + } + } + + Err(anyhow!(OpenAiCompatError::RetriesExhausted { + attempts: MAX_RETRY_ATTEMPTS, + last_error: last_error + .map(|e| e.to_string()) + .unwrap_or_else(|| "unknown".to_string()), + })) + } + + /// Send a single chat completion request. 
+ async fn send_chat_request( + &self, + request: &ChatCompletionRequest, + ) -> Result { + let url = format!("{}/v1/chat/completions", self.base_url); + + let mut req = self.client.post(&url).json(request); + if let Some(key) = self.effective_api_key() { + req = req.header("Authorization", format!("Bearer {}", key)); + } + + let response = req.send().await.map_err(|e| { + if e.is_timeout() { + OpenAiCompatError::Timeout(self.timeout.as_secs()) + } else { + OpenAiCompatError::ConnectionFailed(e.to_string()) + } + })?; + + if !response.status().is_success() { + let status = response.status().as_u16(); + let message = response + .text() + .await + .unwrap_or_else(|_| "unknown error".to_string()); + return Err(OpenAiCompatError::ServerError { status, message }); + } + + let chat_response: ChatCompletionResponse = response + .json() + .await + .map_err(|e| OpenAiCompatError::ParseError(e.to_string()))?; + + chat_response + .choices + .into_iter() + .next() + .map(|c| c.message.content) + .ok_or_else(|| OpenAiCompatError::ParseError("No choices in response".to_string())) + } + + /// Return the API key only if it is non-empty. 
+ fn effective_api_key(&self) -> Option<&str> { + self.api_key + .as_deref() + .filter(|k| !k.is_empty()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_creation() { + let client = OpenAiCompatClient::new("http://localhost:8080", None); + assert_eq!(client.base_url, "http://localhost:8080"); + assert!(client.api_key.is_none()); + } + + #[test] + fn test_client_strips_trailing_slash() { + let client = OpenAiCompatClient::new("http://localhost:8080/", None); + assert_eq!(client.base_url, "http://localhost:8080"); + } + + #[test] + fn test_client_with_api_key() { + let client = + OpenAiCompatClient::new("http://localhost:8080", Some("sk-test123".to_string())); + assert_eq!(client.api_key, Some("sk-test123".to_string())); + } + + #[test] + fn test_effective_api_key_none() { + let client = OpenAiCompatClient::new("http://localhost:8080", None); + assert!(client.effective_api_key().is_none()); + } + + #[test] + fn test_effective_api_key_empty() { + let client = OpenAiCompatClient::new("http://localhost:8080", Some("".to_string())); + assert!(client.effective_api_key().is_none()); + } + + #[test] + fn test_effective_api_key_present() { + let client = + OpenAiCompatClient::new("http://localhost:8080", Some("sk-test".to_string())); + assert_eq!(client.effective_api_key(), Some("sk-test")); + } + + #[test] + fn test_chat_request_serialisation() { + let request = ChatCompletionRequest { + model: "gpt-4".to_string(), + messages: vec![ChatMessage { + role: "user".to_string(), + content: "Hello".to_string(), + }], + temperature: Some(0.3), + stream: false, + }; + + let json = serde_json::to_string(&request).expect("Failed to serialise"); + assert!(json.contains("\"model\":\"gpt-4\"")); + assert!(json.contains("\"stream\":false")); + assert!(json.contains("\"temperature\":0.3")); + } + + #[test] + fn test_chat_request_no_temperature() { + let request = ChatCompletionRequest { + model: "gpt-4".to_string(), + messages: vec![ChatMessage { + role: 
"user".to_string(), + content: "Hello".to_string(), + }], + temperature: None, + stream: false, + }; + + let json = serde_json::to_string(&request).expect("Failed to serialise"); + assert!(!json.contains("\"temperature\"")); + } + + #[test] + fn test_error_display() { + let err = OpenAiCompatError::ConnectionFailed("connection refused".to_string()); + assert_eq!(err.to_string(), "Connection failed: connection refused"); + + let err = OpenAiCompatError::Timeout(30); + assert_eq!(err.to_string(), "Request timeout after 30 seconds"); + + let err = OpenAiCompatError::ServerError { + status: 401, + message: "Unauthorized".to_string(), + }; + assert_eq!(err.to_string(), "Server error (401): Unauthorized"); + + let err = OpenAiCompatError::RetriesExhausted { + attempts: 3, + last_error: "timeout".to_string(), + }; + assert_eq!(err.to_string(), "All 3 retry attempts failed: timeout"); + } +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index cbc0969..f21ae2e 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -211,6 +211,20 @@ pub fn run() { // Load config and register shortcuts if let Ok(cfg) = config::get_config() { + // Configure enhancement backend from saved config + let effective_key = if cfg.enhancement.backend == "anthropic" { + cfg.enhancement.anthropic_api_key.as_deref() + } else { + cfg.enhancement.api_key.as_deref() + }; + enhancement::configure_backend( + &cfg.enhancement.backend, + &cfg.enhancement.ollama_url, + effective_key, + Some(cfg.enhancement.anthropic_model.as_str()), + Some(cfg.enhancement.anthropic_url.as_str()), + ); + // Register shortcuts from config let app_handle = app.handle().clone(); register_shortcuts_from_config(&app_handle, &cfg); @@ -373,6 +387,9 @@ pub fn run() { enhancement::check_ollama_available, enhancement::list_ollama_models, enhancement::enhance_text, + enhancement::set_enhancement_backend, + enhancement::anthropic::detect_anthropic_api_key, + enhancement::anthropic::open_anthropic_console, 
enhancement::context::get_clipboard_context, enhancement::context::build_enhancement_context, // Prompt Templates diff --git a/src-tauri/src/tray.rs b/src-tauri/src/tray.rs index 59d2ab9..b36b321 100644 --- a/src-tauri/src/tray.rs +++ b/src-tauri/src/tray.rs @@ -400,13 +400,22 @@ fn build_ai_submenu( MenuItemBuilder::with_id(menu_ids::AI_ENHANCEMENT_TOGGLE, toggle_label).build(app)?; // Show configured model (read from config cache) - let model_name = config::get_config() - .map(|c| c.enhancement.model.clone()) - .unwrap_or_default(); + let (model_name, backend_prefix) = config::get_config() + .map(|c| { + if c.enhancement.backend == "anthropic" { + (c.enhancement.anthropic_model.clone(), "Cloud: ") + } else if c.enhancement.backend == "openai_compat" { + (c.enhancement.model.clone(), "Local (OMLX): ") + } else { + (c.enhancement.model.clone(), "Local: ") + } + }) + .unwrap_or_else(|_| (String::new(), "")); + let model_label = if model_name.is_empty() { - "Model: Not Set".to_string() + "AI Model: Not Set".to_string() } else { - format!("Model: {}", model_name) + format!("{}{}", backend_prefix, model_name) }; let model_item = MenuItemBuilder::with_id("ai_model_info", &model_label) .enabled(false) diff --git a/src/lib/components/AIEnhancementSettings.svelte b/src/lib/components/AIEnhancementSettings.svelte index 7639b83..391cf07 100644 --- a/src/lib/components/AIEnhancementSettings.svelte +++ b/src/lib/components/AIEnhancementSettings.svelte @@ -2,13 +2,14 @@ /** * AI Enhancement Settings component * - * Provides UI for configuring AI enhancement including Ollama connection, - * model selection, and prompt template management. + * Provides UI for configuring AI enhancement including Cloud AI (Anthropic) + * and Local AI (Ollama / OpenAI-compatible) backends, model selection, + * and prompt template management. 
   */

  import { invoke } from '@tauri-apps/api/core';
  import { onMount } from 'svelte';
-  import { configStore } from '../stores/config.svelte';
+  import { configStore, type EnhancementBackend } from '../stores/config.svelte';
  import { toastStore } from '../stores/toast.svelte';

  /** Prompt template matching Rust PromptTemplate struct */
@@ -34,6 +35,26 @@
  let newPromptTemplate = $state('');
  let promptError = $state<string | null>(null);

+  // Track last local backend for toggle memory
+  let lastLocalBackend = $state<'ollama' | 'openai_compat'>('ollama');
+
+  // Anthropic-specific state
+  let anthropicKeyDetected = $state<string | null>(null);
+  let anthropicKeyVisible = $state(false);
+
+  /** Check if the current backend is cloud-based */
+  function isCloudBackend(): boolean {
+    return configStore.config.enhancement.backend === 'anthropic';
+  }
+
+  /** Check if the current backend is local */
+  function isLocalBackend(): boolean {
+    return (
+      configStore.config.enhancement.backend === 'ollama' ||
+      configStore.config.enhancement.backend === 'openai_compat'
+    );
+  }
+
  /** Check if Ollama server is available */
  async function checkOllama(): Promise<void> {
    isCheckingOllama = true;
@@ -45,7 +66,7 @@
        await loadModels();
      }
    } catch (e) {
-      error = e instanceof Error ? e.message : 'Failed to check Ollama connection';
+      error = e instanceof Error ?
e.message : 'Failed to check connection';
      ollamaAvailable = false;
    } finally {
      isCheckingOllama = false;
@@ -59,7 +80,7 @@
    try {
      ollamaModels = await invoke<string[]>('list_ollama_models');
    } catch (e) {
-      console.error('Failed to load Ollama models:', e);
+      console.error('Failed to load models:', e);
      ollamaModels = [];
    } finally {
      isLoadingModels = false;
@@ -120,9 +141,115 @@
    configStore.updateEnhancement('ollamaUrl', input.value);
  }

-  /** Handle Ollama URL blur (save and re-check) */
+  /** Handle URL blur (save, update backend, and re-check) */
  async function handleUrlBlur(): Promise<void> {
    await saveSettings();
+    await applyBackend();
+    await checkOllama();
+  }
+
+  /** Handle backend selection change */
+  async function handleBackendChange(event: Event): Promise<void> {
+    const select = event.target as HTMLSelectElement;
+    configStore.updateEnhancement('backend', select.value as EnhancementBackend);
+    await saveSettings();
+    await applyBackend();
+    await checkOllama();
+  }
+
+  /** Handle API key change */
+  function handleApiKeyChange(event: Event): void {
+    const input = event.target as HTMLInputElement;
+    configStore.updateEnhancement('apiKey', input.value);
+  }
+
+  /** Handle API key blur (save and re-check) */
+  async function handleApiKeyBlur(): Promise<void> {
+    await saveSettings();
+    await applyBackend();
+    await checkOllama();
+  }
+
+  /** Notify the backend of the current enhancement backend config */
+  async function applyBackend(): Promise<void> {
+    try {
+      await invoke('set_enhancement_backend', {
+        backend: configStore.config.enhancement.backend,
+        baseUrl: configStore.config.enhancement.ollamaUrl,
+        apiKey: configStore.config.enhancement.apiKey || null,
+        anthropicApiKey: configStore.config.enhancement.anthropicApiKey || null,
+        anthropicModel: configStore.config.enhancement.anthropicModel || null,
+        anthropicBaseUrl: configStore.config.enhancement.anthropicUrl || null,
+      });
+      invoke('refresh_tray_menu').catch(() => {});
+    } catch (e) {
+      console.error('Failed to set enhancement
backend:', e);
+    }
+  }
+
+  /** Detect Anthropic API key from environment */
+  async function detectAnthropicKey(): Promise<void> {
+    try {
+      const key = await invoke<string | null>('detect_anthropic_api_key');
+      if (key) {
+        anthropicKeyDetected = key;
+        if (!configStore.config.enhancement.anthropicApiKey) {
+          configStore.updateEnhancement('anthropicApiKey', key);
+          await saveSettings();
+        }
+      }
+    } catch (e) {
+      console.error('Failed to detect Anthropic key:', e);
+    }
+  }
+
+  /** Open Anthropic console in browser */
+  async function openAnthropicConsole(): Promise<void> {
+    await invoke('open_anthropic_console');
+  }
+
+  /** Handle Anthropic API key change */
+  function handleAnthropicKeyChange(event: Event): void {
+    const input = event.target as HTMLInputElement;
+    configStore.updateEnhancement('anthropicApiKey', input.value);
+  }
+
+  /** Handle Anthropic API key blur */
+  async function handleAnthropicKeyBlur(): Promise<void> {
+    await saveSettings();
+    if (configStore.config.enhancement.backend === 'anthropic') {
+      await applyBackend();
+      await checkOllama();
+    }
+  }
+
+  /** Handle Anthropic model change */
+  async function handleAnthropicModelChange(event: Event): Promise<void> {
+    const select = event.target as HTMLSelectElement;
+    configStore.updateEnhancement('anthropicModel', select.value);
+    await saveSettings();
+    if (configStore.config.enhancement.backend === 'anthropic') {
+      await applyBackend();
+    }
+  }
+
+  /** Switch to cloud backend */
+  async function switchToCloud(): Promise<void> {
+    // Remember what local backend was before switching
+    if (isLocalBackend()) {
+      lastLocalBackend = configStore.config.enhancement.backend as 'ollama' | 'openai_compat';
+    }
+    configStore.updateEnhancement('backend', 'anthropic');
+    await saveSettings();
+    await applyBackend();
+    await checkOllama();
+  }
+
+  /** Switch to local backend */
+  async function switchToLocal(): Promise<void> {
+    configStore.updateEnhancement('backend', lastLocalBackend);
+    await saveSettings();
+    await applyBackend();
    await checkOllama();
  }

@@ -233,8 +360,13 @@
  onMount(async () => {
    await configStore.load();
+    // Remember the last local backend
+    if (isLocalBackend()) {
+      lastLocalBackend = configStore.config.enhancement.backend as 'ollama' | 'openai_compat';
+    }
    await loadPrompts();
    await checkOllama();
+    await detectAnthropicKey();
  });
@@ -247,7 +379,7 @@
    Enable AI enhancement
-    Use Ollama to enhance transcriptions with grammar correction, formatting, and more
+    Use AI to enhance transcriptions with grammar correction, formatting, and more