diff --git a/openless-all/app/src-tauri/Cargo.lock b/openless-all/app/src-tauri/Cargo.lock index 9e4ca497..d55399f2 100644 --- a/openless-all/app/src-tauri/Cargo.lock +++ b/openless-all/app/src-tauri/Cargo.lock @@ -4175,12 +4175,14 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams 0.4.2", "web-sys", "webpki-roots", ] @@ -4220,7 +4222,7 @@ dependencies = [ "url", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", + "wasm-streams 0.5.0", "web-sys", ] @@ -6147,6 +6149,19 @@ dependencies = [ "wasmparser", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasm-streams" version = "0.5.0" diff --git a/openless-all/app/src-tauri/Cargo.toml b/openless-all/app/src-tauri/Cargo.toml index 5983378e..537ba000 100644 --- a/openless-all/app/src-tauri/Cargo.toml +++ b/openless-all/app/src-tauri/Cargo.toml @@ -24,7 +24,7 @@ serde_json = "1" tokio = { version = "1", features = ["full"] } tokio-tungstenite = { version = "0.24", features = ["rustls-tls-native-roots"] } futures-util = "0.3" -reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls", "stream"] } thiserror = "1" anyhow = "1" log = "0.4" diff --git a/openless-all/app/src-tauri/src/commands.rs b/openless-all/app/src-tauri/src/commands.rs index 9aaa6976..3fef6292 100644 --- a/openless-all/app/src-tauri/src/commands.rs +++ b/openless-all/app/src-tauri/src/commands.rs @@ -155,12 +155,9 @@ pub async fn validate_provider_credentials(kind: String) -> Result validate_llm_provider() .await .map(|()| ProviderCheckResult { ok: true }), - "asr" => { - let config = read_openai_provider_config(&kind)?; - fetch_provider_models(&config) - .await - .map(|_| ProviderCheckResult { ok: true }) - } + "asr" => validate_asr_provider() + .await + .map(|()| ProviderCheckResult { ok: true }), _ => Err(format!("unknown provider kind: {kind}")), } } @@ -232,6 +229,123 @@ async fn validate_llm_provider() -> Result<(), String> { }) } +async fn validate_asr_provider() -> Result<(), String> { + let config = read_openai_provider_config("asr")?; + let model = CredentialsVault::get(CredentialAccount::AsrModel) + .map_err(|e| e.to_string())? + .filter(|s| !s.trim().is_empty()) + .ok_or_else(|| "asrModelMissing".to_string())?; + validate_asr_transcription(&config, model.trim()).await +} + +async fn validate_asr_transcription(config: &ProviderConfig, model: &str) -> Result<(), String> { + const MAX_ASR_VALIDATE_BODY_BYTES: usize = 1024 * 1024; + let url = asr_transcriptions_url(&config.base_url)?; + let wav = encode_wav_16k_mono_silence(250); + let wav_part = reqwest::multipart::Part::bytes(wav) + .file_name("openless-asr-check.wav") + .mime_str("audio/wav") + .map_err(|e| format!("请求体构建失败: {e}"))?; + let form = reqwest::multipart::Form::new() + .part("file", wav_part) + .text("model", model.to_string()); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(20)) + .build() + .map_err(|_| "providerClientInitFailed".to_string())?; + let response = client + .post(&url) + .header("Authorization", format!("Bearer {}", config.api_key)) + .multipart(form) + .send() + .await + .map_err(|e| { + if e.is_timeout() { + "providerRequestTimeout".to_string() + } else { + "providerNetworkError".to_string() + } + })?; + let status = response.status(); + if !status.is_success() { + return Err(format!("providerHttpStatus:{}", status.as_u16())); + } + if let Some(len) = response.content_length() { + if len as usize > MAX_ASR_VALIDATE_BODY_BYTES { + return Err("providerResponseTooLarge".to_string()); + } + } + use futures_util::StreamExt; + let mut body = Vec::::new(); + let mut stream = response.bytes_stream(); + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|_| "providerReadResponseFailed".to_string())?; + if body.len().saturating_add(chunk.len()) > MAX_ASR_VALIDATE_BODY_BYTES { + return Err("providerResponseTooLarge".to_string()); + } + body.extend_from_slice(&chunk); + } + let json: Value = serde_json::from_slice(&body).map_err(|_| "asrInvalidJson".to_string())?; + if !json.is_object() || json.get("text").is_none() { + return Err("asrMissingTextField".to_string()); + } + Ok(()) +} + +fn asr_transcriptions_url(base_url: &str) -> Result { + let parsed = reqwest::Url::parse(base_url.trim()).map_err(|_| "endpointInvalid".to_string())?; + let host = parsed.host_str().unwrap_or_default(); + let localhost = host.eq_ignore_ascii_case("localhost") || host == "127.0.0.1"; + if parsed.scheme() != "https" && !localhost { + return Err("endpointMustUseHttps".to_string()); + } + + // Work on the URL path only so we don't corrupt query parameters. + let mut url = parsed.clone(); + let path = parsed.path().trim_end_matches('/'); + let next_path = if path.ends_with("/audio/transcriptions") { + path.to_string() + } else if path.ends_with("/audio") { + format!("{path}/transcriptions") + } else if let Some(prefix) = path.strip_suffix("/chat/completions") { + format!("{prefix}/audio/transcriptions") + } else { + format!("{path}/audio/transcriptions") + }; + url.set_path(&next_path); + Ok(url.to_string()) +} + +fn encode_wav_16k_mono_silence(duration_ms: u32) -> Vec { + let sample_rate: u32 = 16_000; + let num_channels: u16 = 1; + let bits_per_sample: u16 = 16; + let bytes_per_sample = (bits_per_sample / 8) as usize; + let samples = (sample_rate as usize * duration_ms as usize) / 1000; + let pcm_len = samples * bytes_per_sample; + let data_size = pcm_len as u32; + let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8; + let block_align = num_channels * bits_per_sample / 8; + let chunk_size = 36 + data_size; + + let mut wav = Vec::with_capacity(44 + pcm_len); + wav.extend_from_slice(b"RIFF"); + wav.extend_from_slice(&chunk_size.to_le_bytes()); + wav.extend_from_slice(b"WAVE"); + wav.extend_from_slice(b"fmt "); + wav.extend_from_slice(&16u32.to_le_bytes()); + wav.extend_from_slice(&1u16.to_le_bytes()); + wav.extend_from_slice(&num_channels.to_le_bytes()); + wav.extend_from_slice(&sample_rate.to_le_bytes()); + wav.extend_from_slice(&byte_rate.to_le_bytes()); + wav.extend_from_slice(&block_align.to_le_bytes()); + wav.extend_from_slice(&bits_per_sample.to_le_bytes()); + wav.extend_from_slice(b"data"); + wav.extend_from_slice(&data_size.to_le_bytes()); + wav.resize(44 + pcm_len, 0); + wav +} + async fn fetch_provider_models(config: &ProviderConfig) -> Result, String> { let url = models_url(&config.base_url); log::info!("[provider-check] GET {url}"); @@ -582,8 +696,8 @@ fn _ensure_snapshot_used(_: CredentialsSnapshot) {} #[cfg(test)] mod tests { use super::{ - fetch_provider_models, models_url, parse_model_ids, persist_settings, ProviderConfig, - SettingsWriter, + asr_transcriptions_url, fetch_provider_models, models_url, parse_model_ids, + persist_settings, ProviderConfig, SettingsWriter, }; use crate::types::{ HotkeyBinding, HotkeyMode, HotkeyTrigger, QaHotkeyBinding, UserPreferences, @@ -627,6 +741,30 @@ mod tests { ); } + #[test] + fn asr_transcriptions_url_accepts_base_or_transcriptions_endpoint() { + assert_eq!( + asr_transcriptions_url("https://api.openai.com/v1").unwrap(), + "https://api.openai.com/v1/audio/transcriptions" + ); + assert_eq!( + asr_transcriptions_url("https://api.openai.com/v1/chat/completions").unwrap(), + "https://api.openai.com/v1/audio/transcriptions" + ); + assert_eq!( + asr_transcriptions_url("https://api.openai.com/v1/audio").unwrap(), + "https://api.openai.com/v1/audio/transcriptions" + ); + assert_eq!( + asr_transcriptions_url("https://api.openai.com/v1/audio/transcriptions").unwrap(), + "https://api.openai.com/v1/audio/transcriptions" + ); + assert_eq!( + asr_transcriptions_url("https://api.openai.com/v1?api-version=2024-12-01").unwrap(), + "https://api.openai.com/v1/audio/transcriptions?api-version=2024-12-01" + ); + } + #[test] fn parse_model_ids_sorts_and_deduplicates() { let models = diff --git a/openless-all/app/src/i18n/en.ts b/openless-all/app/src/i18n/en.ts index dda72c28..069b6f52 100644 --- a/openless-all/app/src/i18n/en.ts +++ b/openless-all/app/src/i18n/en.ts @@ -328,6 +328,11 @@ export const en: typeof zhCN = { modelSaved: 'Saved model {{model}}.', validateSuccess: 'Connection check passed.', providerHttpStatus: 'Provider returned HTTP {{status}}. Check the API key permissions or endpoint.', + endpointMustUseHttps: 'Endpoint must use HTTPS (localhost/127.0.0.1 are allowed for local testing).', + endpointInvalid: 'Endpoint format is invalid.', + responseTooLarge: 'Provider response is too large to validate safely.', + asrInvalidJson: 'ASR response is not valid JSON.', + asrMissingTextField: 'ASR response is missing the text field.', apiKeyMissing: 'API Key is empty.', endpointMissing: 'Endpoint is empty.', requestTimeout: 'Request timed out. Try again later.', diff --git a/openless-all/app/src/i18n/zh-CN.ts b/openless-all/app/src/i18n/zh-CN.ts index c5bd4220..d156ef1c 100644 --- a/openless-all/app/src/i18n/zh-CN.ts +++ b/openless-all/app/src/i18n/zh-CN.ts @@ -326,6 +326,11 @@ export const zhCN = { modelSaved: '已保存模型 {{model}}。', validateSuccess: '连接检查通过。', providerHttpStatus: '供应商接口返回 {{status}},请检查 API Key 权限或 Endpoint。', + endpointMustUseHttps: 'Endpoint 必须使用 HTTPS(本地 localhost/127.0.0.1 测试除外)。', + endpointInvalid: 'Endpoint 格式不合法。', + responseTooLarge: '供应商响应过大,已停止验证以保证安全。', + asrInvalidJson: 'ASR 响应不是有效 JSON。', + asrMissingTextField: 'ASR 响应缺少 text 字段。', apiKeyMissing: 'API Key 为空。', endpointMissing: 'Endpoint 为空。', requestTimeout: '请求超时,请稍后重试。', diff --git a/openless-all/app/src/pages/Settings.tsx b/openless-all/app/src/pages/Settings.tsx index 88cf7d0a..71c1abfd 100644 --- a/openless-all/app/src/pages/Settings.tsx +++ b/openless-all/app/src/pages/Settings.tsx @@ -593,10 +593,14 @@ function ProviderTools({ kind, modelAccount, onModelSelected }: { kind: 'llm' | setResult(result.ok ? 'success' : 'error', t('settings.providers.validateSuccess')); } catch (error) { const message = error instanceof Error ? error.message : String(error); - if (kind === 'llm' && message === 'llmModelMissing') { + if ((kind === 'llm' && message === 'llmModelMissing') || (kind === 'asr' && message === 'asrModelMissing')) { setResult('empty', t('settings.providers.modelMissing')); return; } + if (message === 'modelsEmpty') { + setResult('empty', t('settings.providers.modelsEmpty')); + return; + } setResult('error', providerErrorMessage(error, t)); } }; @@ -664,6 +668,14 @@ function providerErrorMessage(error: unknown, t: ReturnType