diff --git a/openless-all/app/src-tauri/src/asr/local/mod.rs b/openless-all/app/src-tauri/src/asr/local/mod.rs index 5412fb19..ec75ed62 100644 --- a/openless-all/app/src-tauri/src/asr/local/mod.rs +++ b/openless-all/app/src-tauri/src/asr/local/mod.rs @@ -3,7 +3,7 @@ //! 当前本地引擎: //! - **macOS**:`antirez/qwen-asr` 纯 C + Accelerate(`local_provider` / `qwen_engine`) //! - **Windows**:Foundry Local Whisper(`foundry_*`),以及 sherpa-onnx-local -//! 实验 provider(`sherpa*`,M1 仅骨架,详见 `docs/windows-sherpa-onnx-asr-plan.md`) +//! 实验 provider(`sherpa*`,offline batch + online streaming) pub mod cache; pub mod download; diff --git a/openless-all/app/src-tauri/src/asr/local/sherpa.rs b/openless-all/app/src-tauri/src/asr/local/sherpa.rs index 8354ea5d..8354a4fd 100644 --- a/openless-all/app/src-tauri/src/asr/local/sherpa.rs +++ b/openless-all/app/src-tauri/src/asr/local/sherpa.rs @@ -1,9 +1,7 @@ //! Windows sherpa-onnx 本地 ASR 的常量、catalog 与事件载荷。 //! -//! M1 阶段:纯描述层;不依赖 `sherpa-onnx` crate,不做实际推理。 -//! 与 `foundry.rs` 形状对齐,便于前端命令链路与 Foundry 同形复用。 -//! -//! 推理接入见 `sherpa_runtime.rs`(M2)。 +//! 当前 catalog 覆盖 Windows offline batch 模型和实验 online streaming 模型; +//! `sherpa_runtime.rs` 分别持有 `OfflineRecognizer` / `OnlineRecognizer`。 use std::path::{Path, PathBuf}; @@ -12,6 +10,7 @@ use serde::Serialize; pub const PROVIDER_ID: &str = "sherpa-onnx-local"; pub const DEFAULT_MODEL_ALIAS: &str = "sense-voice-small-zh"; +pub const DEFAULT_ONLINE_MODEL_ALIAS: &str = "zipformer-bilingual-zh-en-streaming"; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] #[serde(rename_all = "camelCase")] @@ -30,7 +29,7 @@ pub enum SherpaFamily { pub enum SherpaMode { /// 录音停止后整段 PCM 一次性识别。 Offline, - /// 边录边识别 partial / final segment。M5 才接。 + /// 边录边识别 partial / final segment。 Online, } @@ -47,8 +46,8 @@ pub struct SherpaModel { pub quality_tier: &'static str, } -/// M1 catalog 三档:默认 SenseVoice,中文专用 Paraformer,多语 Whisper 兜底。 -/// 文件清单 + 校验和会在 M3 模型管理阶段补全;M1 只暴露元数据驱动 UI。 +/// Catalog:默认 SenseVoice,中文专用 Paraformer,多语 Whisper 兜底, +/// Qwen3-ASR 实验档,以及 Zipformer online streaming 实验档。 #[allow(dead_code)] pub const MODELS: &[SherpaModel] = &[ SherpaModel { @@ -83,6 +82,14 @@ pub const MODELS: &[SherpaModel] = &[ languages: &["multi"], quality_tier: "qwen3-balanced", }, + SherpaModel { + alias: DEFAULT_ONLINE_MODEL_ALIAS, + display_name: "Zipformer Streaming bilingual (zh/en)", + family: SherpaFamily::Zipformer, + mode: SherpaMode::Online, + languages: &["zh", "en"], + quality_tier: "streaming-experimental", + }, ]; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -102,6 +109,18 @@ pub fn model_alias_is_known(alias: &str) -> bool { MODELS.iter().any(|model| model.alias == alias) } +pub fn mode_for_alias(alias: &str) -> Result { + MODELS + .iter() + .find(|model| model.alias == alias) + .map(|model| model.mode) + .ok_or_else(|| anyhow::anyhow!("unknown sherpa-onnx model alias: {alias}")) +} + +pub fn alias_is_online(alias: &str) -> bool { + matches!(mode_for_alias(alias), Ok(SherpaMode::Online)) +} + pub fn hf_repo_for_alias(alias: &str) -> Result<&'static str> { match alias { "sense-voice-small-zh" => { @@ -109,6 +128,9 @@ pub fn hf_repo_for_alias(alias: &str) -> Result<&'static str> { } "paraformer-zh" => Ok("csukuangfj/sherpa-onnx-paraformer-zh-2024-03-09"), "whisper-small-multi" => Ok("csukuangfj/sherpa-onnx-whisper-small"), + DEFAULT_ONLINE_MODEL_ALIAS => { + Ok("csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20") + } _ => anyhow::bail!("unknown sherpa-onnx model alias: {alias}"), } } @@ -124,6 +146,12 @@ pub fn required_files_for_alias(alias: &str) -> Result<&'static [&'static str]> "decoder.int8.onnx", "tokenizer", ]), + DEFAULT_ONLINE_MODEL_ALIAS => Ok(&[ + "encoder-epoch-99-avg-1.int8.onnx", + "decoder-epoch-99-avg-1.onnx", + "joiner-epoch-99-avg-1.int8.onnx", + "tokens.txt", + ]), _ => anyhow::bail!("unknown sherpa-onnx model alias: {alias}"), } } @@ -165,6 +193,18 @@ pub fn download_files_for_alias(alias: &str) -> Result<&'static [(&'static str, ("small-decoder.int8.onnx", "decoder.int8.onnx"), ("small-tokens.txt", "tokens.txt"), ]), + DEFAULT_ONLINE_MODEL_ALIAS => Ok(&[ + ( + "encoder-epoch-99-avg-1.int8.onnx", + "encoder-epoch-99-avg-1.int8.onnx", + ), + ("decoder-epoch-99-avg-1.onnx", "decoder-epoch-99-avg-1.onnx"), + ( + "joiner-epoch-99-avg-1.int8.onnx", + "joiner-epoch-99-avg-1.int8.onnx", + ), + ("tokens.txt", "tokens.txt"), + ]), _ => anyhow::bail!("unknown sherpa-onnx model alias: {alias}"), } } @@ -292,25 +332,39 @@ impl SherpaPrepareProgressPayload { #[allow(dead_code)] pub struct SherpaRuntimeStatus { pub provider_id: String, - /// M1 阶段恒为 false:sherpa-onnx crate 尚未接入。 + /// 当前平台是否具备 sherpa-onnx 推理能力。Windows 为 true;其他平台保留 + /// provider 元数据但不提供本地 sherpa 推理。 pub available: bool, /// 当前模型是否已加载到内存。 pub runtime_ready: bool, pub active_model: String, pub loaded_model_id: Option, pub error: Option, + /// 最近一次 prepare/load 耗时。缓存命中也会记录一次很小的耗时。 + pub last_prepare_ms: Option, + /// 最近一次 batch decode 耗时,不含录音时间。 + pub last_transcribe_ms: Option, + /// 最近一次送入 recognizer 的音频时长。 + pub last_audio_ms: Option, + /// 最近一次 prepare/transcribe 错误,方便 UI 和日志定位可恢复失败。 + pub last_error: Option, } impl SherpaRuntimeStatus { #[allow(dead_code)] pub fn unavailable(active_model: String, error: impl Into) -> Self { + let error = error.into(); Self { provider_id: PROVIDER_ID.into(), available: false, runtime_ready: false, active_model, loaded_model_id: None, - error: Some(error.into()), + error: Some(error.clone()), + last_prepare_ms: None, + last_transcribe_ms: None, + last_audio_ms: None, + last_error: Some(error), } } } @@ -342,9 +396,33 @@ mod tests { "paraformer-zh", "whisper-small-multi", "qwen3-asr-0.6b-int8", + "zipformer-bilingual-zh-en-streaming", ] ); assert!(catalog.iter().all(|m| !m.cached)); + assert_eq!(catalog.last().unwrap().mode, SherpaMode::Online); + } + + #[test] + fn online_zipformer_has_download_and_required_files() { + assert_eq!( + mode_for_alias(DEFAULT_ONLINE_MODEL_ALIAS).unwrap(), + SherpaMode::Online + ); + assert!(alias_is_online(DEFAULT_ONLINE_MODEL_ALIAS)); + assert_eq!( + hf_repo_for_alias(DEFAULT_ONLINE_MODEL_ALIAS).unwrap(), + "csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" + ); + assert_eq!( + required_files_for_alias(DEFAULT_ONLINE_MODEL_ALIAS).unwrap(), + &[ + "encoder-epoch-99-avg-1.int8.onnx", + "decoder-epoch-99-avg-1.onnx", + "joiner-epoch-99-avg-1.int8.onnx", + "tokens.txt", + ] + ); } #[test] @@ -394,6 +472,7 @@ mod tests { assert!(!status.runtime_ready); assert_eq!(status.active_model, "paraformer-zh"); assert_eq!(status.error.as_deref(), Some("not ready")); + assert_eq!(status.last_error.as_deref(), Some("not ready")); } #[test] diff --git a/openless-all/app/src-tauri/src/asr/local/sherpa_download.rs b/openless-all/app/src-tauri/src/asr/local/sherpa_download.rs index cc882545..51f72c42 100644 --- a/openless-all/app/src-tauri/src/asr/local/sherpa_download.rs +++ b/openless-all/app/src-tauri/src/asr/local/sherpa_download.rs @@ -863,7 +863,193 @@ mod tests { impl Drop for TempModelDir { fn drop(&mut self) { let _ = fs::remove_dir_all(&self.0); + if let Ok(extract_dir) = archive_extract_dir(&self.0) { + let _ = fs::remove_dir_all(extract_dir); + } + } + } + + fn write_release_archive_fixture( + archive_path: &Path, + archive: sherpa::SherpaReleaseArchive, + files: &[(&str, &[u8])], + ) { + let src_root = std::env::temp_dir().join(format!( + "openless-sherpa-archive-src-{}", + uuid::Uuid::new_v4() + )); + fs::create_dir_all(&src_root).expect("create archive source root"); + for (relative, bytes) in files { + let path = src_root.join(relative); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).expect("create archive source parent"); + } + fs::write(path, bytes).expect("write archive source file"); } + + let file = fs::File::create(archive_path).expect("create archive file"); + let encoder = bzip2::write::BzEncoder::new(file, bzip2::Compression::default()); + let mut builder = tar::Builder::new(encoder); + builder + .append_dir_all(archive.root_dir, &src_root) + .expect("append archive root"); + let encoder = builder.into_inner().expect("finish tar"); + encoder.finish().expect("finish bzip2"); + let _ = fs::remove_dir_all(src_root); + } + + #[test] + fn verify_file_rejects_size_mismatch() { + let dir = TempModelDir::new("verify-size"); + let path = dir.path().join("model.bin"); + fs::write(&path, b"abc").expect("write test file"); + let file = SherpaRemoteFile { + path: "model.bin".into(), + local_path: "model.bin".into(), + size: 4, + sha256: None, + }; + + let message = format!("{:#}", verify_file(&path, &file).unwrap_err()); + + assert!(message.contains("文件大小不匹配")); + assert!(message.contains("actual=3")); + assert!(message.contains("expected=4")); + } + + #[test] + fn verify_file_rejects_sha256_mismatch() { + let dir = TempModelDir::new("verify-sha"); + let path = dir.path().join("model.bin"); + fs::write(&path, b"abc").expect("write test file"); + let file = SherpaRemoteFile { + path: "model.bin".into(), + local_path: "model.bin".into(), + size: 3, + sha256: Some("0000000000000000000000000000000000000000000000000000000000000000".into()), + }; + + let message = format!("{:#}", verify_file(&path, &file).unwrap_err()); + + assert!(message.contains("SHA-256 不匹配")); + assert!(message + .contains("expected=0000000000000000000000000000000000000000000000000000000000000000")); + } + + #[test] + fn verify_file_accepts_case_insensitive_sha256() { + let dir = TempModelDir::new("verify-sha-ok"); + let path = dir.path().join("model.bin"); + fs::write(&path, b"abc").expect("write test file"); + let file = SherpaRemoteFile { + path: "model.bin".into(), + local_path: "model.bin".into(), + size: 3, + sha256: Some(sha256_file(&path).unwrap().to_ascii_uppercase()), + }; + + verify_file(&path, &file).expect("sha should verify"); + } + + #[test] + fn archive_extract_dir_uses_sibling_path() { + let dir = TempModelDir::new("extract-dir"); + let name = dir.path().file_name().unwrap().to_string_lossy(); + + let extract_dir = archive_extract_dir(dir.path()).unwrap(); + + assert_eq!( + extract_dir, + dir.path().with_file_name(format!("{name}.extracting")) + ); + } + + #[test] + fn release_archive_bytes_uses_partial_archive_when_not_extracted() { + let alias = "qwen3-asr-0.6b-int8"; + let archive = sherpa::release_archive_for_alias(alias).expect("release archive"); + let dir = TempModelDir::new("release-archive-partial"); + let partial_path = dir.path().join(archive.file_name).with_extension("partial"); + fs::write(partial_path, b"partial").expect("write partial archive"); + + assert_eq!( + downloaded_release_archive_bytes(dir.path(), alias, archive), + 7 + ); + } + + #[test] + fn extract_release_archive_rejects_missing_required_file() { + let alias = "qwen3-asr-0.6b-int8"; + let archive = sherpa::release_archive_for_alias(alias).expect("release archive"); + let dir = TempModelDir::new("release-archive-missing"); + let archive_path = dir.path().join(archive.file_name); + write_release_archive_fixture( + &archive_path, + archive, + &[("conv_frontend.onnx", b"conv" as &[u8])], + ); + + let message = format!( + "{:#}", + extract_release_archive(&archive_path, dir.path(), archive, alias).unwrap_err() + ); + + assert!(message.contains("archive required path missing")); + assert!(message.contains("encoder.int8.onnx")); + } + + #[test] + fn extract_release_archive_moves_required_files_and_removes_work_paths() { + let alias = "qwen3-asr-0.6b-int8"; + let archive = sherpa::release_archive_for_alias(alias).expect("release archive"); + let dir = TempModelDir::new("release-archive-success"); + let archive_path = dir.path().join(archive.file_name); + write_release_archive_fixture( + &archive_path, + archive, + &[ + ("conv_frontend.onnx", b"conv" as &[u8]), + ("encoder.int8.onnx", b"encoder" as &[u8]), + ("decoder.int8.onnx", b"decoder" as &[u8]), + ("tokenizer/tokenizer.json", b"tok" as &[u8]), + ], + ); + + extract_release_archive(&archive_path, dir.path(), archive, alias).unwrap(); + + assert_eq!( + fs::read(dir.path().join("conv_frontend.onnx")).unwrap(), + b"conv" + ); + assert_eq!( + fs::read(dir.path().join("encoder.int8.onnx")).unwrap(), + b"encoder" + ); + assert_eq!( + fs::read(dir.path().join("decoder.int8.onnx")).unwrap(), + b"decoder" + ); + assert_eq!( + fs::read(dir.path().join("tokenizer").join("tokenizer.json")).unwrap(), + b"tok" + ); + assert!(!archive_path.exists()); + assert!(!archive_extract_dir(dir.path()).unwrap().exists()); + } + + #[test] + fn download_manager_cancel_sets_active_flag() { + let manager = SherpaDownloadManager::new(); + let flag = Arc::new(AtomicBool::new(false)); + manager + .cancel_flags + .lock() + .insert("sense-voice-small-zh".into(), Arc::clone(&flag)); + + manager.cancel("sense-voice-small-zh"); + + assert!(flag.load(Ordering::SeqCst)); } #[test] diff --git a/openless-all/app/src-tauri/src/asr/local/sherpa_provider.rs b/openless-all/app/src-tauri/src/asr/local/sherpa_provider.rs index 741d6159..7cd48e21 100644 --- a/openless-all/app/src-tauri/src/asr/local/sherpa_provider.rs +++ b/openless-all/app/src-tauri/src/asr/local/sherpa_provider.rs @@ -1,35 +1,58 @@ -//! sherpa-onnx 本地 ASR provider(M1 骨架)。 +//! sherpa-onnx 本地 ASR provider(Windows offline batch + online streaming)。 //! //! 形状与 `foundry_provider.rs` 对齐: //! - 作为 `Recorder::AudioConsumer` 持续吃 PCM //! - 录音结束后 `transcribe(timeout)` 返回 `RawTranscript` -//! - `cancel()` 让任何 in-flight transcription 提前结束(M1 桩,仅清 buffer) +//! - `cancel()` 让任何 in-flight transcription 提前结束,并清理已缓存 PCM //! -//! M1 阶段: -//! - `transcribe` 调 `SherpaOnnxRuntime::transcribe_pcm`(M1 返回空串) -//! - 让主链路在 Windows + `sherpa-onnx-local` provider 时能跑完 -//! begin_session → 录音 → end_session → polish → insert 的形态 -//! - M1 空 transcript 会走现有 emptyTranscript 护栏;M2 接真实推理后复用同一收尾路径 +//! Offline 模型停止录音后把整段 16kHz mono s16le PCM 交给 +//! `SherpaOnnxRuntime::transcribe_pcm`。Online 模型在独立 worker 中实时消费 PCM, +//! partial token 通过回调上抛,停止录音后返回 final `RawTranscript`。 -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::Arc; +use std::thread::JoinHandle; use std::time::Duration; +use std::time::Instant; use anyhow::Result; use parking_lot::Mutex; use crate::asr::RawTranscript; -use super::sherpa_runtime::SherpaOnnxRuntime; +use super::sherpa; +use super::sherpa_runtime::{SherpaOnlineSession, SherpaOnnxRuntime}; pub struct SherpaOnnxAsr { runtime: Arc, model_alias: String, language_hint: Option, - buffer: Mutex>, + mode: SherpaProviderMode, cancel_generation: AtomicU64, } +enum SherpaProviderMode { + Offline { buffer: Mutex> }, + Online { worker: Mutex> }, +} + +struct OnlineWorker { + tx: Sender, + result_rx: Mutex>>>, + join_handle: Mutex>>, + audio_bytes: AtomicU64, + cancelled: Arc, +} + +enum OnlineWorkerMessage { + Pcm(Vec), + Finish, + Cancel, +} + +pub type SherpaTokenHandler = Arc; + impl SherpaOnnxAsr { pub fn new( runtime: Arc, @@ -40,11 +63,35 @@ impl SherpaOnnxAsr { runtime, model_alias, language_hint: normalize_language_hint(language_hint), - buffer: Mutex::new(Vec::new()), + mode: SherpaProviderMode::Offline { + buffer: Mutex::new(Vec::new()), + }, cancel_generation: AtomicU64::new(0), } } + pub async fn new_for_model( + runtime: Arc, + model_alias: String, + language_hint: Option, + token_handler: Option, + ) -> Result { + if sherpa::alias_is_online(&model_alias) { + let session = runtime.create_online_session(&model_alias).await?; + Ok(Self { + runtime, + model_alias, + language_hint: normalize_language_hint(language_hint), + mode: SherpaProviderMode::Online { + worker: Mutex::new(Some(OnlineWorker::spawn(session, token_handler))), + }, + cancel_generation: AtomicU64::new(0), + }) + } else { + Ok(Self::new(runtime, model_alias, language_hint)) + } + } + #[allow(dead_code)] pub fn model_alias(&self) -> &str { &self.model_alias @@ -56,8 +103,23 @@ impl SherpaOnnxAsr { } pub async fn transcribe(&self, audio_timeout: Duration) -> Result { + match &self.mode { + SherpaProviderMode::Offline { buffer } => { + self.transcribe_offline(buffer, audio_timeout).await + } + SherpaProviderMode::Online { worker } => { + self.transcribe_online(worker, audio_timeout).await + } + } + } + + async fn transcribe_offline( + &self, + buffer: &Mutex>, + audio_timeout: Duration, + ) -> Result { let cancel_generation = self.cancel_generation.load(Ordering::SeqCst); - let pcm = self.buffer.lock().clone(); + let pcm = buffer.lock().clone(); if pcm.is_empty() { return Ok(RawTranscript { text: String::new(), @@ -76,8 +138,54 @@ impl SherpaOnnxAsr { } // 与 Foundry 行为对齐:进入推理后清 buffer,避免下一轮重复消费。 - self.buffer.lock().clear(); + buffer.lock().clear(); + + let text = result?; + Ok(RawTranscript { + text: trim_transcript_text(&text), + duration_ms, + }) + } + async fn transcribe_online( + &self, + worker_slot: &Mutex>, + audio_timeout: Duration, + ) -> Result { + let cancel_generation = self.cancel_generation.load(Ordering::SeqCst); + let Some(worker) = worker_slot.lock().take() else { + return Ok(RawTranscript { + text: String::new(), + duration_ms: 0, + }); + }; + let duration_ms = pcm_duration_ms_from_bytes(worker.audio_bytes.load(Ordering::SeqCst)); + let started = Instant::now(); + let result = worker.finish(audio_timeout).await; + let elapsed_ms = started.elapsed().as_millis().min(u128::from(u64::MAX)) as u64; + if self.cancel_generation.load(Ordering::SeqCst) != cancel_generation { + self.runtime.record_streaming_result( + &self.model_alias, + duration_ms, + elapsed_ms, + Some("sherpa-onnx streaming transcription cancelled".into()), + ); + anyhow::bail!("sherpa-onnx streaming transcription cancelled"); + } + match &result { + Ok(_) => self.runtime.record_streaming_result( + &self.model_alias, + duration_ms, + elapsed_ms, + None, + ), + Err(error) => self.runtime.record_streaming_result( + &self.model_alias, + duration_ms, + elapsed_ms, + Some(format!("{error:#}")), + ), + } let text = result?; Ok(RawTranscript { text: trim_transcript_text(&text), @@ -88,18 +196,36 @@ impl SherpaOnnxAsr { pub fn cancel(&self) { self.cancel_generation.fetch_add(1, Ordering::SeqCst); self.runtime.request_cancel_prepare(); - self.buffer.lock().clear(); + match &self.mode { + SherpaProviderMode::Offline { buffer } => buffer.lock().clear(), + SherpaProviderMode::Online { worker } => { + if let Some(worker) = worker.lock().take() { + worker.cancel(); + } + } + } } } impl crate::recorder::AudioConsumer for SherpaOnnxAsr { fn consume_pcm_chunk(&self, pcm: &[u8]) { - self.buffer.lock().extend_from_slice(pcm); + match &self.mode { + SherpaProviderMode::Offline { buffer } => buffer.lock().extend_from_slice(pcm), + SherpaProviderMode::Online { worker } => { + if let Some(worker) = worker.lock().as_ref() { + worker.send_pcm(pcm); + } + } + } } } fn pcm_duration_ms(pcm: &[u8]) -> u64 { - (pcm.len() as u64 / 2) * 1000 / 16_000 + pcm_duration_ms_from_bytes(pcm.len() as u64) +} + +fn pcm_duration_ms_from_bytes(bytes: u64) -> u64 { + (bytes / 2) * 1000 / 16_000 } fn trim_transcript_text(text: &str) -> String { @@ -111,6 +237,118 @@ fn normalize_language_hint(raw: Option) -> Option { .filter(|s| !s.is_empty()) } +impl OnlineWorker { + fn spawn(mut session: SherpaOnlineSession, token_handler: Option) -> Self { + let alias = session.alias().to_string(); + let (tx, rx) = mpsc::channel::(); + let (result_tx, result_rx) = mpsc::channel::>(); + let cancelled = Arc::new(AtomicBool::new(false)); + let worker_cancelled = Arc::clone(&cancelled); + let join_handle = std::thread::Builder::new() + .name(format!("openless-sherpa-online-{alias}")) + .spawn(move || { + let emit = |piece: &str| { + if piece.is_empty() || worker_cancelled.load(Ordering::SeqCst) { + return; + } + if let Some(handler) = token_handler.as_ref() { + handler(piece.to_string()); + } + }; + let result = loop { + match rx.recv() { + Ok(OnlineWorkerMessage::Pcm(pcm)) => { + if worker_cancelled.load(Ordering::SeqCst) { + break Err(anyhow::anyhow!("sherpa-onnx streaming cancelled")); + } + if let Err(error) = session.accept_pcm_chunk(&pcm, &emit) { + break Err(error); + } + if worker_cancelled.load(Ordering::SeqCst) { + break Err(anyhow::anyhow!("sherpa-onnx streaming cancelled")); + } + } + Ok(OnlineWorkerMessage::Finish) => { + if worker_cancelled.load(Ordering::SeqCst) { + break Err(anyhow::anyhow!("sherpa-onnx streaming cancelled")); + } + break session.finish(&emit); + } + Ok(OnlineWorkerMessage::Cancel) | Err(_) => { + worker_cancelled.store(true, Ordering::SeqCst); + break Err(anyhow::anyhow!("sherpa-onnx streaming cancelled")); + } + } + }; + let _ = result_tx.send(result); + }) + .expect("spawn sherpa online worker"); + + Self { + tx, + result_rx: Mutex::new(Some(result_rx)), + join_handle: Mutex::new(Some(join_handle)), + audio_bytes: AtomicU64::new(0), + cancelled, + } + } + + fn send_pcm(&self, pcm: &[u8]) { + if pcm.is_empty() || self.cancelled.load(Ordering::SeqCst) { + return; + } + self.audio_bytes + .fetch_add(pcm.len() as u64, Ordering::SeqCst); + if self + .tx + .send(OnlineWorkerMessage::Pcm(pcm.to_vec())) + .is_err() + { + log::warn!("[sherpa-asr] online worker is not accepting PCM"); + } + } + + async fn finish(self, audio_timeout: Duration) -> Result { + let _ = self.tx.send(OnlineWorkerMessage::Finish); + let result_rx = self + .result_rx + .lock() + .take() + .ok_or_else(|| anyhow::anyhow!("sherpa-onnx streaming result already taken"))?; + let join_handle = self.join_handle.lock().take(); + let result = tokio::time::timeout(audio_timeout, async move { + tokio::task::spawn_blocking(move || { + result_rx.recv().map_err(|error| { + anyhow::anyhow!("sherpa-onnx streaming worker closed: {error}") + })? + }) + .await + .map_err(|error| anyhow::anyhow!("sherpa-onnx streaming join failed: {error:#}"))? + }) + .await; + let result = match result { + Ok(result) => result, + Err(_) => { + self.cancelled.store(true, Ordering::SeqCst); + let _ = self.tx.send(OnlineWorkerMessage::Cancel); + anyhow::bail!("sherpa-onnx streaming transcribe timeout"); + } + }; + if let Some(join_handle) = join_handle { + let _ = join_handle.join(); + } + result + } + + fn cancel(self) { + self.cancelled.store(true, Ordering::SeqCst); + let _ = self.tx.send(OnlineWorkerMessage::Cancel); + if let Some(join_handle) = self.join_handle.lock().take() { + let _ = join_handle.join(); + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -145,7 +383,10 @@ mod tests { let provider = make_provider(); provider.consume_pcm_chunk(&[1, 2, 3, 4]); provider.consume_pcm_chunk(&[5, 6]); - assert_eq!(provider.buffer.lock().len(), 6); + match &provider.mode { + SherpaProviderMode::Offline { buffer } => assert_eq!(buffer.lock().len(), 6), + SherpaProviderMode::Online { .. } => panic!("expected offline provider"), + } } #[tokio::test] @@ -166,17 +407,29 @@ mod tests { provider.consume_pcm_chunk(&vec![0u8; 32_000]); let result = provider.transcribe(Duration::from_secs(5)).await; assert!(result.is_err()); - assert!(provider.buffer.lock().is_empty()); + match &provider.mode { + SherpaProviderMode::Offline { buffer } => assert!(buffer.lock().is_empty()), + SherpaProviderMode::Online { .. } => panic!("expected offline provider"), + } } #[test] fn cancel_clears_buffer_and_bumps_generation() { - let provider = make_provider(); + let runtime = Arc::new(SherpaOnnxRuntime::new()); + let provider = SherpaOnnxAsr::new( + Arc::clone(&runtime), + "sense-voice-small-zh".into(), + Some(" ZH ".into()), + ); provider.consume_pcm_chunk(&[1, 2, 3, 4]); let before = provider.cancel_generation.load(Ordering::SeqCst); provider.cancel(); let after = provider.cancel_generation.load(Ordering::SeqCst); assert!(after > before); - assert!(provider.buffer.lock().is_empty()); + assert!(runtime.cancel_prepare_requested_for_tests()); + match &provider.mode { + SherpaProviderMode::Offline { buffer } => assert!(buffer.lock().is_empty()), + SherpaProviderMode::Online { .. } => panic!("expected offline provider"), + } } } diff --git a/openless-all/app/src-tauri/src/asr/local/sherpa_runtime.rs b/openless-all/app/src-tauri/src/asr/local/sherpa_runtime.rs index 59662294..af24da3d 100644 --- a/openless-all/app/src-tauri/src/asr/local/sherpa_runtime.rs +++ b/openless-all/app/src-tauri/src/asr/local/sherpa_runtime.rs @@ -1,46 +1,64 @@ -//! sherpa-onnx 本地 ASR runtime(M1 骨架)。 +//! sherpa-onnx 本地 ASR runtime(Windows offline batch + online streaming)。 //! //! 设计与 `foundry_runtime.rs` 对齐:runtime 是模型/会话/生命周期的单一持有者, //! 不感知 `Coordinator` / `Recorder` / UI / Tauri 事件。失败统一通过 //! `anyhow::Error` 上抛,由上层翻译为用户可见文案。 //! -//! M1 阶段: -//! - 全平台编译通过(避免 macOS / Linux CI 红线) -//! - 不引入 `sherpa-onnx` crate(M2 才加 Windows-only 依赖) -//! - `ensure_loaded` / `transcribe_pcm` / `release_now` 全部桩实现 -//! - 仅维持 active_model / runtime_ready 这种「状态门面」,便于前端联调 +//! 当前 Windows 路径接入 `sherpa-onnx` 的 `OfflineRecognizer` 和 +//! `OnlineRecognizer`,支持模型加载、缓存、整段 PCM 转写、online 分块解码和释放。 +//! 非 Windows 仍只保留可编译的状态门面。 use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::time::Instant; use anyhow::{Context, Result}; use parking_lot::Mutex; use tokio::sync::Mutex as AsyncMutex; use crate::asr::local::sherpa::{ - self, SherpaCatalogModel, SherpaFamily, SherpaPreparePhase, SherpaPrepareProgressPayload, - SherpaRuntimeStatus, PROVIDER_ID, + self, SherpaCatalogModel, SherpaFamily, SherpaMode, SherpaPreparePhase, + SherpaPrepareProgressPayload, SherpaRuntimeStatus, PROVIDER_ID, }; #[cfg(target_os = "windows")] use sherpa_onnx::{ OfflineParaformerModelConfig, OfflineQwen3ASRModelConfig, OfflineRecognizer, OfflineRecognizerConfig, OfflineSenseVoiceModelConfig, OfflineWhisperModelConfig, + OnlineRecognizer, OnlineRecognizerConfig, }; -/// 模型加载状态。M1 阶段不持有任何 native handle; -/// M2 引入 sherpa-onnx crate 后再补 `recognizer: Arc` 之类的字段。 +/// Offline 模型加载状态。Windows 持有 native `OfflineRecognizer`;其他平台仅保留 alias +/// 以维持跨平台编译与状态查询形状。 #[derive(Clone)] -struct LoadedModel { +struct LoadedOfflineModel { alias: String, #[cfg(target_os = "windows")] recognizer: Arc, } +/// Online 模型加载状态。每次听写会话会从 recognizer 创建独立 `OnlineStream`。 +#[derive(Clone)] +struct LoadedOnlineModel { + alias: String, + #[cfg(target_os = "windows")] + recognizer: Arc, +} + #[derive(Default)] struct RuntimeState { - loaded: Option, + offline_loaded: Option, + online_loaded: Option, + diagnostics: RuntimeDiagnostics, +} + +#[derive(Clone, Default)] +struct RuntimeDiagnostics { + last_prepare_ms: Option, + last_transcribe_ms: Option, + last_audio_ms: Option, + last_error: Option, } /// 跨会话单例。生命周期由 `AsyncMutex` 串行化,确保 ensure_loaded / release 不会并发。 @@ -65,31 +83,32 @@ impl SherpaOnnxRuntime { } } - /// 返回当前 runtime 是否真的具备推理能力。M1 永远是 false; - /// M2 接入 sherpa-onnx 后改为编译期 `#[cfg(target_os = "windows")]` 真值。 + /// 返回当前 runtime 是否真的具备推理能力。当前仅 Windows 接入 + /// `sherpa-onnx` offline recognizer。 #[allow(dead_code)] pub fn is_available(&self) -> bool { cfg!(target_os = "windows") } pub async fn status_snapshot(&self, active_model: &str) -> SherpaRuntimeStatus { - let loaded_model_id = self - .state - .lock() - .loaded - .as_ref() - .map(|loaded| loaded.alias.clone()); + let state = self.state.lock(); + let loaded_model_id = active_loaded_model_id(&state, active_model); + let diagnostics = state.diagnostics.clone(); SherpaRuntimeStatus { provider_id: PROVIDER_ID.into(), available: self.is_available(), runtime_ready: loaded_model_id.is_some(), active_model: active_model.to_string(), loaded_model_id, - error: None, + error: diagnostics.last_error.clone(), + last_prepare_ms: diagnostics.last_prepare_ms, + last_transcribe_ms: diagnostics.last_transcribe_ms, + last_audio_ms: diagnostics.last_audio_ms, + last_error: diagnostics.last_error, } } - /// M1:返回静态 catalog。M3 接入下载管理后会合并本地缓存状态。 + /// 返回静态 catalog,并合并本地缓存状态与已下载字节数。 #[allow(dead_code)] pub async fn catalog_snapshot(&self) -> Result> { let mut catalog = sherpa::static_catalog_models(); @@ -115,13 +134,45 @@ impl SherpaOnnxRuntime { } pub async fn ensure_loaded_with_progress(&self, alias: &str, progress: F) -> Result + where + F: Fn(SherpaPrepareProgressPayload) + Send + Sync + 'static, + { + let started = Instant::now(); + let result = self + .ensure_loaded_with_progress_inner(alias, progress) + .await; + let elapsed_ms = started.elapsed().as_millis().min(u128::from(u64::MAX)) as u64; + match &result { + Ok(loaded_alias) => { + log::info!( + "[sherpa-asr] prepare finished model={} elapsed_ms={}", + loaded_alias, + elapsed_ms + ); + self.record_prepare_result(elapsed_ms, None); + } + Err(error) => { + let message = format!("{error:#}"); + log::warn!( + "[sherpa-asr] prepare failed model={} elapsed_ms={} error={}", + alias, + elapsed_ms, + message + ); + self.record_prepare_result(elapsed_ms, Some(message)); + } + } + result + } + + async fn ensure_loaded_with_progress_inner(&self, alias: &str, progress: F) -> Result where F: Fn(SherpaPrepareProgressPayload) + Send + Sync + 'static, { let _lifecycle = self.lifecycle.lock().await; self.cancel_prepare.store(false, Ordering::SeqCst); validate_alias(alias)?; - if let Some(loaded) = self.cached_loaded_model(alias) { + if let Some(loaded_alias) = self.cached_loaded_alias(alias) { progress(SherpaPrepareProgressPayload::new( SherpaPreparePhase::Finished, alias, @@ -129,7 +180,7 @@ impl SherpaOnnxRuntime { Some(100.0), None, )); - return Ok(loaded.alias); + return Ok(loaded_alias); } self.check_prepare_cancelled()?; let dir = sherpa::model_dir_for_alias(alias)?; @@ -158,7 +209,10 @@ impl SherpaOnnxRuntime { Some(100.0), None, )); - self.state.lock().loaded = Some(loaded.clone()); + match loaded { + LoadedModel::Offline(loaded) => self.state.lock().offline_loaded = Some(loaded), + LoadedModel::Online(loaded) => self.state.lock().online_loaded = Some(loaded), + } progress(SherpaPrepareProgressPayload::new( SherpaPreparePhase::Finished, alias, @@ -169,8 +223,8 @@ impl SherpaOnnxRuntime { Ok(alias.to_string()) } - /// M1:永远返回空串,配合 mock pipeline 让用户的话不被「丢失也不被乱写」。 - /// 真实接入见 M2 `OfflineRecognizer::decode`。 + /// Windows 下用已加载的 `OfflineRecognizer` 做整段 PCM batch 转写;非 Windows + /// 保持空实现,避免把 sherpa provider 暴露为可用推理能力。 #[allow(dead_code)] pub async fn transcribe_pcm( &self, @@ -182,21 +236,68 @@ impl SherpaOnnxRuntime { if pcm.is_empty() { return Ok(String::new()); } + if sherpa::mode_for_alias(alias)? != SherpaMode::Offline { + anyhow::bail!("sherpa-onnx model {alias} is online-only; use streaming API"); + } + let audio_ms = pcm_duration_ms(pcm); let loaded_alias = self.ensure_loaded(alias).await?; let loaded = self .state .lock() - .loaded + .offline_loaded .clone() .filter(|loaded| loaded.alias == loaded_alias) - .context("sherpa-onnx model not loaded")?; - transcribe_loaded_model( + .context("sherpa-onnx offline model not loaded")?; + let started = Instant::now(); + let result = transcribe_loaded_model( loaded, pcm.to_vec(), language_hint.map(str::to_string), audio_timeout, ) - .await + .await; + let elapsed_ms = started.elapsed().as_millis().min(u128::from(u64::MAX)) as u64; + match &result { + Ok(text) => { + log::info!( + "[sherpa-asr] transcribe finished model={} audio_ms={} elapsed_ms={} text_chars={}", + alias, + audio_ms, + elapsed_ms, + text.chars().count() + ); + self.record_transcribe_result(audio_ms, elapsed_ms, None); + } + Err(error) => { + let message = format!("{error:#}"); + log::warn!( + "[sherpa-asr] transcribe failed model={} audio_ms={} elapsed_ms={} error={}", + alias, + audio_ms, + elapsed_ms, + message + ); + self.record_transcribe_result(audio_ms, elapsed_ms, Some(message)); + } + } + result + } + + /// 创建独立 online 解码 session。调用者负责按 Recorder PCM chunk 喂入, + /// 并在停止录音时调用 `finish()` 刷出 final text。 + pub async fn create_online_session(&self, alias: &str) -> Result { + if sherpa::mode_for_alias(alias)? != SherpaMode::Online { + anyhow::bail!("sherpa-onnx model {alias} is not an online streaming model"); + } + let loaded_alias = self.ensure_loaded(alias).await?; + let loaded = self + .state + .lock() + .online_loaded + .clone() + .filter(|loaded| loaded.alias == loaded_alias) + .context("sherpa-onnx online model not loaded")?; + create_online_session_from_loaded(loaded) } pub fn request_cancel_prepare(&self) { @@ -210,7 +311,9 @@ impl SherpaOnnxRuntime { pub async fn release_now(&self) -> Result<()> { let _lifecycle = self.lifecycle.lock().await; - self.state.lock().loaded = None; + let mut state = self.state.lock(); + state.offline_loaded = None; + state.online_loaded = None; Ok(()) } @@ -223,8 +326,21 @@ impl SherpaOnnxRuntime { validate_alias(alias)?; { let mut state = self.state.lock(); - if state.loaded.as_ref().map(|loaded| loaded.alias.as_str()) == Some(alias) { - state.loaded = None; + if state + .offline_loaded + .as_ref() + .map(|loaded| loaded.alias.as_str()) + == Some(alias) + { + state.offline_loaded = None; + } + if state + .online_loaded + .as_ref() + .map(|loaded| loaded.alias.as_str()) + == Some(alias) + { + state.online_loaded = None; } } let dir = sherpa::model_dir_for_alias(alias)?; @@ -235,13 +351,25 @@ impl SherpaOnnxRuntime { Ok(()) } - fn cached_loaded_model(&self, alias: &str) -> Option { - self.state - .lock() - .loaded + fn cached_loaded_alias(&self, alias: &str) -> Option { + let state = self.state.lock(); + if state + .offline_loaded + .as_ref() + .filter(|loaded| loaded.alias == alias) + .is_some() + { + return Some(alias.to_string()); + } + if state + .online_loaded .as_ref() .filter(|loaded| loaded.alias == alias) - .cloned() + .is_some() + { + return Some(alias.to_string()); + } + None } fn check_prepare_cancelled(&self) -> Result<()> { @@ -250,6 +378,59 @@ impl SherpaOnnxRuntime { } Ok(()) } + + fn record_prepare_result(&self, elapsed_ms: u64, error: Option) { + let mut state = self.state.lock(); + state.diagnostics.last_prepare_ms = Some(elapsed_ms); + state.diagnostics.last_error = error; + } + + fn record_transcribe_result(&self, audio_ms: u64, elapsed_ms: u64, error: Option) { + let mut state = self.state.lock(); + state.diagnostics.last_audio_ms = Some(audio_ms); + state.diagnostics.last_transcribe_ms = Some(elapsed_ms); + state.diagnostics.last_error = error; + } + + pub fn record_streaming_result( + &self, + alias: &str, + audio_ms: u64, + elapsed_ms: u64, + error: Option, + ) { + match &error { + Some(message) => log::warn!( + "[sherpa-asr] streaming finished model={} audio_ms={} elapsed_ms={} error={}", + alias, + audio_ms, + elapsed_ms, + message + ), + None => log::info!( + "[sherpa-asr] streaming finished model={} audio_ms={} elapsed_ms={}", + alias, + audio_ms, + elapsed_ms + ), + } + self.record_transcribe_result(audio_ms, elapsed_ms, error); + } +} + +fn active_loaded_model_id(state: &RuntimeState, active_model: &str) -> Option { + state + .offline_loaded + .as_ref() + .filter(|loaded| loaded.alias == active_model) + .map(|loaded| loaded.alias.clone()) + .or_else(|| { + state + .online_loaded + .as_ref() + .filter(|loaded| loaded.alias == active_model) + .map(|loaded| loaded.alias.clone()) + }) } fn validate_alias(alias: &str) -> Result<()> { @@ -302,16 +483,34 @@ fn accumulate_dir_size(dir: &Path, bytes: &mut u64) { } } +fn pcm_duration_ms(pcm: &[u8]) -> u64 { + (pcm.len() as u64 / 2) * 1000 / 16_000 +} + +enum LoadedModel { + Offline(LoadedOfflineModel), + Online(LoadedOnlineModel), +} + #[cfg(target_os = "windows")] async fn load_model(alias: &str, dir: &Path) -> Result { let alias = alias.to_string(); let dir = dir.to_path_buf(); - tokio::task::spawn_blocking(move || { - let recognizer = create_offline_recognizer(&alias, &dir)?; - Ok(LoadedModel { - alias, - recognizer: Arc::new(recognizer), - }) + tokio::task::spawn_blocking(move || match sherpa::mode_for_alias(&alias)? { + SherpaMode::Offline => { + let recognizer = create_offline_recognizer(&alias, &dir)?; + Ok(LoadedModel::Offline(LoadedOfflineModel { + alias, + recognizer: Arc::new(recognizer), + })) + } + SherpaMode::Online => { + let recognizer = create_online_recognizer(&alias, &dir)?; + Ok(LoadedModel::Online(LoadedOnlineModel { + alias, + recognizer: Arc::new(recognizer), + })) + } }) .await .map_err(|e| anyhow::anyhow!("sherpa-onnx load join failed: {e:#}"))? @@ -319,9 +518,14 @@ async fn load_model(alias: &str, dir: &Path) -> Result { #[cfg(not(target_os = "windows"))] async fn load_model(alias: &str, _dir: &Path) -> Result { - Ok(LoadedModel { - alias: alias.to_string(), - }) + match sherpa::mode_for_alias(alias)? { + SherpaMode::Offline => Ok(LoadedModel::Offline(LoadedOfflineModel { + alias: alias.to_string(), + })), + SherpaMode::Online => Ok(LoadedModel::Online(LoadedOnlineModel { + alias: alias.to_string(), + })), + } } #[cfg(target_os = "windows")] @@ -374,6 +578,36 @@ fn create_offline_recognizer(alias: &str, dir: &Path) -> Result Result { + let mut config = OnlineRecognizerConfig::default(); + config.model_config.num_threads = std::thread::available_parallelism() + .map(|n| n.get().clamp(1, 4) as i32) + .unwrap_or(2); + config.model_config.provider = Some("cpu".into()); + config.enable_endpoint = true; + config.rule1_min_trailing_silence = 2.4; + config.rule2_min_trailing_silence = 1.2; + config.rule3_min_utterance_length = 20.0; + config.decoding_method = Some("greedy_search".into()); + match model_family(alias)? { + SherpaFamily::Zipformer => { + config.model_config.tokens = Some(path_to_string(&dir.join("tokens.txt"))?); + config.model_config.transducer.encoder = Some(path_to_string( + &dir.join("encoder-epoch-99-avg-1.int8.onnx"), + )?); + config.model_config.transducer.decoder = + Some(path_to_string(&dir.join("decoder-epoch-99-avg-1.onnx"))?); + config.model_config.transducer.joiner = Some(path_to_string( + &dir.join("joiner-epoch-99-avg-1.int8.onnx"), + )?); + } + family => anyhow::bail!("sherpa-onnx family {family:?} is not supported by online ASR"), + } + OnlineRecognizer::create(&config) + .ok_or_else(|| anyhow::anyhow!("create sherpa-onnx online recognizer failed")) +} + fn model_family(alias: &str) -> Result { sherpa::MODELS .iter() @@ -392,7 +626,7 @@ fn path_to_string(path: &Path) -> Result { #[cfg(target_os = "windows")] async fn transcribe_loaded_model( - loaded: LoadedModel, + loaded: LoadedOfflineModel, pcm: Vec, language_hint: Option, audio_timeout: std::time::Duration, @@ -422,7 +656,7 @@ async fn transcribe_loaded_model( #[cfg(not(target_os = "windows"))] async fn transcribe_loaded_model( - _loaded: LoadedModel, + _loaded: LoadedOfflineModel, _pcm: Vec, _language_hint: Option, _audio_timeout: std::time::Duration, @@ -430,6 +664,170 @@ async fn transcribe_loaded_model( Ok(String::new()) } +pub struct SherpaOnlineSession { + alias: String, + #[cfg(target_os = "windows")] + recognizer: Arc, + #[cfg(target_os = "windows")] + stream: sherpa_onnx::OnlineStream, + committed_text: String, + last_partial_text: String, + last_emitted_text: String, +} + +impl SherpaOnlineSession { + pub fn alias(&self) -> &str { + &self.alias + } + + pub fn accept_pcm_chunk(&mut self, pcm: &[u8], on_token: &F) -> Result<()> + where + F: Fn(&str), + { + if pcm.is_empty() { + return Ok(()); + } + let samples = pcm_s16le_to_f32(pcm)?; + self.accept_samples(&samples, on_token) + } + + pub fn finish(&mut self, on_token: &F) -> Result + where + F: Fn(&str), + { + self.finish_inner(on_token) + } + + #[cfg(target_os = "windows")] + fn accept_samples(&mut self, samples: &[f32], on_token: &F) -> Result<()> + where + F: Fn(&str), + { + self.stream.accept_waveform(16_000, samples); + self.drain_ready(on_token); + Ok(()) + } + + #[cfg(not(target_os = "windows"))] + fn accept_samples(&mut self, _samples: &[f32], _on_token: &F) -> Result<()> + where + F: Fn(&str), + { + Ok(()) + } + + #[cfg(target_os = "windows")] + fn finish_inner(&mut self, on_token: &F) -> Result + where + F: Fn(&str), + { + self.stream.input_finished(); + self.drain_ready(on_token); + self.capture_result(on_token, false); + Ok(self.final_text()) + } + + #[cfg(not(target_os = "windows"))] + fn finish_inner(&mut self, _on_token: &F) -> Result + where + F: Fn(&str), + { + Ok(self.final_text()) + } + + #[cfg(target_os = "windows")] + fn drain_ready(&mut self, on_token: &F) + where + F: Fn(&str), + { + while self.recognizer.is_ready(&self.stream) { + self.recognizer.decode(&self.stream); + self.capture_result(on_token, true); + } + } + + #[cfg(target_os = "windows")] + fn capture_result(&mut self, on_token: &F, allow_endpoint_reset: bool) + where + F: Fn(&str), + { + let Some(result) = self.recognizer.get_result(&self.stream) else { + return; + }; + let text = result.text.trim().to_string(); + if text.is_empty() { + return; + } + self.emit_delta(&text, on_token); + self.last_partial_text = text.clone(); + if allow_endpoint_reset && self.recognizer.is_endpoint(&self.stream) { + append_segment(&mut self.committed_text, &text); + self.last_partial_text.clear(); + self.last_emitted_text.clear(); + self.recognizer.reset(&self.stream); + } + } + + fn emit_delta(&mut self, text: &str, on_token: &F) + where + F: Fn(&str), + { + if text == self.last_emitted_text { + return; + } + if let Some(delta) = text.strip_prefix(&self.last_emitted_text) { + if !delta.is_empty() { + on_token(delta); + } + } else { + on_token(text); + } + self.last_emitted_text = text.to_string(); + } + + fn final_text(&self) -> String { + let mut text = self.committed_text.clone(); + if !self.last_partial_text.is_empty() { + append_segment(&mut text, &self.last_partial_text); + } + text.trim().to_string() + } +} + +fn append_segment(text: &mut String, segment: &str) { + let segment = segment.trim(); + if segment.is_empty() { + return; + } + if !text.is_empty() && !text.ends_with(char::is_whitespace) { + text.push(' '); + } + text.push_str(segment); +} + +#[cfg(target_os = "windows")] +fn create_online_session_from_loaded(loaded: LoadedOnlineModel) -> Result { + let stream = loaded.recognizer.create_stream(); + Ok(SherpaOnlineSession { + alias: loaded.alias, + recognizer: loaded.recognizer, + stream, + committed_text: String::new(), + last_partial_text: String::new(), + last_emitted_text: String::new(), + }) +} + +#[cfg(not(target_os = "windows"))] +fn create_online_session_from_loaded(loaded: LoadedOnlineModel) -> Result { + Ok(SherpaOnlineSession { + alias: loaded.alias, + committed_text: String::new(), + last_partial_text: String::new(), + last_emitted_text: String::new(), + }) +} + fn pcm_s16le_to_f32(pcm: &[u8]) -> Result> { if pcm.len() % 2 != 0 { anyhow::bail!("PCM buffer length is not aligned to i16 samples"); @@ -445,7 +843,7 @@ mod tests { use super::*; #[tokio::test] - async fn new_runtime_reports_skeleton_shape() { + async fn new_runtime_reports_offline_batch_status_shape() { let runtime = SherpaOnnxRuntime::new(); let status = runtime.status_snapshot("sense-voice-small-zh").await; @@ -455,6 +853,30 @@ mod tests { assert_eq!(status.active_model, "sense-voice-small-zh"); assert_eq!(status.loaded_model_id, None); assert_eq!(status.error, None); + assert_eq!(status.last_prepare_ms, None); + assert_eq!(status.last_transcribe_ms, None); + assert_eq!(status.last_audio_ms, None); + assert_eq!(status.last_error, None); + } + + #[cfg(not(target_os = "windows"))] + #[tokio::test] + async fn status_snapshot_only_marks_active_model_ready() { + let runtime = SherpaOnnxRuntime::new(); + runtime.state.lock().offline_loaded = Some(LoadedOfflineModel { + alias: "sense-voice-small-zh".into(), + }); + + let active = runtime.status_snapshot("sense-voice-small-zh").await; + let inactive = runtime.status_snapshot("paraformer-zh").await; + + assert!(active.runtime_ready); + assert_eq!( + active.loaded_model_id.as_deref(), + Some("sense-voice-small-zh") + ); + assert!(!inactive.runtime_ready); + assert_eq!(inactive.loaded_model_id, None); } #[tokio::test] @@ -462,6 +884,37 @@ mod tests { let runtime = SherpaOnnxRuntime::new(); let result = runtime.ensure_loaded("unknown-sherpa-model").await; assert!(result.is_err()); + let status = runtime.status_snapshot("sense-voice-small-zh").await; + assert!(status.last_prepare_ms.is_some()); + assert!(status + .last_error + .as_deref() + .unwrap_or_default() + .contains("unknown-sherpa-model")); + assert_eq!(status.error, status.last_error); + } + + #[tokio::test] + async fn status_snapshot_reports_runtime_diagnostics() { + let runtime = SherpaOnnxRuntime::new(); + runtime.record_prepare_result(12, None); + runtime.record_transcribe_result(1_250, 34, Some("decode failed".into())); + + let status = runtime.status_snapshot("paraformer-zh").await; + + assert_eq!(status.last_prepare_ms, Some(12)); + assert_eq!(status.last_audio_ms, Some(1_250)); + assert_eq!(status.last_transcribe_ms, Some(34)); + assert_eq!(status.error.as_deref(), Some("decode failed")); + assert_eq!(status.last_error.as_deref(), Some("decode failed")); + } + + #[test] + fn request_cancel_prepare_marks_runtime_cancelled() { + let runtime = SherpaOnnxRuntime::new(); + assert!(!runtime.cancel_prepare_requested_for_tests()); + runtime.request_cancel_prepare(); + assert!(runtime.cancel_prepare_requested_for_tests()); } #[test] @@ -474,16 +927,47 @@ mod tests { let result = ensure_required_files("paraformer-zh", &dir); std::fs::remove_dir_all(&dir).ok(); assert!(result.is_err()); + let message = format!("{:#}", result.unwrap_err()); + assert!(message.contains("model.int8.onnx")); + assert!(message.contains(&dir.display().to_string())); + } + + #[test] + fn model_dir_size_mb_counts_nested_files() { + let dir = std::env::temp_dir().join(format!( + "openless-sherpa-runtime-size-test-{}", + uuid::Uuid::new_v4() + )); + let nested = dir.join("nested"); + std::fs::create_dir_all(&nested).unwrap(); + std::fs::write(dir.join("top.bin"), vec![1u8; 1024 * 1024]).unwrap(); + std::fs::write(nested.join("child.bin"), vec![2u8; 1024 * 1024]).unwrap(); + + let size = model_dir_size_mb(&dir); + + std::fs::remove_dir_all(&dir).ok(); + assert_eq!(size, Some(2)); } #[tokio::test] async fn release_now_clears_loaded_model() { let runtime = SherpaOnnxRuntime::new(); + #[cfg(not(target_os = "windows"))] + { + runtime.state.lock().offline_loaded = Some(LoadedOfflineModel { + alias: "sense-voice-small-zh".into(), + }); + runtime.state.lock().online_loaded = Some(LoadedOnlineModel { + alias: sherpa::DEFAULT_ONLINE_MODEL_ALIAS.into(), + }); + } runtime.release_now().await.unwrap(); let status = runtime.status_snapshot("paraformer-zh").await; assert!(!status.runtime_ready); assert_eq!(status.loaded_model_id, None); + assert!(runtime.state.lock().offline_loaded.is_none()); + assert!(runtime.state.lock().online_loaded.is_none()); } #[tokio::test] @@ -501,6 +985,34 @@ mod tests { assert!(text.is_empty()); } + #[tokio::test] + async fn transcribe_pcm_rejects_online_model_alias() { + let runtime = SherpaOnnxRuntime::new(); + let result = runtime + .transcribe_pcm( + sherpa::DEFAULT_ONLINE_MODEL_ALIAS, + &[0, 0], + None, + std::time::Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + assert!(format!("{:#}", result.unwrap_err()).contains("online-only")); + } + + #[tokio::test] + async fn create_online_session_rejects_offline_model_alias() { + let runtime = SherpaOnnxRuntime::new(); + let result = runtime + .create_online_session(sherpa::DEFAULT_MODEL_ALIAS) + .await; + + assert!(result.is_err()); + let error = result.err().unwrap(); + assert!(format!("{error:#}").contains("not an online streaming model")); + } + #[test] fn pcm_s16le_to_f32_converts_samples() { let samples = pcm_s16le_to_f32(&[0, 0, 0xff, 0x7f, 0x00, 0x80]).unwrap(); @@ -514,4 +1026,31 @@ mod tests { fn pcm_s16le_to_f32_rejects_odd_length() { assert!(pcm_s16le_to_f32(&[0]).is_err()); } + + #[test] + fn append_segment_inserts_space_between_segments() { + let mut text = String::new(); + append_segment(&mut text, "你好"); + append_segment(&mut text, "world"); + assert_eq!(text, "你好 world"); + } + + #[cfg(target_os = "windows")] + #[test] + fn sherpa_onnx_crate_exposes_online_recognizer_api_for_streaming_phase() { + let mut config = sherpa_onnx::OnlineRecognizerConfig::default(); + config.model_config.transducer.encoder = Some("encoder.onnx".into()); + config.model_config.transducer.decoder = Some("decoder.onnx".into()); + config.model_config.transducer.joiner = Some("joiner.onnx".into()); + config.model_config.tokens = Some("tokens.txt".into()); + config.enable_endpoint = true; + config.decoding_method = Some("greedy_search".into()); + + assert_eq!( + config.model_config.transducer.encoder.as_deref(), + Some("encoder.onnx") + ); + assert!(config.enable_endpoint); + assert_eq!(config.decoding_method.as_deref(), Some("greedy_search")); + } } diff --git a/openless-all/app/src-tauri/src/commands.rs b/openless-all/app/src-tauri/src/commands.rs index 8d12c8ab..fc750d4c 100644 --- a/openless-all/app/src-tauri/src/commands.rs +++ b/openless-all/app/src-tauri/src/commands.rs @@ -2513,11 +2513,11 @@ fn emit_foundry_prepare_progress(app: &AppHandle, payload: FoundryPrepareProgres } } -// ───────────────────── Windows local ASR (sherpa-onnx-local, M1 骨架) ───────────────────── +// ───────────── Windows local ASR (sherpa-onnx-local, offline batch + online) ───────────── // -// 命令形态与 Foundry 同形,让前端命令封装可以复用同一种 hook 模式;M1 阶段 -// 不做下载 / 不接 sherpa-onnx crate / 不做实际推理,详见 -// `docs/windows-sherpa-onnx-asr-plan.md`。 +// 命令形态与 Foundry 同形,让前端命令封装可以复用同一种 hook 模式;当前支持 +// catalog / 下载 / prepare / release / 删除 / 状态查询,推理由 coordinator 的 +// 听写链路触发。offline 模型停止录音后 batch decode;online 模型录音时输出 partial。 fn active_sherpa_model_from_prefs(prefs: &UserPreferences) -> String { if sherpa_model_alias_is_known(&prefs.sherpa_onnx_model) { @@ -3319,13 +3319,13 @@ pub async fn github_device_flow_poll( mod tests { use super::{ active_asr_is_keyless_for_validation, active_foundry_model_from_prefs, - asr_configured_for_provider, asr_transcriptions_url, fetch_provider_models, - is_gemini_base_url, is_valid_local_pack_id, is_valid_session_id, + active_sherpa_model_from_prefs, asr_configured_for_provider, asr_transcriptions_url, + fetch_provider_models, is_gemini_base_url, is_valid_local_pack_id, is_valid_session_id, llm_configured_for_provider, local_asr_release_plan_for_provider, models_url, - normalize_foundry_language_hint, parse_gemini_model_ids, parse_latest_beta_from_atom, - parse_model_ids, persist_settings, release_foundry_runtime_if_inactive, - release_sherpa_runtime_if_inactive, - validate_foundry_model_alias, ProviderConfig, SettingsWriter, + normalize_foundry_language_hint, normalize_sherpa_language_hint, parse_gemini_model_ids, + parse_latest_beta_from_atom, parse_model_ids, persist_settings, + release_foundry_runtime_if_inactive, release_sherpa_runtime_if_inactive, + validate_foundry_model_alias, validate_sherpa_model_alias, ProviderConfig, SettingsWriter, }; use crate::persistence::CredentialsSnapshot; use crate::types::{ @@ -3553,6 +3553,49 @@ mod tests { } } + #[test] + fn sherpa_language_hint_accepts_empty_and_supported_lowercase_tags() { + assert_eq!(normalize_sherpa_language_hint("").unwrap(), ""); + assert_eq!(normalize_sherpa_language_hint(" ").unwrap(), ""); + assert_eq!(normalize_sherpa_language_hint("zh").unwrap(), "zh"); + assert_eq!(normalize_sherpa_language_hint(" en ").unwrap(), "en"); + assert_eq!(normalize_sherpa_language_hint("zh-cn").unwrap(), "zh-cn"); + assert_eq!(normalize_sherpa_language_hint("yue").unwrap(), "yue"); + } + + #[test] + fn sherpa_language_hint_normalizes_uppercase_and_rejects_digits() { + assert_eq!(normalize_sherpa_language_hint("ZH").unwrap(), "zh"); + assert!(normalize_sherpa_language_hint("zh-1").is_err()); + assert!(normalize_sherpa_language_hint("zh_CN").is_err()); + } + + #[test] + fn sherpa_model_alias_validation_matches_catalog() { + assert!( + validate_sherpa_model_alias(crate::asr::local::sherpa::DEFAULT_MODEL_ALIAS).is_ok() + ); + assert!(validate_sherpa_model_alias("qwen3-asr-0.6b-int8").is_ok()); + assert!( + validate_sherpa_model_alias(crate::asr::local::sherpa::DEFAULT_ONLINE_MODEL_ALIAS) + .is_ok() + ); + assert!(validate_sherpa_model_alias("zipformer-zh-streaming").is_err()); + } + + #[test] + fn sherpa_active_model_pref_falls_back_to_default_for_unknown_alias() { + let prefs = UserPreferences { + sherpa_onnx_model: "zipformer-zh-streaming".to_string(), + ..Default::default() + }; + + assert_eq!( + active_sherpa_model_from_prefs(&prefs), + crate::asr::local::sherpa::DEFAULT_MODEL_ALIAS + ); + } + #[test] fn credentials_status_accepts_keyless_custom_llm_only() { let keyless_ready = CredentialsSnapshot { diff --git a/openless-all/app/src-tauri/src/coordinator.rs b/openless-all/app/src-tauri/src/coordinator.rs index 25d58480..21fab30b 100644 --- a/openless-all/app/src-tauri/src/coordinator.rs +++ b/openless-all/app/src-tauri/src/coordinator.rs @@ -142,8 +142,7 @@ enum ActiveAsr { Bailian(Arc), #[cfg(target_os = "windows")] FoundryLocalWhisper(Arc), - /// Windows sherpa-onnx 本地 ASR(M1 骨架,详见 - /// `docs/windows-sherpa-onnx-asr-plan.md`)。 + /// Windows sherpa-onnx 本地 ASR(offline batch + 实验 online streaming)。 #[cfg(target_os = "windows")] SherpaOnnxLocal(Arc), /// 本地 Qwen3-ASR;只在 macOS + 模型已下载时可达。 @@ -186,7 +185,7 @@ struct Inner { local_asr_cache: Arc, #[cfg(target_os = "windows")] foundry_local_runtime: Arc, - /// Windows sherpa-onnx 本地 ASR runtime(M1 骨架)。与 Foundry 同处一个 + /// Windows sherpa-onnx 本地 ASR runtime。与 Foundry 同处一个 /// 位置、同一 lifecycle 语义;上层通过 `ActiveAsr::SherpaOnnxLocal` 后只调 /// runtime,不会跨模块调。 #[cfg(target_os = "windows")] @@ -312,8 +311,8 @@ impl Coordinator { } } - /// 保留旧构造函数:现有调用点(含单元测试)只传 Foundry runtime, - /// sherpa-onnx runtime 采用默认骨架实例。入产后(lib.rs)请走 + /// 保留旧构造函数:现有调用点(含单元测试)只传 Foundry runtime。 + /// sherpa-onnx runtime 这里创建默认 offline batch 实例;入产后(lib.rs)请走 /// `new_with_local_runtimes`,确保 Tauri State 共享同一个 Arc。 #[cfg(target_os = "windows")] pub fn new_with_foundry_runtime(foundry_local_runtime: Arc) -> Self { @@ -2392,8 +2391,7 @@ fn sherpa_onnx_release_keep_secs(inner: &Arc) -> u32 { } /// 与 `schedule_foundry_local_asr_release` 同形:session_id 老旧则不释放, -/// 避免下一轮 session 重加载同一个模型。M1 阶段 runtime 是骨架,`release_now` -/// 只清 alias state,不会报错。 +/// 避免下一轮 session 立即重加载同一个 offline batch 模型。 #[cfg(target_os = "windows")] fn schedule_sherpa_onnx_release(inner: &Arc, session_id: SessionId) { let keep_secs = sherpa_onnx_release_keep_secs(inner); @@ -4176,8 +4174,8 @@ fn local_qwen_transcribe_timeout(audio_secs: f64) -> std::time::Duration { std::time::Duration::from_secs(secs) } -/// sherpa-onnx M1 阶段超时与 Foundry 同档。M2 接入真实推理后视 CPU 模型 -/// 实际耗时再调(中文 SenseVoice small int8 在 4 核 CPU 上一般 < 3s/30s 音频)。 +/// sherpa-onnx offline batch 暂与 Foundry 同档;后续按 Windows 真机 CPU/模型 +/// 实测结果再调整。 #[cfg(target_os = "windows")] fn sherpa_audio_transcribe_timeout_duration() -> std::time::Duration { std::time::Duration::from_secs(COORDINATOR_GLOBAL_TIMEOUT_SECS) diff --git a/openless-all/app/src-tauri/src/coordinator/dictation.rs b/openless-all/app/src-tauri/src/coordinator/dictation.rs index c7b33838..ff66b871 100644 --- a/openless-all/app/src-tauri/src/coordinator/dictation.rs +++ b/openless-all/app/src-tauri/src/coordinator/dictation.rs @@ -622,9 +622,10 @@ pub(super) async fn begin_session(inner: &Arc) -> Result<(), String> { return Ok(()); } - // Windows sherpa-onnx-local(M1 骨架):与 Foundry 同形分支,复用 Recorder / - // ActiveAsr / start_recorder_and_enter_listening。runtime/transcribe 都是 - // 桩——M1 这里只验证主链路能跑到 sherpa 这条路径;真实推理见 M2。 + // Windows sherpa-onnx-local:与 Foundry 同形分支,复用 Recorder / + // ActiveAsr / start_recorder_and_enter_listening。offline 模型走 batch; + // online 模型在 provider 内部 worker 中边录边解码,并通过 local-asr-token + // 推 partial 给前端胶囊。 #[cfg(target_os = "windows")] if sherpa::is_sherpa_onnx_local(&active_asr) { let prefs = inner.prefs.get(); @@ -639,11 +640,38 @@ pub(super) async fn begin_session(inner: &Arc) -> Result<(), String> { } else { Some(language_hint) }; - let local = Arc::new(SherpaOnnxAsr::new( + let token_handler = inner.app.lock().clone().map(|app| { + Arc::new(move |piece: String| { + if let Err(error) = app.emit("local-asr-token", piece) { + log::warn!("[sherpa-asr] emit token failed: {error}"); + } + }) as crate::asr::local::sherpa_provider::SherpaTokenHandler + }); + let local = match SherpaOnnxAsr::new_for_model( Arc::clone(&inner.sherpa_onnx_runtime), model_alias, language_hint, - )); + token_handler, + ) + .await + { + Ok(local) => Arc::new(local), + Err(e) => { + log::error!("[coord] sherpa-onnx init failed: {e:#}"); + emit_capsule( + inner, + CapsuleState::Error, + 0.0, + 0, + Some(format!("本地模型初始化失败: {e}")), + None, + ); + restore_prepared_windows_ime_session(inner, current_session_id); + inner.state.lock().phase = SessionPhase::Idle; + schedule_capsule_idle(inner, CAPSULE_AUTO_HIDE_DELAY_MS); + return Err(format!("sherpa-onnx init failed: {e}")); + } + }; store_asr_for_session( inner, current_session_id, @@ -1282,9 +1310,8 @@ pub(super) async fn end_session(inner: &Arc) -> Result<(), String> { } } } - // Windows sherpa-onnx(M1 骨架):transcribe 当前返回空 RawTranscript, - // 上层 empty-transcript guard 会写 emptyTranscript 历史并显示错误胶囊。 - // M2 接入推理后这里的行为就跟 Foundry 完全一致。 + // Windows sherpa-onnx offline batch:停止录音后整段转写,再复用现有 + // polish / insert / history 收尾路径。 #[cfg(target_os = "windows")] ActiveAsr::SherpaOnnxLocal(local) => { debug_assert!(!uses_global_timeout); diff --git a/openless-all/app/src-tauri/src/types.rs b/openless-all/app/src-tauri/src/types.rs index fe7137e6..1d9d925f 100644 --- a/openless-all/app/src-tauri/src/types.rs +++ b/openless-all/app/src-tauri/src/types.rs @@ -613,8 +613,7 @@ pub struct UserPreferences { /// Windows Foundry Local Whisper 模型在 runtime 中保持加载多久。 #[serde(default = "default_local_asr_keep_loaded_secs")] pub foundry_local_asr_keep_loaded_secs: u32, - /// Windows sherpa-onnx 本地 ASR(M1 实验 provider,详见 - /// `docs/windows-sherpa-onnx-asr-plan.md`)当前激活的模型 alias。 + /// Windows sherpa-onnx 本地 ASR 当前激活的模型 alias。 #[serde(default = "default_sherpa_onnx_model")] pub sherpa_onnx_model: String, /// Windows sherpa-onnx 语言 hint(BCP-47 / ISO 639-1 小写)。空 = 自动。