Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion openless-all/app/src-tauri/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion openless-all/app/src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ serde_json = "1"
tokio = { version = "1", features = ["full"] }
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-native-roots"] }
futures-util = "0.3"
reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls"] }
reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls", "stream"] }
thiserror = "1"
anyhow = "1"
log = "0.4"
Expand Down
154 changes: 146 additions & 8 deletions openless-all/app/src-tauri/src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,9 @@ pub async fn validate_provider_credentials(kind: String) -> Result<ProviderCheck
"llm" => validate_llm_provider()
.await
.map(|()| ProviderCheckResult { ok: true }),
"asr" => {
let config = read_openai_provider_config(&kind)?;
fetch_provider_models(&config)
.await
.map(|_| ProviderCheckResult { ok: true })
}
"asr" => validate_asr_provider()
.await
.map(|()| ProviderCheckResult { ok: true }),
_ => Err(format!("unknown provider kind: {kind}")),
}
}
Expand Down Expand Up @@ -232,6 +229,123 @@ async fn validate_llm_provider() -> Result<(), String> {
})
}

async fn validate_asr_provider() -> Result<(), String> {
let config = read_openai_provider_config("asr")?;
let model = CredentialsVault::get(CredentialAccount::AsrModel)
.map_err(|e| e.to_string())?
.filter(|s| !s.trim().is_empty())
.ok_or_else(|| "asrModelMissing".to_string())?;
validate_asr_transcription(&config, model.trim()).await
}

async fn validate_asr_transcription(config: &ProviderConfig, model: &str) -> Result<(), String> {
const MAX_ASR_VALIDATE_BODY_BYTES: usize = 1024 * 1024;
let url = asr_transcriptions_url(&config.base_url)?;
let wav = encode_wav_16k_mono_silence(250);
let wav_part = reqwest::multipart::Part::bytes(wav)
.file_name("openless-asr-check.wav")
.mime_str("audio/wav")
.map_err(|e| format!("请求体构建失败: {e}"))?;
let form = reqwest::multipart::Form::new()
.part("file", wav_part)
.text("model", model.to_string());
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(20))
.build()
.map_err(|_| "providerClientInitFailed".to_string())?;
let response = client
.post(&url)
.header("Authorization", format!("Bearer {}", config.api_key))
.multipart(form)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
"providerRequestTimeout".to_string()
} else {
"providerNetworkError".to_string()
}
})?;
let status = response.status();
if !status.is_success() {
return Err(format!("providerHttpStatus:{}", status.as_u16()));
}
if let Some(len) = response.content_length() {
if len as usize > MAX_ASR_VALIDATE_BODY_BYTES {
return Err("providerResponseTooLarge".to_string());
}
}
use futures_util::StreamExt;
let mut body = Vec::<u8>::new();
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|_| "providerReadResponseFailed".to_string())?;
if body.len().saturating_add(chunk.len()) > MAX_ASR_VALIDATE_BODY_BYTES {
return Err("providerResponseTooLarge".to_string());
}
body.extend_from_slice(&chunk);
}
let json: Value = serde_json::from_slice(&body).map_err(|_| "asrInvalidJson".to_string())?;
if !json.is_object() || json.get("text").is_none() {
return Err("asrMissingTextField".to_string());
}
Ok(())
}

fn asr_transcriptions_url(base_url: &str) -> Result<String, String> {
let parsed = reqwest::Url::parse(base_url.trim()).map_err(|_| "endpointInvalid".to_string())?;
let host = parsed.host_str().unwrap_or_default();
let localhost = host.eq_ignore_ascii_case("localhost") || host == "127.0.0.1";
if parsed.scheme() != "https" && !localhost {
return Err("endpointMustUseHttps".to_string());
}

// Work on the URL path only so we don't corrupt query parameters.
let mut url = parsed.clone();
let path = parsed.path().trim_end_matches('/');
let next_path = if path.ends_with("/audio/transcriptions") {
path.to_string()
} else if path.ends_with("/audio") {
format!("{path}/transcriptions")
} else if let Some(prefix) = path.strip_suffix("/chat/completions") {
format!("{prefix}/audio/transcriptions")
} else {
format!("{path}/audio/transcriptions")
};
url.set_path(&next_path);
Ok(url.to_string())
}

fn encode_wav_16k_mono_silence(duration_ms: u32) -> Vec<u8> {
let sample_rate: u32 = 16_000;
let num_channels: u16 = 1;
let bits_per_sample: u16 = 16;
let bytes_per_sample = (bits_per_sample / 8) as usize;
let samples = (sample_rate as usize * duration_ms as usize) / 1000;
let pcm_len = samples * bytes_per_sample;
let data_size = pcm_len as u32;
let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8;
let block_align = num_channels * bits_per_sample / 8;
let chunk_size = 36 + data_size;

let mut wav = Vec::with_capacity(44 + pcm_len);
wav.extend_from_slice(b"RIFF");
wav.extend_from_slice(&chunk_size.to_le_bytes());
wav.extend_from_slice(b"WAVE");
wav.extend_from_slice(b"fmt ");
wav.extend_from_slice(&16u32.to_le_bytes());
wav.extend_from_slice(&1u16.to_le_bytes());
wav.extend_from_slice(&num_channels.to_le_bytes());
wav.extend_from_slice(&sample_rate.to_le_bytes());
wav.extend_from_slice(&byte_rate.to_le_bytes());
wav.extend_from_slice(&block_align.to_le_bytes());
wav.extend_from_slice(&bits_per_sample.to_le_bytes());
wav.extend_from_slice(b"data");
wav.extend_from_slice(&data_size.to_le_bytes());
wav.resize(44 + pcm_len, 0);
wav
}

async fn fetch_provider_models(config: &ProviderConfig) -> Result<Vec<String>, String> {
let url = models_url(&config.base_url);
log::info!("[provider-check] GET {url}");
Expand Down Expand Up @@ -582,8 +696,8 @@ fn _ensure_snapshot_used(_: CredentialsSnapshot) {}
#[cfg(test)]
mod tests {
use super::{
fetch_provider_models, models_url, parse_model_ids, persist_settings, ProviderConfig,
SettingsWriter,
asr_transcriptions_url, fetch_provider_models, models_url, parse_model_ids,
persist_settings, ProviderConfig, SettingsWriter,
};
use crate::types::{
HotkeyBinding, HotkeyMode, HotkeyTrigger, QaHotkeyBinding, UserPreferences,
Expand Down Expand Up @@ -627,6 +741,30 @@ mod tests {
);
}

#[test]
fn asr_transcriptions_url_accepts_base_or_transcriptions_endpoint() {
assert_eq!(
asr_transcriptions_url("https://api.openai.com/v1").unwrap(),
"https://api.openai.com/v1/audio/transcriptions"
);
assert_eq!(
asr_transcriptions_url("https://api.openai.com/v1/chat/completions").unwrap(),
"https://api.openai.com/v1/audio/transcriptions"
);
assert_eq!(
asr_transcriptions_url("https://api.openai.com/v1/audio").unwrap(),
"https://api.openai.com/v1/audio/transcriptions"
);
assert_eq!(
asr_transcriptions_url("https://api.openai.com/v1/audio/transcriptions").unwrap(),
"https://api.openai.com/v1/audio/transcriptions"
);
assert_eq!(
asr_transcriptions_url("https://api.openai.com/v1?api-version=2024-12-01").unwrap(),
"https://api.openai.com/v1/audio/transcriptions?api-version=2024-12-01"
);
}

#[test]
fn parse_model_ids_sorts_and_deduplicates() {
let models =
Expand Down
5 changes: 5 additions & 0 deletions openless-all/app/src/i18n/en.ts
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,11 @@ export const en: typeof zhCN = {
modelSaved: 'Saved model {{model}}.',
validateSuccess: 'Connection check passed.',
providerHttpStatus: 'Provider returned HTTP {{status}}. Check the API key permissions or endpoint.',
endpointMustUseHttps: 'Endpoint must use HTTPS (localhost/127.0.0.1 are allowed for local testing).',
endpointInvalid: 'Endpoint format is invalid.',
responseTooLarge: 'Provider response is too large to validate safely.',
asrInvalidJson: 'ASR response is not valid JSON.',
asrMissingTextField: 'ASR response is missing the text field.',
apiKeyMissing: 'API Key is empty.',
endpointMissing: 'Endpoint is empty.',
requestTimeout: 'Request timed out. Try again later.',
Expand Down
5 changes: 5 additions & 0 deletions openless-all/app/src/i18n/zh-CN.ts
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ export const zhCN = {
modelSaved: '已保存模型 {{model}}。',
validateSuccess: '连接检查通过。',
providerHttpStatus: '供应商接口返回 {{status}},请检查 API Key 权限或 Endpoint。',
endpointMustUseHttps: 'Endpoint 必须使用 HTTPS(本地 localhost/127.0.0.1 测试除外)。',
endpointInvalid: 'Endpoint 格式不合法。',
responseTooLarge: '供应商响应过大,已停止验证以保证安全。',
asrInvalidJson: 'ASR 响应不是有效 JSON。',
asrMissingTextField: 'ASR 响应缺少 text 字段。',
apiKeyMissing: 'API Key 为空。',
endpointMissing: 'Endpoint 为空。',
requestTimeout: '请求超时,请稍后重试。',
Expand Down
14 changes: 13 additions & 1 deletion openless-all/app/src/pages/Settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -593,10 +593,14 @@ function ProviderTools({ kind, modelAccount, onModelSelected }: { kind: 'llm' |
setResult(result.ok ? 'success' : 'error', t('settings.providers.validateSuccess'));
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
if (kind === 'llm' && message === 'llmModelMissing') {
if ((kind === 'llm' && message === 'llmModelMissing') || (kind === 'asr' && message === 'asrModelMissing')) {
setResult('empty', t('settings.providers.modelMissing'));
return;
}
if (message === 'modelsEmpty') {
setResult('empty', t('settings.providers.modelsEmpty'));
return;
}
setResult('error', providerErrorMessage(error, t));
}
};
Expand Down Expand Up @@ -664,6 +668,14 @@ function providerErrorMessage(error: unknown, t: ReturnType<typeof useTranslatio
if (message.startsWith('providerHttpStatus:')) {
return t('settings.providers.providerHttpStatus', { status: message.split(':')[1] || '?' });
}
if (message === 'endpointMustUseHttps') return t('settings.providers.endpointMustUseHttps');
if (message === 'endpointInvalid') return t('settings.providers.endpointInvalid');
if (message === 'providerResponseTooLarge') return t('settings.providers.responseTooLarge');
if (message === 'asrInvalidJson') return t('settings.providers.asrInvalidJson');
if (message === 'asrMissingTextField') return t('settings.providers.asrMissingTextField');
if (message === 'providerNetworkError') return t('common.networkError');
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Define translation for mapped network validation error

providerErrorMessage maps providerNetworkError to t('common.networkError'), but neither en.ts nor zh-CN.ts defines common.networkError. When ASR validation hits a network failure, the UI will show a raw missing-key string instead of a user-friendly message, so the new error-code mapping degrades the settings UX in a common failure path.

Useful? React with 👍 / 👎.

if (message === 'providerReadResponseFailed' || message === 'providerClientInitFailed') return t('common.operationFailed');
if (message === 'providerRequestTimeout') return t('settings.providers.requestTimeout');
if (message.includes('API Key')) return t('settings.providers.apiKeyMissing');
if (message.includes('Endpoint')) return t('settings.providers.endpointMissing');
if (message.includes('timeout') || message.includes('超时')) return t('settings.providers.requestTimeout');
Expand Down
Loading