Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src-tauri/src/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@ pub struct AppInner {
pub pending_audio_stop: Mutex<Option<tokio::sync::oneshot::Sender<()>>>,
pub pending_audio_warmup: Mutex<Option<tokio::sync::oneshot::Sender<()>>>,
pub latest_transcript: Mutex<(String, String)>, // (final_text, partial_text)
/// Audio sample chunks captured before the ASR session is ready (background
/// connect in progress, or during a reconnect gap). Drained into the session
/// once it attaches. Always accessed while holding `asr_session` to stay
/// ordered against the drain.
pub pending_audio: Mutex<Vec<Vec<f32>>>,
/// Resolves when the background ASR connect finishes (Ok) or fails (Err).
/// `stop_recording` awaits this when the user stops before the session is ready.
pub connect_rx: Mutex<Option<tokio::sync::oneshot::Receiver<Result<(), String>>>>,
/// Recording-session generation. Bumped on each start and on cancel so a
/// stale background connect task (from a cancelled/superseded session) can
/// detect it is obsolete and discard its result.
pub session_epoch: std::sync::atomic::AtomicU64,
/// Finalized text carried across ASR reconnects within a single recording.
/// Each reconnect starts a fresh server-side session with no memory of prior
/// audio, so already-recognized text is accumulated here and prepended to the
/// new session's output. Reset at the start of every recording.
pub accumulated_text: Mutex<String>,
}

pub type AppHandle = Arc<AppInner>;
Expand All @@ -49,5 +66,9 @@ pub fn create_app_state(
pending_audio_stop: Mutex::new(None),
pending_audio_warmup: Mutex::new(None),
latest_transcript: Mutex::new((String::new(), String::new())),
pending_audio: Mutex::new(Vec::new()),
connect_rx: Mutex::new(None),
session_epoch: std::sync::atomic::AtomicU64::new(0),
accumulated_text: Mutex::new(String::new()),
})
}
150 changes: 113 additions & 37 deletions src-tauri/src/asr/doubao.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,30 @@ fn normalize_error_message(error: &str) -> String {
error.to_string()
}

/// Classify a Doubao ASR error code as fatal (unrecoverable by reconnect) or
/// transient. Parameter-invalid errors won't recover by reconnecting; server-side
/// timeouts / busy errors and network drops are transient.
fn is_fatal_asr_code(code: u64) -> bool {
matches!(code, 45000001 | 45000002)
}

// ---------------------------------------------------------------------------
// WebSocket sink type alias
// ---------------------------------------------------------------------------

type WsSink = futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;

/// A frame to write to the WebSocket, serialized through a single writer task.
enum WsWrite {
/// A non-last audio frame.
Audio(Vec<u8>),
/// The last-packet (commit) frame. The writer drops any audio enqueued after
/// it, so the server never sees a packet past the final one.
Last(Vec<u8>),
/// Close the connection.
Close,
}

// ---------------------------------------------------------------------------
// DoubaoEngine — AsrEngine implementation
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -491,10 +509,16 @@ impl AsrEngine for DoubaoEngine {
);
headers.insert("X-Api-Connect-Id", connect_id.parse().unwrap());

// Connect
let (ws_stream, _) = connect_async(request)
.await
.map_err(|e| format!("ASR WebSocket connection failed: {}", e))?;
// Connect with a bounded timeout. Without it a stalled handshake relies on
// the OS-level TCP timeout (tens of seconds); the caller retries instead.
let (ws_stream, _) =
match tokio::time::timeout(std::time::Duration::from_secs(5), connect_async(request))
.await
{
Ok(Ok(pair)) => pair,
Ok(Err(e)) => return Err(format!("ASR WebSocket connection failed: {}", e)),
Err(_) => return Err("ASR WebSocket 连接超时".to_string()),
};

let (sink, mut stream) = ws_stream.split();

Expand All @@ -518,24 +542,54 @@ impl AsrEngine for DoubaoEngine {
let final_text: Arc<Mutex<String>> = Arc::new(Mutex::new(String::new()));
let partial_text: Arc<Mutex<String>> = Arc::new(Mutex::new(String::new()));
let latest_result_text: Arc<Mutex<String>> = Arc::new(Mutex::new(String::new()));
let sink = Arc::new(Mutex::new(Some(sink)));
let commit_tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<String>>>> =
Arc::new(Mutex::new(None));

// Dedicated writer task: a single FIFO consumer of the sink. Keeps frames
// ordered and drops any audio enqueued after the last packet, so the server
// never sees a packet past the final one (which it rejects).
let (writer_tx, mut writer_rx) = mpsc::unbounded_channel::<WsWrite>();
tokio::spawn(async move {
let mut sink: WsSink = sink;
let mut last_sent = false;
while let Some(msg) = writer_rx.recv().await {
match msg {
WsWrite::Audio(bytes) => {
if last_sent {
continue;
}
if sink.send(Message::Binary(bytes.into())).await.is_err() {
break;
}
}
WsWrite::Last(bytes) => {
if last_sent {
continue;
}
last_sent = true;
let _ = sink.send(Message::Binary(bytes.into())).await;
}
WsWrite::Close => {
let _ = sink.send(Message::Close(None)).await;
break;
}
}
}
});

let (event_tx, event_rx) = mpsc::unbounded_channel();

let session = DoubaoSession {
is_ready: is_ready.clone(),
is_committed: is_committed.clone(),
final_text: final_text.clone(),
latest_result_text: latest_result_text.clone(),
sender: sink.clone(),
writer_tx,
commit_tx: commit_tx.clone(),
};

// Spawn message handler task
let event_tx_clone = event_tx.clone();
let sink_for_handler = sink.clone();
tokio::spawn(async move {
while let Some(msg) = stream.next().await {
match msg {
Expand Down Expand Up @@ -604,8 +658,24 @@ impl AsrEngine for DoubaoEngine {
.and_then(|v| v.as_str())
.map(|m| format!("ASR error {}: {}", code, m))
.unwrap_or_else(|| format!("ASR error code {}", code));
let _ =
event_tx_clone.send(AsrEvent::Error(message.to_string()));
let _ = event_tx_clone.send(AsrEvent::Error {
message: message.to_string(),
fatal: is_fatal_asr_code(code),
});
// If a commit is waiting, resolve it now with the
// best text we have instead of blocking until the
// 5s timeout (the socket is about to be reset).
if is_committed.load(Ordering::SeqCst) {
if let Some(tx) = commit_tx.lock().await.take() {
let latest = latest_result_text.lock().await.clone();
let ft = final_text.lock().await.clone();
let _ = tx.send(if latest.is_empty() {
ft
} else {
latest
});
}
}
continue;
}
}
Expand Down Expand Up @@ -742,14 +812,19 @@ impl AsrEngine for DoubaoEngine {
.or_else(|| payload.get("error").and_then(|e| e.get("message")))
.and_then(|v| v.as_str())
.unwrap_or("ASR 服务异常");
let _ = event_tx_clone.send(AsrEvent::Error(message.to_string()));
// Unknown text-protocol error: attempt reconnect before giving up.
let _ = event_tx_clone.send(AsrEvent::Error {
message: message.to_string(),
fatal: false,
});
}
}
}
Ok(Message::Close(frame)) => {
is_ready.store(false, Ordering::SeqCst);
// Prevent lingering audio sends after connection closes
*sink_for_handler.lock().await = None;
// Lingering audio sends are already prevented: is_ready=false
// gates append_audio, and the FIFO writer task drops frames
// after the last packet / exits when its send fails.
let code: Option<u16> = frame.as_ref().map(|f| f.code.into());
let reason = frame
.as_ref()
Expand All @@ -771,9 +846,21 @@ impl AsrEngine for DoubaoEngine {
}
Err(e) => {
is_ready.store(false, Ordering::SeqCst);
*sink_for_handler.lock().await = None;
let msg = normalize_error_message(&e.to_string());
let _ = event_tx_clone.send(AsrEvent::Error(msg));
// Connection reset without a clean Close: resolve any pending
// commit so the caller doesn't wait the full 5s timeout.
if is_committed.load(Ordering::SeqCst) {
if let Some(tx) = commit_tx.lock().await.take() {
let latest = latest_result_text.lock().await.clone();
let ft = final_text.lock().await.clone();
let _ = tx.send(if latest.is_empty() { ft } else { latest });
}
}
// Transport-level failure (network drop): recoverable by reconnect.
let _ = event_tx_clone.send(AsrEvent::Error {
message: msg,
fatal: false,
});
break;
}
_ => {}
Expand All @@ -796,7 +883,9 @@ struct DoubaoSession {
is_committed: Arc<AtomicBool>,
final_text: Arc<Mutex<String>>,
latest_result_text: Arc<Mutex<String>>,
sender: Arc<Mutex<Option<WsSink>>>,
/// Sends frames to the dedicated writer task. A single FIFO consumer keeps
/// frames ordered and guarantees the last packet is written after all audio.
writer_tx: mpsc::UnboundedSender<WsWrite>,
commit_tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<String>>>>,
}

Expand All @@ -817,14 +906,9 @@ impl AsrSession for DoubaoSession {
.flat_map(|s| s.to_le_bytes())
.collect();
let frame = encode_audio_only_request(&audio, false);
let sender = self.sender.clone();
tokio::spawn(async move {
if let Some(ref mut sink) = *sender.lock().await {
if let Err(e) = sink.send(Message::Binary(frame.into())).await {
log_asr!(debug, "Audio send skipped (connection closing): {}", e);
}
}
});
// Hand the frame to the writer task; FIFO order is preserved and the
// writer drops anything enqueued after the last packet.
let _ = self.writer_tx.send(WsWrite::Audio(frame));
}

async fn commit_and_await_final(&self) -> Result<String, String> {
Expand All @@ -834,15 +918,13 @@ impl AsrSession for DoubaoSession {
if self.is_committed.load(Ordering::SeqCst) {
return Err("录音已结束".to_string());
}
// Mark committed (stops further appends) and enqueue the last packet.
// Because all prior audio was enqueued before this call (the renderer
// flushes and acks before stop proceeds) and the writer is FIFO, the
// last packet is guaranteed to be written after every audio frame.
self.is_committed.store(true, Ordering::SeqCst);

// Send last-audio frame
let frame = encode_audio_only_request(&[], true);
{
if let Some(ref mut sink) = *self.sender.lock().await {
let _ = sink.send(Message::Binary(frame.into())).await;
}
}
let _ = self.writer_tx.send(WsWrite::Last(frame));

// Wait for final result with timeout
let (tx, rx) = tokio::sync::oneshot::channel();
Expand All @@ -861,13 +943,7 @@ impl AsrSession for DoubaoSession {

fn close(&self) {
self.is_ready.store(false, Ordering::SeqCst);
let sender = self.sender.clone();
tokio::spawn(async move {
let mut guard = sender.lock().await;
if let Some(ref mut sink) = guard.take() {
let _ = sink.send(Message::Close(None)).await;
}
});
let _ = self.writer_tx.send(WsWrite::Close);
}
}

Expand Down
8 changes: 7 additions & 1 deletion src-tauri/src/asr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ pub enum AsrEvent {
final_text: String,
partial_text: String,
},
Error(String),
Error {
message: String,
/// Whether the error is unrecoverable (reconnecting cannot help). Fatal
/// errors finalize the recording with whatever text was already
/// recognized; non-fatal errors trigger an auto-reconnect attempt.
fatal: bool,
},
Close {
code: Option<u16>,
reason: String,
Expand Down
72 changes: 46 additions & 26 deletions src-tauri/src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ use tauri::{utils::Theme, AppHandle, Emitter, Manager, State};
// Re-export paste::PasteResult for use in commands
use paste::PasteResult;

/// Cap on audio chunks buffered before the ASR session is ready (~100ms per
/// chunk, so ~30s). Bounds memory if the connect keeps failing.
const MAX_PENDING_CHUNKS: usize = 300;

/// Detect the actual OS-level light/dark theme preference.
fn detect_system_theme() -> &'static str {
#[cfg(target_os = "macos")]
Expand Down Expand Up @@ -300,39 +304,55 @@ pub async fn send_audio_chunk(
);
}

// Decode base64 → i16 PCM bytes → f32 samples
let bytes = match base64::engine::general_purpose::STANDARD.decode(&base64_chunk) {
Ok(data) => data,
Err(_) => {
log_audio!(warn, "Chunk #{} base64 decode failed", n);
return Ok(serde_json::json!({ "ok": false, "message": "音频数据解码失败" }));
}
};
let samples: Vec<f32> = bytes
.chunks_exact(2)
.map(|chunk| {
let sample = i16::from_le_bytes([chunk[0], chunk[1]]);
sample as f32 / 32768.0
})
.collect();

// Drive the native waveform (macOS only) from the same PCM the ASR receives,
// whether the chunk is sent immediately or buffered.
#[cfg(target_os = "macos")]
if let Some(level) = compute_audio_level(&samples) {
crate::overlay::set_audio_level(&_app, level);
}

// Hold the `asr_session` lock across the decision so buffering stays ordered
// against the background connect task's drain (same lock), guaranteeing no
// buffered chunk is silently dropped between drain and session-attach.
let session = state.asr_session.lock().await;
if let Some(ref session) = *session {
if session.is_ready() {
// Decode base64 → i16 PCM bytes → f32 samples
let bytes = match base64::engine::general_purpose::STANDARD.decode(&base64_chunk) {
Ok(data) => data,
Err(_) => {
log_audio!(warn, "Chunk #{} base64 decode failed", n);
return Ok(serde_json::json!({ "ok": false, "message": "音频数据解码失败" }));
}
};
let samples: Vec<f32> = bytes
.chunks_exact(2)
.map(|chunk| {
let sample = i16::from_le_bytes([chunk[0], chunk[1]]);
sample as f32 / 32768.0
})
.collect();
// Drive the native waveform (macOS only) from the same PCM the ASR receives.
#[cfg(target_os = "macos")]
if let Some(level) = compute_audio_level(&samples) {
crate::overlay::set_audio_level(&_app, level);
}
session.append_audio(&samples);
return Ok(serde_json::json!({ "ok": true }));
}
log_audio!(warn, "Chunk #{} dropped: session not ready", n);
} else {
if n == 0 {
log_audio!(warn, "Chunk #{} dropped: no session", n);
}
}
Ok(serde_json::json!({ "ok": false, "message": "ASR 会话未建立" }))

// Session not ready yet (background connect in progress, or reconnect gap):
// buffer the samples so nothing the user says before the session attaches is
// lost. Drained into the session once it attaches.
let mut pending = state.pending_audio.lock().await;
if pending.len() < MAX_PENDING_CHUNKS {
pending.push(samples);
} else if n.is_multiple_of(50) {
log_audio!(
warn,
"Pending audio buffer full ({} chunks), dropping chunk #{}",
MAX_PENDING_CHUNKS,
n
);
}
Ok(serde_json::json!({ "ok": true, "buffered": true }))
}

/// Notify that audio has stopped in the renderer.
Expand Down
Loading