From 45d9bbaa2b69c1e0c324c3e82425c185fec34a93 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 07:53:42 +0700 Subject: [PATCH 01/44] Add provider-neutral transcript speaker metadata --- .../backend/schema/transcript_segment.dart | 30 +++++ app/test/widgets/transcript_test.dart | 31 ++++++ backend/models/transcript_segment.py | 51 ++++++++- backend/tests/unit/test_transcript_segment.py | 104 ++++++++++++++++++ backend/utils/stt/pre_recorded.py | 50 ++++++++- .../Backend-Rust/src/models/conversation.rs | 16 +++ .../Backend-Rust/src/services/firestore.rs | 56 ++++++++++ desktop/Desktop/Sources/APIClient.swift | 44 +++++++- .../Sources/TranscriptionService.swift | 8 ++ .../Desktop/Tests/ListenProtocolTests.swift | 17 +++ .../TranscriptSpeakerAssignmentTests.swift | 34 ++++++ 11 files changed, 432 insertions(+), 9 deletions(-) diff --git a/app/lib/backend/schema/transcript_segment.dart b/app/lib/backend/schema/transcript_segment.dart index 56091c48a68..dce020ba3e3 100644 --- a/app/lib/backend/schema/transcript_segment.dart +++ b/app/lib/backend/schema/transcript_segment.dart @@ -33,6 +33,13 @@ class TranscriptSegment { List translations = []; bool speechProfileProcessed; String? sttProvider; + String? sttModel; + String? providerClusterId; + String? providerSpeakerLabel; + String speakerIdentityState; + double? speakerIdentityConfidence; + String? speakerIdentitySource; + String? speakerIdentityVersion; TranscriptSegment({ required this.id, @@ -45,6 +52,13 @@ class TranscriptSegment { required this.translations, this.speechProfileProcessed = true, this.sttProvider, + this.sttModel, + this.providerClusterId, + this.providerSpeakerLabel, + this.speakerIdentityState = 'legacy_ambiguous', + this.speakerIdentityConfidence, + this.speakerIdentitySource, + this.speakerIdentityVersion, }) { final parts = speaker?.split('_') ?? []; speakerId = parts.length > 1 ? (int.tryParse(parts[1]) ?? 
0) : 0; @@ -74,6 +88,15 @@ class TranscriptSegment { translations: json['translations'] != null ? Translation.fromJsonList(json['translations'] as List) : [], speechProfileProcessed: (json['speech_profile_processed'] ?? true) as bool, sttProvider: json['stt_provider'] as String?, + sttModel: json['stt_model'] as String?, + providerClusterId: json['provider_cluster_id'] as String?, + providerSpeakerLabel: json['provider_speaker_label'] as String?, + speakerIdentityState: (json['speaker_identity_state'] ?? 'legacy_ambiguous') as String, + speakerIdentityConfidence: json['speaker_identity_confidence'] != null + ? double.tryParse(json['speaker_identity_confidence'].toString()) + : null, + speakerIdentitySource: json['speaker_identity_source'] as String?, + speakerIdentityVersion: json['speaker_identity_version'] as String?, ); } @@ -88,6 +111,13 @@ class TranscriptSegment { 'end': end, 'translations': translations.map((t) => t.toJson()).toList(), if (sttProvider != null) 'stt_provider': sttProvider, + if (sttModel != null) 'stt_model': sttModel, + if (providerClusterId != null) 'provider_cluster_id': providerClusterId, + if (providerSpeakerLabel != null) 'provider_speaker_label': providerSpeakerLabel, + 'speaker_identity_state': speakerIdentityState, + if (speakerIdentityConfidence != null) 'speaker_identity_confidence': speakerIdentityConfidence, + if (speakerIdentitySource != null) 'speaker_identity_source': speakerIdentitySource, + if (speakerIdentityVersion != null) 'speaker_identity_version': speakerIdentityVersion, }; } diff --git a/app/test/widgets/transcript_test.dart b/app/test/widgets/transcript_test.dart index 9ed35e9fc09..11e37519802 100644 --- a/app/test/widgets/transcript_test.dart +++ b/app/test/widgets/transcript_test.dart @@ -42,6 +42,37 @@ void main() { ); } + group('TranscriptSegment compatibility metadata', () { + test('decodes provider speaker metadata without changing legacy fields', () { + final segment = TranscriptSegment.fromJson({ + 'id': 
'seg-provider', + 'text': 'Hello', + 'speaker': null, + 'speaker_id': 0, + 'is_user': false, + 'person_id': null, + 'start': 0.0, + 'end': 1.0, + 'stt_provider': 'provider-a', + 'stt_model': 'async-large', + 'provider_cluster_id': 'speaker-alpha', + 'provider_speaker_label': null, + 'speaker_identity_state': 'unknown', + 'speaker_identity_confidence': null, + 'speaker_identity_source': null, + 'speaker_identity_version': 'v1', + }); + + expect(segment.speaker, isNull); + expect(segment.speakerId, 0); + expect(segment.sttProvider, 'provider-a'); + expect(segment.sttModel, 'async-large'); + expect(segment.providerClusterId, 'speaker-alpha'); + expect(segment.speakerIdentityState, 'unknown'); + expect(segment.speakerIdentityVersion, 'v1'); + }); + }); + group('Speaker label display', () { testWidgets('shows person name when personId is set and in cache', (tester) async { final now = DateTime.now(); diff --git a/backend/models/transcript_segment.py b/backend/models/transcript_segment.py index 44647ce7aca..9f8af3b0414 100644 --- a/backend/models/transcript_segment.py +++ b/backend/models/transcript_segment.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Optional, List, Tuple +from typing import Literal, Optional, List, Tuple import uuid import re from pydantic import BaseModel, Field @@ -23,6 +23,38 @@ class Translation(BaseModel): text: str +SpeakerIdentityState = Literal['unknown', 'legacy_ambiguous', 'unassigned', 'identified', 'user'] + + +class ProviderTranscriptWord(BaseModel): + text: str + start: float + end: float + provider_cluster_id: Optional[str] = None + speaker_label: Optional[str] = None + confidence: Optional[float] = None + + +class ProviderTranscriptUtterance(BaseModel): + text: str + start: float + end: float + provider_cluster_id: Optional[str] = None + speaker_label: Optional[str] = None + confidence: Optional[float] = None + words: Optional[List[ProviderTranscriptWord]] = None + + +class ProviderTranscriptResult(BaseModel): + 
provider: str + model: Optional[str] = None + language: Optional[str] = None + duration: Optional[float] = None + words: List[ProviderTranscriptWord] = [] + utterances: List[ProviderTranscriptUtterance] = [] + raw_provider_result_id: Optional[str] = None + + class TranscriptSegment(BaseModel): id: Optional[str] = None text: str @@ -35,6 +67,13 @@ class TranscriptSegment(BaseModel): translations: Optional[List[Translation]] = [] speech_profile_processed: bool = True stt_provider: Optional[str] = None + stt_model: Optional[str] = None + provider_cluster_id: Optional[str] = None + provider_speaker_label: Optional[str] = None + speaker_identity_state: SpeakerIdentityState = 'legacy_ambiguous' + speaker_identity_confidence: Optional[float] = None + speaker_identity_source: Optional[str] = None + speaker_identity_version: Optional[str] = None def __init__(self, **data): super().__init__(**data) @@ -45,10 +84,16 @@ def __init__(self, **data): try: self.speaker_id = int(self.speaker.split('_', 1)[1]) except (ValueError, IndexError): - self.speaker_id = 0 - else: + if self.speaker_id is None: + self.speaker_id = 0 + elif self.speaker_id is None: self.speaker_id = 0 + if self.person_id and self.speaker_identity_state in ('legacy_ambiguous', 'unassigned', 'unknown'): + self.speaker_identity_state = 'user' if self.is_user else 'identified' + elif self.is_user and self.speaker_identity_state in ('legacy_ambiguous', 'unassigned', 'unknown'): + self.speaker_identity_state = 'user' + def get_timestamp_string(self): start_duration = timedelta(seconds=int(self.start)) end_duration = timedelta(seconds=int(self.end)) diff --git a/backend/tests/unit/test_transcript_segment.py b/backend/tests/unit/test_transcript_segment.py index 1503791fe83..6b6b9b1150d 100644 --- a/backend/tests/unit/test_transcript_segment.py +++ b/backend/tests/unit/test_transcript_segment.py @@ -1,6 +1,9 @@ import pytest +import sys +import types from models.transcript_segment import TranscriptSegment +from 
models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord def _segment(text, speaker="SPEAKER_00", is_user=False, start=0.0, end=1.0): @@ -19,6 +22,107 @@ def _normalize_punctuation(text): return text.strip().replace(" ", " ").replace(" ,", ",").replace(" .", ".").replace(" ?", "?") +def test_provider_transcript_result_preserves_opaque_cluster_ids(): + result = ProviderTranscriptResult( + provider="test-provider", + model="async-large", + words=[ + ProviderTranscriptWord( + text="hello", + start=0.0, + end=0.5, + provider_cluster_id="speaker-alpha", + speaker_label="SPEAKER_00", + ) + ], + ) + + payload = result.model_dump() + + assert payload["provider"] == "test-provider" + assert payload["model"] == "async-large" + assert payload["words"][0]["provider_cluster_id"] == "speaker-alpha" + + +def test_transcript_segment_serializes_canonical_identity_metadata(): + segment = TranscriptSegment( + text="Hello", + speaker="SPEAKER_07", + is_user=False, + person_id="person-1", + start=0.0, + end=1.0, + stt_provider="provider-a", + stt_model="model-b", + provider_cluster_id="cluster-x", + provider_speaker_label="SPEAKER_07", + speaker_identity_confidence=0.91, + speaker_identity_source="omi_speaker_embedding", + speaker_identity_version="v1", + ) + + payload = segment.model_dump() + + assert payload["speaker_id"] == 7 + assert payload["speaker"] == "SPEAKER_07" + assert payload["provider_cluster_id"] == "cluster-x" + assert payload["speaker_identity_state"] == "identified" + assert payload["speaker_identity_confidence"] == 0.91 + assert payload["speaker_identity_source"] == "omi_speaker_embedding" + + +def test_unknown_identity_is_explicit_and_legacy_zero_remains_ambiguous(): + unknown = TranscriptSegment( + text="Hello", + speaker=None, + is_user=False, + person_id=None, + start=0.0, + end=1.0, + provider_cluster_id=None, + speaker_identity_state="unknown", + ) + legacy = TranscriptSegment(text="Hi", speaker="SPEAKER_00", is_user=False, start=1.0, 
end=2.0) + + assert unknown.speaker_id == 0 + assert unknown.speaker_identity_state == "unknown" + assert legacy.speaker_id == 0 + assert legacy.speaker_identity_state == "legacy_ambiguous" + + +def test_postprocess_words_does_not_promote_malformed_speaker_to_speaker_zero(): + deepgram = types.ModuleType("deepgram") + deepgram.DeepgramClient = lambda *args, **kwargs: object() + deepgram.DeepgramClientOptions = lambda *args, **kwargs: object() + byok = types.ModuleType("utils.byok") + byok.get_byok_key = lambda *args, **kwargs: None + endpoints = types.ModuleType("utils.other.endpoints") + endpoints.timeit = lambda fn: fn + sys.modules.setdefault("deepgram", deepgram) + sys.modules.setdefault("fal_client", types.ModuleType("fal_client")) + sys.modules.setdefault("utils.byok", byok) + sys.modules.setdefault("utils.other.endpoints", endpoints) + + from utils.stt import pre_recorded + + segments = pre_recorded.postprocess_words( + [ + { + "timestamp": [0.0, 0.5], + "speaker": "provider-speaker-alpha", + "text": "hello", + } + ], + duration=1, + ) + + assert len(segments) == 1 + assert segments[0].speaker is None + assert segments[0].speaker_id == 0 + assert segments[0].provider_cluster_id == "provider-speaker-alpha" + assert segments[0].speaker_identity_state == "unknown" + + def test_forward_merge_on_short_incomplete_last_sentence(): a = _segment("Hello there. 
and then", speaker="SPEAKER_00", start=0.0, end=4.0) b = _segment("we continue speaking.", speaker="SPEAKER_01", start=4.0, end=7.0) diff --git a/backend/utils/stt/pre_recorded.py b/backend/utils/stt/pre_recorded.py index 59eaa9dd718..767b5d36f1f 100644 --- a/backend/utils/stt/pre_recorded.py +++ b/backend/utils/stt/pre_recorded.py @@ -30,6 +30,23 @@ def _deepgram_client_for_request() -> DeepgramClient: return _deepgram_client +def _deepgram_speaker_fields(speaker_id) -> dict: + if speaker_id is None: + return {'speaker': None, 'provider_cluster_id': None, 'provider_speaker_label': None} + + provider_cluster_id = str(speaker_id) + try: + speaker_label = f"SPEAKER_{int(speaker_id):02d}" + except (TypeError, ValueError): + speaker_label = None + + return { + 'speaker': speaker_label, + 'provider_cluster_id': provider_cluster_id, + 'provider_speaker_label': speaker_label, + } + + # Languages supported by nova-3 _deepgram_nova3_languages = { "ar", @@ -224,10 +241,15 @@ def deepgram_prerecorded( words = [] for w in dg_words: speaker_id = w.get('speaker', 0) + speaker_fields = _deepgram_speaker_fields(speaker_id) words.append( { 'timestamp': [w['start'], w['end']], - 'speaker': f"SPEAKER_{speaker_id:02d}" if speaker_id is not None else None, + 'speaker': speaker_fields['speaker'], + 'provider_cluster_id': speaker_fields['provider_cluster_id'], + 'provider_speaker_label': speaker_fields['provider_speaker_label'], + 'stt_provider': 'deepgram', + 'stt_model': model, 'text': w.get('punctuated_word', w['word']), } ) @@ -358,10 +380,15 @@ def deepgram_prerecorded_from_bytes( words = [] for w in dg_words: speaker_id = w.get('speaker', 0) + speaker_fields = _deepgram_speaker_fields(speaker_id) words.append( { 'timestamp': [w['start'], w['end']], - 'speaker': f"SPEAKER_{speaker_id:02d}" if speaker_id is not None else None, + 'speaker': speaker_fields['speaker'], + 'provider_cluster_id': speaker_fields['provider_cluster_id'], + 'provider_speaker_label': 
speaker_fields['provider_speaker_label'], + 'stt_provider': 'deepgram', + 'stt_model': model, 'text': w.get('punctuated_word', w['word']), } ) @@ -440,14 +467,21 @@ def _words_cleaning(words: List[dict]): for i, w in enumerate(words): # if w['timestamp'][0] == w['timestamp'][1]: # continue + raw_speaker = w.get('speaker') + speaker = raw_speaker if isinstance(raw_speaker, str) and raw_speaker.startswith('SPEAKER_') else None words_cleaned.append( { 'start': round(w['timestamp'][0], 2), 'end': round(w['timestamp'][1] or w['timestamp'][0] + 1, 2), - 'speaker': w['speaker'], + 'speaker': speaker, + 'provider_cluster_id': w.get('provider_cluster_id') or raw_speaker, + 'provider_speaker_label': w.get('provider_speaker_label') or speaker, + 'stt_provider': w.get('stt_provider'), + 'stt_model': w.get('stt_model'), 'text': str(w['text']).strip(), 'is_user': False, 'person_id': None, + 'speaker_identity_state': 'unassigned' if speaker else 'unknown', } ) @@ -470,10 +504,11 @@ def _words_cleaning(words: List[dict]): speaker = prev_speaker elif next_speaker: speaker = next_speaker - else: - speaker = 'SPEAKER_00' words_cleaned[i]['speaker'] = speaker + words_cleaned[i]['provider_speaker_label'] = speaker + if speaker: + words_cleaned[i]['speaker_identity_state'] = 'unassigned' # for chunk in words_cleaned: # print(chunk) @@ -527,6 +562,11 @@ def _segments_as_objects(segments: List[dict]) -> List[TranscriptSegment]: person_id=None, start=round(segment['start'] - starts_at, 2), end=round(segment['end'] - starts_at, 2), + stt_provider=segment.get('stt_provider'), + stt_model=segment.get('stt_model'), + provider_cluster_id=segment.get('provider_cluster_id'), + provider_speaker_label=segment.get('provider_speaker_label'), + speaker_identity_state='user' if segment['is_user'] else segment.get('speaker_identity_state', 'unknown'), ) for segment in segments ] diff --git a/desktop/Backend-Rust/src/models/conversation.rs b/desktop/Backend-Rust/src/models/conversation.rs index 
511da1ae8d2..4e26c048f97 100644 --- a/desktop/Backend-Rust/src/models/conversation.rs +++ b/desktop/Backend-Rust/src/models/conversation.rs @@ -24,6 +24,22 @@ pub struct TranscriptSegment { pub start: f64, #[serde(default)] pub end: f64, + #[serde(default)] + pub stt_provider: Option, + #[serde(default)] + pub stt_model: Option, + #[serde(default)] + pub provider_cluster_id: Option, + #[serde(default)] + pub provider_speaker_label: Option, + #[serde(default)] + pub speaker_identity_state: Option, + #[serde(default)] + pub speaker_identity_confidence: Option, + #[serde(default)] + pub speaker_identity_source: Option, + #[serde(default)] + pub speaker_identity_version: Option, } fn default_speaker() -> String { diff --git a/desktop/Backend-Rust/src/services/firestore.rs b/desktop/Backend-Rust/src/services/firestore.rs index 40d29caed40..426e6273296 100644 --- a/desktop/Backend-Rust/src/services/firestore.rs +++ b/desktop/Backend-Rust/src/services/firestore.rs @@ -4271,6 +4271,14 @@ impl FirestoreService { end: seg.get("end") .and_then(|s| s.as_f64()) .unwrap_or(0.0), + stt_provider: seg.get("stt_provider").and_then(|s| s.as_str()).map(|s| s.to_string()), + stt_model: seg.get("stt_model").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_cluster_id: seg.get("provider_cluster_id").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_speaker_label: seg.get("provider_speaker_label").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_state: seg.get("speaker_identity_state").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_confidence: seg.get("speaker_identity_confidence").and_then(|s| s.as_f64()), + speaker_identity_source: seg.get("speaker_identity_source").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_version: seg.get("speaker_identity_version").and_then(|s| s.as_str()).map(|s| s.to_string()), }) }) .collect(); @@ -4319,6 +4327,14 @@ impl FirestoreService { end: seg.get("end") .and_then(|s| 
s.as_f64()) .unwrap_or(0.0), + stt_provider: seg.get("stt_provider").and_then(|s| s.as_str()).map(|s| s.to_string()), + stt_model: seg.get("stt_model").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_cluster_id: seg.get("provider_cluster_id").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_speaker_label: seg.get("provider_speaker_label").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_state: seg.get("speaker_identity_state").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_confidence: seg.get("speaker_identity_confidence").and_then(|s| s.as_f64()), + speaker_identity_source: seg.get("speaker_identity_source").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_version: seg.get("speaker_identity_version").and_then(|s| s.as_str()).map(|s| s.to_string()), }) }) .collect(); @@ -4380,6 +4396,14 @@ impl FirestoreService { person_id: self.parse_string(seg_fields, "person_id"), start: self.parse_float(seg_fields, "start").unwrap_or(0.0), end: self.parse_float(seg_fields, "end").unwrap_or(0.0), + stt_provider: self.parse_string(seg_fields, "stt_provider"), + stt_model: self.parse_string(seg_fields, "stt_model"), + provider_cluster_id: self.parse_string(seg_fields, "provider_cluster_id"), + provider_speaker_label: self.parse_string(seg_fields, "provider_speaker_label"), + speaker_identity_state: self.parse_string(seg_fields, "speaker_identity_state"), + speaker_identity_confidence: self.parse_float(seg_fields, "speaker_identity_confidence"), + speaker_identity_source: self.parse_string(seg_fields, "speaker_identity_source"), + speaker_identity_version: self.parse_string(seg_fields, "speaker_identity_version"), }) }) .collect()) @@ -4444,6 +4468,14 @@ impl FirestoreService { .get("end") .and_then(|s| s.as_f64()) .unwrap_or(0.0), + stt_provider: seg.get("stt_provider").and_then(|s| s.as_str()).map(|s| s.to_string()), + stt_model: seg.get("stt_model").and_then(|s| s.as_str()).map(|s| s.to_string()), 
+ provider_cluster_id: seg.get("provider_cluster_id").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_speaker_label: seg.get("provider_speaker_label").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_state: seg.get("speaker_identity_state").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_confidence: seg.get("speaker_identity_confidence").and_then(|s| s.as_f64()), + speaker_identity_source: seg.get("speaker_identity_source").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_version: seg.get("speaker_identity_version").and_then(|s| s.as_str()).map(|s| s.to_string()), }) }) .collect()) @@ -4577,6 +4609,30 @@ impl FirestoreService { if let Some(person_id) = &seg.person_id { segment["person_id"] = json!(person_id); } + if let Some(stt_provider) = &seg.stt_provider { + segment["stt_provider"] = json!(stt_provider); + } + if let Some(stt_model) = &seg.stt_model { + segment["stt_model"] = json!(stt_model); + } + if let Some(provider_cluster_id) = &seg.provider_cluster_id { + segment["provider_cluster_id"] = json!(provider_cluster_id); + } + if let Some(provider_speaker_label) = &seg.provider_speaker_label { + segment["provider_speaker_label"] = json!(provider_speaker_label); + } + if let Some(speaker_identity_state) = &seg.speaker_identity_state { + segment["speaker_identity_state"] = json!(speaker_identity_state); + } + if let Some(speaker_identity_confidence) = seg.speaker_identity_confidence { + segment["speaker_identity_confidence"] = json!(speaker_identity_confidence); + } + if let Some(speaker_identity_source) = &seg.speaker_identity_source { + segment["speaker_identity_source"] = json!(speaker_identity_source); + } + if let Some(speaker_identity_version) = &seg.speaker_identity_version { + segment["speaker_identity_version"] = json!(speaker_identity_version); + } segment }).collect(); let json_str = serde_json::to_string(&segments_json).unwrap_or_else(|_| "[]".to_string()); diff --git 
a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 5f24f348c28..21e06f4d01b 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -826,6 +826,14 @@ struct TranscriptSegment: Codable, Identifiable { let start: Double let end: Double let translations: [TranscriptTranslation] + let sttProvider: String? + let sttModel: String? + let providerClusterId: String? + let providerSpeakerLabel: String? + let speakerIdentityState: String? + let speakerIdentityConfidence: Double? + let speakerIdentitySource: String? + let speakerIdentityVersion: String? var speakerId: Int { guard let speaker = speaker else { return 0 } @@ -840,6 +848,14 @@ struct TranscriptSegment: Codable, Identifiable { case id, text, speaker case isUser = "is_user" case personId = "person_id" + case sttProvider = "stt_provider" + case sttModel = "stt_model" + case providerClusterId = "provider_cluster_id" + case providerSpeakerLabel = "provider_speaker_label" + case speakerIdentityState = "speaker_identity_state" + case speakerIdentityConfidence = "speaker_identity_confidence" + case speakerIdentitySource = "speaker_identity_source" + case speakerIdentityVersion = "speaker_identity_version" case start, end, translations } @@ -856,6 +872,16 @@ struct TranscriptSegment: Codable, Identifiable { end = try container.decodeIfPresent(Double.self, forKey: .end) ?? 0 translations = try container.decodeIfPresent([TranscriptTranslation].self, forKey: .translations) ?? 
[] + sttProvider = try container.decodeIfPresent(String.self, forKey: .sttProvider) + sttModel = try container.decodeIfPresent(String.self, forKey: .sttModel) + providerClusterId = try container.decodeIfPresent(String.self, forKey: .providerClusterId) + providerSpeakerLabel = try container.decodeIfPresent(String.self, forKey: .providerSpeakerLabel) + speakerIdentityState = + try container.decodeIfPresent(String.self, forKey: .speakerIdentityState) + speakerIdentityConfidence = + try container.decodeIfPresent(Double.self, forKey: .speakerIdentityConfidence) + speakerIdentitySource = try container.decodeIfPresent(String.self, forKey: .speakerIdentitySource) + speakerIdentityVersion = try container.decodeIfPresent(String.self, forKey: .speakerIdentityVersion) } /// Memberwise initializer for creating from local storage @@ -868,7 +894,15 @@ struct TranscriptSegment: Codable, Identifiable { personId: String?, start: Double, end: Double, - translations: [TranscriptTranslation] = [] + translations: [TranscriptTranslation] = [], + sttProvider: String? = nil, + sttModel: String? = nil, + providerClusterId: String? = nil, + providerSpeakerLabel: String? = nil, + speakerIdentityState: String? = nil, + speakerIdentityConfidence: Double? = nil, + speakerIdentitySource: String? = nil, + speakerIdentityVersion: String? 
= nil ) { self.id = id self.backendId = backendId @@ -879,6 +913,14 @@ struct TranscriptSegment: Codable, Identifiable { self.start = start self.end = end self.translations = translations + self.sttProvider = sttProvider + self.sttModel = sttModel + self.providerClusterId = providerClusterId + self.providerSpeakerLabel = providerSpeakerLabel + self.speakerIdentityState = speakerIdentityState + self.speakerIdentityConfidence = speakerIdentityConfidence + self.speakerIdentitySource = speakerIdentitySource + self.speakerIdentityVersion = speakerIdentityVersion } /// Formatted timestamp string (e.g., "00:01:30 - 00:01:45") diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index 33e30c91220..0e88c1308f0 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -37,6 +37,14 @@ class TranscriptionService { let start: Double let end: Double let translations: [BackendTranslation]? + let stt_provider: String? + let stt_model: String? + let provider_cluster_id: String? + let provider_speaker_label: String? + let speaker_identity_state: String? + let speaker_identity_confidence: Double? + let speaker_identity_source: String? + let speaker_identity_version: String? 
} /// Message event (from `/v4/listen` only — not used by PTT transcribe-stream) diff --git a/desktop/Desktop/Tests/ListenProtocolTests.swift b/desktop/Desktop/Tests/ListenProtocolTests.swift index 6ef3cddaa2a..406f5c2f197 100644 --- a/desktop/Desktop/Tests/ListenProtocolTests.swift +++ b/desktop/Desktop/Tests/ListenProtocolTests.swift @@ -42,6 +42,23 @@ final class ListenProtocolTests: XCTestCase { XCTAssertFalse(seg.is_user) } + func testDecodeSegmentWithProviderIdentityMetadata() throws { + let json = """ + [{"id":"seg-1","text":"hello","speaker":null,"speaker_id":null,"is_user":false,"person_id":null,"start":0.0,"end":1.0,"provider_cluster_id":"speaker-alpha","speaker_identity_state":"unknown","stt_provider":"provider-a","stt_model":"async-large"}] + """ + let data = json.data(using: .utf8)! + let segments = try JSONDecoder().decode([TranscriptionService.BackendSegment].self, from: data) + + XCTAssertEqual(segments.count, 1) + let seg = segments[0] + XCTAssertNil(seg.speaker) + XCTAssertNil(seg.speaker_id) + XCTAssertEqual(seg.provider_cluster_id, "speaker-alpha") + XCTAssertEqual(seg.speaker_identity_state, "unknown") + XCTAssertEqual(seg.stt_provider, "provider-a") + XCTAssertEqual(seg.stt_model, "async-large") + } + func testDecodeMultipleSegments() throws { let json = """ [ diff --git a/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift b/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift index 3b733688c71..618686ee5fe 100644 --- a/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift +++ b/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift @@ -205,6 +205,40 @@ final class TranscriptSpeakerAssignmentTests: XCTestCase { XCTAssertTrue(segment.translations.isEmpty, "Translations should default to empty array when not present in JSON") } + func testTranscriptSegmentDecodesProviderIdentityMetadata() throws { + let json = """ + { + "id": "seg_provider", + "text": "Hello", + "speaker": null, + "is_user": false, + "person_id": 
null, + "start": 0.0, + "end": 1.0, + "stt_provider": "provider-a", + "stt_model": "async-large", + "provider_cluster_id": "speaker-alpha", + "provider_speaker_label": null, + "speaker_identity_state": "unknown", + "speaker_identity_confidence": 0.42, + "speaker_identity_source": "omi_speaker_embedding", + "speaker_identity_version": "v1" + } + """.data(using: .utf8)! + + let segment = try JSONDecoder().decode(TranscriptSegment.self, from: json) + + XCTAssertNil(segment.speaker) + XCTAssertEqual(segment.speakerId, 0) + XCTAssertEqual(segment.sttProvider, "provider-a") + XCTAssertEqual(segment.sttModel, "async-large") + XCTAssertEqual(segment.providerClusterId, "speaker-alpha") + XCTAssertEqual(segment.speakerIdentityState, "unknown") + XCTAssertEqual(segment.speakerIdentityConfidence, 0.42) + XCTAssertEqual(segment.speakerIdentitySource, "omi_speaker_embedding") + XCTAssertEqual(segment.speakerIdentityVersion, "v1") + } + func testSpeakerSegmentTranslationsPreserved() { let translations = [ SegmentTranslation(lang: "en", text: "Hello"), From 487d7ba1c3b9b952991e9a45cabfdf02c7d99be5 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 08:01:28 +0700 Subject: [PATCH 02/44] Introduce Deepgram STT provider facade --- .../unit/test_streaming_deepgram_backoff.py | 18 ++ .../tests/unit/test_stt_provider_facade.py | 160 +++++++++++++ backend/utils/stt/deepgram_adapter.py | 222 ++++++++++++++++++ backend/utils/stt/pre_recorded.py | 200 ++++------------ backend/utils/stt/providers.py | 91 +++++++ backend/utils/stt/streaming.py | 27 +++ 6 files changed, 558 insertions(+), 160 deletions(-) create mode 100644 backend/tests/unit/test_stt_provider_facade.py create mode 100644 backend/utils/stt/deepgram_adapter.py create mode 100644 backend/utils/stt/providers.py diff --git a/backend/tests/unit/test_streaming_deepgram_backoff.py b/backend/tests/unit/test_streaming_deepgram_backoff.py index d5457612fff..93399787e57 100644 --- 
a/backend/tests/unit/test_streaming_deepgram_backoff.py +++ b/backend/tests/unit/test_streaming_deepgram_backoff.py @@ -15,14 +15,18 @@ # Mock heavy dependencies before importing streaming module _mock_modules = {} for mod_name in [ + 'cachetools', 'database', 'database._client', 'database.users', + 'utils.http_client', 'utils.other.storage', 'deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1', + 'onnxruntime', + 'pydub', 'websockets', 'websockets.exceptions', ]: @@ -39,6 +43,20 @@ sys.modules['deepgram'].LiveTranscriptionEvents = MagicMock() sys.modules['deepgram.clients.live.v1'].LiveOptions = MagicMock +if 'cachetools' in _mock_modules: + + class FakeTTLCache(dict): + def __init__(self, maxsize=None, ttl=None): + super().__init__() + + sys.modules['cachetools'].TTLCache = FakeTTLCache + +if 'pydub' in _mock_modules: + sys.modules['pydub'].AudioSegment = MagicMock + +if 'utils.http_client' in _mock_modules: + sys.modules['utils.http_client'].get_stt_client = MagicMock() + from utils.stt.streaming import connect_to_deepgram_with_backoff, process_audio_dg # noqa: E402 from utils.stt.streaming import deepgram_options, deepgram_cloud_options # noqa: E402 from utils.stt.streaming import get_stt_service_for_language, STTService, should_preserve_filler_words # noqa: E402 diff --git a/backend/tests/unit/test_stt_provider_facade.py b/backend/tests/unit/test_stt_provider_facade.py new file mode 100644 index 00000000000..87eb60abd98 --- /dev/null +++ b/backend/tests/unit/test_stt_provider_facade.py @@ -0,0 +1,160 @@ +import sys +from unittest.mock import MagicMock + + +for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + +sys.modules['deepgram'].DeepgramClient = MagicMock +sys.modules['deepgram'].DeepgramClientOptions = MagicMock +sys.modules['deepgram'].LiveTranscriptionEvents = MagicMock() 
+sys.modules['deepgram.clients.live.v1'].LiveOptions = MagicMock + +from utils.stt.deepgram_adapter import ( # noqa: E402 + DeepgramPrerecordedTranscriptionProvider, + normalize_deepgram_prerecorded_result, + provider_result_to_legacy_words, +) +from utils.stt.providers import ( # noqa: E402 + STTProviderName, + STTWorkload, + get_prerecorded_provider_name, + get_streaming_provider_name, +) + + +def _deepgram_fixture(words=None, utterances=None): + return { + 'metadata': {'request_id': 'dg-request-1', 'duration': 3.25}, + 'results': { + 'channels': [ + { + 'detected_language': 'en-US', + 'alternatives': [ + { + 'words': ( + words + if words is not None + else [ + { + 'word': 'hello', + 'punctuated_word': 'Hello', + 'start': 0.0, + 'end': 0.4, + 'confidence': 0.91, + 'speaker': 2, + }, + { + 'word': 'world', + 'punctuated_word': 'world.', + 'start': 0.5, + 'end': 1.1, + 'confidence': 0.88, + 'speaker': 2, + }, + ] + ), + } + ], + } + ], + 'utterances': ( + utterances + if utterances is not None + else [ + { + 'transcript': 'Hello world.', + 'start': 0.0, + 'end': 1.1, + 'confidence': 0.9, + 'speaker': 2, + } + ] + ), + }, + } + + +def test_deepgram_fixture_normalizes_to_provider_transcript_result(): + result = normalize_deepgram_prerecorded_result(_deepgram_fixture(), model='nova-3') + + assert result.provider == STTProviderName.deepgram.value + assert result.model == 'nova-3' + assert result.language == 'en' + assert result.duration == 3.25 + assert result.raw_provider_result_id == 'dg-request-1' + assert len(result.words) == 2 + assert result.words[0].text == 'Hello' + assert result.words[0].provider_cluster_id == '2' + assert result.words[0].speaker_label == 'SPEAKER_02' + assert result.utterances[0].provider_cluster_id == '2' + + +def test_deepgram_result_converts_to_legacy_words_for_existing_callers(): + result = normalize_deepgram_prerecorded_result(_deepgram_fixture(), model='nova-3') + + words = provider_result_to_legacy_words(result) + + assert words == [ + 
{ + 'timestamp': [0.0, 0.4], + 'speaker': 'SPEAKER_02', + 'provider_cluster_id': '2', + 'provider_speaker_label': 'SPEAKER_02', + 'stt_provider': 'deepgram', + 'stt_model': 'nova-3', + 'text': 'Hello', + }, + { + 'timestamp': [0.5, 1.1], + 'speaker': 'SPEAKER_02', + 'provider_cluster_id': '2', + 'provider_speaker_label': 'SPEAKER_02', + 'stt_provider': 'deepgram', + 'stt_model': 'nova-3', + 'text': 'world.', + }, + ] + + +def test_deepgram_adapter_preserves_prerecorded_request_options(): + fake_response = MagicMock() + fake_response.to_dict.return_value = _deepgram_fixture() + fake_rest = MagicMock() + fake_rest.transcribe_url.return_value = fake_response + fake_client = MagicMock() + fake_client.listen.rest.v.return_value = fake_rest + provider = DeepgramPrerecordedTranscriptionProvider(lambda: fake_client, timeout=MagicMock()) + + result, detected_language = provider.transcribe_url( + 'https://example.test/audio.wav', + return_language=True, + diarize=False, + language='multi', + model='nova-3', + keywords=['Omi', 'custom'], + ) + + call_args = fake_rest.transcribe_url.call_args + assert call_args.args[0] == {'url': 'https://example.test/audio.wav'} + assert call_args.args[1]['diarize'] is False + assert call_args.args[1]['detect_language'] is True + assert call_args.args[1]['keyterm'] == ['Omi', 'custom'] + assert 'language' not in call_args.args[1] + assert result.provider == 'deepgram' + assert detected_language == 'en' + + +def test_provider_routing_keeps_all_current_workloads_on_deepgram(): + assert get_streaming_provider_name(STTWorkload.ptt) == STTProviderName.deepgram + assert get_streaming_provider_name(STTWorkload.realtime) == STTProviderName.deepgram + + for workload in [ + STTWorkload.background, + STTWorkload.postprocess, + STTWorkload.ptt, + STTWorkload.sync, + STTWorkload.voice_message, + ]: + assert get_prerecorded_provider_name(workload) == STTProviderName.deepgram diff --git a/backend/utils/stt/deepgram_adapter.py 
b/backend/utils/stt/deepgram_adapter.py new file mode 100644 index 00000000000..bdf6237be08 --- /dev/null +++ b/backend/utils/stt/deepgram_adapter.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass +from io import BytesIO +from typing import Callable, List, Optional, Sequence, Tuple, Union + +import httpx +from deepgram import DeepgramClient + +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptUtterance, ProviderTranscriptWord +from utils.stt.providers import STTProviderName + + +@dataclass(frozen=True) +class DeepgramPrerecordedOptions: + model: str = 'nova-3' + diarize: bool = True + language: Optional[str] = None + return_language: bool = False + keywords: Optional[Sequence[str]] = None + sample_rate: int = 16000 + encoding: Optional[str] = None + channels: int = 1 + nova3_keyword_prefix_match: bool = False + + +def deepgram_speaker_fields(speaker_id) -> dict: + if speaker_id is None: + return {'speaker': None, 'provider_cluster_id': None, 'provider_speaker_label': None} + + provider_cluster_id = str(speaker_id) + try: + speaker_label = f'SPEAKER_{int(speaker_id):02d}' + except (TypeError, ValueError): + speaker_label = None + + return { + 'speaker': speaker_label, + 'provider_cluster_id': provider_cluster_id, + 'provider_speaker_label': speaker_label, + } + + +def normalize_deepgram_prerecorded_result( + result: dict, model: str, language: Optional[str] = None +) -> ProviderTranscriptResult: + channels = result.get('results', {}).get('channels', []) + if not channels: + raise Exception('No channels found in response') + + alternatives = channels[0].get('alternatives', []) + if not alternatives: + raise Exception('No alternatives found in response') + + alternative = alternatives[0] + detected_language = _normalize_language(channels[0].get('detected_language') or language) + words = [_normalize_deepgram_word(word) for word in alternative.get('words', [])] + utterances = [ + _normalize_deepgram_utterance(utterance) for 
utterance in result.get('results', {}).get('utterances', []) + ] + duration = result.get('metadata', {}).get('duration') + + return ProviderTranscriptResult( + provider=STTProviderName.deepgram.value, + model=model, + language=detected_language, + duration=duration, + words=words, + utterances=utterances, + raw_provider_result_id=result.get('metadata', {}).get('request_id'), + ) + + +def provider_result_to_legacy_words(result: ProviderTranscriptResult) -> List[dict]: + words = [] + for word in result.words: + words.append( + { + 'timestamp': [word.start, word.end], + 'speaker': word.speaker_label, + 'provider_cluster_id': word.provider_cluster_id, + 'provider_speaker_label': word.speaker_label, + 'stt_provider': result.provider, + 'stt_model': result.model, + 'text': word.text, + } + ) + return words + + +def _normalize_language(language: Optional[str]) -> Optional[str]: + if language and '-' in language: + return language.split('-', 1)[0] + return language + + +def _normalize_deepgram_word(word: dict) -> ProviderTranscriptWord: + speaker_fields = deepgram_speaker_fields(word.get('speaker')) + return ProviderTranscriptWord( + text=word.get('punctuated_word', word.get('word', '')), + start=word['start'], + end=word['end'], + provider_cluster_id=speaker_fields['provider_cluster_id'], + speaker_label=speaker_fields['speaker'], + confidence=word.get('confidence'), + ) + + +def _normalize_deepgram_utterance(utterance: dict) -> ProviderTranscriptUtterance: + speaker_fields = deepgram_speaker_fields(utterance.get('speaker')) + utterance_words = utterance.get('words') + return ProviderTranscriptUtterance( + text=utterance.get('transcript', ''), + start=utterance['start'], + end=utterance['end'], + provider_cluster_id=speaker_fields['provider_cluster_id'], + speaker_label=speaker_fields['speaker'], + confidence=utterance.get('confidence'), + words=[_normalize_deepgram_word(word) for word in utterance_words] if utterance_words else None, + ) + + +class 
DeepgramPrerecordedTranscriptionProvider: + provider_name = STTProviderName.deepgram + + def __init__( + self, + client_factory: Callable[[], DeepgramClient], + timeout: httpx.Timeout, + ): + self._client_factory = client_factory + self._timeout = timeout + + def transcribe_url( + self, + audio_url: str, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'nova-3', + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: + options = DeepgramPrerecordedOptions( + model=model, + diarize=diarize, + language=language, + return_language=return_language, + keywords=keywords, + ) + request_options = self._request_options(options) + response = ( + self._client_factory() + .listen.rest.v('1') + .transcribe_url({'url': audio_url}, request_options, timeout=self._timeout) + ) + result = normalize_deepgram_prerecorded_result(response.to_dict(), model=model, language=language) + if return_language: + return result, result.language or 'en' + return result + + def transcribe_bytes( + self, + audio_bytes: bytes, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'nova-3', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: + options = DeepgramPrerecordedOptions( + model=model, + diarize=diarize, + language=language, + return_language=return_language, + keywords=keywords, + sample_rate=sample_rate, + encoding=encoding, + channels=channels, + nova3_keyword_prefix_match=True, + ) + request_options = self._request_options(options) + audio_buffer = BytesIO(audio_bytes) + mimetype = 'audio/raw' if encoding else 'audio/wav' + source = {'buffer': audio_buffer, 'mimetype': mimetype} + response = ( + 
self._client_factory().listen.rest.v('1').transcribe_file(source, request_options, timeout=self._timeout) + ) + result = normalize_deepgram_prerecorded_result(response.to_dict(), model=model, language=language) + if return_language: + return result, result.language or 'en' + return result + + def _request_options(self, options: DeepgramPrerecordedOptions) -> dict: + is_multi = options.language == 'multi' + request_options = { + 'model': options.model, + 'smart_format': True, + 'punctuate': True, + 'diarize': options.diarize, + 'utterances': True, + 'detect_language': options.return_language or is_multi, + } + if options.language and not is_multi: + request_options['language'] = options.language + + if options.keywords: + is_nova3_keyword_model = options.model in ('nova-3',) or ( + options.nova3_keyword_prefix_match and str(options.model).startswith('nova-3') + ) + if is_nova3_keyword_model: + request_options['keyterm'] = list(options.keywords) + else: + request_options['keywords'] = list(options.keywords) + + if options.encoding: + request_options['encoding'] = options.encoding + request_options['sample_rate'] = options.sample_rate + request_options['channels'] = options.channels + + return request_options diff --git a/backend/utils/stt/pre_recorded.py b/backend/utils/stt/pre_recorded.py index 767b5d36f1f..ce75fc9e6e9 100644 --- a/backend/utils/stt/pre_recorded.py +++ b/backend/utils/stt/pre_recorded.py @@ -1,6 +1,5 @@ import os from collections import defaultdict -from io import BytesIO from typing import List, Optional, Sequence, Tuple, Union import fal_client @@ -10,6 +9,11 @@ from models.transcript_segment import TranscriptSegment from utils.byok import get_byok_key from utils.other.endpoints import timeit +from utils.stt.deepgram_adapter import ( + DeepgramPrerecordedTranscriptionProvider, + deepgram_speaker_fields, + provider_result_to_legacy_words, +) import logging _DG_TIMEOUT = httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0) @@ -31,20 +35,11 
@@ def _deepgram_client_for_request() -> DeepgramClient: def _deepgram_speaker_fields(speaker_id) -> dict: - if speaker_id is None: - return {'speaker': None, 'provider_cluster_id': None, 'provider_speaker_label': None} + return deepgram_speaker_fields(speaker_id) - provider_cluster_id = str(speaker_id) - try: - speaker_label = f"SPEAKER_{int(speaker_id):02d}" - except (TypeError, ValueError): - speaker_label = None - return { - 'speaker': speaker_label, - 'provider_cluster_id': provider_cluster_id, - 'provider_speaker_label': speaker_label, - } +def _deepgram_prerecorded_provider() -> DeepgramPrerecordedTranscriptionProvider: + return DeepgramPrerecordedTranscriptionProvider(_deepgram_client_for_request, _DG_TIMEOUT) # Languages supported by nova-3 @@ -190,79 +185,20 @@ def deepgram_prerecorded( logger.info(f'deepgram_prerecorded {audio_url} {speakers_count} {attempts}') try: - # 'multi' language means auto-detection - is_multi = language == 'multi' - should_detect_language = return_language or is_multi - options = { - "model": model, - "smart_format": True, - "punctuate": True, - "diarize": diarize, - "detect_language": should_detect_language, - "utterances": True, - } - if language and not is_multi: - options["language"] = language - - if keywords: - if model in ('nova-3',): - options["keyterm"] = list(keywords) - else: - options["keywords"] = list(keywords) - - response = ( - _deepgram_client_for_request() - .listen.rest.v("1") - .transcribe_url({"url": audio_url}, options, timeout=_DG_TIMEOUT) + result = _deepgram_prerecorded_provider().transcribe_url( + audio_url, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, ) - - # Extract words from response - result = response.to_dict() - channels = result.get('results', {}).get('channels', []) - if not channels: - raise Exception('No channels found in response') - - alternatives = channels[0].get('alternatives', []) - if not 
alternatives: - raise Exception('No alternatives found in response') - - dg_words = alternatives[0].get('words', []) - if not dg_words: - if return_language: - detected_lang = channels[0].get('detected_language', 'en') - if detected_lang and '-' in detected_lang: - detected_lang = detected_lang.split('-')[0] - return [], detected_lang or 'en' - return [] - - # Convert Deepgram format to fal_whisperx compatible format - # Deepgram: {word, start, end, confidence, punctuated_word, speaker (int)} - # Expected: {timestamp: [start, end], speaker: 'SPEAKER_XX', text: 'word'} - words = [] - for w in dg_words: - speaker_id = w.get('speaker', 0) - speaker_fields = _deepgram_speaker_fields(speaker_id) - words.append( - { - 'timestamp': [w['start'], w['end']], - 'speaker': speaker_fields['speaker'], - 'provider_cluster_id': speaker_fields['provider_cluster_id'], - 'provider_speaker_label': speaker_fields['provider_speaker_label'], - 'stt_provider': 'deepgram', - 'stt_model': model, - 'text': w.get('punctuated_word', w['word']), - } - ) - if return_language: - # Deepgram returns detected_language in the channel - detected_lang = channels[0].get('detected_language', 'en') - # Normalize language code (Deepgram might return 'en-US', we want 'en') - if detected_lang and '-' in detected_lang: - detected_lang = detected_lang.split('-')[0] - return words, detected_lang or 'en' + transcript_result, detected_language = result + return provider_result_to_legacy_words(transcript_result), detected_language - return words + return provider_result_to_legacy_words(result) except Exception as e: logger.error(f'Deepgram prerecorded error: {e}') @@ -317,89 +253,33 @@ def deepgram_prerecorded_from_bytes( Or tuple of (words, language) if return_language=True """ logger.info( - f'deepgram_prerecorded_from_bytes bytes_len={len(audio_bytes)} {sample_rate} {diarize} {attempts} encoding={encoding} language={language} model={model}' + 'deepgram_prerecorded_from_bytes bytes_len=%s %s %s %s encoding=%s 
language=%s model=%s', + len(audio_bytes), + sample_rate, + diarize, + attempts, + encoding, + language, + model, ) try: - is_multi = language == 'multi' - should_detect_language = return_language or is_multi - options = { - "model": model, - "smart_format": True, - "punctuate": True, - "diarize": diarize, - "utterances": True, - "detect_language": should_detect_language, - } - if language and not is_multi: - options["language"] = language - - if keywords: - if str(model).startswith("nova-3"): - options["keyterm"] = list(keywords) - else: - options["keywords"] = list(keywords) - - # For raw PCM, Deepgram needs encoding + sample_rate to interpret the bytes - if encoding: - options["encoding"] = encoding - options["sample_rate"] = sample_rate - options["channels"] = channels - - # Wrap bytes in BytesIO for Deepgram client - audio_buffer = BytesIO(audio_bytes) - mimetype = "audio/raw" if encoding else "audio/wav" - source = {"buffer": audio_buffer, "mimetype": mimetype} - - response = ( - _deepgram_client_for_request().listen.rest.v("1").transcribe_file(source, options, timeout=_DG_TIMEOUT) + result = _deepgram_prerecorded_provider().transcribe_bytes( + audio_bytes, + sample_rate=sample_rate, + diarize=diarize, + encoding=encoding, + channels=channels, + language=language, + model=model, + return_language=return_language, + keywords=keywords, ) - - # Extract words from response - result = response.to_dict() - result_channels = result.get('results', {}).get('channels', []) - if not result_channels: - raise Exception('No channels found in response') - - alternatives = result_channels[0].get('alternatives', []) - if not alternatives: - raise Exception('No alternatives found in response') - - dg_words = alternatives[0].get('words', []) - if not dg_words: - if return_language: - detected_lang = result_channels[0].get('detected_language', 'en') - if detected_lang and '-' in detected_lang: - detected_lang = detected_lang.split('-')[0] - return [], detected_lang or 'en' - 
return [] - - # Convert Deepgram format to standard format - # Deepgram: {word, start, end, confidence, punctuated_word, speaker (int)} - # Expected: {timestamp: [start, end], speaker: 'SPEAKER_XX', text: 'word'} - words = [] - for w in dg_words: - speaker_id = w.get('speaker', 0) - speaker_fields = _deepgram_speaker_fields(speaker_id) - words.append( - { - 'timestamp': [w['start'], w['end']], - 'speaker': speaker_fields['speaker'], - 'provider_cluster_id': speaker_fields['provider_cluster_id'], - 'provider_speaker_label': speaker_fields['provider_speaker_label'], - 'stt_provider': 'deepgram', - 'stt_model': model, - 'text': w.get('punctuated_word', w['word']), - } - ) - if return_language: - detected_lang = result_channels[0].get('detected_language', 'en') - if detected_lang and '-' in detected_lang: - detected_lang = detected_lang.split('-')[0] - return words, detected_lang or 'en' + transcript_result, detected_language = result + return provider_result_to_legacy_words(transcript_result), detected_language - return words + return provider_result_to_legacy_words(result) except Exception as e: logger.error(f'Deepgram prerecorded from bytes error: {e}') diff --git a/backend/utils/stt/providers.py b/backend/utils/stt/providers.py new file mode 100644 index 00000000000..cd041a87fda --- /dev/null +++ b/backend/utils/stt/providers.py @@ -0,0 +1,91 @@ +from enum import Enum +from typing import Callable, List, Optional, Protocol, Sequence, Tuple, Union + +from models.transcript_segment import ProviderTranscriptResult + + +class STTProviderName(str, Enum): + deepgram = 'deepgram' + + +class STTWorkload(str, Enum): + background = 'background' + postprocess = 'postprocess' + ptt = 'ptt' + realtime = 'realtime' + sync = 'sync' + voice_message = 'voice_message' + + +class PrerecordedSTTProvider(Protocol): + provider_name: STTProviderName + + def transcribe_url( + self, + audio_url: str, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + 
language: Optional[str] = None, + model: str = 'nova-3', + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: ... + + def transcribe_bytes( + self, + audio_bytes: bytes, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'nova-3', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: ... + + +class StreamingSTTProvider(Protocol): + provider_name: STTProviderName + + async def connect_stream( + self, + stream_transcript, + language: str, + sample_rate: int, + channels: int, + model: str = 'nova-3', + keywords: List[str] = [], + vad_gate=None, + is_active: Optional[Callable[[], bool]] = None, + ): ... + + +class DiarizationProvider(Protocol): + provider_name: STTProviderName + + +class SpeakerIdentityProvider(Protocol): + provider_name: str + + +_PRERECORDED_WORKLOAD_PROVIDERS = { + STTWorkload.background: STTProviderName.deepgram, + STTWorkload.postprocess: STTProviderName.deepgram, + STTWorkload.ptt: STTProviderName.deepgram, + STTWorkload.sync: STTProviderName.deepgram, + STTWorkload.voice_message: STTProviderName.deepgram, +} + +_STREAMING_WORKLOAD_PROVIDERS = { + STTWorkload.ptt: STTProviderName.deepgram, + STTWorkload.realtime: STTProviderName.deepgram, +} + + +def get_prerecorded_provider_name(workload: STTWorkload) -> STTProviderName: + return _PRERECORDED_WORKLOAD_PROVIDERS[STTWorkload(workload)] + + +def get_streaming_provider_name(workload: STTWorkload) -> STTProviderName: + return _STREAMING_WORKLOAD_PROVIDERS[STTWorkload(workload)] diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index 6fa923e7c39..e354895f4d9 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -11,6 +11,7 @@ from utils.byok import get_byok_key from 
utils.executors import sync_executor, run_blocking from utils.stt.safe_socket import KeepaliveConfig, SafeDeepgramSocket # noqa: F401 — re-exported for backward compat +from utils.stt.providers import STTProviderName from utils.stt.vad_gate import GatedDeepgramSocket import logging @@ -278,6 +279,32 @@ def on_dg_error(self, error, **kwargs): return safe_conn +class DeepgramStreamingTranscriptionProvider: + provider_name = STTProviderName.deepgram + + async def connect_stream( + self, + stream_transcript, + language: str, + sample_rate: int, + channels: int, + model: str = 'nova-3', + keywords: List[str] = [], + vad_gate=None, + is_active: Optional[Callable[[], bool]] = None, + ): + return await process_audio_dg( + stream_transcript, + language, + sample_rate, + channels, + model=model, + keywords=keywords, + vad_gate=vad_gate, + is_active=is_active, + ) + + # Calculate backoff with jitter def calculate_backoff_with_jitter(attempt, base_delay=1000, max_delay=32000): jitter = random.random() * base_delay From 9bb10888d226dc3292cefbde551d72e8ef86bacf Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 08:07:48 +0700 Subject: [PATCH 03/44] Add conversation reconstructor for STT results --- .../unit/test_conversation_reconstructor.py | 206 ++++++++++++++ .../utils/stt/conversation_reconstructor.py | 263 ++++++++++++++++++ backend/utils/stt/pre_recorded.py | 134 ++------- 3 files changed, 491 insertions(+), 112 deletions(-) create mode 100644 backend/tests/unit/test_conversation_reconstructor.py create mode 100644 backend/utils/stt/conversation_reconstructor.py diff --git a/backend/tests/unit/test_conversation_reconstructor.py b/backend/tests/unit/test_conversation_reconstructor.py new file mode 100644 index 00000000000..4edecf686c2 --- /dev/null +++ b/backend/tests/unit/test_conversation_reconstructor.py @@ -0,0 +1,206 @@ +import sys +from unittest.mock import MagicMock + +from models.transcript_segment import ( + ProviderTranscriptResult, + 
ProviderTranscriptUtterance, + ProviderTranscriptWord, +) +from utils.stt.conversation_reconstructor import ConversationReconstructor, reconstruct_conversation + +if 'deepgram' not in sys.modules: + sys.modules['deepgram'] = MagicMock() + +from utils.stt.deepgram_adapter import normalize_deepgram_prerecorded_result + + +def _word(text, start, end, cluster=None, label=None): + return ProviderTranscriptWord( + text=text, + start=start, + end=end, + provider_cluster_id=cluster, + speaker_label=label, + ) + + +def test_reconstructs_word_only_provider_result_with_stable_ordering_and_cluster_metadata(): + result = ProviderTranscriptResult( + provider='test-provider', + model='async-model', + words=[ + _word('later', 2.0, 2.5, cluster='cluster-b', label='SPEAKER_01'), + _word('hello', 0.0, 0.4, cluster='cluster-a', label='SPEAKER_00'), + _word('world.', 0.5, 1.0, cluster='cluster-a', label='SPEAKER_00'), + ], + ) + + segments = reconstruct_conversation(result) + + assert [segment.text for segment in segments] == ['Hello world.', 'Later'] + assert segments[0].start == 0.0 + assert segments[0].end == 1.0 + assert segments[0].provider_cluster_id == 'cluster-a' + assert segments[0].speaker == 'SPEAKER_00' + assert segments[0].speaker_identity_state == 'unassigned' + assert segments[0].stt_provider == 'test-provider' + assert segments[0].stt_model == 'async-model' + assert segments[1].provider_cluster_id == 'cluster-b' + + +def test_reconstructs_utterance_only_provider_result_with_explicit_unknown_identity(): + result = ProviderTranscriptResult( + provider='assemblyai', + utterances=[ + ProviderTranscriptUtterance( + text='opaque speaker label', + start=4.0, + end=5.0, + provider_cluster_id='speaker-a', + speaker_label='A', + ), + ProviderTranscriptUtterance(text='no cluster', start=5.2, end=6.0), + ], + ) + + segments = reconstruct_conversation(result) + + assert [segment.text for segment in segments] == ['Opaque speaker label', 'No cluster'] + assert 
segments[0].provider_cluster_id == 'speaker-a' + assert segments[0].provider_speaker_label == 'A' + assert segments[0].speaker is None + assert segments[0].speaker_identity_state == 'unknown' + assert segments[1].provider_cluster_id is None + assert segments[1].speaker_identity_state == 'unknown' + + +def test_mixed_utterances_and_words_do_not_duplicate_words_covered_by_utterances(): + result = ProviderTranscriptResult( + provider='mixed', + utterances=[ + ProviderTranscriptUtterance( + text='Hello world.', + start=0.0, + end=1.1, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ) + ], + words=[ + _word('Hello', 0.0, 0.4, cluster='0', label='SPEAKER_00'), + _word('world.', 0.5, 1.1, cluster='0', label='SPEAKER_00'), + _word('Outside.', 2.0, 2.6, cluster='1', label='SPEAKER_01'), + ], + ) + + segments = reconstruct_conversation(result) + + assert [segment.text for segment in segments] == ['Hello world.', 'Outside.'] + assert segments[0].provider_cluster_id == '0' + assert segments[1].provider_cluster_id == '1' + + +def test_reconstructor_preserves_legacy_deepgram_prerecorded_parity(): + deepgram_result = { + 'metadata': {'request_id': 'dg-request-1', 'duration': 3.25}, + 'results': { + 'channels': [ + { + 'detected_language': 'en-US', + 'alternatives': [ + { + 'words': [ + { + 'word': 'hello', + 'punctuated_word': 'Hello', + 'start': 0.0, + 'end': 0.4, + 'confidence': 0.91, + 'speaker': 2, + }, + { + 'word': 'world', + 'punctuated_word': 'world.', + 'start': 0.5, + 'end': 1.1, + 'confidence': 0.88, + 'speaker': 2, + }, + ] + } + ], + } + ], + 'utterances': [ + { + 'transcript': 'Hello world.', + 'start': 0.0, + 'end': 1.1, + 'confidence': 0.9, + 'speaker': 2, + } + ], + }, + } + + result = normalize_deepgram_prerecorded_result(deepgram_result, model='nova-3') + segments = reconstruct_conversation(result) + + assert len(segments) == 1 + assert segments[0].text == 'Hello world.' 
+ assert segments[0].speaker == 'SPEAKER_02' + assert segments[0].speaker_id == 2 + assert segments[0].provider_cluster_id == '2' + assert segments[0].provider_speaker_label == 'SPEAKER_02' + assert segments[0].stt_provider == 'deepgram' + assert segments[0].stt_model == 'nova-3' + + +def test_overlap_duplicate_candidates_keep_longer_text_once(): + reconstructor = ConversationReconstructor() + result = ProviderTranscriptResult( + provider='test-provider', + utterances=[ + ProviderTranscriptUtterance( + text='hello', + start=0.0, + end=1.0, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ), + ProviderTranscriptUtterance( + text='hello there', + start=0.5, + end=1.5, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ), + ], + ) + + segments = reconstructor.reconstruct(result) + + assert len(segments) == 1 + assert segments[0].text == 'Hello there' + assert segments[0].start == 0.0 + assert segments[0].end == 1.5 + + +def test_skip_window_marks_dominant_preskip_cluster_as_user(): + result = ProviderTranscriptResult( + provider='test-provider', + words=[ + _word('my', 0.0, 0.2, cluster='me', label='SPEAKER_00'), + _word('voice', 0.3, 0.6, cluster='me', label='SPEAKER_00'), + _word('after', 3.0, 3.4, cluster='me', label='SPEAKER_00'), + _word('guest', 3.5, 4.0, cluster='guest', label='SPEAKER_01'), + ], + ) + + segments = reconstruct_conversation(result, skip_n_seconds=2) + + assert segments[0].text == 'After' + assert segments[0].is_user is True + assert segments[0].speaker_identity_state == 'user' + assert segments[0].start == 0.0 + assert segments[1].is_user is False diff --git a/backend/utils/stt/conversation_reconstructor.py b/backend/utils/stt/conversation_reconstructor.py new file mode 100644 index 00000000000..c9f8c8ac0db --- /dev/null +++ b/backend/utils/stt/conversation_reconstructor.py @@ -0,0 +1,263 @@ +from dataclasses import dataclass +from typing import Iterable, List, Optional, Sequence, Tuple + +from models.transcript_segment import 
( + ProviderTranscriptResult, + ProviderTranscriptUtterance, + ProviderTranscriptWord, + TranscriptSegment, +) + + +@dataclass +class _SegmentCandidate: + text: str + start: float + end: float + provider_cluster_id: Optional[str] + speaker_label: Optional[str] + + +class ConversationReconstructor: + def __init__(self, max_same_cluster_gap_seconds: float = 30.0): + self.max_same_cluster_gap_seconds = max_same_cluster_gap_seconds + + def reconstruct( + self, + result: ProviderTranscriptResult, + skip_n_seconds: int = 0, + user_provider_cluster_id: Optional[str] = None, + ) -> List[TranscriptSegment]: + user_cluster_id = user_provider_cluster_id or self._retrieve_user_cluster_id(result, skip_n_seconds) + candidates = self._build_candidates(result, skip_n_seconds) + candidates = [candidate for candidate in candidates if candidate.start >= skip_n_seconds and candidate.text] + candidates = self._sort_and_dedupe_candidates(candidates) + candidates = self._merge_adjacent_candidates(candidates) + + if not candidates: + return [] + + starts_at = candidates[0].start + segments = [] + for candidate in candidates: + is_user = bool(user_cluster_id and candidate.provider_cluster_id == user_cluster_id) + speaker_label = self._legacy_speaker_label(candidate.speaker_label) + identity_state = self._identity_state(is_user, candidate.provider_cluster_id, speaker_label) + segments.append( + TranscriptSegment( + text=candidate.text.strip().capitalize(), + speaker=speaker_label, + is_user=is_user, + person_id=None, + start=round(candidate.start - starts_at, 2), + end=round(candidate.end - starts_at, 2), + stt_provider=result.provider, + stt_model=result.model, + provider_cluster_id=candidate.provider_cluster_id, + provider_speaker_label=candidate.speaker_label, + speaker_identity_state=identity_state, + ) + ) + return segments + + def _build_candidates(self, result: ProviderTranscriptResult, skip_n_seconds: int) -> List[_SegmentCandidate]: + candidates = 
self._utterance_candidates(result.utterances) + if not result.words: + return candidates + + candidate_words = [word for word in result.words if word.start >= skip_n_seconds] + if not candidates: + return self._word_candidates(candidate_words) + + covered_intervals = [(candidate.start, candidate.end) for candidate in candidates] + uncovered_words = [word for word in candidate_words if not self._word_is_covered(word, covered_intervals)] + return candidates + self._word_candidates(uncovered_words) + + def _utterance_candidates(self, utterances: Sequence[ProviderTranscriptUtterance]) -> List[_SegmentCandidate]: + candidates = [] + for utterance in utterances: + text = utterance.text.strip() + if not text and utterance.words: + text = ' '.join(word.text.strip() for word in utterance.words if word.text.strip()) + if not text: + continue + provider_cluster_id = utterance.provider_cluster_id + speaker_label = utterance.speaker_label + if utterance.words and (not provider_cluster_id or not speaker_label): + provider_cluster_id, speaker_label = self._dominant_speaker(utterance.words) + candidates.append( + _SegmentCandidate( + text=text, + start=utterance.start, + end=utterance.end, + provider_cluster_id=provider_cluster_id, + speaker_label=speaker_label, + ) + ) + return candidates + + def _word_candidates(self, words: Sequence[ProviderTranscriptWord]) -> List[_SegmentCandidate]: + candidates = [] + normalized_words = self._interpolate_missing_word_speakers( + sorted(words, key=lambda item: (item.start, item.end)) + ) + for word in normalized_words: + text = word.text.strip() + if not text: + continue + if candidates and self._should_merge_word(candidates[-1], word): + candidates[-1].text = f'{candidates[-1].text} {text}' + candidates[-1].end = word.end + continue + candidates.append( + _SegmentCandidate( + text=text, + start=word.start, + end=word.end, + provider_cluster_id=word.provider_cluster_id, + speaker_label=word.speaker_label, + ) + ) + return candidates + + def 
_interpolate_missing_word_speakers( + self, words: Sequence[ProviderTranscriptWord] + ) -> List[ProviderTranscriptWord]: + normalized_words = [word.model_copy() for word in words] + for index, word in enumerate(normalized_words): + if word.provider_cluster_id or word.speaker_label: + continue + + previous_word = normalized_words[index - 1] if index > 0 else None + next_word = normalized_words[index + 1] if index < len(normalized_words) - 1 else None + previous_has_speaker = previous_word and (previous_word.provider_cluster_id or previous_word.speaker_label) + next_has_speaker = next_word and (next_word.provider_cluster_id or next_word.speaker_label) + if previous_has_speaker and next_has_speaker: + if previous_word.provider_cluster_id == next_word.provider_cluster_id: + word.provider_cluster_id = previous_word.provider_cluster_id + word.speaker_label = previous_word.speaker_label + else: + secs_from_previous = word.start - previous_word.end + secs_to_next = next_word.start - word.end + source = previous_word if secs_from_previous < secs_to_next else next_word + word.provider_cluster_id = source.provider_cluster_id + word.speaker_label = source.speaker_label + elif previous_has_speaker: + word.provider_cluster_id = previous_word.provider_cluster_id + word.speaker_label = previous_word.speaker_label + elif next_has_speaker: + word.provider_cluster_id = next_word.provider_cluster_id + word.speaker_label = next_word.speaker_label + return normalized_words + + def _sort_and_dedupe_candidates(self, candidates: Sequence[_SegmentCandidate]) -> List[_SegmentCandidate]: + ordered = sorted(candidates, key=lambda item: (item.start, item.end)) + deduped = [] + for candidate in ordered: + if not deduped: + deduped.append(candidate) + continue + + previous = deduped[-1] + if self._is_duplicate_overlap(previous, candidate): + if len(candidate.text) > len(previous.text): + candidate.start = min(previous.start, candidate.start) + candidate.end = max(previous.end, candidate.end) + 
deduped[-1] = candidate + continue + deduped.append(candidate) + return deduped + + def _merge_adjacent_candidates(self, candidates: Sequence[_SegmentCandidate]) -> List[_SegmentCandidate]: + merged = [] + for candidate in candidates: + if merged and self._should_merge_candidate(merged[-1], candidate): + merged[-1].text = f'{merged[-1].text} {candidate.text}' + merged[-1].end = candidate.end + continue + merged.append(candidate) + return merged + + def _retrieve_user_cluster_id(self, result: ProviderTranscriptResult, skip_n_seconds: int) -> Optional[str]: + if not skip_n_seconds: + return None + + speaker_counts = {} + speaker_sources: Iterable[Tuple[float, Optional[str]]] = ( + [(word.start, word.provider_cluster_id) for word in result.words] + if result.words + else [(utterance.start, utterance.provider_cluster_id) for utterance in result.utterances] + ) + for start, provider_cluster_id in sorted(speaker_sources, key=lambda item: item[0]): + if start >= skip_n_seconds: + break + if not provider_cluster_id: + continue + speaker_counts[provider_cluster_id] = speaker_counts.get(provider_cluster_id, 0) + 1 + return max(speaker_counts, key=speaker_counts.get) if speaker_counts else None + + def _dominant_speaker(self, words: Sequence[ProviderTranscriptWord]) -> Tuple[Optional[str], Optional[str]]: + speaker_counts = {} + labels_by_cluster = {} + for word in words: + if not word.provider_cluster_id: + continue + speaker_counts[word.provider_cluster_id] = speaker_counts.get(word.provider_cluster_id, 0) + 1 + if word.speaker_label: + labels_by_cluster[word.provider_cluster_id] = word.speaker_label + + if not speaker_counts: + return None, None + + cluster_id = max(speaker_counts, key=speaker_counts.get) + return cluster_id, labels_by_cluster.get(cluster_id) + + def _word_is_covered(self, word: ProviderTranscriptWord, intervals: Sequence[Tuple[float, float]]) -> bool: + return any(start <= word.start and word.end <= end for start, end in intervals) + + def 
_should_merge_word(self, previous: _SegmentCandidate, word: ProviderTranscriptWord) -> bool: + return ( + previous.provider_cluster_id == word.provider_cluster_id + and word.start - previous.end < self.max_same_cluster_gap_seconds + ) + + def _should_merge_candidate(self, previous: _SegmentCandidate, candidate: _SegmentCandidate) -> bool: + return ( + previous.provider_cluster_id == candidate.provider_cluster_id + and candidate.start - previous.end < self.max_same_cluster_gap_seconds + ) + + def _is_duplicate_overlap(self, previous: _SegmentCandidate, candidate: _SegmentCandidate) -> bool: + if candidate.start >= previous.end: + return False + + previous_text = self._normalize_text(previous.text) + candidate_text = self._normalize_text(candidate.text) + return ( + bool(previous_text and candidate_text) + and previous.provider_cluster_id == candidate.provider_cluster_id + and (previous_text == candidate_text or previous_text in candidate_text or candidate_text in previous_text) + ) + + def _legacy_speaker_label(self, speaker_label: Optional[str]) -> Optional[str]: + if isinstance(speaker_label, str) and speaker_label.startswith('SPEAKER_'): + return speaker_label + return None + + def _identity_state(self, is_user: bool, provider_cluster_id: Optional[str], speaker_label: Optional[str]) -> str: + if is_user: + return 'user' + if speaker_label: + return 'unassigned' + return 'unknown' + + def _normalize_text(self, text: str) -> str: + return ' '.join(text.lower().split()) + + +def reconstruct_conversation( + result: ProviderTranscriptResult, + skip_n_seconds: int = 0, + user_provider_cluster_id: Optional[str] = None, +) -> List[TranscriptSegment]: + return ConversationReconstructor().reconstruct(result, skip_n_seconds, user_provider_cluster_id) diff --git a/backend/utils/stt/pre_recorded.py b/backend/utils/stt/pre_recorded.py index ce75fc9e6e9..2907684c493 100644 --- a/backend/utils/stt/pre_recorded.py +++ b/backend/utils/stt/pre_recorded.py @@ -1,14 +1,14 @@ import 
os -from collections import defaultdict from typing import List, Optional, Sequence, Tuple, Union import fal_client import httpx from deepgram import DeepgramClient, DeepgramClientOptions -from models.transcript_segment import TranscriptSegment +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord, TranscriptSegment from utils.byok import get_byok_key from utils.other.endpoints import timeit +from utils.stt.conversation_reconstructor import reconstruct_conversation from utils.stt.deepgram_adapter import ( DeepgramPrerecordedTranscriptionProvider, deepgram_speaker_fields, @@ -342,121 +342,31 @@ def fal_whisperx( return [] -def _words_cleaning(words: List[dict]): - words_cleaned: List[dict] = [] - for i, w in enumerate(words): - # if w['timestamp'][0] == w['timestamp'][1]: - # continue - raw_speaker = w.get('speaker') +def legacy_words_to_provider_result(words: List[dict]) -> ProviderTranscriptResult: + provider_words = [] + provider = None + model = None + for word in words: + raw_speaker = word.get('speaker') speaker = raw_speaker if isinstance(raw_speaker, str) and raw_speaker.startswith('SPEAKER_') else None - words_cleaned.append( - { - 'start': round(w['timestamp'][0], 2), - 'end': round(w['timestamp'][1] or w['timestamp'][0] + 1, 2), - 'speaker': speaker, - 'provider_cluster_id': w.get('provider_cluster_id') or raw_speaker, - 'provider_speaker_label': w.get('provider_speaker_label') or speaker, - 'stt_provider': w.get('stt_provider'), - 'stt_model': w.get('stt_model'), - 'text': str(w['text']).strip(), - 'is_user': False, - 'person_id': None, - 'speaker_identity_state': 'unassigned' if speaker else 'unknown', - } + timestamp = word['timestamp'] + provider = provider or word.get('stt_provider') + model = model or word.get('stt_model') + provider_words.append( + ProviderTranscriptWord( + text=str(word['text']).strip(), + start=round(timestamp[0], 2), + end=round(timestamp[1] or timestamp[0] + 1, 2), + 
provider_cluster_id=word.get('provider_cluster_id') or raw_speaker, + speaker_label=word.get('provider_speaker_label') or speaker, + confidence=word.get('confidence'), + ) ) - for i, word in enumerate(words_cleaned): - speaker = word['speaker'] - if not speaker: - prev_chunk = words_cleaned[i - 1] if i > 0 else None - next_chunk = words_cleaned[i + 1] if i < len(words_cleaned) - 1 else None - prev_speaker = prev_chunk['speaker'] if prev_chunk else None - next_speaker = next_chunk['speaker'] if next_chunk else None - - if prev_speaker and next_speaker: - if prev_speaker == next_speaker: - speaker = prev_chunk['speaker'] - else: - secs_from_prev = word['start'] - prev_chunk['end'] if prev_chunk else 0 - secs_to_next = next_chunk['start'] - word['end'] if next_chunk else 0 - speaker = prev_speaker if secs_from_prev < secs_to_next else next_speaker - elif prev_speaker: - speaker = prev_speaker - elif next_speaker: - speaker = next_speaker - - words_cleaned[i]['speaker'] = speaker - words_cleaned[i]['provider_speaker_label'] = speaker - if speaker: - words_cleaned[i]['speaker_identity_state'] = 'unassigned' - - # for chunk in words_cleaned: - # print(chunk) - return words_cleaned - - -def _retrieve_user_speaker_id(words: list, skip_n_seconds: int): - if not skip_n_seconds: - return None - - user_speaker_id = defaultdict(int) - for word in words: - if word['start'] >= skip_n_seconds: - break - if not word['speaker']: - continue - user_speaker_id[word['speaker']] += 1 - - user_speaker_id = max(user_speaker_id, key=user_speaker_id.get) if user_speaker_id else None - return user_speaker_id - - -def _merge_segments(words: List[dict], skip_n_seconds: int, user_speaker_id: str): - segments = [] - for word in words: - if word['start'] < skip_n_seconds: - continue - word['is_user'] = word['speaker'] == user_speaker_id if word['speaker'] else False - - same_prev_speaker = word['speaker'] == segments[-1]['speaker'] if segments else False - seconds_from_prev = word['start'] - 
segments[-1]['end'] if segments else 0 - - # TODO: consider having a max segment size too - if segments and same_prev_speaker and seconds_from_prev < 30: - segments[-1]['end'] = word['end'] - segments[-1]['text'] += ' ' + word['text'] - else: - segments.append(word) - return segments - - -def _segments_as_objects(segments: List[dict]) -> List[TranscriptSegment]: - if not segments: - return [] - starts_at = segments[0]['start'] - return [ - TranscriptSegment( - text=str(segment['text']).strip().capitalize(), - speaker=segment['speaker'], - is_user=segment['is_user'], - person_id=None, - start=round(segment['start'] - starts_at, 2), - end=round(segment['end'] - starts_at, 2), - stt_provider=segment.get('stt_provider'), - stt_model=segment.get('stt_model'), - provider_cluster_id=segment.get('provider_cluster_id'), - provider_speaker_label=segment.get('provider_speaker_label'), - speaker_identity_state='user' if segment['is_user'] else segment.get('speaker_identity_state', 'unknown'), - ) - for segment in segments - ] + return ProviderTranscriptResult(provider=provider or 'unknown', model=model, words=provider_words) def postprocess_words( words: List[dict], duration: int, skip_n_seconds: int = 0 # , merge_segments: bool = True ) -> List[TranscriptSegment]: - words: List[dict] = _words_cleaning(words) - user_speaker_id = _retrieve_user_speaker_id(words, skip_n_seconds) - segments = _merge_segments(words, skip_n_seconds, user_speaker_id) - segments = _segments_as_objects(segments) - return segments + return reconstruct_conversation(legacy_words_to_provider_result(words), skip_n_seconds=skip_n_seconds) From 790ad0318625f7ae46c4987be99827d6444131be Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 08:14:13 +0700 Subject: [PATCH 04/44] Add transcription provider usage ledger --- .../database/transcription_provider_usage.py | 411 ++++++++++++++++++ .../unit/test_transcription_provider_usage.py | 241 ++++++++++ backend/utils/metrics.py | 186 +++++++- 3 files 
changed, 837 insertions(+), 1 deletion(-) create mode 100644 backend/database/transcription_provider_usage.py create mode 100644 backend/tests/unit/test_transcription_provider_usage.py diff --git a/backend/database/transcription_provider_usage.py b/backend/database/transcription_provider_usage.py new file mode 100644 index 00000000000..2db80e48212 --- /dev/null +++ b/backend/database/transcription_provider_usage.py @@ -0,0 +1,411 @@ +from datetime import datetime, timedelta, timezone +from typing import Any, List, Optional +from uuid import uuid4 + +from google.cloud import firestore +from google.cloud.firestore_v1 import FieldFilter + +from ._client import db +from utils.metrics import ( + identity_confidence_bucket, + observe_transcription_provider_audio_seconds, + observe_transcription_provider_fallback, + observe_transcription_provider_identity_confidence, + observe_transcription_provider_request, + observe_transcription_provider_retry, + observe_transcription_provider_speaker_clusters, +) + +RUNS_COLLECTION = 'transcription_provider_runs' +DAILY_USAGE_COLLECTION = 'transcription_provider_usage_daily' +RUN_TTL_DAYS = 180 + +FORBIDDEN_LEDGER_KEYS = { + 'audio_bytes', + 'audio', + 'raw_audio_bytes', + 'text', + 'transcript', + 'transcript_text', + 'words', + 'word_records', + 'chunks', + 'utterances', +} + + +def utc_day_bucket(value: Optional[datetime] = None) -> str: + value = value or datetime.now(timezone.utc) + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).strftime('%Y-%m-%d') + + +def _safe_doc_id_part(value: str) -> str: + return str(value or 'unknown').replace('/', '_') + + +def daily_rollup_doc_id(day: str, provider: str, model: str, workload: str) -> str: + return ':'.join( + [ + _safe_doc_id_part(day), + _safe_doc_id_part(provider), + _safe_doc_id_part(model), + _safe_doc_id_part(workload), + ] + ) + + +def _run_ref(run_id: str): + return db.collection(RUNS_COLLECTION).document(run_id) + + 
+def _rollup_ref(day: str, provider: str, model: str, workload: str): + return db.collection(DAILY_USAGE_COLLECTION).document(daily_rollup_doc_id(day, provider, model, workload)) + + +def _utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def _ttl_expires_at(now: datetime) -> datetime: + return now + timedelta(days=RUN_TTL_DAYS) + + +def _reject_forbidden_payload_keys(payload: dict[str, Any]) -> None: + forbidden = _find_forbidden_payload_keys(payload) + if forbidden: + raise ValueError(f'transcription provider ledger payload contains forbidden keys: {sorted(forbidden)}') + + +def _find_forbidden_payload_keys(value: Any) -> set[str]: + if isinstance(value, dict): + forbidden = FORBIDDEN_LEDGER_KEYS & set(value) + for nested in value.values(): + forbidden.update(_find_forbidden_payload_keys(nested)) + return forbidden + if isinstance(value, list): + forbidden = set() + for nested in value: + forbidden.update(_find_forbidden_payload_keys(nested)) + return forbidden + return set() + + +def create_provider_run( + uid: str, + provider: str, + model: str, + workload: str, + run_id: Optional[str] = None, + conversation_id: Optional[str] = None, + provider_job_ref: Optional[str] = None, + artifact_refs: Optional[dict[str, str]] = None, + started_at: Optional[datetime] = None, +) -> str: + run_id = run_id or str(uuid4()) + now = _utc_now() + started_at = started_at or now + payload = { + 'run_id': run_id, + 'uid': uid, + 'conversation_id': conversation_id, + 'provider': provider, + 'model': model, + 'workload': workload, + 'status': 'started', + 'provider_job_ref': provider_job_ref, + 'artifact_refs': artifact_refs or {}, + 'timing': { + 'started_at': started_at, + 'completed_at': None, + 'latency_ms': None, + }, + 'raw_audio_seconds': 0.0, + 'speech_active_seconds': 0.0, + 'billable_seconds': 0.0, + 'estimated_cost_usd': 0.0, + 'retry_count': 0, + 'fallback_count': 0, + 'transcript_segment_count': 0, + 'transcript_word_count': 0, + 'speaker_cluster_count': 0, 
+ 'identified_speaker_cluster_count': 0, + 'identity_confidence_summary': {}, + 'error_class': None, + 'created_at': now, + 'updated_at': now, + 'expires_at': _ttl_expires_at(now), + } + _reject_forbidden_payload_keys(payload) + _run_ref(run_id).set(payload, merge=False) + return run_id + + +def finalize_provider_run( + run_id: str, + provider: str, + model: str, + workload: str, + status: str, + started_at: datetime, + completed_at: Optional[datetime] = None, + raw_audio_seconds: float = 0.0, + speech_active_seconds: float = 0.0, + billable_seconds: float = 0.0, + estimated_cost_usd: float = 0.0, + retry_count: int = 0, + fallback_count: int = 0, + transcript_segment_count: int = 0, + transcript_word_count: int = 0, + speaker_cluster_count: int = 0, + identified_speaker_cluster_count: int = 0, + identity_confidence_summary: Optional[dict[str, Any]] = None, + error_class: Optional[str] = None, + artifact_refs: Optional[dict[str, str]] = None, + fallback_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', +) -> None: + completed_at = completed_at or _utc_now() + if started_at.tzinfo is None: + started_at = started_at.replace(tzinfo=timezone.utc) + if completed_at.tzinfo is None: + completed_at = completed_at.replace(tzinfo=timezone.utc) + latency_seconds = max((completed_at - started_at).total_seconds(), 0.0) + summary = identity_confidence_summary or {} + payload = { + 'status': status, + 'timing': { + 'started_at': started_at, + 'completed_at': completed_at, + 'latency_ms': int(latency_seconds * 1000), + }, + 'raw_audio_seconds': raw_audio_seconds, + 'speech_active_seconds': speech_active_seconds, + 'billable_seconds': billable_seconds, + 'estimated_cost_usd': estimated_cost_usd, + 'retry_count': retry_count, + 'fallback_count': fallback_count, + 'transcript_segment_count': transcript_segment_count, + 'transcript_word_count': transcript_word_count, + 'speaker_cluster_count': speaker_cluster_count, + 'identified_speaker_cluster_count': 
identified_speaker_cluster_count, + 'identity_confidence_summary': summary, + 'error_class': error_class, + 'artifact_refs': artifact_refs or {}, + 'updated_at': completed_at, + } + _reject_forbidden_payload_keys(payload) + _run_ref(run_id).set(payload, merge=True) + increment_daily_rollup( + day=utc_day_bucket(completed_at), + provider=provider, + model=model, + workload=workload, + status=status, + raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=speech_active_seconds, + billable_seconds=billable_seconds, + estimated_cost_usd=estimated_cost_usd, + retry_count=retry_count, + fallback_count=fallback_count, + transcript_segment_count=transcript_segment_count, + transcript_word_count=transcript_word_count, + speaker_cluster_count=speaker_cluster_count, + identified_speaker_cluster_count=identified_speaker_cluster_count, + identity_confidence_summary=summary, + ) + emit_provider_run_metrics( + provider=provider, + model=model, + workload=workload, + status=status, + latency_seconds=latency_seconds, + raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=speech_active_seconds, + billable_seconds=billable_seconds, + retry_count=retry_count, + fallback_count=fallback_count, + fallback_provider=fallback_provider, + fallback_reason=fallback_reason, + speaker_cluster_count=speaker_cluster_count, + identified_speaker_cluster_count=identified_speaker_cluster_count, + identity_confidence_summary=summary, + ) + + +def increment_daily_rollup( + day: str, + provider: str, + model: str, + workload: str, + status: str, + raw_audio_seconds: float = 0.0, + speech_active_seconds: float = 0.0, + billable_seconds: float = 0.0, + estimated_cost_usd: float = 0.0, + retry_count: int = 0, + fallback_count: int = 0, + transcript_segment_count: int = 0, + transcript_word_count: int = 0, + speaker_cluster_count: int = 0, + identified_speaker_cluster_count: int = 0, + identity_confidence_summary: Optional[dict[str, Any]] = None, +) -> None: + update = { + 'day': day, + 
'provider': provider, + 'model': model, + 'workload': workload, + 'run_count': firestore.Increment(1), + f'status_counts.{status}': firestore.Increment(1), + 'raw_audio_seconds': firestore.Increment(raw_audio_seconds), + 'speech_active_seconds': firestore.Increment(speech_active_seconds), + 'billable_seconds': firestore.Increment(billable_seconds), + 'estimated_cost_usd': firestore.Increment(estimated_cost_usd), + 'retry_count': firestore.Increment(retry_count), + 'fallback_count': firestore.Increment(fallback_count), + 'transcript_segment_count': firestore.Increment(transcript_segment_count), + 'transcript_word_count': firestore.Increment(transcript_word_count), + 'speaker_cluster_count': firestore.Increment(speaker_cluster_count), + 'identified_speaker_cluster_count': firestore.Increment(identified_speaker_cluster_count), + 'last_updated': _utc_now(), + } + for bucket, count in (identity_confidence_summary or {}).items(): + if isinstance(count, (int, float)) and count > 0: + update[f'identity_confidence_counts.{bucket}'] = firestore.Increment(count) + _rollup_ref(day, provider, model, workload).set(update, merge=True) + + +def rebuild_daily_rollup_from_runs(day: str, provider: str, model: str, workload: str) -> dict[str, Any]: + query = ( + db.collection(RUNS_COLLECTION) + .where(filter=FieldFilter('provider', '==', provider)) + .where(filter=FieldFilter('model', '==', model)) + .where(filter=FieldFilter('workload', '==', workload)) + ) + rollup = _empty_rollup(day, provider, model, workload) + for doc in query.stream(): + data = doc.to_dict() or {} + completed_at = (data.get('timing') or {}).get('completed_at') + if not completed_at or utc_day_bucket(completed_at) != day: + continue + _add_run_to_rollup(rollup, data) + _rollup_ref(day, provider, model, workload).set(rollup, merge=False) + return rollup + + +def purge_provider_runs_for_user(uid: str, batch_size: int = 400) -> int: + deleted = 0 + query = 
db.collection(RUNS_COLLECTION).where(filter=FieldFilter('uid', '==', uid)) + batch = db.batch() + pending = 0 + for doc in query.stream(): + batch.delete(doc.reference) + deleted += 1 + pending += 1 + if pending >= batch_size: + batch.commit() + batch = db.batch() + pending = 0 + if pending: + batch.commit() + return deleted + + +def emit_provider_run_metrics( + provider: str, + model: str, + workload: str, + status: str, + latency_seconds: float, + raw_audio_seconds: float, + speech_active_seconds: float, + billable_seconds: float, + retry_count: int, + fallback_count: int, + speaker_cluster_count: int, + identified_speaker_cluster_count: int, + identity_confidence_summary: Optional[dict[str, Any]] = None, + fallback_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', +) -> None: + observe_transcription_provider_request(provider, model, workload, status, latency_seconds) + observe_transcription_provider_audio_seconds( + provider, + model, + workload, + raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=speech_active_seconds, + billable_seconds=billable_seconds, + ) + observe_transcription_provider_retry(provider, model, workload, 'provider_retry', retry_count) + if fallback_count > 0: + observe_transcription_provider_fallback( + provider, + fallback_provider or 'unknown', + workload, + fallback_reason, + fallback_count, + ) + observe_transcription_provider_speaker_clusters( + provider, + model, + workload, + speaker_cluster_count=speaker_cluster_count, + identified_speaker_cluster_count=identified_speaker_cluster_count, + ) + for bucket, count in (identity_confidence_summary or {}).items(): + observe_transcription_provider_identity_confidence(provider, model, workload, bucket, count) + + +def summarize_identity_confidences(confidences: List[Optional[float]]) -> dict[str, int]: + summary: dict[str, int] = {} + for confidence in confidences: + bucket = identity_confidence_bucket(confidence) + summary[bucket] = summary.get(bucket, 
0) + 1 + return summary + + +def _empty_rollup(day: str, provider: str, model: str, workload: str) -> dict[str, Any]: + return { + 'day': day, + 'provider': provider, + 'model': model, + 'workload': workload, + 'run_count': 0, + 'status_counts': {}, + 'raw_audio_seconds': 0.0, + 'speech_active_seconds': 0.0, + 'billable_seconds': 0.0, + 'estimated_cost_usd': 0.0, + 'retry_count': 0, + 'fallback_count': 0, + 'transcript_segment_count': 0, + 'transcript_word_count': 0, + 'speaker_cluster_count': 0, + 'identified_speaker_cluster_count': 0, + 'identity_confidence_counts': {}, + 'last_updated': _utc_now(), + } + + +def _add_run_to_rollup(rollup: dict[str, Any], data: dict[str, Any]) -> None: + rollup['run_count'] += 1 + status = data.get('status') or 'unknown' + rollup['status_counts'][status] = rollup['status_counts'].get(status, 0) + 1 + for field in ( + 'raw_audio_seconds', + 'speech_active_seconds', + 'billable_seconds', + 'estimated_cost_usd', + 'retry_count', + 'fallback_count', + 'transcript_segment_count', + 'transcript_word_count', + 'speaker_cluster_count', + 'identified_speaker_cluster_count', + ): + rollup[field] += data.get(field, 0) or 0 + for bucket, count in (data.get('identity_confidence_summary') or {}).items(): + rollup['identity_confidence_counts'][bucket] = rollup['identity_confidence_counts'].get(bucket, 0) + count diff --git a/backend/tests/unit/test_transcription_provider_usage.py b/backend/tests/unit/test_transcription_provider_usage.py new file mode 100644 index 00000000000..3e655cb8fac --- /dev/null +++ b/backend/tests/unit/test_transcription_provider_usage.py @@ -0,0 +1,241 @@ +from datetime import datetime, timezone +import sys +import types +from unittest.mock import MagicMock + +import pytest + +_google_module = sys.modules.setdefault('google', types.ModuleType('google')) +_google_cloud_module = sys.modules.setdefault('google.cloud', types.ModuleType('google.cloud')) +_google_firestore_module = types.ModuleType('google.cloud.firestore') 
+_google_firestore_module.Increment = lambda value: {'__increment': value} +_google_firestore_v1_module = types.ModuleType('google.cloud.firestore_v1') +_google_firestore_v1_module.FieldFilter = lambda field, op, value: (field, op, value) +sys.modules['google.cloud.firestore'] = _google_firestore_module +sys.modules['google.cloud.firestore_v1'] = _google_firestore_v1_module +setattr(_google_module, 'cloud', _google_cloud_module) +setattr(_google_cloud_module, 'firestore', _google_firestore_module) +_mock_client_module = MagicMock() +_mock_client_module.db = MagicMock() +sys.modules['database._client'] = _mock_client_module +_prometheus_module = types.ModuleType('prometheus_client') +_metric_factory = lambda *args, **kwargs: MagicMock(labels=MagicMock(return_value=MagicMock())) +_prometheus_module.Counter = _metric_factory +_prometheus_module.Gauge = _metric_factory +_prometheus_module.Histogram = _metric_factory +_prometheus_module.generate_latest = lambda: b'' +_prometheus_module.CONTENT_TYPE_LATEST = 'text/plain' +sys.modules['prometheus_client'] = _prometheus_module +_fastapi_module = types.ModuleType('fastapi') +_fastapi_module.Response = lambda content=None, media_type=None: {'content': content, 'media_type': media_type} +sys.modules['fastapi'] = _fastapi_module + +from database import transcription_provider_usage as usage +from utils import metrics + + +class _FakeSnapshot: + def __init__(self, data): + self._data = data + self.reference = MagicMock() + + def to_dict(self): + return self._data + + +class _FakeDoc: + def __init__(self, doc_id): + self.id = doc_id + self.set_calls = [] + + def set(self, data, merge=False): + self.set_calls.append({'data': data, 'merge': merge}) + + +class _FakeCollection: + def __init__(self, name, docs): + self.name = name + self.docs = docs + self.filters = [] + + def document(self, doc_id): + return self.docs.setdefault((self.name, doc_id), _FakeDoc(doc_id)) + + def where(self, filter=None, *args, **kwargs): + 
self.filters.append(filter) + return self + + def stream(self): + return iter([]) + + +class _FakeDb: + def __init__(self): + self.docs = {} + self.collections = {} + self.batch_ref = MagicMock() + + def collection(self, name): + return self.collections.setdefault(name, _FakeCollection(name, self.docs)) + + def batch(self): + return self.batch_ref + + +def _inc(value): + return {'__increment': value} + + +def test_create_and_finalize_provider_run_writes_ledger_rollup_and_metrics(monkeypatch): + fake_db = _FakeDb() + monkeypatch.setattr(usage, 'db', fake_db) + monkeypatch.setattr(usage.firestore, 'Increment', _inc) + emitted = [] + monkeypatch.setattr( + usage, + 'emit_provider_run_metrics', + lambda **kwargs: emitted.append(kwargs), + ) + + started_at = datetime(2026, 5, 20, 23, 59, 58, tzinfo=timezone.utc) + completed_at = datetime(2026, 5, 21, 0, 0, 3, tzinfo=timezone.utc) + run_id = usage.create_provider_run( + uid='user-1', + provider='assemblyai', + model='universal-2', + workload='background', + run_id='run-1', + conversation_id='conv-1', + artifact_refs={'provider_result': 'gs://bucket/result.json'}, + started_at=started_at, + ) + usage.finalize_provider_run( + run_id=run_id, + provider='assemblyai', + model='universal-2', + workload='background', + status='success', + started_at=started_at, + completed_at=completed_at, + raw_audio_seconds=60.0, + speech_active_seconds=42.0, + billable_seconds=60.0, + estimated_cost_usd=0.37, + retry_count=1, + fallback_count=0, + transcript_segment_count=12, + transcript_word_count=140, + speaker_cluster_count=3, + identified_speaker_cluster_count=2, + identity_confidence_summary={'high': 2, 'unknown': 1}, + artifact_refs={'provider_result': 'gs://bucket/result.json'}, + ) + + run_doc = fake_db.docs[(usage.RUNS_COLLECTION, 'run-1')] + assert run_doc.set_calls[0]['merge'] is False + assert run_doc.set_calls[0]['data']['status'] == 'started' + assert run_doc.set_calls[0]['data']['expires_at'] is not None + assert 
run_doc.set_calls[1]['merge'] is True + finalized = run_doc.set_calls[1]['data'] + assert finalized['status'] == 'success' + assert finalized['timing']['latency_ms'] == 5000 + assert 'transcript_text' not in finalized + assert 'words' not in finalized + + rollup_doc = fake_db.docs[(usage.DAILY_USAGE_COLLECTION, '2026-05-21:assemblyai:universal-2:background')] + rollup = rollup_doc.set_calls[0]['data'] + assert rollup['run_count'] == {'__increment': 1} + assert rollup['raw_audio_seconds'] == {'__increment': 60.0} + assert rollup['identity_confidence_counts.high'] == {'__increment': 2} + assert emitted[0]['latency_seconds'] == 5.0 + assert emitted[0]['billable_seconds'] == 60.0 + + +def test_rejects_transcript_text_and_chunk_payloads(): + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'transcript_text': 'hello'}) + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'chunks': [{'start': 0}]}) + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'artifact_refs': {'transcript': 'gs://bucket/transcript.txt'}}) + + +def test_utc_daily_bucket_and_rollup_rebuild(monkeypatch): + fake_db = _FakeDb() + collection = fake_db.collection(usage.RUNS_COLLECTION) + included = _FakeSnapshot( + { + 'provider': 'assemblyai', + 'model': 'universal-2', + 'workload': 'background', + 'status': 'success', + 'timing': {'completed_at': datetime(2026, 5, 21, 0, 1, tzinfo=timezone.utc)}, + 'raw_audio_seconds': 10, + 'speech_active_seconds': 6, + 'billable_seconds': 10, + 'estimated_cost_usd': 0.1, + 'retry_count': 1, + 'fallback_count': 0, + 'transcript_segment_count': 4, + 'transcript_word_count': 40, + 'speaker_cluster_count': 2, + 'identified_speaker_cluster_count': 1, + 'identity_confidence_summary': {'high': 1}, + } + ) + excluded = _FakeSnapshot( + { + 'timing': {'completed_at': datetime(2026, 5, 22, 0, 1, tzinfo=timezone.utc)}, + 'status': 'success', + 'raw_audio_seconds': 999, + } + ) + collection.stream = lambda: 
iter([included, excluded]) + monkeypatch.setattr(usage, 'db', fake_db) + + assert usage.utc_day_bucket(datetime(2026, 5, 21, 7, 1)) == '2026-05-21' + assert ( + usage.daily_rollup_doc_id('2026-05-21', 'provider', 'model/v1', 'sync') == '2026-05-21:provider:model_v1:sync' + ) + + rollup = usage.rebuild_daily_rollup_from_runs('2026-05-21', 'assemblyai', 'universal-2', 'background') + + assert rollup['run_count'] == 1 + assert rollup['raw_audio_seconds'] == 10.0 + assert rollup['status_counts'] == {'success': 1} + assert rollup['identity_confidence_counts'] == {'high': 1} + + +def test_purge_provider_runs_for_user_deletes_top_level_run_records(monkeypatch): + fake_db = _FakeDb() + collection = fake_db.collection(usage.RUNS_COLLECTION) + docs = [_FakeSnapshot({'uid': 'user-1'}), _FakeSnapshot({'uid': 'user-1'})] + collection.stream = lambda: iter(docs) + monkeypatch.setattr(usage, 'db', fake_db) + + deleted = usage.purge_provider_runs_for_user('user-1') + + assert deleted == 2 + assert fake_db.batch_ref.delete.call_count == 2 + fake_db.batch_ref.commit.assert_called_once() + + +def test_metrics_reject_high_cardinality_labels(): + assert metrics.identity_confidence_bucket(None) == 'unknown' + assert metrics.identity_confidence_bucket(0.91) == 'very_high' + with pytest.raises(ValueError): + metrics._provider_metric_labels(provider='assemblyai', user_id='user-1') + with pytest.raises(ValueError): + metrics._provider_metric_labels(provider='assemblyai', transcript_text='hello world') + + +def test_provider_metrics_source_does_not_define_forbidden_label_names(): + forbidden_labels = { + "['provider', 'model', 'workload', 'user_id']", + "['provider', 'model', 'workload', 'conversation_id']", + "['provider', 'model', 'workload', 'provider_job_id']", + "['provider', 'model', 'workload', 'transcript_text']", + } + source = metrics.__loader__.get_source(metrics.__name__) + for label_list in forbidden_labels: + assert label_list not in source diff --git a/backend/utils/metrics.py 
b/backend/utils/metrics.py index 30d06192d02..45990ffe1be 100644 --- a/backend/utils/metrics.py +++ b/backend/utils/metrics.py @@ -1,4 +1,6 @@ -from prometheus_client import Counter, Gauge, generate_latest, CONTENT_TYPE_LATEST +from typing import Optional + +from prometheus_client import Counter, Gauge, Histogram, generate_latest, CONTENT_TYPE_LATEST from fastapi import Response BACKEND_LISTEN_ACTIVE_WS_CONNECTIONS = Gauge( @@ -26,6 +28,188 @@ 'Number of sessions currently in degraded mode (pusher unavailable)', ) +TRANSCRIPTION_PROVIDER_REQUESTS = Counter( + 'transcription_provider_requests_total', + 'Total transcription provider requests by provider, model, workload, and status', + ['provider', 'model', 'workload', 'status'], +) + +TRANSCRIPTION_PROVIDER_LATENCY_SECONDS = Histogram( + 'transcription_provider_latency_seconds', + 'Transcription provider request latency by provider, model, workload, and status', + ['provider', 'model', 'workload', 'status'], + buckets=(0.5, 1, 2.5, 5, 10, 30, 60, 120, 300, 600, float('inf')), +) + +TRANSCRIPTION_PROVIDER_RETRIES = Counter( + 'transcription_provider_retries_total', + 'Total transcription provider retry attempts by provider, model, workload, and reason', + ['provider', 'model', 'workload', 'reason'], +) + +TRANSCRIPTION_PROVIDER_FALLBACKS = Counter( + 'transcription_provider_fallbacks_total', + 'Total transcription provider fallbacks by from provider, to provider, and workload', + ['from_provider', 'to_provider', 'workload', 'reason'], +) + +TRANSCRIPTION_PROVIDER_AUDIO_SECONDS = Counter( + 'transcription_provider_audio_seconds_total', + 'Total transcription provider audio seconds by provider, model, workload, and kind', + ['provider', 'model', 'workload', 'kind'], +) + +TRANSCRIPTION_PROVIDER_BILLABLE_SECONDS = Counter( + 'transcription_provider_billable_seconds_total', + 'Total transcription provider billable seconds by provider, model, and workload', + ['provider', 'model', 'workload'], +) + 
+TRANSCRIPTION_PROVIDER_SPEAKER_CLUSTERS = Counter( + 'transcription_provider_speaker_clusters_total', + 'Total transcription provider speaker clusters by provider, model, workload, and kind', + ['provider', 'model', 'workload', 'kind'], +) + +TRANSCRIPTION_PROVIDER_IDENTITY_CONFIDENCE = Counter( + 'transcription_provider_identity_confidence_total', + 'Total speaker identity assignments by provider, model, workload, and confidence bucket', + ['provider', 'model', 'workload', 'bucket'], +) + +TRANSCRIPTION_PROVIDER_ALLOWED_LABELS = { + 'provider', + 'model', + 'workload', + 'status', + 'reason', + 'from_provider', + 'to_provider', + 'kind', + 'bucket', +} + +TRANSCRIPTION_PROVIDER_FORBIDDEN_LABELS = { + 'uid', + 'user_id', + 'conversation_id', + 'provider_job_id', + 'transcript', + 'transcript_text', + 'text', + 'run_id', +} + + +def _provider_metric_labels(**labels: str) -> dict: + unexpected = set(labels) - TRANSCRIPTION_PROVIDER_ALLOWED_LABELS + forbidden = set(labels) & TRANSCRIPTION_PROVIDER_FORBIDDEN_LABELS + if unexpected or forbidden: + raise ValueError(f'Unsafe transcription provider metric labels: {sorted(unexpected | forbidden)}') + return {key: str(value or 'unknown') for key, value in labels.items()} + + +def observe_transcription_provider_request( + provider: str, + model: str, + workload: str, + status: str, + latency_seconds: float, +) -> None: + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, status=status) + TRANSCRIPTION_PROVIDER_REQUESTS.labels(**labels).inc() + if latency_seconds >= 0: + TRANSCRIPTION_PROVIDER_LATENCY_SECONDS.labels(**labels).observe(latency_seconds) + + +def observe_transcription_provider_retry(provider: str, model: str, workload: str, reason: str, count: int = 1) -> None: + if count <= 0: + return + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, reason=reason) + TRANSCRIPTION_PROVIDER_RETRIES.labels(**labels).inc(count) + + +def 
observe_transcription_provider_fallback( + from_provider: str, + to_provider: str, + workload: str, + reason: str, + count: int = 1, +) -> None: + if count <= 0: + return + labels = _provider_metric_labels( + from_provider=from_provider, + to_provider=to_provider, + workload=workload, + reason=reason, + ) + TRANSCRIPTION_PROVIDER_FALLBACKS.labels(**labels).inc(count) + + +def observe_transcription_provider_audio_seconds( + provider: str, + model: str, + workload: str, + raw_audio_seconds: float = 0.0, + speech_active_seconds: float = 0.0, + billable_seconds: float = 0.0, +) -> None: + for kind, seconds in ( + ('raw', raw_audio_seconds), + ('speech_active', speech_active_seconds), + ('billable', billable_seconds), + ): + if seconds <= 0: + continue + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, kind=kind) + TRANSCRIPTION_PROVIDER_AUDIO_SECONDS.labels(**labels).inc(seconds) + if billable_seconds > 0: + labels = _provider_metric_labels(provider=provider, model=model, workload=workload) + TRANSCRIPTION_PROVIDER_BILLABLE_SECONDS.labels(**labels).inc(billable_seconds) + + +def observe_transcription_provider_speaker_clusters( + provider: str, + model: str, + workload: str, + speaker_cluster_count: int = 0, + identified_speaker_cluster_count: int = 0, +) -> None: + for kind, count in ( + ('provider', speaker_cluster_count), + ('identified', identified_speaker_cluster_count), + ): + if count <= 0: + continue + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, kind=kind) + TRANSCRIPTION_PROVIDER_SPEAKER_CLUSTERS.labels(**labels).inc(count) + + +def identity_confidence_bucket(confidence: Optional[float]) -> str: + if confidence is None: + return 'unknown' + if confidence >= 0.90: + return 'very_high' + if confidence >= 0.75: + return 'high' + if confidence >= 0.50: + return 'medium' + return 'low' + + +def observe_transcription_provider_identity_confidence( + provider: str, + model: str, + workload: str, 
+ bucket: str, + count: int = 1, +) -> None: + if count <= 0: + return + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, bucket=bucket) + TRANSCRIPTION_PROVIDER_IDENTITY_CONFIDENCE.labels(**labels).inc(count) + def metrics_response() -> Response: return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST) From 1117eb89325336d77397bedca3c5fdd614faf3a5 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 08:22:02 +0700 Subject: [PATCH 05/44] Route background transcription through provider service --- backend/routers/sync.py | 23 +- .../unit/test_background_provider_service.py | 126 +++++++ backend/utils/chat.py | 107 +++--- .../conversations/postprocess_conversation.py | 14 +- backend/utils/stt/provider_service.py | 314 ++++++++++++++++++ 5 files changed, 519 insertions(+), 65 deletions(-) create mode 100644 backend/tests/unit/test_background_provider_service.py create mode 100644 backend/utils/stt/provider_service.py diff --git a/backend/routers/sync.py b/backend/routers/sync.py index 4cd3e72fe51..230abc43aa9 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -61,7 +61,8 @@ from utils.byok import get_byok_keys, set_byok_keys from utils.http_client import _get_semaphore from utils.log_sanitizer import sanitize -from utils.stt.pre_recorded import deepgram_prerecorded, get_deepgram_model_for_language, postprocess_words +from utils.stt import provider_service as stt_provider_service +from utils.stt.providers import STTWorkload from utils.stt.vad import vad_is_empty from utils.fair_use import ( record_speech_ms, @@ -957,29 +958,33 @@ def delete_file(): single_language_mode = prefs.get('single_language_mode', False) if single_language_mode and user_language: - dg_language, dg_model = get_deepgram_model_for_language(user_language) + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model(user_language) else: - dg_language, dg_model = 
get_deepgram_model_for_language('multi') + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model('multi') # When single-language mode is active, trust the user's language choice # rather than Deepgram's detection (avoids overriding explicit selection). use_return_language = not (single_language_mode and user_language) - words, detected_language = deepgram_prerecorded( + transcription = stt_provider_service.transcribe_url( url, + workload=STTWorkload.sync, + uid=uid, + conversation_id=target_conversation_id, speakers_count=3, - attempts=0, - return_language=True, - language=dg_language, - model=dg_model, + return_language=use_return_language, + language=stt_language, + model=stt_model, keywords=vocabulary if vocabulary else None, ) + words = transcription.words + detected_language = transcription.detected_language or stt_language language = user_language if (single_language_mode and user_language) else detected_language if not words: # DG processed audio successfully but found no speech (silence/noise). # Real DG failures now raise RuntimeError and are caught by the except block. 
logger.info(f'No transcript words for segment {path} (silence or noise-only audio)') return - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments if not transcript_segments: logger.warning(f'Postprocessing returned empty for segment {path} (words present but no segments)') return diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py new file mode 100644 index 00000000000..342d6fd80dc --- /dev/null +++ b/backend/tests/unit/test_background_provider_service.py @@ -0,0 +1,126 @@ +import os +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + +sys.modules['deepgram'].DeepgramClient = MagicMock +sys.modules['deepgram'].DeepgramClientOptions = MagicMock + +os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') + +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord # noqa: E402 +from utils.stt import provider_service # noqa: E402 +from utils.stt.providers import STTProviderName, STTWorkload, get_prerecorded_provider_name # noqa: E402 + + +def _provider_result(provider='deepgram', model='nova-3'): + return ProviderTranscriptResult( + provider=provider, + model=model, + language='en', + duration=2.0, + words=[ + ProviderTranscriptWord( + text='hello', + start=0.0, + end=0.4, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ), + ProviderTranscriptWord( + text='world', + start=0.5, + end=1.0, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ), + ], + ) + + +def test_provider_service_transcribes_sync_upload_and_finalizes_deepgram_run(): + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.deepgram + 
fake_provider.transcribe_url.return_value = (_provider_result(), 'en') + + with patch.object(provider_service, '_deepgram_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-sync' + ) as create_run, patch.object(provider_service, 'finalize_provider_run') as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + conversation_id='conversation-1', + return_language=True, + language='multi', + model='nova-3', + keywords=['Omi'], + ) + + assert response.detected_language == 'en' + assert [segment.text for segment in response.segments] == ['Hello world'] + assert response.words[0]['stt_provider'] == 'deepgram' + fake_provider.transcribe_url.assert_called_once() + assert fake_provider.transcribe_url.call_args.kwargs['keywords'] == ['Omi'] + create_run.assert_called_once() + assert create_run.call_args.kwargs['workload'] == 'sync' + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['run_id'] == 'run-sync' + assert finalize_run.call_args.kwargs['provider'] == 'deepgram' + assert finalize_run.call_args.kwargs['workload'] == 'sync' + assert finalize_run.call_args.kwargs['status'] == 'succeeded' + assert finalize_run.call_args.kwargs['transcript_segment_count'] == 1 + + +def test_provider_service_finalizes_background_run_on_deepgram_default(): + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.deepgram + fake_provider.transcribe_url.return_value = _provider_result() + + with patch.object(provider_service, '_deepgram_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-background' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/background.wav', + workload=STTWorkload.background, + uid='uid-1', + 
model='nova-3', + raw_audio_seconds=9.5, + ) + + assert response.result.provider == 'deepgram' + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['workload'] == 'background' + assert finalize_run.call_args.kwargs['provider'] == 'deepgram' + assert finalize_run.call_args.kwargs['raw_audio_seconds'] == 9.5 + + +def test_prerecorded_ptt_and_realtime_related_workloads_stay_deepgram(): + assert get_prerecorded_provider_name(STTWorkload.ptt) == STTProviderName.deepgram + assert get_prerecorded_provider_name(STTWorkload.voice_message) == STTProviderName.deepgram + + +def test_background_call_sites_use_provider_service_layer(): + backend_root = Path(__file__).resolve().parents[2] + with open(backend_root / 'routers/sync.py') as f: + sync_source = f.read() + with open(backend_root / 'utils/conversations/postprocess_conversation.py') as f: + postprocess_source = f.read() + with open(backend_root / 'utils/chat.py') as f: + chat_source = f.read() + + assert 'from utils.stt.pre_recorded import' not in sync_source + assert 'stt_provider_service.transcribe_url' in sync_source + assert 'workload=STTWorkload.sync' in sync_source + + assert 'from utils.stt.pre_recorded import' not in postprocess_source + assert 'stt_provider_service.transcribe_url' in postprocess_source + assert 'workload=STTWorkload.postprocess' in postprocess_source + + assert 'from utils.stt.pre_recorded import' not in chat_source + assert 'workload=STTWorkload.voice_message' in chat_source + assert 'workload=STTWorkload.ptt' in chat_source diff --git a/backend/utils/chat.py b/backend/utils/chat.py index d58acd85394..f6ad86f7637 100644 --- a/backend/utils/chat.py +++ b/backend/utils/chat.py @@ -19,12 +19,8 @@ from utils.notifications import send_notification from utils.other.storage import get_syncing_file_temporal_signed_url, delete_syncing_temporal_file from utils.retrieval.graph import execute_graph_chat, execute_graph_chat_stream -from utils.stt.pre_recorded import ( - 
deepgram_prerecorded, - deepgram_prerecorded_from_bytes, - postprocess_words, - get_deepgram_model_for_language, -) +from utils.stt import provider_service as stt_provider_service +from utils.stt.providers import STTWorkload from utils.llm.usage_tracker import track_usage, set_usage_context, reset_usage_context, Features import logging @@ -73,27 +69,28 @@ def delete_file(): if not language: language = resolve_voice_message_language(uid, None) - # Get the appropriate Deepgram model for this language - stt_language, stt_model = get_deepgram_model_for_language(language) + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model(language) is_multi = stt_language == 'multi' try: - if is_multi: - words, detected_language = deepgram_prerecorded( - url, diarize=False, language=stt_language, return_language=True, model=stt_model - ) - else: - words = deepgram_prerecorded( - url, diarize=False, language=stt_language, return_language=False, model=stt_model - ) - detected_language = stt_language + transcription = stt_provider_service.transcribe_url( + url, + workload=STTWorkload.voice_message, + uid=uid, + diarize=False, + language=stt_language, + return_language=is_multi, + model=stt_model, + ) + words = transcription.words + detected_language = transcription.detected_language if is_multi else stt_language except RuntimeError as e: logger.error(f'Voice message transcription failed for {path}: {e}') return None, stt_language if not is_multi else 'en' if not words: logger.info('no words') return None, detected_language - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments del words if not transcript_segments: logger.error('failed to get deepgram segments') @@ -125,41 +122,31 @@ def transcribe_pcm_bytes( if not language: language = resolve_voice_message_language(uid, None) - stt_language, stt_model = get_deepgram_model_for_language(language) + stt_language, 
stt_model = stt_provider_service.resolve_prerecorded_language_model(language) is_multi = stt_language == 'multi' # Let RuntimeError propagate so the router can distinguish backend failure from no-speech - if is_multi: - result = deepgram_prerecorded_from_bytes( - audio_bytes, - sample_rate=sample_rate, - diarize=False, - encoding=encoding, - channels=channels, - language=stt_language, - model=stt_model, - return_language=True, - keywords=keywords, - ) - words, detected_language = result - else: - words = deepgram_prerecorded_from_bytes( - audio_bytes, - sample_rate=sample_rate, - diarize=False, - encoding=encoding, - channels=channels, - language=stt_language, - model=stt_model, - keywords=keywords, - ) - detected_language = stt_language + transcription = stt_provider_service.transcribe_bytes( + audio_bytes, + workload=STTWorkload.ptt, + uid=uid, + sample_rate=sample_rate, + diarize=False, + encoding=encoding, + channels=channels, + language=stt_language, + model=stt_model, + return_language=is_multi, + keywords=keywords, + ) + words = transcription.words + detected_language = transcription.detected_language if is_multi else stt_language if not words: logger.info('transcribe_pcm_bytes: no words') return None, detected_language - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments del words if not transcript_segments: logger.error('transcribe_pcm_bytes: failed to get segments') @@ -190,15 +177,22 @@ def delete_file(): if not language: language = resolve_voice_message_language(uid, None) - # Get the appropriate Deepgram model for this language - stt_language, stt_model = get_deepgram_model_for_language(language) + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model(language) try: - words = deepgram_prerecorded(url, diarize=False, language=stt_language, model=stt_model) + transcription = stt_provider_service.transcribe_url( + url, + 
workload=STTWorkload.voice_message, + uid=uid, + diarize=False, + language=stt_language, + model=stt_model, + ) + words = transcription.words except RuntimeError as e: logger.error(f'Voice message transcription failed for {path}: {e}') return [] - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments del words if not transcript_segments: logger.error('failed to get deepgram segments') @@ -264,15 +258,22 @@ def delete_file(): if not language: language = resolve_voice_message_language(uid, None) - # Get the appropriate Deepgram model for this language - stt_language, stt_model = get_deepgram_model_for_language(language) + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model(language) try: - words = deepgram_prerecorded(url, diarize=False, language=stt_language, model=stt_model) + transcription = stt_provider_service.transcribe_url( + url, + workload=STTWorkload.voice_message, + uid=uid, + diarize=False, + language=stt_language, + model=stt_model, + ) + words = transcription.words except RuntimeError as e: logger.error(f'Voice message transcription failed for {path}: {e}') return - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments del words if not transcript_segments: logger.error('failed to get deepgram segments') diff --git a/backend/utils/conversations/postprocess_conversation.py b/backend/utils/conversations/postprocess_conversation.py index a2cc66375a3..a4ea6a94882 100644 --- a/backend/utils/conversations/postprocess_conversation.py +++ b/backend/utils/conversations/postprocess_conversation.py @@ -15,7 +15,8 @@ from models.transcript_segment import TranscriptSegment from utils.conversations.process_conversation import process_conversation, process_user_emotion from utils.other.storage import upload_postprocessing_audio, delete_postprocessing_audio, 
upload_conversation_recording -from utils.stt.pre_recorded import deepgram_prerecorded, postprocess_words +from utils.stt import provider_service as stt_provider_service +from utils.stt.providers import STTWorkload from utils.stt.speech_profile import get_speech_profile_matching_predictions from utils.stt.vad import vad_is_empty import logging @@ -79,8 +80,15 @@ def postprocess_conversation( upload_conversation_recording(file_path, uid, conversation_id) speakers_count = len(set([segment.speaker for segment in conversation.transcript_segments])) - words = deepgram_prerecorded(signed_url, speakers_count=speakers_count) - fal_segments = postprocess_words(words, aseg.duration_seconds) + transcription = stt_provider_service.transcribe_url( + signed_url, + workload=STTWorkload.postprocess, + uid=uid, + conversation_id=conversation_id, + speakers_count=speakers_count, + raw_audio_seconds=aseg.duration_seconds, + ) + fal_segments = transcription.segments # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL count = len(''.join([segment.text.strip() for segment in conversation.transcript_segments])) diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py new file mode 100644 index 00000000000..a05fb73a659 --- /dev/null +++ b/backend/utils/stt/provider_service.py @@ -0,0 +1,314 @@ +import logging +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import List, Optional, Sequence, Tuple + +from models.transcript_segment import ProviderTranscriptResult, TranscriptSegment +from utils.stt.conversation_reconstructor import reconstruct_conversation +from utils.stt.deepgram_adapter import provider_result_to_legacy_words +from utils.stt.providers import STTProviderName, STTWorkload, get_prerecorded_provider_name + +logger = logging.getLogger(__name__) + + +def create_provider_run(**kwargs) -> str: + from database.transcription_provider_usage import 
create_provider_run as _create_provider_run + + return _create_provider_run(**kwargs) + + +def finalize_provider_run(**kwargs) -> None: + from database.transcription_provider_usage import finalize_provider_run as _finalize_provider_run + + _finalize_provider_run(**kwargs) + + +def summarize_identity_confidences(confidences): + from database.transcription_provider_usage import summarize_identity_confidences as _summarize_identity_confidences + + return _summarize_identity_confidences(confidences) + + +def _deepgram_prerecorded_provider(): + from utils.stt.pre_recorded import _deepgram_prerecorded_provider as _provider + + return _provider() + + +def get_deepgram_model_for_language(language: str) -> Tuple[str, str]: + from utils.stt.pre_recorded import get_deepgram_model_for_language as _get_deepgram_model_for_language + + return _get_deepgram_model_for_language(language) + + +@dataclass +class PrerecordedTranscriptionResponse: + result: ProviderTranscriptResult + detected_language: Optional[str] + segments: List[TranscriptSegment] + words: List[dict] + run_id: Optional[str] + + +def resolve_prerecorded_language_model(language: Optional[str]) -> Tuple[str, str]: + return get_deepgram_model_for_language(language or 'multi') + + +def transcribe_url( + audio_url: str, + workload: STTWorkload, + uid: Optional[str] = None, + conversation_id: Optional[str] = None, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'nova-3', + keywords: Optional[Sequence[str]] = None, + skip_n_seconds: int = 0, + raw_audio_seconds: float = 0.0, +) -> PrerecordedTranscriptionResponse: + workload = STTWorkload(workload) + provider_name = get_prerecorded_provider_name(workload) + provider = _get_prerecorded_provider(provider_name) + started_at = datetime.now(timezone.utc) + run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) + + try: + provider_result, 
detected_language, retry_count = _transcribe_url_with_retry( + provider, + audio_url, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, + ) + segments = reconstruct_conversation(provider_result, skip_n_seconds=skip_n_seconds) + words = provider_result_to_legacy_words(provider_result) + _finalize_run( + run_id, + provider_result, + workload, + started_at, + 'succeeded', + retry_count=retry_count, + raw_audio_seconds=raw_audio_seconds or provider_result.duration or 0.0, + segments=segments, + ) + return PrerecordedTranscriptionResponse( + result=provider_result, + detected_language=detected_language, + segments=segments, + words=words, + run_id=run_id, + ) + except Exception as e: + _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') + + +def transcribe_bytes( + audio_bytes: bytes, + workload: STTWorkload, + uid: Optional[str] = None, + conversation_id: Optional[str] = None, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'nova-3', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + skip_n_seconds: int = 0, + raw_audio_seconds: float = 0.0, +) -> PrerecordedTranscriptionResponse: + workload = STTWorkload(workload) + provider_name = get_prerecorded_provider_name(workload) + provider = _get_prerecorded_provider(provider_name) + started_at = datetime.now(timezone.utc) + run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) + + try: + provider_result, detected_language, retry_count = _transcribe_bytes_with_retry( + provider, + audio_bytes, + sample_rate=sample_rate, + diarize=diarize, + encoding=encoding, + channels=channels, + language=language, + 
model=model, + return_language=return_language, + keywords=keywords, + ) + segments = reconstruct_conversation(provider_result, skip_n_seconds=skip_n_seconds) + words = provider_result_to_legacy_words(provider_result) + _finalize_run( + run_id, + provider_result, + workload, + started_at, + 'succeeded', + retry_count=retry_count, + raw_audio_seconds=raw_audio_seconds or provider_result.duration or 0.0, + segments=segments, + ) + return PrerecordedTranscriptionResponse( + result=provider_result, + detected_language=detected_language, + segments=segments, + words=words, + run_id=run_id, + ) + except Exception as e: + _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') + + +def _get_prerecorded_provider(provider_name: STTProviderName): + if provider_name == STTProviderName.deepgram: + return _deepgram_prerecorded_provider() + raise ValueError(f'Unsupported prerecorded STT provider: {provider_name}') + + +def _transcribe_url_with_retry( + provider, audio_url: str, **kwargs +) -> Tuple[ProviderTranscriptResult, Optional[str], int]: + last_error = None + for attempt in range(2): + try: + result = provider.transcribe_url(audio_url, **kwargs) + transcript_result, detected_language = _unpack_provider_result(result, kwargs.get('return_language')) + return transcript_result, detected_language, attempt + except Exception as e: + last_error = e + logger.error( + 'provider prerecorded url transcription error attempt=%s provider=%s: %s', + attempt, + provider.provider_name, + e, + ) + raise last_error + + +def _transcribe_bytes_with_retry( + provider, audio_bytes: bytes, **kwargs +) -> Tuple[ProviderTranscriptResult, Optional[str], int]: + last_error = None + for attempt in range(2): + try: + result = provider.transcribe_bytes(audio_bytes, **kwargs) + transcript_result, detected_language = _unpack_provider_result(result, 
kwargs.get('return_language')) + return transcript_result, detected_language, attempt + except Exception as e: + last_error = e + logger.error( + 'provider prerecorded bytes transcription error attempt=%s provider=%s: %s', + attempt, + provider.provider_name, + e, + ) + raise last_error + + +def _unpack_provider_result(result, return_language: bool) -> Tuple[ProviderTranscriptResult, Optional[str]]: + if return_language: + transcript_result, detected_language = result + return transcript_result, detected_language + transcript_result = result + return transcript_result, transcript_result.language + + +def _create_run( + uid: Optional[str], + provider: str, + model: str, + workload: str, + conversation_id: Optional[str], + started_at: datetime, +) -> Optional[str]: + if not uid: + return None + try: + return create_provider_run( + uid=uid, + provider=provider, + model=model, + workload=workload, + conversation_id=conversation_id, + started_at=started_at, + ) + except Exception as e: + logger.warning('failed to create transcription provider run ledger uid=%s workload=%s: %s', uid, workload, e) + return None + + +def _finalize_run( + run_id: Optional[str], + result: ProviderTranscriptResult, + workload: STTWorkload, + started_at: datetime, + status: str, + retry_count: int, + raw_audio_seconds: float, + segments: List[TranscriptSegment], +) -> None: + if not run_id: + return + clusters = { + item.provider_cluster_id for item in list(result.words) + list(result.utterances) if item.provider_cluster_id + } + confidences = [segment.speaker_identity_confidence for segment in segments] + try: + finalize_provider_run( + run_id=run_id, + provider=result.provider, + model=result.model or 'unknown', + workload=workload.value, + status=status, + started_at=started_at, + raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=raw_audio_seconds, + billable_seconds=raw_audio_seconds, + retry_count=retry_count, + transcript_segment_count=len(segments), + 
transcript_word_count=len(result.words), + speaker_cluster_count=len(clusters), + identified_speaker_cluster_count=len( + {segment.provider_cluster_id for segment in segments if segment.person_id} + ), + identity_confidence_summary=summarize_identity_confidences(confidences), + ) + except Exception as e: + logger.warning('failed to finalize transcription provider run ledger run_id=%s: %s', run_id, e) + + +def _finalize_failed_run( + run_id: Optional[str], + provider: str, + model: str, + workload: str, + started_at: datetime, + error: Exception, + raw_audio_seconds: float, +) -> None: + if not run_id: + return + try: + finalize_provider_run( + run_id=run_id, + provider=provider, + model=model, + workload=workload, + status='failed', + started_at=started_at, + raw_audio_seconds=raw_audio_seconds, + error_class=error.__class__.__name__, + ) + except Exception as finalize_error: + logger.warning( + 'failed to finalize failed transcription provider run ledger run_id=%s: %s', run_id, finalize_error + ) From 79bfafcdf95c05a566acbc0f5779f0d6eaf2f38b Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 08:30:21 +0700 Subject: [PATCH 06/44] Implement cluster-scoped speaker identity --- backend/models/message_event.py | 7 + backend/models/transcript_segment.py | 5 +- backend/routers/sync.py | 125 +----- .../unit/test_background_speaker_identity.py | 167 ++++++++ .../utils/stt/background_speaker_identity.py | 392 ++++++++++++++++++ backend/utils/stt/provider_service.py | 20 +- 6 files changed, 602 insertions(+), 114 deletions(-) create mode 100644 backend/tests/unit/test_background_speaker_identity.py create mode 100644 backend/utils/stt/background_speaker_identity.py diff --git a/backend/models/message_event.py b/backend/models/message_event.py index 0bd386689f5..cca36fd3146 100644 --- a/backend/models/message_event.py +++ b/backend/models/message_event.py @@ -154,6 +154,13 @@ class SpeakerLabelSuggestionEvent(MessageEvent): person_id: str person_name: str 
segment_id: str + version: int = 1 + provider_cluster_id: Optional[str] = None + speaker_identity_state: Optional[str] = None + confidence: Optional[float] = None + source: Optional[str] = None + provenance: Optional[dict[str, Any]] = None + candidates: Optional[List[dict[str, Any]]] = None def to_json(self): j = self.model_dump(mode="json") diff --git a/backend/models/transcript_segment.py b/backend/models/transcript_segment.py index 9f8af3b0414..3dcbbb35c94 100644 --- a/backend/models/transcript_segment.py +++ b/backend/models/transcript_segment.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Literal, Optional, List, Tuple +from typing import Any, Literal, Optional, List, Tuple import uuid import re from pydantic import BaseModel, Field @@ -74,6 +74,9 @@ class TranscriptSegment(BaseModel): speaker_identity_confidence: Optional[float] = None speaker_identity_source: Optional[str] = None speaker_identity_version: Optional[str] = None + speaker_identity_provenance: Optional[dict[str, Any]] = None + speaker_identity_candidates: Optional[List[dict[str, Any]]] = None + speaker_identity_text_hints: Optional[List[dict[str, Any]]] = None def __init__(self, **data): super().__init__(**data) diff --git a/backend/routers/sync.py b/backend/routers/sync.py index 230abc43aa9..e9b5f698b7f 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -76,13 +76,8 @@ FAIR_USE_ENABLED, FAIR_USE_RESTRICT_DAILY_DG_MS, ) -from utils.speaker_assignment import process_speaker_assigned_segments -from utils.speaker_identification import detect_speaker_from_text -from utils.stt.speaker_embedding import ( - extract_embedding_from_bytes, - compare_embeddings, - SPEAKER_MATCH_THRESHOLD, -) +from utils.stt.background_speaker_identity import identify_background_speaker_clusters +from utils.stt.speaker_embedding import extract_embedding_from_bytes from utils.subscription import has_transcription_credits logger = logging.getLogger(__name__) @@ -822,111 +817,21 @@ def 
identify_speakers_for_segments( person_embeddings_cache: Dict[str, dict], uid: str, ) -> None: - """Identify speakers in transcript segments using voice embeddings and text detection. - - Modifies segments in-place by assigning person_id and is_user fields. + """Identify background speakers once per provider cluster. - Steps: - 1. Voice embedding matching (requires audio_bytes and non-empty cache): - For each unique speaker_id, find the longest segment (>=1s), extract audio clip, - get embedding, match against person_embeddings_cache. - 2. Text-based detection ("I am X") runs independently for all unmatched speakers. - 3. Apply assignments via process_speaker_assigned_segments. + Text self-introduction is retained as hint metadata only. It must not create + or apply a durable speaker identity without voice evidence. """ - speaker_to_person_map: Dict[int, Tuple[str, str]] = {} - segment_person_assignment_map: Dict[str, str] = {} - - # Group segments by speaker_id, find best (longest) segment per speaker for embedding - speaker_segments: Dict[int, List[TranscriptSegment]] = {} - for seg in transcript_segments: - sid = seg.speaker_id if seg.speaker_id is not None else 0 - speaker_segments.setdefault(sid, []).append(seg) - - # Voice embedding matching (only when audio and cached embeddings are available) - # Track matched person_ids so each person is only assigned to one speaker - # (diarization tells us speakers are distinct — no person can be two speakers). - matched_person_ids: set = set() - - if audio_bytes and person_embeddings_cache: - # Sort speakers by best single segment duration (longest first) — this is the clip - # actually used for embedding, so it determines match quality. - # Note: matched_person_ids assumes diarization is correct (one person = one speaker). - # If diarization fragments one person across speaker IDs, only the best match wins. 
- sorted_speakers = sorted( - speaker_segments.items(), - key=lambda kv: max(s.end - s.start for s in kv[1]), - reverse=True, - ) - - for speaker_id, segments in sorted_speakers: - best_seg = max(segments, key=lambda s: s.end - s.start) - seg_duration = best_seg.end - best_seg.start - - if seg_duration < SPEAKER_ID_MIN_AUDIO: - continue - - clip_wav = _extract_speaker_clip_wav(audio_bytes, best_seg.start, best_seg.end) - if not clip_wav: - continue - - try: - query_embedding = extract_embedding_from_bytes(clip_wav, "sync_speaker.wav") - except (ValueError, Exception) as e: - logger.info(f'Speaker ID: embedding extraction failed for speaker {speaker_id}: {e} uid={uid}') - continue - - # Compare only against unmatched candidates (each person can be one speaker) - best_match = None - best_distance = float('inf') - for person_id, data in person_embeddings_cache.items(): - if person_id in matched_person_ids: - continue - distance = compare_embeddings(query_embedding, data['embedding']) - if distance < best_distance: - best_distance = distance - best_match = (person_id, data['name']) - - if best_match and best_distance < SPEAKER_MATCH_THRESHOLD: - person_id, person_name = best_match - speaker_to_person_map[speaker_id] = (person_id, person_name) - segment_person_assignment_map[best_seg.id] = person_id - matched_person_ids.add(person_id) - logger.info( - f'Speaker ID (sync): speaker {speaker_id} -> {person_id} ' - f'(distance={best_distance:.3f}) uid={uid}' - ) - - # Text-based detection runs independently for all unmatched speakers. - # For speaker_id > 0 (diarized): update both speaker_to_person_map and per-segment map. - # For speaker_id <= 0 (undiarized): only assign per-segment (avoid mapping all speaker_id=0 - # segments to one person when diarization is inactive). 
- for speaker_id, segments in speaker_segments.items(): - if speaker_id in speaker_to_person_map: - continue - for seg in segments: - detected_name = detect_speaker_from_text(seg.text) - if detected_name: - person = users_db.get_person_by_name(uid, detected_name) - if person: - # Per-segment assignment always applies - segment_person_assignment_map[seg.id] = person['id'] - # Update speaker map only when diarization is active - if speaker_id > 0: - speaker_to_person_map[speaker_id] = (person['id'], person['name']) - logger.info( - f'Speaker ID (sync): text detection speaker {speaker_id} -> ' - f'{person["id"]} via "{detected_name}" uid={uid}' - ) - if speaker_id > 0: - break # One match per diarized speaker is enough - - # Apply all assignments to segments - if speaker_to_person_map or segment_person_assignment_map: - process_speaker_assigned_segments( - transcript_segments, - segment_person_assignment_map, - speaker_to_person_map, - ) + assignments = identify_background_speaker_clusters( + transcript_segments, + audio_bytes, + person_embeddings_cache or {}, + embedding_extractor=extract_embedding_from_bytes, + ) + identified_count = sum(1 for assignment in assignments.values() if assignment.state in ('identified', 'user')) + logger.info( + f'Speaker ID (sync): cluster identity assignments={len(assignments)} identified={identified_count} uid={uid}' + ) def process_segment( diff --git a/backend/tests/unit/test_background_speaker_identity.py b/backend/tests/unit/test_background_speaker_identity.py new file mode 100644 index 00000000000..f49744637f2 --- /dev/null +++ b/backend/tests/unit/test_background_speaker_identity.py @@ -0,0 +1,167 @@ +import io +import wave + +import numpy as np + +from models.message_event import SpeakerLabelSuggestionEvent +from models.transcript_segment import TranscriptSegment +from utils.stt.background_speaker_identity import ( + SPEAKER_IDENTITY_SOURCE, + SPEAKER_IDENTITY_VERSION, + identify_background_speaker_clusters, + 
select_representative_cluster_spans, +) + + +def _segment(segment_id, start, end, cluster='cluster-a', speaker='SPEAKER_01', text='hello from speaker'): + return TranscriptSegment( + id=segment_id, + text=text, + speaker=speaker, + is_user=False, + start=start, + end=end, + provider_cluster_id=cluster, + provider_speaker_label=speaker, + speaker_identity_state='unassigned', + ) + + +def _wav_bytes(duration_seconds=30, sample_rate=16000): + samples = np.zeros(int(duration_seconds * sample_rate), dtype=np.int16) + out = io.BytesIO() + with wave.open(out, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(samples.tobytes()) + return out.getvalue() + + +def test_cluster_sampling_prefers_clean_spans_and_caps_total_duration(): + cluster_segments = [ + _segment('short', 0.0, 0.5), + _segment('overlap', 1.0, 8.0), + _segment('clean-long', 9.0, 24.0), + _segment('clean-second', 24.5, 33.0), + ] + all_segments = cluster_segments + [_segment('other-overlap', 2.0, 3.0, cluster='cluster-b')] + + spans = select_representative_cluster_spans(cluster_segments, all_segments) + + assert [span.segment_id for span in spans] == ['clean-long'] + assert sum(span.duration for span in spans) == 10.0 + assert all(span.segment_id != 'overlap' for span in spans) + assert all(span.duration >= 1.0 for span in spans) + + +def test_cluster_identity_applies_one_voice_assignment_to_all_cluster_segments(): + alice_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + bob_embedding = np.array([[0.0, 1.0]], dtype=np.float32) + segments = [ + _segment('a1', 0.0, 3.0, cluster='cluster-a', speaker='SPEAKER_01'), + _segment('a2', 3.5, 6.5, cluster='cluster-a', speaker='SPEAKER_01'), + _segment('b1', 7.0, 10.0, cluster='cluster-b', speaker='SPEAKER_02'), + ] + cache = { + 'person-alice': {'embedding': alice_embedding, 'name': 'Alice'}, + 'person-bob': {'embedding': bob_embedding, 'name': 'Bob'}, + } + + def fake_extract(_audio, _filename): + return 
alice_embedding + + assignments = identify_background_speaker_clusters(segments, _wav_bytes(), cache, embedding_extractor=fake_extract) + + assert assignments['cluster-a'].person_id == 'person-alice' + assert assignments['cluster-a'].confidence == 1.0 + assert segments[0].person_id == 'person-alice' + assert segments[1].person_id == 'person-alice' + assert segments[0].speaker_identity_source == SPEAKER_IDENTITY_SOURCE + assert segments[0].speaker_identity_version == SPEAKER_IDENTITY_VERSION + assert segments[0].speaker_identity_provenance['provider_cluster_id'] == 'cluster-a' + assert segments[0].speaker_identity_candidates[0]['person_id'] == 'person-alice' + + +def test_low_confidence_cluster_remains_explicitly_unknown_with_candidate_metadata(): + query_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + distant_embedding = np.array([[0.0, 1.0]], dtype=np.float32) + segments = [_segment('a1', 0.0, 3.0)] + cache = {'person-distant': {'embedding': distant_embedding, 'name': 'Distant'}} + + assignments = identify_background_speaker_clusters( + segments, + _wav_bytes(), + cache, + embedding_extractor=lambda _audio, _filename: query_embedding, + ) + + assert assignments['cluster-a'].state == 'unknown' + assert assignments['cluster-a'].reason == 'below_threshold' + assert segments[0].speaker_identity_state == 'unknown' + assert segments[0].person_id is None + assert segments[0].speaker_identity_candidates[0]['person_id'] == 'person-distant' + assert segments[0].speaker_identity_confidence is None + + +def test_text_self_introduction_is_hint_only_without_voice_assignment(): + segments = [_segment('intro', 0.0, 3.0, text='I am Alice and I joined the call.')] + + assignments = identify_background_speaker_clusters(segments, audio_bytes=None, person_embeddings_cache={}) + + assert assignments['cluster-a'].state == 'unknown' + assert assignments['cluster-a'].text_hints[0]['detected_name'] == 'Alice' + assert segments[0].person_id is None + assert 
segments[0].speaker_identity_state == 'unknown' + assert segments[0].speaker_identity_text_hints[0]['source'] == 'text_self_introduction' + + +def test_user_sentinel_is_not_persisted_as_durable_person_identity(): + user_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + segments = [_segment('me', 0.0, 4.0)] + cache = {'user': {'embedding': user_embedding, 'name': 'User'}} + + identify_background_speaker_clusters( + segments, + _wav_bytes(), + cache, + embedding_extractor=lambda _audio, _filename: user_embedding, + ) + + assert segments[0].is_user is True + assert segments[0].person_id is None + assert segments[0].speaker_identity_state == 'user' + assert segments[0].speaker_identity_candidates[0]['person_id'] is None + assert segments[0].speaker_identity_candidates[0]['is_user'] is True + + +def test_speaker_label_suggestion_event_preserves_legacy_shape_and_accepts_cluster_metadata(): + legacy = SpeakerLabelSuggestionEvent( + speaker_id=1, + person_id='person-1', + person_name='Alice', + segment_id='segment-1', + ).to_json() + + assert legacy['type'] == 'speaker_label_suggestion' + assert legacy['version'] == 1 + assert legacy['speaker_id'] == 1 + assert legacy['person_id'] == 'person-1' + + extended = SpeakerLabelSuggestionEvent( + speaker_id=1, + person_id='person-1', + person_name='Alice', + segment_id='segment-1', + version=2, + provider_cluster_id='cluster-a', + speaker_identity_state='identified', + confidence=0.91, + source=SPEAKER_IDENTITY_SOURCE, + provenance={'sample_seconds': 6.0}, + candidates=[{'person_id': 'person-1', 'confidence': 0.91}], + ).to_json() + + assert extended['version'] == 2 + assert extended['provider_cluster_id'] == 'cluster-a' + assert extended['provenance']['sample_seconds'] == 6.0 diff --git a/backend/utils/stt/background_speaker_identity.py b/backend/utils/stt/background_speaker_identity.py new file mode 100644 index 00000000000..2bb98c7df03 --- /dev/null +++ b/backend/utils/stt/background_speaker_identity.py @@ -0,0 +1,392 @@ 
+import io +import logging +import re +import wave +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional + +import numpy as np + +from models.transcript_segment import TranscriptSegment + +logger = logging.getLogger(__name__) + +USER_SELF_PERSON_ID = 'user' +SPEAKER_MATCH_THRESHOLD = 0.45 +SPEAKER_IDENTITY_VERSION = 'omi-speaker-identity:v1' +SPEAKER_IDENTITY_SOURCE = 'omi_speaker_embedding' +UNKNOWN_SPEAKER_IDENTITY_SOURCE = 'omi_speaker_embedding:no_match' +TEXT_HINT_SOURCE = 'text_self_introduction' + +MIN_CLUSTER_SPAN_SECONDS = 1.0 +PREFERRED_CLUSTER_SAMPLE_SECONDS = 5.0 +MAX_CLUSTER_SAMPLE_SECONDS = 20.0 +MAX_CLUSTER_SPAN_SECONDS = 10.0 +OVERLAP_TOLERANCE_SECONDS = 0.1 + +SELF_INTRODUCTION_HINT_PATTERNS = [ + r"\b(I am|I'm|i am|i'm|My name is|my name is)\s+([A-Z][a-zA-Z]*)\b", + r"\b([A-Z][a-zA-Z]*)\s+is my name\b", +] + + +@dataclass +class ClusterSampleSpan: + start: float + end: float + segment_id: str + duration: float + + def as_dict(self) -> dict: + return { + 'start': round(self.start, 3), + 'end': round(self.end, 3), + 'segment_id': self.segment_id, + 'duration': round(self.duration, 3), + } + + +@dataclass +class ClusterIdentityAssignment: + provider_cluster_id: str + speaker_id: Optional[int] + state: str + person_id: Optional[str] = None + person_name: Optional[str] = None + is_user: bool = False + confidence: Optional[float] = None + distance: Optional[float] = None + source: str = UNKNOWN_SPEAKER_IDENTITY_SOURCE + version: str = SPEAKER_IDENTITY_VERSION + candidates: List[dict] = field(default_factory=list) + sample_spans: List[ClusterSampleSpan] = field(default_factory=list) + text_hints: List[dict] = field(default_factory=list) + reason: Optional[str] = None + + def provenance(self) -> dict: + data = { + 'provider_cluster_id': self.provider_cluster_id, + 'speaker_id': self.speaker_id, + 'sample_spans': [span.as_dict() for span in self.sample_spans], + 'sample_seconds': round(sum(span.duration for span in 
self.sample_spans), 3), + } + if self.distance is not None: + data['distance'] = round(self.distance, 6) + if self.reason: + data['reason'] = self.reason + return data + + +def identify_background_speaker_clusters( + transcript_segments: List[TranscriptSegment], + audio_bytes: Optional[bytes], + person_embeddings_cache: Dict[str, dict], + embedding_extractor: Optional[Callable[[bytes, str], np.ndarray]] = None, + match_threshold: float = SPEAKER_MATCH_THRESHOLD, +) -> Dict[str, ClusterIdentityAssignment]: + """Assign canonical Omi identity once per provider speaker cluster. + + Text self-introductions are recorded only as hints. Durable assignment + requires voice evidence from the existing Omi speaker embedding service or + a later explicit user correction. + """ + clusters = _group_segments_by_cluster(transcript_segments) + assignments: Dict[str, ClusterIdentityAssignment] = {} + matched_person_ids: set[str] = set() + + for cluster_id, cluster_segments in clusters.items(): + text_hints = _collect_text_hints(cluster_segments, cluster_id) + speaker_id = _cluster_speaker_id(cluster_segments) + sample_spans = select_representative_cluster_spans(cluster_segments, transcript_segments) + assignment = ClusterIdentityAssignment( + provider_cluster_id=cluster_id, + speaker_id=speaker_id, + state='unknown', + sample_spans=sample_spans, + text_hints=text_hints, + ) + + if not audio_bytes: + assignment.reason = 'missing_audio' + assignments[cluster_id] = assignment + continue + if not person_embeddings_cache: + assignment.reason = 'missing_candidate_embeddings' + assignments[cluster_id] = assignment + continue + if not embedding_extractor: + assignment.reason = 'missing_embedding_extractor' + assignments[cluster_id] = assignment + continue + if not sample_spans: + assignment.reason = 'no_usable_cluster_spans' + assignments[cluster_id] = assignment + continue + + sample_wav = extract_cluster_sample_wav(audio_bytes, sample_spans) + if not sample_wav: + assignment.reason = 
'sample_extraction_failed' + assignments[cluster_id] = assignment + continue + + try: + query_embedding = embedding_extractor(sample_wav, f'cluster_{_safe_filename(cluster_id)}.wav') + except Exception as e: + logger.info('background speaker identity embedding failed cluster=%s: %s', cluster_id, e) + assignment.reason = 'embedding_extraction_failed' + assignments[cluster_id] = assignment + continue + + candidates = _rank_candidates(query_embedding, person_embeddings_cache, matched_person_ids, match_threshold) + assignment.candidates = _public_candidates(candidates) + best = candidates[0] if candidates else None + if not best or best['distance'] >= match_threshold: + assignment.reason = 'below_threshold' + assignments[cluster_id] = assignment + continue + + match_person_id = best['_match_person_id'] + matched_person_ids.add(match_person_id) + assignment.state = 'user' if match_person_id == USER_SELF_PERSON_ID else 'identified' + assignment.person_id = None if match_person_id == USER_SELF_PERSON_ID else match_person_id + assignment.person_name = best.get('person_name') + assignment.is_user = match_person_id == USER_SELF_PERSON_ID + assignment.confidence = best['confidence'] + assignment.distance = best['distance'] + assignment.source = SPEAKER_IDENTITY_SOURCE + assignment.reason = None + assignments[cluster_id] = assignment + + apply_cluster_identity_assignments(transcript_segments, assignments) + return assignments + + +def select_representative_cluster_spans( + cluster_segments: List[TranscriptSegment], + all_segments: List[TranscriptSegment], + preferred_seconds: float = PREFERRED_CLUSTER_SAMPLE_SECONDS, + max_seconds: float = MAX_CLUSTER_SAMPLE_SECONDS, +) -> List[ClusterSampleSpan]: + usable_spans = [] + for segment in cluster_segments: + duration = max(0.0, segment.end - segment.start) + if duration < MIN_CLUSTER_SPAN_SECONDS: + continue + if _has_obvious_overlap(segment, all_segments): + continue + if len(segment.text.split()) < 2: + continue + + end = 
segment.end + if duration > MAX_CLUSTER_SPAN_SECONDS: + center = (segment.start + segment.end) / 2 + start = max(segment.start, center - MAX_CLUSTER_SPAN_SECONDS / 2) + end = min(segment.end, start + MAX_CLUSTER_SPAN_SECONDS) + else: + start = segment.start + + usable_spans.append( + ClusterSampleSpan( + start=start, + end=end, + segment_id=segment.id, + duration=end - start, + ) + ) + + usable_spans.sort(key=lambda span: (-span.duration, span.start)) + + selected = [] + total = 0.0 + for span in usable_spans: + if total >= preferred_seconds and selected: + break + remaining = max_seconds - total + if remaining <= 0: + break + if span.duration > remaining: + span = ClusterSampleSpan( + start=span.start, + end=span.start + remaining, + segment_id=span.segment_id, + duration=remaining, + ) + selected.append(span) + total += span.duration + + return sorted(selected, key=lambda span: span.start) + + +def extract_cluster_sample_wav(audio_bytes: bytes, spans: List[ClusterSampleSpan]) -> Optional[bytes]: + try: + with wave.open(io.BytesIO(audio_bytes), 'rb') as wf: + framerate = wf.getframerate() + n_channels = wf.getnchannels() + sampwidth = wf.getsampwidth() + n_frames = wf.getnframes() + total_duration = n_frames / framerate + frames = [] + for span in spans: + start = max(0.0, min(span.start, total_duration)) + end = max(0.0, min(span.end, total_duration)) + if end - start < MIN_CLUSTER_SPAN_SECONDS: + continue + wf.setpos(int(start * framerate)) + frames.append(wf.readframes(int((end - start) * framerate))) + except Exception as e: + logger.info('background speaker identity sample extraction failed: %s', e) + return None + + frames = [frame for frame in frames if frame] + if not frames: + return None + + out = io.BytesIO() + with wave.open(out, 'wb') as out_wf: + out_wf.setnchannels(n_channels) + out_wf.setsampwidth(sampwidth) + out_wf.setframerate(framerate) + for frame in frames: + out_wf.writeframes(frame) + return out.getvalue() + + +def 
apply_cluster_identity_assignments( + transcript_segments: List[TranscriptSegment], + assignments: Dict[str, ClusterIdentityAssignment], +) -> None: + for segment in transcript_segments: + cluster_id = _segment_cluster_key(segment) + assignment = assignments.get(cluster_id) + if not assignment: + continue + + segment.speaker_identity_state = assignment.state + segment.speaker_identity_confidence = assignment.confidence + segment.speaker_identity_source = assignment.source + segment.speaker_identity_version = assignment.version + segment.speaker_identity_provenance = assignment.provenance() + segment.speaker_identity_candidates = assignment.candidates + segment.speaker_identity_text_hints = assignment.text_hints + + if assignment.is_user: + segment.is_user = True + segment.person_id = None + elif assignment.person_id: + segment.is_user = False + segment.person_id = assignment.person_id + + +def _group_segments_by_cluster(transcript_segments: List[TranscriptSegment]) -> Dict[str, List[TranscriptSegment]]: + clusters: Dict[str, List[TranscriptSegment]] = {} + for segment in transcript_segments: + clusters.setdefault(_segment_cluster_key(segment), []).append(segment) + return clusters + + +def _segment_cluster_key(segment: TranscriptSegment) -> str: + if segment.provider_cluster_id: + return str(segment.provider_cluster_id) + if segment.provider_speaker_label: + return f'provider-label:{segment.provider_speaker_label}' + return f'legacy-speaker:{segment.speaker_id if segment.speaker_id is not None else 0}' + + +def _cluster_speaker_id(segments: List[TranscriptSegment]) -> Optional[int]: + counts = {} + for segment in segments: + if segment.speaker_id is None: + continue + counts[segment.speaker_id] = counts.get(segment.speaker_id, 0) + 1 + return max(counts, key=counts.get) if counts else None + + +def _has_obvious_overlap(segment: TranscriptSegment, all_segments: List[TranscriptSegment]) -> bool: + for other in all_segments: + if other.id == segment.id: + continue + 
if _segment_cluster_key(other) == _segment_cluster_key(segment): + continue + overlap = min(segment.end, other.end) - max(segment.start, other.start) + if overlap > OVERLAP_TOLERANCE_SECONDS: + return True + return False + + +def _collect_text_hints(segments: List[TranscriptSegment], cluster_id: str) -> List[dict]: + hints = [] + for segment in segments: + detected_name = _detect_self_introduction_hint(segment.text) + if not detected_name: + continue + hints.append( + { + 'source': TEXT_HINT_SOURCE, + 'provider_cluster_id': cluster_id, + 'segment_id': segment.id, + 'detected_name': detected_name, + } + ) + return hints + + +def _detect_self_introduction_hint(text: str) -> Optional[str]: + for pattern in SELF_INTRODUCTION_HINT_PATTERNS: + match = re.search(pattern, text) + if not match: + continue + name = match.groups()[-1] + if name and len(name) >= 2: + return name.capitalize() + return None + + +def _rank_candidates( + query_embedding: np.ndarray, + person_embeddings_cache: Dict[str, dict], + matched_person_ids: set[str], + match_threshold: float, +) -> List[dict]: + candidates = [] + for person_id, data in person_embeddings_cache.items(): + if person_id in matched_person_ids: + continue + distance = compare_embeddings(query_embedding, data['embedding']) + confidence = max(0.0, min(1.0, 1.0 - (distance / match_threshold))) + candidates.append( + { + '_match_person_id': person_id, + 'person_id': None if person_id == USER_SELF_PERSON_ID else person_id, + 'person_name': data.get('name'), + 'is_user': person_id == USER_SELF_PERSON_ID, + 'distance': round(distance, 6), + 'confidence': round(confidence, 6), + 'source': SPEAKER_IDENTITY_SOURCE, + } + ) + candidates.sort(key=lambda item: (item['distance'], item['_match_person_id'])) + return candidates[:3] + + +def compare_embeddings(embedding1: np.ndarray, embedding2: np.ndarray) -> float: + if embedding1.shape[1] != embedding2.shape[1]: + return 2.0 + norm1 = np.linalg.norm(embedding1) + norm2 = 
np.linalg.norm(embedding2) + if norm1 == 0 or norm2 == 0: + return 2.0 + similarity = float(np.dot(embedding1.flatten(), embedding2.flatten()) / (norm1 * norm2)) + return 1.0 - similarity + + +def _public_candidates(candidates: List[dict]) -> List[dict]: + public = [] + for candidate in candidates: + data = dict(candidate) + data.pop('_match_person_id', None) + public.append(data) + return public + + +def _safe_filename(value: str) -> str: + return ''.join(ch if ch.isalnum() else '_' for ch in value)[:80] or 'unknown' diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index a05fb73a659..fdcc1f09f25 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -24,9 +24,23 @@ def finalize_provider_run(**kwargs) -> None: def summarize_identity_confidences(confidences): - from database.transcription_provider_usage import summarize_identity_confidences as _summarize_identity_confidences - - return _summarize_identity_confidences(confidences) + summary = {} + for confidence in confidences: + bucket = _identity_confidence_bucket(confidence) + summary[bucket] = summary.get(bucket, 0) + 1 + return summary + + +def _identity_confidence_bucket(confidence: Optional[float]) -> str: + if confidence is None: + return 'unknown' + if confidence >= 0.90: + return 'very_high' + if confidence >= 0.75: + return 'high' + if confidence >= 0.50: + return 'medium' + return 'low' def _deepgram_prerecorded_provider(): From 1770f349469159e7fa62979f7683d32b2491c458 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 08:47:29 +0700 Subject: [PATCH 07/44] Add client support for canonical speaker metadata --- .../backend/schema/transcript_segment.dart | 29 +++- app/lib/pages/conversation_detail/page.dart | 31 +++-- app/lib/widgets/transcript.dart | 47 ++++--- app/test/widgets/transcript_test.dart | 26 ++++ .../Backend-Rust/src/models/conversation.rs | 15 +- .../Backend-Rust/src/services/firestore.rs | 130 
++++++++++++++++-- desktop/Desktop/Sources/APIClient.swift | 17 +++ desktop/Desktop/Sources/AppState.swift | 30 +++- .../Components/SpeakerBubbleView.swift | 4 +- .../Pages/ConversationDetailView.swift | 10 +- .../Sources/Rewind/Core/RewindDatabase.swift | 13 ++ .../Rewind/Core/TranscriptionModels.swift | 46 ++++++- .../Rewind/Core/TranscriptionStorage.swift | 36 ++++- .../TranscriptSpeakerAssignmentTests.swift | 36 +++++ 14 files changed, 400 insertions(+), 70 deletions(-) diff --git a/app/lib/backend/schema/transcript_segment.dart b/app/lib/backend/schema/transcript_segment.dart index dce020ba3e3..605dc09b292 100644 --- a/app/lib/backend/schema/transcript_segment.dart +++ b/app/lib/backend/schema/transcript_segment.dart @@ -64,6 +64,10 @@ class TranscriptSegment { speakerId = parts.length > 1 ? (int.tryParse(parts[1]) ?? 0) : 0; } + bool get hasExplicitUnknownSpeakerIdentity => speakerIdentityState == 'unknown'; + + bool get hasLegacySpeakerLabel => speaker != null && speaker!.isNotEmpty; + @override String toString() { return 'TranscriptSegment: {id: $id text: $text, speaker: $speakerId, isUser: $isUser, start: $start, end: $end}'; @@ -80,7 +84,7 @@ class TranscriptSegment { return TranscriptSegment( id: (json['id'] ?? '') as String, text: json['text'] as String, - speaker: (json['speaker'] ?? 'SPEAKER_00') as String, + speaker: json['speaker'] as String?, isUser: (json['is_user'] ?? false) as bool, personId: json['person_id'], start: double.tryParse(json['start'].toString()) ?? 0.0, @@ -225,7 +229,7 @@ class TranscriptSegment { if (segment.personId != null && peopleMap.containsKey(segment.personId)) { speakerName = peopleMap[segment.personId]!; } else { - var displayId = '${getDisplaySpeakerId(segment.speakerId, segments)}'; + var displayId = getDisplaySpeakerIdForSegment(segment, segments); speakerName = speakerLabelBuilder != null ? 
speakerLabelBuilder(displayId) : 'Speaker $displayId'; } transcript += '$timestampStr $speakerName: $segmentText '; @@ -272,4 +276,25 @@ class TranscriptSegment { // Normalize: subtract minimum and add 1 to make it 1-indexed return speakerId - minSpeakerId + 1; } + + static String getDisplaySpeakerIdForSegment(TranscriptSegment segment, List segments) { + if (segment.hasLegacySpeakerLabel) { + return '${getDisplaySpeakerId(segment.speakerId, segments)}'; + } + + final providerClusterId = segment.providerClusterId; + if (providerClusterId != null && providerClusterId.isNotEmpty) { + final clusterIds = []; + for (final item in segments) { + final clusterId = item.providerClusterId; + if (!item.isUser && clusterId != null && clusterId.isNotEmpty && !clusterIds.contains(clusterId)) { + clusterIds.add(clusterId); + } + } + final index = clusterIds.indexOf(providerClusterId); + if (index >= 0) return '${index + 1}'; + } + + return '?'; + } } diff --git a/app/lib/pages/conversation_detail/page.dart b/app/lib/pages/conversation_detail/page.dart index 760249f428a..02a20f9dd35 100644 --- a/app/lib/pages/conversation_detail/page.dart +++ b/app/lib/pages/conversation_detail/page.dart @@ -328,8 +328,8 @@ class _ConversationDetailPageState extends State with Ti final conversation = provider.conversation; final summaryContent = conversation.appResults.isNotEmpty && conversation.appResults[0].content.trim().isNotEmpty - ? conversation.appResults[0].content.trim() - : conversation.structured.toString(); + ? 
conversation.appResults[0].content.trim() + : conversation.structured.toString(); _copyContent(context, summaryContent); break; case 'download_audio': @@ -669,8 +669,8 @@ class _ConversationDetailPageState extends State with Ti provider.conversation.starred = newStarredState; // Update in conversation provider context.read().updateConversationInSortedList( - provider.conversation, - ); + provider.conversation, + ); // Track star/unstar action PlatformManager.instance.analytics.conversationStarToggled( conversation: provider.conversation, @@ -989,13 +989,15 @@ class _ConversationDetailPageState extends State with Ti child: Consumer( builder: (context, provider, child) { final conversation = provider.conversation; - final hasActionItems = - conversation.structured.actionItems.where((item) => !item.deleted).isNotEmpty; + final hasActionItems = conversation.structured.actionItems + .where((item) => !item.deleted) + .isNotEmpty; return ConversationBottomBar( mode: ConversationBottomBarMode.detail, selectedTab: selectedTab, conversation: conversation, - hasSegments: conversation.transcriptSegments.isNotEmpty || + hasSegments: + conversation.transcriptSegments.isNotEmpty || conversation.photos.isNotEmpty || conversation.externalIntegration != null, hasActionItems: hasActionItems, @@ -1436,10 +1438,12 @@ class _TranscriptWidgetsState extends State with AutomaticKee } final segments = provider.conversation.transcriptSegments; final segment = segments[segmentIndex]; - final person = - segment.personId != null ? SharedPreferencesUtil().getPersonById(segment.personId!) : null; - final speakerName = person?.name ?? - context.l10n.speakerWithId('${TranscriptSegment.getDisplaySpeakerId(segment.speakerId, segments)}'); + final person = segment.personId != null + ? SharedPreferencesUtil().getPersonById(segment.personId!) + : null; + final speakerName = + person?.name ?? 
+ context.l10n.speakerWithId(TranscriptSegment.getDisplaySpeakerIdForSegment(segment, segments)); PlatformManager.instance.analytics.editSegmentTextStarted(); bool saved = false; showEditSegmentBottomSheet( @@ -1497,8 +1501,9 @@ class _TranscriptWidgetsState extends State with AutomaticKee ); if (segmentIndex == -1) continue; provider.conversation.transcriptSegments[segmentIndex].isUser = finalPersonId == 'user'; - provider.conversation.transcriptSegments[segmentIndex].personId = - finalPersonId == 'user' ? null : finalPersonId; + provider.conversation.transcriptSegments[segmentIndex].personId = finalPersonId == 'user' + ? null + : finalPersonId; } await assignBulkConversationTranscriptSegments( provider.conversation.id, diff --git a/app/lib/widgets/transcript.dart b/app/lib/widgets/transcript.dart index cea056f70da..98ead3929ae 100644 --- a/app/lib/widgets/transcript.dart +++ b/app/lib/widgets/transcript.dart @@ -111,8 +111,9 @@ class _TranscriptWidgetState extends State { return Image.asset(Assets.images.speaker0Icon.path, width: 24, height: 24); } // Always modulo by speakerImagePath.length to prevent index out of bounds - final imageIndex = - person != null ? person.colorIdx! % speakerImagePath.length : speakerId % speakerImagePath.length; + final imageIndex = person != null + ? person.colorIdx! 
% speakerImagePath.length + : speakerId % speakerImagePath.length; return Image.asset(speakerImagePath[imageIndex], width: 24, height: 24); } @@ -321,13 +322,13 @@ class _TranscriptWidgetState extends State { _isAutoScrolling = true; _scrollController .animateTo( - targetOffset.clamp(0.0, _scrollController.position.maxScrollExtent), - duration: const Duration(milliseconds: 400), - curve: Curves.easeInOutCubic, - ) + targetOffset.clamp(0.0, _scrollController.position.maxScrollExtent), + duration: const Duration(milliseconds: 400), + curve: Curves.easeInOutCubic, + ) .then((_) { - _isAutoScrolling = false; - }); + _isAutoScrolling = false; + }); } } @@ -507,9 +508,9 @@ class _TranscriptWidgetState extends State { data.speakerId == omiSpeakerId ? 'omi' : (person?.name ?? - context.l10n.speakerWithId( - '${TranscriptSegment.getDisplaySpeakerId(data.speakerId, widget.segments)}', - )), + context.l10n.speakerWithId( + TranscriptSegment.getDisplaySpeakerIdForSegment(data, widget.segments), + )), style: TextStyle( color: data.speakerId == omiSpeakerId || person != null ? Colors.grey.shade300 @@ -550,8 +551,8 @@ class _TranscriptWidgetState extends State { isUser ? 18 : (segmentIdx > 0 && !widget.segments[segmentIdx - 1].isUser) - ? 6 - : 18, + ? 6 + : 18, ), topRight: Radius.circular(isUser ? 18 : 18), bottomLeft: Radius.circular(18), @@ -614,8 +615,9 @@ class _TranscriptWidgetState extends State { Text( SttProviderConfig.getDisplayName(data.sttProvider), style: TextStyle( - color: - isUser ? Colors.white.withValues(alpha: 0.5) : Colors.grey.shade500, + color: isUser + ? Colors.white.withValues(alpha: 0.5) + : Colors.grey.shade500, fontSize: 10, fontStyle: FontStyle.italic, ), @@ -624,8 +626,9 @@ class _TranscriptWidgetState extends State { Text( ' · ', style: TextStyle( - color: - isUser ? Colors.white.withValues(alpha: 0.5) : Colors.grey.shade500, + color: isUser + ? 
Colors.white.withValues(alpha: 0.5) + : Colors.grey.shade500, fontSize: 10, ), ), @@ -641,8 +644,9 @@ class _TranscriptWidgetState extends State { child: Icon( Icons.play_circle_outline, size: 16, - color: - isUser ? Colors.white.withValues(alpha: 0.7) : Colors.grey.shade400, + color: isUser + ? Colors.white.withValues(alpha: 0.7) + : Colors.grey.shade400, ), ), const SizedBox(width: 6), @@ -651,8 +655,9 @@ class _TranscriptWidgetState extends State { Text( data.getTimestampString(), style: TextStyle( - color: - isUser ? Colors.white.withValues(alpha: 0.7) : Colors.grey.shade400, + color: isUser + ? Colors.white.withValues(alpha: 0.7) + : Colors.grey.shade400, fontSize: 11, ), ), diff --git a/app/test/widgets/transcript_test.dart b/app/test/widgets/transcript_test.dart index 11e37519802..ded47d5821a 100644 --- a/app/test/widgets/transcript_test.dart +++ b/app/test/widgets/transcript_test.dart @@ -71,6 +71,32 @@ void main() { expect(segment.speakerIdentityState, 'unknown'); expect(segment.speakerIdentityVersion, 'v1'); }); + + test('uses provider cluster labels when legacy speaker label is absent', () { + final first = TranscriptSegment.fromJson({ + 'id': 'seg-provider-a', + 'text': 'Hello', + 'speaker': null, + 'is_user': false, + 'start': 0.0, + 'end': 1.0, + 'provider_cluster_id': 'speaker-alpha', + 'speaker_identity_state': 'unknown', + }); + final second = TranscriptSegment.fromJson({ + 'id': 'seg-provider-b', + 'text': 'Hi', + 'speaker': null, + 'is_user': false, + 'start': 1.0, + 'end': 2.0, + 'provider_cluster_id': 'speaker-beta', + 'speaker_identity_state': 'unknown', + }); + + expect(TranscriptSegment.getDisplaySpeakerIdForSegment(first, [first, second]), '1'); + expect(TranscriptSegment.getDisplaySpeakerIdForSegment(second, [first, second]), '2'); + }); }); group('Speaker label display', () { diff --git a/desktop/Backend-Rust/src/models/conversation.rs b/desktop/Backend-Rust/src/models/conversation.rs index 4e26c048f97..c6d508c9583 100644 --- 
a/desktop/Backend-Rust/src/models/conversation.rs +++ b/desktop/Backend-Rust/src/models/conversation.rs @@ -12,10 +12,10 @@ pub struct TranscriptSegment { #[serde(default)] pub id: Option, pub text: String, - #[serde(default = "default_speaker")] - pub speaker: String, #[serde(default)] - pub speaker_id: i32, + pub speaker: Option, + #[serde(default)] + pub speaker_id: Option, #[serde(default)] pub is_user: bool, #[serde(default)] @@ -42,10 +42,6 @@ pub struct TranscriptSegment { pub speaker_identity_version: Option, } -fn default_speaker() -> String { - "SPEAKER_00".to_string() -} - impl TranscriptSegment { /// Convert segments to transcript text for LLM processing /// Copied from Python segments_to_transcript_text @@ -56,7 +52,10 @@ impl TranscriptSegment { let speaker_name = if segment.is_user { "User".to_string() } else { - format!("Speaker {}", segment.speaker_id) + match segment.speaker_id { + Some(speaker_id) => format!("Speaker {}", speaker_id), + None => "Speaker ?".to_string(), + } }; format!("{}: {}", speaker_name, segment.text) }) diff --git a/desktop/Backend-Rust/src/services/firestore.rs b/desktop/Backend-Rust/src/services/firestore.rs index 426e6273296..31e98ceb983 100644 --- a/desktop/Backend-Rust/src/services/firestore.rs +++ b/desktop/Backend-Rust/src/services/firestore.rs @@ -4254,11 +4254,10 @@ impl FirestoreService { text: seg.get("text")?.as_str()?.to_string(), speaker: seg.get("speaker") .and_then(|s| s.as_str()) - .unwrap_or("SPEAKER_00") - .to_string(), + .map(|s| s.to_string()), speaker_id: seg.get("speaker_id") .and_then(|s| s.as_i64()) - .unwrap_or(0) as i32, + .map(|s| s as i32), is_user: seg.get("is_user") .and_then(|s| s.as_bool()) .unwrap_or(false), @@ -4310,11 +4309,10 @@ impl FirestoreService { text: seg.get("text")?.as_str()?.to_string(), speaker: seg.get("speaker") .and_then(|s| s.as_str()) - .unwrap_or("SPEAKER_00") - .to_string(), + .map(|s| s.to_string()), speaker_id: seg.get("speaker_id") .and_then(|s| s.as_i64()) - 
.unwrap_or(0) as i32, + .map(|s| s as i32), is_user: seg.get("is_user") .and_then(|s| s.as_bool()) .unwrap_or(false), @@ -4390,8 +4388,8 @@ impl FirestoreService { Some(TranscriptSegment { id: self.parse_string(seg_fields, "id"), text: self.parse_string(seg_fields, "text").unwrap_or_default(), - speaker: self.parse_string(seg_fields, "speaker").unwrap_or_else(|| "SPEAKER_00".to_string()), - speaker_id: self.parse_int(seg_fields, "speaker_id").unwrap_or(0), + speaker: self.parse_string(seg_fields, "speaker"), + speaker_id: self.parse_int(seg_fields, "speaker_id"), is_user: self.parse_bool(seg_fields, "is_user").unwrap_or(false), person_id: self.parse_string(seg_fields, "person_id"), start: self.parse_float(seg_fields, "start").unwrap_or(0.0), @@ -4446,12 +4444,11 @@ impl FirestoreService { speaker: seg .get("speaker") .and_then(|s| s.as_str()) - .unwrap_or("SPEAKER_00") - .to_string(), + .map(|s| s.to_string()), speaker_id: seg .get("speaker_id") .and_then(|s| s.as_i64()) - .unwrap_or(0) as i32, + .map(|s| s as i32), is_user: seg .get("is_user") .and_then(|s| s.as_bool()) @@ -9323,8 +9320,6 @@ impl FirestoreService { .map(|seg| { let mut fields = json!({ "text": {"stringValue": seg.text}, - "speaker": {"stringValue": seg.speaker}, - "speaker_id": {"integerValue": seg.speaker_id.to_string()}, "is_user": {"booleanValue": seg.is_user}, "start": {"doubleValue": seg.start}, "end": {"doubleValue": seg.end} @@ -9332,9 +9327,39 @@ impl FirestoreService { if let Some(ref id) = seg.id { fields["id"] = json!({"stringValue": id}); } + if let Some(ref speaker) = seg.speaker { + fields["speaker"] = json!({"stringValue": speaker}); + } + if let Some(speaker_id) = seg.speaker_id { + fields["speaker_id"] = json!({"integerValue": speaker_id.to_string()}); + } if let Some(ref pid) = seg.person_id { fields["person_id"] = json!({"stringValue": pid}); } + if let Some(ref stt_provider) = seg.stt_provider { + fields["stt_provider"] = json!({"stringValue": stt_provider}); + } + if let 
Some(ref stt_model) = seg.stt_model { + fields["stt_model"] = json!({"stringValue": stt_model}); + } + if let Some(ref provider_cluster_id) = seg.provider_cluster_id { + fields["provider_cluster_id"] = json!({"stringValue": provider_cluster_id}); + } + if let Some(ref provider_speaker_label) = seg.provider_speaker_label { + fields["provider_speaker_label"] = json!({"stringValue": provider_speaker_label}); + } + if let Some(ref speaker_identity_state) = seg.speaker_identity_state { + fields["speaker_identity_state"] = json!({"stringValue": speaker_identity_state}); + } + if let Some(speaker_identity_confidence) = seg.speaker_identity_confidence { + fields["speaker_identity_confidence"] = json!({"doubleValue": speaker_identity_confidence}); + } + if let Some(ref speaker_identity_source) = seg.speaker_identity_source { + fields["speaker_identity_source"] = json!({"stringValue": speaker_identity_source}); + } + if let Some(ref speaker_identity_version) = seg.speaker_identity_version { + fields["speaker_identity_version"] = json!({"stringValue": speaker_identity_version}); + } json!({"mapValue": {"fields": fields}}) }) .collect(); @@ -10026,6 +10051,85 @@ mod tests { assert_ne!(id, document_id_from_seed("different content")); } + fn test_firestore_service() -> FirestoreService { + FirestoreService { + client: Client::new(), + project_id: "test-project".to_string(), + credentials: None, + cached_token: Arc::new(RwLock::new(None)), + encryption_secret: None, + } + } + + #[test] + fn parse_plain_transcript_segments_preserves_explicit_unknown_identity() { + let service = test_firestore_service(); + let fields = json!({ + "transcript_segments": { + "arrayValue": { + "values": [{ + "mapValue": { + "fields": { + "id": {"stringValue": "seg-provider"}, + "text": {"stringValue": "Hello"}, + "is_user": {"booleanValue": false}, + "start": {"doubleValue": 0.0}, + "end": {"doubleValue": 1.0}, + "stt_provider": {"stringValue": "provider-a"}, + "stt_model": {"stringValue": 
"async-large"}, + "provider_cluster_id": {"stringValue": "speaker-alpha"}, + "speaker_identity_state": {"stringValue": "unknown"}, + "speaker_identity_confidence": {"doubleValue": 0.42}, + "speaker_identity_source": {"stringValue": "omi_speaker_embedding"}, + "speaker_identity_version": {"stringValue": "v1"} + } + } + }] + } + } + }); + + let segments = service.parse_transcript_segments(&fields, "uid").unwrap(); + assert_eq!(segments.len(), 1); + assert_eq!(segments[0].speaker, None); + assert_eq!(segments[0].speaker_id, None); + assert_eq!(segments[0].provider_cluster_id.as_deref(), Some("speaker-alpha")); + assert_eq!(segments[0].speaker_identity_state.as_deref(), Some("unknown")); + assert_eq!(segments[0].speaker_identity_confidence, Some(0.42)); + assert_eq!(segments[0].speaker_identity_source.as_deref(), Some("omi_speaker_embedding")); + assert_eq!(segments[0].speaker_identity_version.as_deref(), Some("v1")); + } + + #[test] + fn parse_plain_transcript_segments_preserves_legacy_speaker_fields() { + let service = test_firestore_service(); + let fields = json!({ + "transcript_segments": { + "arrayValue": { + "values": [{ + "mapValue": { + "fields": { + "id": {"stringValue": "seg-legacy"}, + "text": {"stringValue": "Hello"}, + "speaker": {"stringValue": "SPEAKER_00"}, + "speaker_id": {"integerValue": "0"}, + "is_user": {"booleanValue": false}, + "start": {"doubleValue": 0.0}, + "end": {"doubleValue": 1.0} + } + } + }] + } + } + }); + + let segments = service.parse_transcript_segments(&fields, "uid").unwrap(); + assert_eq!(segments.len(), 1); + assert_eq!(segments[0].speaker.as_deref(), Some("SPEAKER_00")); + assert_eq!(segments[0].speaker_id, Some(0)); + assert_eq!(segments[0].speaker_identity_state, None); + } + // --- Firestore BYOK state parsing tests --- #[test] diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 21e06f4d01b..36d6f8b49e0 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ 
b/desktop/Desktop/Sources/APIClient.swift @@ -844,6 +844,23 @@ struct TranscriptSegment: Codable, Identifiable { return 0 } + var hasExplicitUnknownSpeakerIdentity: Bool { + speakerIdentityState == "unknown" + } + + var displaySpeakerSuffix: String { + if speaker != nil { + return "\(speakerId)" + } + if let providerSpeakerLabel, !providerSpeakerLabel.isEmpty { + return providerSpeakerLabel + } + if let providerClusterId, !providerClusterId.isEmpty { + return providerClusterId + } + return "?" + } + enum CodingKeys: String, CodingKey { case id, text, speaker case isUser = "is_user" diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index aacaad4ed83..8f02a8bd340 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -2434,7 +2434,15 @@ class AppState: ObservableObject { personId: isUser ? nil : personId, start: old.start, end: old.end, - translations: old.translations + translations: old.translations, + sttProvider: old.sttProvider, + sttModel: old.sttModel, + providerClusterId: old.providerClusterId, + providerSpeakerLabel: old.providerSpeakerLabel, + speakerIdentityState: old.speakerIdentityState, + speakerIdentityConfidence: old.speakerIdentityConfidence, + speakerIdentitySource: old.speakerIdentitySource, + speakerIdentityVersion: old.speakerIdentityVersion ) } } @@ -2541,7 +2549,15 @@ class AppState: ObservableObject { isUser: segment.is_user, personId: segment.person_id, speakerLabel: segment.speaker, - translationsJson: translationsJson + translationsJson: translationsJson, + sttProvider: segment.stt_provider, + sttModel: segment.stt_model, + providerClusterId: segment.provider_cluster_id, + providerSpeakerLabel: segment.provider_speaker_label, + speakerIdentityState: segment.speaker_identity_state, + speakerIdentityConfidence: segment.speaker_identity_confidence, + speakerIdentitySource: segment.speaker_identity_source, + speakerIdentityVersion: 
segment.speaker_identity_version ) } catch { logError("Transcription: Failed to persist segment to DB", error: error) @@ -2709,7 +2725,15 @@ class AppState: ObservableObject { isUser: translated.is_user, personId: translated.person_id, speakerLabel: translated.speaker, - translationsJson: translationsJson + translationsJson: translationsJson, + sttProvider: translated.stt_provider, + sttModel: translated.stt_model, + providerClusterId: translated.provider_cluster_id, + providerSpeakerLabel: translated.provider_speaker_label, + speakerIdentityState: translated.speaker_identity_state, + speakerIdentityConfidence: translated.speaker_identity_confidence, + speakerIdentitySource: translated.speaker_identity_source, + speakerIdentityVersion: translated.speaker_identity_version ) } } diff --git a/desktop/Desktop/Sources/MainWindow/Components/SpeakerBubbleView.swift b/desktop/Desktop/Sources/MainWindow/Components/SpeakerBubbleView.swift index 4d2d7652d79..14010e72024 100644 --- a/desktop/Desktop/Sources/MainWindow/Components/SpeakerBubbleView.swift +++ b/desktop/Desktop/Sources/MainWindow/Components/SpeakerBubbleView.swift @@ -27,7 +27,7 @@ struct SpeakerBubbleView: View { private var speakerLabel: String { if isUser { return "You" } if let name = personName { return name } - return "Speaker \(segment.speakerId)" + return "Speaker \(segment.displaySpeakerSuffix)" } private var avatarInitial: String { @@ -35,7 +35,7 @@ struct SpeakerBubbleView: View { if let name = personName, let first = name.first { return String(first).uppercased() } - return String(segment.speakerId) + return segment.speaker == nil ? "?" 
: String(segment.speakerId) } var body: some View { diff --git a/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift b/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift index 8d59d6eb43d..3a4697a01dc 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift @@ -717,7 +717,15 @@ struct ConversationDetailView: View { personId: isUser ? nil : personId, start: oldSegment.start, end: oldSegment.end, - translations: oldSegment.translations + translations: oldSegment.translations, + sttProvider: oldSegment.sttProvider, + sttModel: oldSegment.sttModel, + providerClusterId: oldSegment.providerClusterId, + providerSpeakerLabel: oldSegment.providerSpeakerLabel, + speakerIdentityState: oldSegment.speakerIdentityState, + speakerIdentityConfidence: oldSegment.speakerIdentityConfidence, + speakerIdentitySource: oldSegment.speakerIdentitySource, + speakerIdentityVersion: oldSegment.speakerIdentityVersion ) } loadedConversation = updatedConversation diff --git a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift index 7381d242b73..c8b4a6b2289 100644 --- a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift +++ b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift @@ -2156,6 +2156,19 @@ actor RewindDatabase { } } + migrator.registerMigration("addCanonicalSpeakerMetadata") { db in + try db.alter(table: "transcription_segments") { t in + t.add(column: "sttProvider", .text) + t.add(column: "sttModel", .text) + t.add(column: "providerClusterId", .text) + t.add(column: "providerSpeakerLabel", .text) + t.add(column: "speakerIdentityState", .text) + t.add(column: "speakerIdentityConfidence", .double) + t.add(column: "speakerIdentitySource", .text) + t.add(column: "speakerIdentityVersion", .text) + } + } + try migrator.migrate(queue) } diff --git 
a/desktop/Desktop/Sources/Rewind/Core/TranscriptionModels.swift b/desktop/Desktop/Sources/Rewind/Core/TranscriptionModels.swift index 9fe1a3f19cb..011f79beb47 100644 --- a/desktop/Desktop/Sources/Rewind/Core/TranscriptionModels.swift +++ b/desktop/Desktop/Sources/Rewind/Core/TranscriptionModels.swift @@ -193,6 +193,14 @@ struct TranscriptionSegmentRecord: Codable, FetchableRecord, PersistableRecord, var isUser: Bool // Whether this segment is from the user var personId: String? // Associated person ID (if identified) var translationsJson: String? // JSON-encoded [TranscriptTranslation] + var sttProvider: String? + var sttModel: String? + var providerClusterId: String? + var providerSpeakerLabel: String? + var speakerIdentityState: String? + var speakerIdentityConfidence: Double? + var speakerIdentitySource: String? + var speakerIdentityVersion: String? static let databaseTableName = "transcription_segments" @@ -212,7 +220,15 @@ struct TranscriptionSegmentRecord: Codable, FetchableRecord, PersistableRecord, speakerLabel: String? = nil, isUser: Bool = false, personId: String? = nil, - translationsJson: String? = nil + translationsJson: String? = nil, + sttProvider: String? = nil, + sttModel: String? = nil, + providerClusterId: String? = nil, + providerSpeakerLabel: String? = nil, + speakerIdentityState: String? = nil, + speakerIdentityConfidence: Double? = nil, + speakerIdentitySource: String? = nil, + speakerIdentityVersion: String? 
= nil ) { self.id = id self.sessionId = sessionId @@ -228,6 +244,14 @@ struct TranscriptionSegmentRecord: Codable, FetchableRecord, PersistableRecord, self.isUser = isUser self.personId = personId self.translationsJson = translationsJson + self.sttProvider = sttProvider + self.sttModel = sttModel + self.providerClusterId = providerClusterId + self.providerSpeakerLabel = providerSpeakerLabel + self.speakerIdentityState = speakerIdentityState + self.speakerIdentityConfidence = speakerIdentityConfidence + self.speakerIdentitySource = speakerIdentitySource + self.speakerIdentityVersion = speakerIdentityVersion } // MARK: - Persistence Callbacks @@ -429,7 +453,15 @@ extension TranscriptionSegmentRecord { speakerLabel: segment.speaker, isUser: segment.isUser, personId: segment.personId, - translationsJson: translationsJson + translationsJson: translationsJson, + sttProvider: segment.sttProvider, + sttModel: segment.sttModel, + providerClusterId: segment.providerClusterId, + providerSpeakerLabel: segment.providerSpeakerLabel, + speakerIdentityState: segment.speakerIdentityState, + speakerIdentityConfidence: segment.speakerIdentityConfidence, + speakerIdentitySource: segment.speakerIdentitySource, + speakerIdentityVersion: segment.speakerIdentityVersion ) } @@ -448,7 +480,15 @@ extension TranscriptionSegmentRecord { personId: personId, start: startTime, end: endTime, - translations: translations + translations: translations, + sttProvider: sttProvider, + sttModel: sttModel, + providerClusterId: providerClusterId, + providerSpeakerLabel: providerSpeakerLabel, + speakerIdentityState: speakerIdentityState, + speakerIdentityConfidence: speakerIdentityConfidence, + speakerIdentitySource: speakerIdentitySource, + speakerIdentityVersion: speakerIdentityVersion ) } } diff --git a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift index 7b585fd49c6..c0c0ece5eba 100644 --- 
a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift +++ b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift @@ -299,7 +299,15 @@ actor TranscriptionStorage { isUser: Bool = false, personId: String? = nil, speakerLabel: String? = nil, - translationsJson: String? = nil + translationsJson: String? = nil, + sttProvider: String? = nil, + sttModel: String? = nil, + providerClusterId: String? = nil, + providerSpeakerLabel: String? = nil, + speakerIdentityState: String? = nil, + speakerIdentityConfidence: Double? = nil, + speakerIdentitySource: String? = nil, + speakerIdentityVersion: String? = nil ) async throws -> Int64 { let db = try await ensureInitialized() @@ -311,10 +319,22 @@ actor TranscriptionStorage { UPDATE transcription_segments SET text = ?, speaker = ?, startTime = ?, endTime = ?, isUser = ?, personId = ?, speakerLabel = COALESCE(?, speakerLabel), - translationsJson = COALESCE(?, translationsJson) + translationsJson = COALESCE(?, translationsJson), + sttProvider = COALESCE(?, sttProvider), + sttModel = COALESCE(?, sttModel), + providerClusterId = COALESCE(?, providerClusterId), + providerSpeakerLabel = COALESCE(?, providerSpeakerLabel), + speakerIdentityState = COALESCE(?, speakerIdentityState), + speakerIdentityConfidence = COALESCE(?, speakerIdentityConfidence), + speakerIdentitySource = COALESCE(?, speakerIdentitySource), + speakerIdentityVersion = COALESCE(?, speakerIdentityVersion) WHERE sessionId = ? AND segmentId = ? 
""", - arguments: [text, speaker, startTime, endTime, isUser, personId, speakerLabel, translationsJson, sessionId, segId] + arguments: [ + text, speaker, startTime, endTime, isUser, personId, speakerLabel, translationsJson, + sttProvider, sttModel, providerClusterId, providerSpeakerLabel, speakerIdentityState, + speakerIdentityConfidence, speakerIdentitySource, speakerIdentityVersion, sessionId, segId + ] ) return database.changesCount > 0 } @@ -343,7 +363,15 @@ actor TranscriptionStorage { speakerLabel: speakerLabel, isUser: isUser, personId: personId, - translationsJson: translationsJson + translationsJson: translationsJson, + sttProvider: sttProvider, + sttModel: sttModel, + providerClusterId: providerClusterId, + providerSpeakerLabel: providerSpeakerLabel, + speakerIdentityState: speakerIdentityState, + speakerIdentityConfidence: speakerIdentityConfidence, + speakerIdentitySource: speakerIdentitySource, + speakerIdentityVersion: speakerIdentityVersion ) let record = try await db.write { database in diff --git a/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift b/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift index 618686ee5fe..52c012d9068 100644 --- a/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift +++ b/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift @@ -237,6 +237,8 @@ final class TranscriptSpeakerAssignmentTests: XCTestCase { XCTAssertEqual(segment.speakerIdentityConfidence, 0.42) XCTAssertEqual(segment.speakerIdentitySource, "omi_speaker_embedding") XCTAssertEqual(segment.speakerIdentityVersion, "v1") + XCTAssertTrue(segment.hasExplicitUnknownSpeakerIdentity) + XCTAssertEqual(segment.displaySpeakerSuffix, "speaker-alpha") } func testSpeakerSegmentTranslationsPreserved() { @@ -349,6 +351,40 @@ final class TranscriptSpeakerAssignmentTests: XCTestCase { XCTAssertEqual(segment.translations[1].text, "Hola") } + func testTranscriptionSegmentRecordRoundTripWithProviderIdentityMetadata() { + let source = 
TranscriptSegment( + id: "seg_provider_rt", + backendId: "seg_provider_rt", + text: "Hello", + speaker: nil, + isUser: false, + personId: nil, + start: 0, + end: 1, + sttProvider: "provider-a", + sttModel: "async-large", + providerClusterId: "speaker-alpha", + providerSpeakerLabel: nil, + speakerIdentityState: "unknown", + speakerIdentityConfidence: 0.42, + speakerIdentitySource: "omi_speaker_embedding", + speakerIdentityVersion: "v1" + ) + + let record = TranscriptionSegmentRecord.from(source, sessionId: 1, segmentOrder: 0) + let segment = record.toTranscriptSegment() + + XCTAssertNil(segment.speaker) + XCTAssertEqual(segment.displaySpeakerSuffix, "speaker-alpha") + XCTAssertEqual(segment.sttProvider, "provider-a") + XCTAssertEqual(segment.sttModel, "async-large") + XCTAssertEqual(segment.providerClusterId, "speaker-alpha") + XCTAssertEqual(segment.speakerIdentityState, "unknown") + XCTAssertEqual(segment.speakerIdentityConfidence, 0.42) + XCTAssertEqual(segment.speakerIdentitySource, "omi_speaker_embedding") + XCTAssertEqual(segment.speakerIdentityVersion, "v1") + } + func testTranscriptionSegmentRecordRoundTripNilTranslationsJson() { let record = TranscriptionSegmentRecord( sessionId: 1, From 85f1c33900275bff4bb50b4d2cc23ae29a78bdca Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 08:59:39 +0700 Subject: [PATCH 08/44] Add AssemblyAI background STT provider --- backend/.env.template | 10 + backend/tests/unit/test_assemblyai_adapter.py | 224 ++++++++++++++ .../unit/test_background_provider_service.py | 122 ++++++++ backend/utils/stt/assemblyai_adapter.py | 283 ++++++++++++++++++ backend/utils/stt/provider_service.py | 218 +++++++++++++- backend/utils/stt/providers.py | 46 ++- 6 files changed, 900 insertions(+), 3 deletions(-) create mode 100644 backend/tests/unit/test_assemblyai_adapter.py create mode 100644 backend/utils/stt/assemblyai_adapter.py diff --git a/backend/.env.template b/backend/.env.template index b4397a8caf6..aea4e371039 100644 --- 
a/backend/.env.template +++ b/backend/.env.template @@ -14,6 +14,16 @@ REDIS_DB_PASSWORD= DEEPGRAM_API_KEY= +# AssemblyAI async/background STT. Disabled by default; eligible workloads are sync, background, postprocess. +ASSEMBLYAI_API_KEY= +ASSEMBLYAI_BACKGROUND_STT_ENABLED=false +ASSEMBLYAI_BACKGROUND_STT_WORKLOADS=sync,background,postprocess +ASSEMBLYAI_STT_MODEL=universal-2 +ASSEMBLYAI_BASE_URL=https://api.assemblyai.com +ASSEMBLYAI_POLL_INTERVAL_SECONDS=3 +ASSEMBLYAI_MAX_POLL_SECONDS=900 +ASSEMBLYAI_SMOKE_AUDIO_URL= + ADMIN_KEY= OPENAI_API_KEY= diff --git a/backend/tests/unit/test_assemblyai_adapter.py b/backend/tests/unit/test_assemblyai_adapter.py new file mode 100644 index 00000000000..d8140ca7af1 --- /dev/null +++ b/backend/tests/unit/test_assemblyai_adapter.py @@ -0,0 +1,224 @@ +import httpx +import os +import pytest + +from utils.stt.assemblyai_adapter import ( + AssemblyAIAsyncTranscriptionProvider, + AssemblyAIProviderError, + AssemblyAITimeoutError, + normalize_assemblyai_transcript_result, +) + + +class FakeResponse: + def __init__(self, payload, status_code=200): + self._payload = payload + self.status_code = status_code + self.request = httpx.Request('GET', 'https://api.assemblyai.com/test') + + def json(self): + return self._payload + + def raise_for_status(self): + if self.status_code >= 400: + raise httpx.HTTPStatusError('failed', request=self.request, response=self) + + +class FakeClient: + def __init__(self, responses): + self.responses = list(responses) + self.requests = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def request(self, method, url, **kwargs): + self.requests.append((method, url, kwargs)) + response = self.responses.pop(0) + if isinstance(response, Exception): + raise response + return response + + +def _completed_transcript(): + return { + 'id': 'aai-transcript-1', + 'status': 'completed', + 'language_code': 'en_us', + 'speech_model': 'universal-2', + 'audio_duration': 
2500, + 'utterances': [ + { + 'speaker': 'A', + 'text': 'Hello world.', + 'start': 0, + 'end': 1100, + 'confidence': 0.93, + 'words': [ + {'speaker': 'A', 'text': 'Hello', 'start': 0, 'end': 400, 'confidence': 0.94}, + {'speaker': 'A', 'text': 'world.', 'start': 500, 'end': 1100, 'confidence': 0.91}, + ], + } + ], + } + + +def test_assemblyai_result_normalizes_utterances_words_and_speaker_clusters(): + result = normalize_assemblyai_transcript_result(_completed_transcript(), model='universal-2') + + assert result.provider == 'assemblyai' + assert result.model == 'universal-2' + assert result.language == 'en' + assert result.duration == 2.5 + assert result.raw_provider_result_id == 'aai-transcript-1' + assert result.utterances[0].provider_cluster_id == 'A' + assert result.utterances[0].speaker_label == 'ASSEMBLYAI_SPEAKER_A' + assert result.words[1].text == 'world.' + assert result.words[1].start == 0.5 + assert result.words[1].provider_cluster_id == 'A' + + +def test_assemblyai_transcribe_url_submits_diarization_and_polls_to_completion(): + fake_client = FakeClient( + [ + FakeResponse({'id': 'aai-transcript-1', 'status': 'queued'}), + FakeResponse({'id': 'aai-transcript-1', 'status': 'processing'}), + FakeResponse(_completed_transcript()), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=5, + sleeper=lambda seconds: None, + ) + + result, detected_language = provider.transcribe_url( + 'https://example.test/audio.wav', + speakers_count=2, + return_language=True, + language='multi', + model='universal-2', + keywords=['Omi'], + ) + + assert result.provider == 'assemblyai' + assert detected_language == 'en' + submit_payload = fake_client.requests[0][2]['json'] + assert submit_payload['audio_url'] == 'https://example.test/audio.wav' + assert submit_payload['speaker_labels'] is True + assert submit_payload['speakers_expected'] == 2 + assert 
submit_payload['language_detection'] is True + assert submit_payload['keyterms_prompt'] == ['Omi'] + assert fake_client.requests[1][0] == 'GET' + + +def test_assemblyai_transcribe_bytes_uploads_then_transcribes_upload_url(): + fake_client = FakeClient( + [ + FakeResponse({'upload_url': 'https://cdn.assemblyai.test/uploaded.wav'}), + FakeResponse({'id': 'aai-transcript-1'}), + FakeResponse(_completed_transcript()), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=5, + sleeper=lambda seconds: None, + ) + + result = provider.transcribe_bytes(b'audio-bytes', diarize=True) + + assert result.provider == 'assemblyai' + assert fake_client.requests[0][0] == 'POST' + assert fake_client.requests[0][1].endswith('/v2/upload') + assert fake_client.requests[1][2]['json']['audio_url'] == 'https://cdn.assemblyai.test/uploaded.wav' + + +def test_assemblyai_failure_status_normalizes_to_provider_error(): + fake_client = FakeClient( + [ + FakeResponse({'id': 'aai-transcript-1'}), + FakeResponse({'id': 'aai-transcript-1', 'status': 'error', 'error': 'unsupported media'}), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=5, + sleeper=lambda seconds: None, + ) + + with pytest.raises(AssemblyAIProviderError, match='unsupported media'): + provider.transcribe_url('https://example.test/audio.wav') + + +def test_assemblyai_poll_timeout_raises_timeout_error(): + current_time = {'value': 0.0} + + def clock(): + current_time['value'] += 2.0 + return current_time['value'] + + fake_client = FakeClient( + [ + FakeResponse({'id': 'aai-transcript-1'}), + FakeResponse({'id': 'aai-transcript-1', 'status': 'processing'}), + FakeResponse({'id': 'aai-transcript-1', 'status': 'processing'}), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + 
client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=1, + sleeper=lambda seconds: None, + clock=clock, + ) + + with pytest.raises(AssemblyAITimeoutError): + provider.transcribe_url('https://example.test/audio.wav') + + +def test_assemblyai_retries_retryable_http_once(): + fake_client = FakeClient( + [ + FakeResponse({'temporarily': 'busy'}, status_code=503), + FakeResponse({'id': 'aai-transcript-1'}), + FakeResponse(_completed_transcript()), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=5, + sleeper=lambda seconds: None, + ) + + result = provider.transcribe_url('https://example.test/audio.wav') + + assert result.provider == 'assemblyai' + assert len(fake_client.requests) == 3 + + +def test_assemblyai_live_smoke_with_gated_credentials(): + api_key = os.getenv('ASSEMBLYAI_API_KEY') + audio_url = os.getenv('ASSEMBLYAI_SMOKE_AUDIO_URL') + if not api_key or not audio_url: + pytest.skip('ASSEMBLYAI_API_KEY and ASSEMBLYAI_SMOKE_AUDIO_URL are required for live smoke') + + provider = AssemblyAIAsyncTranscriptionProvider(api_key=api_key, poll_interval_seconds=3, max_poll_seconds=180) + + result = provider.transcribe_url(audio_url, diarize=True, language='en', model='universal-2') + + assert result.provider == 'assemblyai' + assert result.raw_provider_result_id + assert result.words or result.utterances diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index 342d6fd80dc..36a31cefea3 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -3,6 +3,8 @@ from pathlib import Path from unittest.mock import MagicMock, patch +import pytest + for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: if mod_name not in sys.modules: 
sys.modules[mod_name] = MagicMock() @@ -104,6 +106,17 @@ def test_prerecorded_ptt_and_realtime_related_workloads_stay_deepgram(): assert get_prerecorded_provider_name(STTWorkload.voice_message) == STTProviderName.deepgram +def test_background_routing_can_select_assemblyai_without_moving_latency_critical_workloads(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync,background,postprocess,ptt,realtime') + + assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai + assert get_prerecorded_provider_name(STTWorkload.postprocess) == STTProviderName.assemblyai + assert get_prerecorded_provider_name(STTWorkload.ptt) == STTProviderName.deepgram + assert get_prerecorded_provider_name(STTWorkload.voice_message) == STTProviderName.deepgram + + def test_background_call_sites_use_provider_service_layer(): backend_root = Path(__file__).resolve().parents[2] with open(backend_root / 'routers/sync.py') as f: @@ -124,3 +137,112 @@ def test_background_call_sites_use_provider_service_layer(): assert 'from utils.stt.pre_recorded import' not in chat_source assert 'workload=STTWorkload.voice_message' in chat_source assert 'workload=STTWorkload.ptt' in chat_source + + +def test_provider_service_uses_assemblyai_for_enabled_sync_workload(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.assemblyai + fake_provider.transcribe_url.return_value = (_provider_result(provider='assemblyai', model='universal-2'), 'en') + + with patch.object(provider_service, '_assemblyai_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ) as 
create_run, patch.object(provider_service, 'finalize_provider_run') as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + conversation_id='conversation-1', + return_language=True, + language='multi', + model='nova-3', + keywords=['Omi'], + ) + + assert response.result.provider == 'assemblyai' + assert response.result.model == 'universal-2' + assert response.words[0]['stt_provider'] == 'assemblyai' + fake_provider.transcribe_url.assert_called_once() + assert fake_provider.transcribe_url.call_args.kwargs['model'] == 'universal-2' + create_run.assert_called_once() + assert create_run.call_args.kwargs['provider'] == 'assemblyai' + assert create_run.call_args.kwargs['model'] == 'universal-2' + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['artifact_refs'] == {} + + +def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + + assemblyai_provider = MagicMock() + assemblyai_provider.provider_name = STTProviderName.assemblyai + assemblyai_provider.transcribe_url.side_effect = RuntimeError('AssemblyAI failed') + + deepgram_provider = MagicMock() + deepgram_provider.provider_name = STTProviderName.deepgram + deepgram_provider.transcribe_url.return_value = _provider_result() + + with patch.object( + provider_service, '_assemblyai_prerecorded_provider', return_value=assemblyai_provider + ), patch.object(provider_service, '_deepgram_prerecorded_provider', return_value=deepgram_provider), patch.object( + provider_service, 'create_provider_run', side_effect=['run-aai', 'run-dg'] + ), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + 
'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + conversation_id='conversation-1', + language='multi', + model='nova-3', + ) + + assert response.result.provider == 'deepgram' + assert assemblyai_provider.transcribe_url.call_count == 2 + deepgram_provider.transcribe_url.assert_called_once() + assert deepgram_provider.transcribe_url.call_args.kwargs['model'] == 'nova-3' + assert finalize_run.call_args_list[0].kwargs['run_id'] == 'run-aai' + assert finalize_run.call_args_list[0].kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args_list[0].kwargs['status'] == 'failed' + assert finalize_run.call_args_list[1].kwargs['run_id'] == 'run-dg' + assert finalize_run.call_args_list[1].kwargs['provider'] == 'deepgram' + assert finalize_run.call_args_list[1].kwargs['fallback_count'] == 1 + assert finalize_run.call_args_list[1].kwargs['fallback_provider'] == 'assemblyai' + + +def test_provider_service_live_assemblyai_smoke_records_ledger_when_credentials_are_present(monkeypatch): + api_key = os.getenv('ASSEMBLYAI_API_KEY') + audio_url = os.getenv('ASSEMBLYAI_SMOKE_AUDIO_URL') + if not api_key or not audio_url: + pytest.skip('ASSEMBLYAI_API_KEY and ASSEMBLYAI_SMOKE_AUDIO_URL are required for live smoke') + + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + + with patch.object(provider_service, 'create_provider_run', return_value='run-aai-live') as create_run, patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + audio_url, + workload=STTWorkload.sync, + uid='uid-live-smoke', + conversation_id='conversation-live-smoke', + language='en', + model='nova-3', + raw_audio_seconds=1.0, + ) + + assert response.result.provider == 'assemblyai' + assert response.result.raw_provider_result_id + create_run.assert_called_once() + finalize_run.assert_called_once() + assert 
finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['status'] == 'succeeded' + assert ( + finalize_run.call_args.kwargs['artifact_refs']['provider_result_id'] == response.result.raw_provider_result_id + ) diff --git a/backend/utils/stt/assemblyai_adapter.py b/backend/utils/stt/assemblyai_adapter.py new file mode 100644 index 00000000000..fb6ae9e0d13 --- /dev/null +++ b/backend/utils/stt/assemblyai_adapter.py @@ -0,0 +1,283 @@ +import os +import time +from typing import Callable, Optional, Sequence, Tuple, Union + +import httpx + +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptUtterance, ProviderTranscriptWord +from utils.stt.providers import STTProviderName + + +class AssemblyAIError(RuntimeError): + pass + + +class AssemblyAIProviderError(AssemblyAIError): + pass + + +class AssemblyAIRetryableError(AssemblyAIError): + pass + + +class AssemblyAITimeoutError(AssemblyAIError): + pass + + +def assemblyai_speaker_fields(speaker_id) -> dict: + if speaker_id is None: + return {'provider_cluster_id': None, 'provider_speaker_label': None} + + provider_cluster_id = str(speaker_id) + return { + 'provider_cluster_id': provider_cluster_id, + 'provider_speaker_label': f'ASSEMBLYAI_SPEAKER_{provider_cluster_id}', + } + + +def normalize_assemblyai_transcript_result( + result: dict, model: str, language: Optional[str] = None +) -> ProviderTranscriptResult: + status = result.get('status') + if status and status != 'completed': + raise AssemblyAIProviderError(f'AssemblyAI transcript status is {status}') + + utterances = [_normalize_assemblyai_utterance(utterance) for utterance in result.get('utterances') or []] + words = [_normalize_assemblyai_word(word) for word in result.get('words') or []] + if not words and utterances: + words = [word for utterance in utterances for word in (utterance.words or [])] + + return ProviderTranscriptResult( + provider=STTProviderName.assemblyai.value, + 
model=result.get('speech_model') or model, + language=_normalize_language(result.get('language_code') or language), + duration=_milliseconds_to_seconds(result.get('audio_duration')), + words=words, + utterances=utterances, + raw_provider_result_id=result.get('id'), + ) + + +def _normalize_assemblyai_word(word: dict) -> ProviderTranscriptWord: + speaker_fields = assemblyai_speaker_fields(word.get('speaker')) + return ProviderTranscriptWord( + text=word.get('text', ''), + start=_milliseconds_to_seconds(word.get('start')), + end=_milliseconds_to_seconds(word.get('end')), + provider_cluster_id=speaker_fields['provider_cluster_id'], + speaker_label=speaker_fields['provider_speaker_label'], + confidence=word.get('confidence'), + ) + + +def _normalize_assemblyai_utterance(utterance: dict) -> ProviderTranscriptUtterance: + speaker_fields = assemblyai_speaker_fields(utterance.get('speaker')) + utterance_words = utterance.get('words') or [] + return ProviderTranscriptUtterance( + text=utterance.get('text', ''), + start=_milliseconds_to_seconds(utterance.get('start')), + end=_milliseconds_to_seconds(utterance.get('end')), + provider_cluster_id=speaker_fields['provider_cluster_id'], + speaker_label=speaker_fields['provider_speaker_label'], + confidence=utterance.get('confidence'), + words=[_normalize_assemblyai_word(word) for word in utterance_words] if utterance_words else None, + ) + + +def _milliseconds_to_seconds(value) -> float: + if value is None: + return 0.0 + return float(value) / 1000.0 + + +def _normalize_language(language: Optional[str]) -> Optional[str]: + if language and '_' in language: + return language.split('_', 1)[0] + if language and '-' in language: + return language.split('-', 1)[0] + return language + + +class AssemblyAIAsyncTranscriptionProvider: + provider_name = STTProviderName.assemblyai + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + timeout: Optional[httpx.Timeout] = None, + poll_interval_seconds: 
Optional[float] = None, + max_poll_seconds: Optional[float] = None, + client_factory: Callable[[], httpx.Client] = httpx.Client, + sleeper: Callable[[float], None] = time.sleep, + clock: Callable[[], float] = time.monotonic, + ): + self._api_key = api_key or os.getenv('ASSEMBLYAI_API_KEY') + self._base_url = (base_url or os.getenv('ASSEMBLYAI_BASE_URL') or 'https://api.assemblyai.com').rstrip('/') + self._timeout = timeout or httpx.Timeout(30.0, read=30.0) + self._poll_interval_seconds = float( + poll_interval_seconds + if poll_interval_seconds is not None + else os.getenv('ASSEMBLYAI_POLL_INTERVAL_SECONDS', '3') + ) + self._max_poll_seconds = float( + max_poll_seconds if max_poll_seconds is not None else os.getenv('ASSEMBLYAI_MAX_POLL_SECONDS', '900') + ) + self._client_factory = client_factory + self._sleeper = sleeper + self._clock = clock + + def transcribe_url( + self, + audio_url: str, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'universal-2', + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: + payload = self._transcript_payload( + audio_url=audio_url, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, + ) + result = self._submit_and_poll(payload) + transcript_result = normalize_assemblyai_transcript_result(result, model=model, language=language) + if return_language: + return transcript_result, transcript_result.language or 'en' + return transcript_result + + def transcribe_bytes( + self, + audio_bytes: bytes, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'universal-2', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, 
Tuple[ProviderTranscriptResult, str]]: + del sample_rate, encoding, channels + upload_url = self._upload_audio(audio_bytes) + return self.transcribe_url( + upload_url, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, + ) + + def _headers(self, content_type: Optional[str] = 'application/json') -> dict: + if not self._api_key: + raise AssemblyAIProviderError('ASSEMBLYAI_API_KEY is not configured') + headers = {'Authorization': self._api_key} + if content_type: + headers['Content-Type'] = content_type + return headers + + def _transcript_payload( + self, + audio_url: str, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'universal-2', + keywords: Optional[Sequence[str]] = None, + ) -> dict: + payload = { + 'audio_url': audio_url, + 'speaker_labels': diarize, + 'punctuate': True, + 'format_text': True, + } + if model: + payload['speech_model'] = model + if speakers_count: + payload['speakers_expected'] = speakers_count + if language and language != 'multi': + payload['language_code'] = language + elif return_language or language == 'multi': + payload['language_detection'] = True + if keywords: + payload['keyterms_prompt'] = list(keywords) + return payload + + def _upload_audio(self, audio_bytes: bytes) -> str: + with self._client_factory() as client: + response = self._request( + client, + 'POST', + f'{self._base_url}/v2/upload', + headers=self._headers('application/octet-stream'), + content=audio_bytes, + ) + upload_url = response.get('upload_url') + if not upload_url: + raise AssemblyAIProviderError('AssemblyAI upload response did not include upload_url') + return upload_url + + def _submit_and_poll(self, payload: dict) -> dict: + with self._client_factory() as client: + submitted = self._request( + client, + 'POST', + f'{self._base_url}/v2/transcript', + headers=self._headers(), + json=payload, + ) + transcript_id = 
submitted.get('id') + if not transcript_id: + raise AssemblyAIProviderError('AssemblyAI transcript response did not include id') + return self._poll_transcript(client, transcript_id) + + def _poll_transcript(self, client: httpx.Client, transcript_id: str) -> dict: + deadline = self._clock() + self._max_poll_seconds + while True: + result = self._request( + client, + 'GET', + f'{self._base_url}/v2/transcript/{transcript_id}', + headers=self._headers(None), + ) + status = result.get('status') + if status == 'completed': + return result + if status == 'error': + raise AssemblyAIProviderError(result.get('error') or 'AssemblyAI transcript failed') + if status not in ('queued', 'processing'): + raise AssemblyAIProviderError(f'AssemblyAI returned unexpected transcript status: {status}') + if self._clock() >= deadline: + raise AssemblyAITimeoutError(f'AssemblyAI transcript {transcript_id} timed out') + self._sleeper(self._poll_interval_seconds) + + def _request(self, client: httpx.Client, method: str, url: str, **kwargs) -> dict: + last_error = None + for attempt in range(2): + try: + response = client.request(method, url, timeout=self._timeout, **kwargs) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + last_error = e + if e.response.status_code in (408, 429, 500, 502, 503, 504) and attempt == 0: + self._sleeper(min(self._poll_interval_seconds, 1.0)) + continue + if e.response.status_code in (408, 429, 500, 502, 503, 504): + raise AssemblyAIRetryableError(f'AssemblyAI HTTP {e.response.status_code}: {e}') from e + raise AssemblyAIProviderError(f'AssemblyAI HTTP {e.response.status_code}: {e}') from e + except (httpx.TimeoutException, httpx.TransportError) as e: + last_error = e + if attempt == 0: + self._sleeper(min(self._poll_interval_seconds, 1.0)) + continue + raise AssemblyAIRetryableError(f'AssemblyAI request failed: {e}') from e + raise AssemblyAIRetryableError(f'AssemblyAI request failed: {last_error}') diff --git 
a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index fdcc1f09f25..d5fba11402e 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -1,12 +1,19 @@ import logging +import os from dataclasses import dataclass from datetime import datetime, timezone from typing import List, Optional, Sequence, Tuple from models.transcript_segment import ProviderTranscriptResult, TranscriptSegment +from utils.stt.assemblyai_adapter import AssemblyAIAsyncTranscriptionProvider from utils.stt.conversation_reconstructor import reconstruct_conversation from utils.stt.deepgram_adapter import provider_result_to_legacy_words -from utils.stt.providers import STTProviderName, STTWorkload, get_prerecorded_provider_name +from utils.stt.providers import ( + STTProviderName, + STTWorkload, + get_fallback_prerecorded_provider_name, + get_prerecorded_provider_name, +) logger = logging.getLogger(__name__) @@ -49,6 +56,10 @@ def _deepgram_prerecorded_provider(): return _provider() +def _assemblyai_prerecorded_provider(): + return AssemblyAIAsyncTranscriptionProvider() + + def get_deepgram_model_for_language(language: str) -> Tuple[str, str]: from utils.stt.pre_recorded import get_deepgram_model_for_language as _get_deepgram_model_for_language @@ -84,6 +95,7 @@ def transcribe_url( ) -> PrerecordedTranscriptionResponse: workload = STTWorkload(workload) provider_name = get_prerecorded_provider_name(workload) + model = _model_for_provider(provider_name, model) provider = _get_prerecorded_provider(provider_name) started_at = datetime.now(timezone.utc) run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) @@ -120,6 +132,31 @@ def transcribe_url( ) except Exception as e: _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + fallback_provider_name = get_fallback_prerecorded_provider_name(provider_name, workload) + if 
fallback_provider_name: + logger.warning( + 'provider prerecorded url transcription falling back workload=%s from_provider=%s to_provider=%s: %s', + workload.value, + provider_name.value, + fallback_provider_name.value, + e, + ) + return _transcribe_url_with_provider( + fallback_provider_name, + audio_url, + workload, + uid=uid, + conversation_id=conversation_id, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=_model_for_provider(fallback_provider_name, model), + keywords=keywords, + skip_n_seconds=skip_n_seconds, + raw_audio_seconds=raw_audio_seconds, + fallback_from_provider=provider_name.value, + ) raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') @@ -141,6 +178,7 @@ def transcribe_bytes( ) -> PrerecordedTranscriptionResponse: workload = STTWorkload(workload) provider_name = get_prerecorded_provider_name(workload) + model = _model_for_provider(provider_name, model) provider = _get_prerecorded_provider(provider_name) started_at = datetime.now(timezone.utc) run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) @@ -179,15 +217,182 @@ def transcribe_bytes( ) except Exception as e: _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + fallback_provider_name = get_fallback_prerecorded_provider_name(provider_name, workload) + if fallback_provider_name: + logger.warning( + 'provider prerecorded bytes transcription falling back workload=%s from_provider=%s to_provider=%s: %s', + workload.value, + provider_name.value, + fallback_provider_name.value, + e, + ) + return _transcribe_bytes_with_provider( + fallback_provider_name, + audio_bytes, + workload, + uid=uid, + conversation_id=conversation_id, + sample_rate=sample_rate, + diarize=diarize, + encoding=encoding, + channels=channels, + language=language, + model=_model_for_provider(fallback_provider_name, model), + 
return_language=return_language, + keywords=keywords, + skip_n_seconds=skip_n_seconds, + raw_audio_seconds=raw_audio_seconds, + fallback_from_provider=provider_name.value, + ) raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') def _get_prerecorded_provider(provider_name: STTProviderName): + if provider_name == STTProviderName.assemblyai: + return _assemblyai_prerecorded_provider() if provider_name == STTProviderName.deepgram: return _deepgram_prerecorded_provider() raise ValueError(f'Unsupported prerecorded STT provider: {provider_name}') +def _model_for_provider(provider_name: STTProviderName, requested_model: str) -> str: + if provider_name == STTProviderName.assemblyai and str(requested_model or '').startswith('nova-'): + return os.getenv('ASSEMBLYAI_STT_MODEL', 'universal-2') + if provider_name == STTProviderName.deepgram and str(requested_model or '').startswith('universal-'): + return 'nova-3' + return requested_model + + +def _transcribe_url_with_provider( + provider_name: STTProviderName, + audio_url: str, + workload: STTWorkload, + uid: Optional[str] = None, + conversation_id: Optional[str] = None, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'nova-3', + keywords: Optional[Sequence[str]] = None, + skip_n_seconds: int = 0, + raw_audio_seconds: float = 0.0, + fallback_from_provider: Optional[str] = None, +) -> PrerecordedTranscriptionResponse: + provider = _get_prerecorded_provider(provider_name) + started_at = datetime.now(timezone.utc) + run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) + try: + provider_result, detected_language, retry_count = _transcribe_url_with_retry( + provider, + audio_url, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, + ) + return _build_success_response( + 
run_id, + provider_result, + workload, + started_at, + retry_count, + raw_audio_seconds, + skip_n_seconds, + fallback_from_provider=fallback_from_provider, + detected_language=detected_language, + ) + except Exception as e: + _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') + + +def _transcribe_bytes_with_provider( + provider_name: STTProviderName, + audio_bytes: bytes, + workload: STTWorkload, + uid: Optional[str] = None, + conversation_id: Optional[str] = None, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'nova-3', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + skip_n_seconds: int = 0, + raw_audio_seconds: float = 0.0, + fallback_from_provider: Optional[str] = None, +) -> PrerecordedTranscriptionResponse: + provider = _get_prerecorded_provider(provider_name) + started_at = datetime.now(timezone.utc) + run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) + try: + provider_result, detected_language, retry_count = _transcribe_bytes_with_retry( + provider, + audio_bytes, + sample_rate=sample_rate, + diarize=diarize, + encoding=encoding, + channels=channels, + language=language, + model=model, + return_language=return_language, + keywords=keywords, + ) + return _build_success_response( + run_id, + provider_result, + workload, + started_at, + retry_count, + raw_audio_seconds, + skip_n_seconds, + fallback_from_provider=fallback_from_provider, + detected_language=detected_language, + ) + except Exception as e: + _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') + + +def _build_success_response( + 
run_id: Optional[str], + provider_result: ProviderTranscriptResult, + workload: STTWorkload, + started_at: datetime, + retry_count: int, + raw_audio_seconds: float, + skip_n_seconds: int, + fallback_from_provider: Optional[str] = None, + detected_language: Optional[str] = None, +) -> PrerecordedTranscriptionResponse: + segments = reconstruct_conversation(provider_result, skip_n_seconds=skip_n_seconds) + words = provider_result_to_legacy_words(provider_result) + _finalize_run( + run_id, + provider_result, + workload, + started_at, + 'succeeded', + retry_count=retry_count, + raw_audio_seconds=raw_audio_seconds or provider_result.duration or 0.0, + segments=segments, + fallback_count=1 if fallback_from_provider else 0, + fallback_provider=fallback_from_provider, + ) + return PrerecordedTranscriptionResponse( + result=provider_result, + detected_language=detected_language, + segments=segments, + words=words, + run_id=run_id, + ) + + def _transcribe_url_with_retry( provider, audio_url: str, **kwargs ) -> Tuple[ProviderTranscriptResult, Optional[str], int]: @@ -269,6 +474,8 @@ def _finalize_run( retry_count: int, raw_audio_seconds: float, segments: List[TranscriptSegment], + fallback_count: int = 0, + fallback_provider: Optional[str] = None, ) -> None: if not run_id: return @@ -288,6 +495,7 @@ def _finalize_run( speech_active_seconds=raw_audio_seconds, billable_seconds=raw_audio_seconds, retry_count=retry_count, + fallback_count=fallback_count, transcript_segment_count=len(segments), transcript_word_count=len(result.words), speaker_cluster_count=len(clusters), @@ -295,11 +503,19 @@ def _finalize_run( {segment.provider_cluster_id for segment in segments if segment.person_id} ), identity_confidence_summary=summarize_identity_confidences(confidences), + artifact_refs=_provider_artifact_refs(result), + fallback_provider=fallback_provider, ) except Exception as e: logger.warning('failed to finalize transcription provider run ledger run_id=%s: %s', run_id, e) +def 
_provider_artifact_refs(result: ProviderTranscriptResult) -> dict[str, str]: + if not result.raw_provider_result_id: + return {} + return {'provider_result_id': result.raw_provider_result_id} + + def _finalize_failed_run( run_id: Optional[str], provider: str, diff --git a/backend/utils/stt/providers.py b/backend/utils/stt/providers.py index cd041a87fda..5c83fb8533b 100644 --- a/backend/utils/stt/providers.py +++ b/backend/utils/stt/providers.py @@ -1,3 +1,4 @@ +import os from enum import Enum from typing import Callable, List, Optional, Protocol, Sequence, Tuple, Union @@ -5,6 +6,7 @@ class STTProviderName(str, Enum): + assemblyai = 'assemblyai' deepgram = 'deepgram' @@ -69,7 +71,7 @@ class SpeakerIdentityProvider(Protocol): provider_name: str -_PRERECORDED_WORKLOAD_PROVIDERS = { +_DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS = { STTWorkload.background: STTProviderName.deepgram, STTWorkload.postprocess: STTProviderName.deepgram, STTWorkload.ptt: STTProviderName.deepgram, @@ -77,6 +79,12 @@ class SpeakerIdentityProvider(Protocol): STTWorkload.voice_message: STTProviderName.deepgram, } +_ASSEMBLYAI_ELIGIBLE_WORKLOADS = { + STTWorkload.background, + STTWorkload.postprocess, + STTWorkload.sync, +} + _STREAMING_WORKLOAD_PROVIDERS = { STTWorkload.ptt: STTProviderName.deepgram, STTWorkload.realtime: STTProviderName.deepgram, @@ -84,8 +92,42 @@ class SpeakerIdentityProvider(Protocol): def get_prerecorded_provider_name(workload: STTWorkload) -> STTProviderName: - return _PRERECORDED_WORKLOAD_PROVIDERS[STTWorkload(workload)] + workload = STTWorkload(workload) + if _assemblyai_background_enabled() and workload in _assemblyai_enabled_workloads(): + return STTProviderName.assemblyai + return _DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS[workload] def get_streaming_provider_name(workload: STTWorkload) -> STTProviderName: return _STREAMING_WORKLOAD_PROVIDERS[STTWorkload(workload)] + + +def get_fallback_prerecorded_provider_name( + provider: STTProviderName, workload: STTWorkload +) -> 
Optional[STTProviderName]: + workload = STTWorkload(workload) + provider = STTProviderName(provider) + fallback = _DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS[workload] + if provider != fallback: + return fallback + return None + + +def _assemblyai_background_enabled() -> bool: + return os.getenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'false').lower() == 'true' + + +def _assemblyai_enabled_workloads() -> set[STTWorkload]: + configured = os.getenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync,background,postprocess') + workloads = set() + for raw_value in configured.split(','): + value = raw_value.strip() + if not value: + continue + try: + workload = STTWorkload(value) + except ValueError: + continue + if workload in _ASSEMBLYAI_ELIGIBLE_WORKLOADS: + workloads.add(workload) + return workloads From f9af9c62291f696cec859f369a49f3dd93ea8e03 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 09:05:14 +0700 Subject: [PATCH 09/44] Add STT provider comparison gate --- .../scripts/stt/provider_comparison_gate.py | 224 ++++++++++ .../fixture_good_meeting.assemblyai.json | 23 ++ ...ixture_good_meeting.assemblyai.rollup.json | 12 + .../fixture_good_meeting.deepgram.json | 23 ++ .../fixture_good_meeting.deepgram.rollup.json | 12 + .../fixtures/stt_provider_eval/manifest.json | 11 + .../tests/unit/test_provider_evaluation.py | 104 +++++ backend/utils/stt/provider_evaluation.py | 389 ++++++++++++++++++ 8 files changed, 798 insertions(+) create mode 100644 backend/scripts/stt/provider_comparison_gate.py create mode 100644 backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json create mode 100644 backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json create mode 100644 backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json create mode 100644 backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json create mode 100644 backend/tests/fixtures/stt_provider_eval/manifest.json create 
mode 100644 backend/tests/unit/test_provider_evaluation.py create mode 100644 backend/utils/stt/provider_evaluation.py diff --git a/backend/scripts/stt/provider_comparison_gate.py b/backend/scripts/stt/provider_comparison_gate.py new file mode 100644 index 00000000000..63b72a70543 --- /dev/null +++ b/backend/scripts/stt/provider_comparison_gate.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import sys +from pathlib import Path +from typing import Any, Optional + +BACKEND_ROOT = Path(__file__).resolve().parents[2] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + +from utils.stt.provider_evaluation import ( # noqa: E402 + ProviderGateThresholds, + build_comparison_report, + compact_markdown_report, + evaluate_report_gates, +) + +try: + from utils.stt.provider_service import _transcribe_bytes_with_provider, _transcribe_url_with_provider # noqa: E402 + from utils.stt.providers import STTProviderName, STTWorkload # noqa: E402 + + LIVE_PROVIDER_IMPORT_ERROR = None +except ModuleNotFoundError as e: + _transcribe_bytes_with_provider = None + _transcribe_url_with_provider = None + STTProviderName = None + STTWorkload = None + LIVE_PROVIDER_IMPORT_ERROR = e + + +def main() -> int: + parser = argparse.ArgumentParser( + description='Compare Deepgram and AssemblyAI background transcription outputs and apply rollout gates.' + ) + parser.add_argument( + '--manifest', + action='append', + default=[], + help='JSON manifest with cases, fixtures, optional audio_url/audio_file, and ledger/rollup files.', + ) + parser.add_argument('--live', action='store_true', help='Run providers for manifest audio_url/audio_file cases.') + parser.add_argument('--uid', default=None, help='Optional uid used to write provider ledger rows during live runs.') + parser.add_argument( + '--conversation-prefix', default='stt-eval', help='Conversation id prefix for live ledger rows.' 
+ ) + parser.add_argument('--output-json', default=None, help='Write full JSON report to this path.') + parser.add_argument('--output-md', default=None, help='Write compact Markdown report to this path.') + parser.add_argument('--fail-on-warning', action='store_true', help='Return non-zero for warning gates too.') + parser.add_argument('--max-wer', type=float, default=ProviderGateThresholds.max_transcript_word_error_rate) + parser.add_argument( + '--max-segment-delta-ratio', type=float, default=ProviderGateThresholds.max_segment_count_delta_ratio + ) + parser.add_argument( + '--max-timestamp-drift', type=float, default=ProviderGateThresholds.max_average_timestamp_drift_seconds + ) + parser.add_argument( + '--max-low-confidence-rate', type=float, default=ProviderGateThresholds.max_low_confidence_identity_rate + ) + parser.add_argument('--max-fallback-rate', type=float, default=ProviderGateThresholds.max_fallback_rate) + parser.add_argument('--max-failure-rate', type=float, default=ProviderGateThresholds.max_failure_rate) + parser.add_argument( + '--allow-missing-instrumentation', + action='store_true', + help='Do not warn when fixture cases omit provider ledger or rollup metrics.', + ) + args = parser.parse_args() + + if not args.manifest: + parser.error('at least one --manifest is required') + + thresholds = ProviderGateThresholds( + max_transcript_word_error_rate=args.max_wer, + max_segment_count_delta_ratio=args.max_segment_delta_ratio, + max_average_timestamp_drift_seconds=args.max_timestamp_drift, + max_low_confidence_identity_rate=args.max_low_confidence_rate, + max_fallback_rate=args.max_fallback_rate, + max_failure_rate=args.max_failure_rate, + require_instrumentation=not args.allow_missing_instrumentation, + ) + cases = [] + skipped_live_cases = [] + for manifest_path in args.manifest: + manifest = _load_json(Path(manifest_path)) + manifest_cases = manifest.get('cases') if isinstance(manifest, dict) else manifest + for index, case in 
enumerate(manifest_cases): + prepared = _prepare_case(case, Path(manifest_path).parent) + if args.live and _case_has_audio(case): + live_case = _run_live_case( + case, + base_path=Path(manifest_path).parent, + uid=args.uid, + conversation_id=f'{args.conversation_prefix}-{case.get("id", index)}', + ) + if live_case: + prepared = live_case + else: + skipped_live_cases.append(case.get('id') or str(index)) + cases.append(prepared) + + report = build_comparison_report(cases, thresholds) + if skipped_live_cases: + report['skipped_live_cases'] = skipped_live_cases + markdown = compact_markdown_report(report) + print(markdown) + + if args.output_json: + _write_text(Path(args.output_json), json.dumps(report, indent=2, sort_keys=True) + '\n') + if args.output_md: + _write_text(Path(args.output_md), markdown + '\n') + + passed, messages = evaluate_report_gates(report, fail_on_warning=args.fail_on_warning) + if not passed: + print('\nGate messages:', file=sys.stderr) + for message in messages: + print(f'- {message}', file=sys.stderr) + return 1 + return 0 + + +def _prepare_case(case: dict[str, Any], base_path: Path) -> dict[str, Any]: + prepared = {'id': case.get('id') or case.get('case_id')} + prepared['deepgram'] = _load_provider_payload(case, base_path, 'deepgram') + prepared['assemblyai'] = _load_provider_payload(case, base_path, 'assemblyai') + return prepared + + +def _load_provider_payload(case: dict[str, Any], base_path: Path, provider: str) -> dict[str, Any]: + payload = {} + inline = case.get(provider) + if inline: + payload.update(inline) + fixture_path = case.get(f'{provider}_fixture') or case.get(f'{provider}_transcript') + if fixture_path: + payload['transcript'] = _load_json(_resolve_path(base_path, fixture_path)) + ledger_path = case.get(f'{provider}_ledger') or case.get(f'{provider}_rollup') + if ledger_path: + payload['ledger'] = _load_json(_resolve_path(base_path, ledger_path)) + return payload + + +def _run_live_case( + case: dict[str, Any], + base_path: 
Path, + uid: Optional[str], + conversation_id: str, +) -> Optional[dict[str, Any]]: + if LIVE_PROVIDER_IMPORT_ERROR: + print( + f"Skipping live case {case.get('id', 'unknown')}: provider dependencies are unavailable " + f"({LIVE_PROVIDER_IMPORT_ERROR}).", + file=sys.stderr, + ) + return None + if not _credentials_available(): + print( + f"Skipping live case {case.get('id', 'unknown')}: " 'DEEPGRAM_API_KEY and ASSEMBLYAI_API_KEY are required.', + file=sys.stderr, + ) + return None + workload = STTWorkload(case.get('workload') or STTWorkload.background.value) + language = case.get('language') + raw_audio_seconds = float(case.get('raw_audio_seconds') or 0.0) + common_kwargs = { + 'workload': workload, + 'uid': uid, + 'conversation_id': conversation_id, + 'language': language, + 'model': case.get('model') or 'nova-3', + 'raw_audio_seconds': raw_audio_seconds, + 'return_language': bool(case.get('return_language', False)), + 'diarize': bool(case.get('diarize', True)), + } + if case.get('audio_url'): + deepgram = _transcribe_url_with_provider(STTProviderName.deepgram, case['audio_url'], **common_kwargs) + assemblyai = _transcribe_url_with_provider( + STTProviderName.assemblyai, + case['audio_url'], + **{**common_kwargs, 'model': case.get('assemblyai_model') or 'universal-2'}, + ) + else: + audio_bytes = _resolve_path(base_path, case['audio_file']).read_bytes() + deepgram = _transcribe_bytes_with_provider(STTProviderName.deepgram, audio_bytes, **common_kwargs) + assemblyai = _transcribe_bytes_with_provider( + STTProviderName.assemblyai, + audio_bytes, + **{**common_kwargs, 'model': case.get('assemblyai_model') or 'universal-2'}, + ) + del audio_bytes + return { + 'id': case.get('id') or conversation_id, + 'deepgram': {'transcript': {'segments': [segment.dict() for segment in deepgram.segments]}}, + 'assemblyai': {'transcript': {'segments': [segment.dict() for segment in assemblyai.segments]}}, + } + + +def _case_has_audio(case: dict[str, Any]) -> bool: + return 
bool(case.get('audio_url') or case.get('audio_file')) + + +def _credentials_available() -> bool: + return bool(os.getenv('DEEPGRAM_API_KEY') and os.getenv('ASSEMBLYAI_API_KEY')) + + +def _load_json(path: Path) -> Any: + with path.open('r', encoding='utf-8') as handle: + return json.load(handle) + + +def _write_text(path: Path, text: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(text, encoding='utf-8') + + +def _resolve_path(base_path: Path, raw_path: str) -> Path: + path = Path(raw_path) + if path.is_absolute(): + return path + return base_path / path + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json new file mode 100644 index 00000000000..5d1f4cea332 --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json @@ -0,0 +1,23 @@ +{ + "segments": [ + { + "text": "We should launch the beta next Monday.", + "start": 0.1, + "end": 2.3, + "provider_cluster_id": "A", + "provider_speaker_label": "ASSEMBLYAI_SPEAKER_A", + "person_id": "person_alex", + "speaker_identity_state": "identified", + "speaker_identity_confidence": 0.88 + }, + { + "text": "I will prepare the dashboard before then.", + "start": 2.6, + "end": 4.9, + "provider_cluster_id": "B", + "provider_speaker_label": "ASSEMBLYAI_SPEAKER_B", + "speaker_identity_state": "unknown", + "speaker_identity_confidence": 0.46 + } + ] +} diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json new file mode 100644 index 00000000000..466c1afec6c --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json @@ -0,0 +1,12 @@ +{ + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 
5.0, + "speech_active_seconds": 4.5, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.00020, + "retry_count": 0, + "fallback_count": 0 +} diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json new file mode 100644 index 00000000000..b70a3d30afb --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json @@ -0,0 +1,23 @@ +{ + "segments": [ + { + "text": "We should launch the beta next Monday.", + "start": 0.0, + "end": 2.2, + "provider_cluster_id": "0", + "provider_speaker_label": "SPEAKER_00", + "person_id": "person_alex", + "speaker_identity_state": "identified", + "speaker_identity_confidence": 0.91 + }, + { + "text": "I will prepare the dashboard before then.", + "start": 2.5, + "end": 4.8, + "provider_cluster_id": "1", + "provider_speaker_label": "SPEAKER_01", + "speaker_identity_state": "unknown", + "speaker_identity_confidence": 0.42 + } + ] +} diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json new file mode 100644 index 00000000000..334f3acfde2 --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json @@ -0,0 +1,12 @@ +{ + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "speech_active_seconds": 4.5, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.00036, + "retry_count": 0, + "fallback_count": 0 +} diff --git a/backend/tests/fixtures/stt_provider_eval/manifest.json b/backend/tests/fixtures/stt_provider_eval/manifest.json new file mode 100644 index 00000000000..20a1aac9c25 --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/manifest.json @@ -0,0 +1,11 @@ +{ + "cases": [ + { + "id": "fixture_good_meeting", + "deepgram_fixture": "fixture_good_meeting.deepgram.json", + 
"assemblyai_fixture": "fixture_good_meeting.assemblyai.json", + "deepgram_rollup": "fixture_good_meeting.deepgram.rollup.json", + "assemblyai_rollup": "fixture_good_meeting.assemblyai.rollup.json" + } + ] +} diff --git a/backend/tests/unit/test_provider_evaluation.py b/backend/tests/unit/test_provider_evaluation.py new file mode 100644 index 00000000000..188d793ae89 --- /dev/null +++ b/backend/tests/unit/test_provider_evaluation.py @@ -0,0 +1,104 @@ +import json +from pathlib import Path + +from utils.stt.provider_evaluation import ( + ProviderGateThresholds, + build_comparison_report, + compact_markdown_report, + evaluate_report_gates, + summarize_provider_output, +) + + +FIXTURE_DIR = Path(__file__).resolve().parents[1] / 'fixtures' / 'stt_provider_eval' + + +def _load_fixture_case() -> dict: + manifest = json.loads((FIXTURE_DIR / 'manifest.json').read_text()) + case = manifest['cases'][0] + return { + 'id': case['id'], + 'deepgram': { + 'transcript': json.loads((FIXTURE_DIR / case['deepgram_fixture']).read_text()), + 'ledger': json.loads((FIXTURE_DIR / case['deepgram_rollup']).read_text()), + }, + 'assemblyai': { + 'transcript': json.loads((FIXTURE_DIR / case['assemblyai_fixture']).read_text()), + 'ledger': json.loads((FIXTURE_DIR / case['assemblyai_rollup']).read_text()), + }, + } + + +def test_fixture_report_passes_and_includes_cost_identity_and_timing_metrics(): + report = build_comparison_report([_load_fixture_case()]) + + assert report['status'] == 'passed' + assert report['case_count'] == 1 + assert report['aggregate']['assemblyai_estimated_cost_usd'] == 0.00020 + case = report['cases'][0] + assert case['comparison']['transcript_word_error_rate'] == 0.0 + assert case['comparison']['average_timestamp_drift_seconds'] > 0.0 + assert case['providers']['assemblyai']['speaker_cluster_count'] == 2 + assert case['providers']['assemblyai']['identified_speaker_cluster_count'] == 1 + assert case['providers']['assemblyai']['unknown_speaker_cluster_count'] == 1 + 
assert case['providers']['assemblyai']['low_confidence_identity_rate'] == 0.5 + assert all(gate['severity'] == 'pass' for gate in case['gates']) + + +def test_threshold_failures_are_reported_for_transcript_drift_and_fallback(): + case = _load_fixture_case() + case['assemblyai']['transcript']['segments'][0]['text'] = 'Completely unrelated output.' + case['assemblyai']['ledger']['fallback_count'] = 1 + + report = build_comparison_report( + [case], + ProviderGateThresholds(max_transcript_word_error_rate=0.05, max_fallback_rate=0.10), + ) + passed, messages = evaluate_report_gates(report) + + assert report['status'] == 'failed' + assert not passed + assert any('transcript_word_error_rate' in message for message in messages) + assert any('assemblyai_fallback_rate' in message for message in messages) + + +def test_missing_instrumentation_is_warning_not_failure_by_default(): + case = _load_fixture_case() + case['assemblyai'].pop('ledger') + + report = build_comparison_report([case]) + passed, messages = evaluate_report_gates(report) + strict_passed, strict_messages = evaluate_report_gates(report, fail_on_warning=True) + + assert report['status'] == 'passed' + assert passed + assert messages == [] + assert not strict_passed + assert any('assemblyai_instrumentation' in message for message in strict_messages) + + +def test_provider_result_words_are_grouped_into_cluster_segments(): + payload = { + 'provider': 'assemblyai', + 'model': 'universal-2', + 'words': [ + {'text': 'hello', 'start': 0.0, 'end': 0.2, 'provider_cluster_id': 'A'}, + {'text': 'there', 'start': 0.2, 'end': 0.5, 'provider_cluster_id': 'A'}, + {'text': 'hi', 'start': 0.7, 'end': 0.9, 'provider_cluster_id': 'B'}, + ], + } + + summary = summarize_provider_output('assemblyai', {'transcript': payload}) + + assert summary['segment_count'] == 2 + assert summary['word_count'] == 3 + assert summary['speaker_cluster_count'] == 2 + + +def test_compact_markdown_report_is_review_friendly(): + report = 
build_comparison_report([_load_fixture_case()]) + markdown = compact_markdown_report(report) + + assert '# STT Provider Evaluation: PASSED' in markdown + assert 'fixture_good_meeting' in markdown + assert 'AssemblyAI cost' in markdown diff --git a/backend/utils/stt/provider_evaluation.py b/backend/utils/stt/provider_evaluation.py new file mode 100644 index 00000000000..ccbb68c1d93 --- /dev/null +++ b/backend/utils/stt/provider_evaluation.py @@ -0,0 +1,389 @@ +from dataclasses import dataclass +from typing import Any, Optional + + +@dataclass(frozen=True) +class ProviderGateThresholds: + max_transcript_word_error_rate: float = 0.35 + max_segment_count_delta_ratio: float = 0.75 + max_average_timestamp_drift_seconds: float = 2.5 + max_low_confidence_identity_rate: float = 0.50 + max_fallback_rate: float = 0.10 + max_failure_rate: float = 0.05 + require_instrumentation: bool = True + + +def build_comparison_report( + cases: list[dict[str, Any]], + thresholds: Optional[ProviderGateThresholds] = None, +) -> dict[str, Any]: + thresholds = thresholds or ProviderGateThresholds() + case_reports = [_compare_case(case, thresholds) for case in cases] + failures = [gate for case in case_reports for gate in case['gates'] if gate['severity'] == 'failure'] + warnings = [gate for case in case_reports for gate in case['gates'] if gate['severity'] == 'warning'] + return { + 'status': 'failed' if failures else 'passed', + 'case_count': len(case_reports), + 'failure_count': len(failures), + 'warning_count': len(warnings), + 'aggregate': _aggregate_case_reports(case_reports), + 'cases': case_reports, + } + + +def evaluate_report_gates(report: dict[str, Any], fail_on_warning: bool = False) -> tuple[bool, list[str]]: + messages = [] + for case in report.get('cases', []): + for gate in case.get('gates', []): + if gate.get('severity') == 'failure' or (fail_on_warning and gate.get('severity') == 'warning'): + messages.append(f"{case.get('id', 'unknown')}: {gate.get('metric')} 
{gate.get('message')}") + return not messages, messages + + +def compact_markdown_report(report: dict[str, Any]) -> str: + aggregate = report.get('aggregate', {}) + lines = [ + f"# STT Provider Evaluation: {report.get('status', 'unknown').upper()}", + '', + '| Cases | Failures | Warnings | Avg WER | Avg timestamp drift | AssemblyAI cost | Deepgram cost |', + '| --- | ---: | ---: | ---: | ---: | ---: | ---: |', + ( + f"| {report.get('case_count', 0)} | {report.get('failure_count', 0)} | " + f"{report.get('warning_count', 0)} | {_fmt_pct(aggregate.get('average_word_error_rate'))} | " + f"{_fmt_seconds(aggregate.get('average_timestamp_drift_seconds'))} | " + f"${aggregate.get('assemblyai_estimated_cost_usd', 0.0):.4f} | " + f"${aggregate.get('deepgram_estimated_cost_usd', 0.0):.4f} |" + ), + '', + '| Case | WER | Segments DG/AAI | Clusters DG/AAI | Unknown AAI | Low-conf AAI | Fallback AAI | Gates |', + '| --- | ---: | ---: | ---: | ---: | ---: | ---: | --- |', + ] + for case in report.get('cases', []): + deepgram = case['providers']['deepgram'] + assemblyai = case['providers']['assemblyai'] + gates = ', '.join( + f"{gate['severity']}:{gate['metric']}" for gate in case.get('gates', []) if gate['severity'] != 'pass' + ) + lines.append( + f"| {case['id']} | {_fmt_pct(case['comparison']['transcript_word_error_rate'])} | " + f"{deepgram['segment_count']}/{assemblyai['segment_count']} | " + f"{deepgram['speaker_cluster_count']}/{assemblyai['speaker_cluster_count']} | " + f"{assemblyai['unknown_speaker_cluster_count']} | " + f"{_fmt_pct(assemblyai['low_confidence_identity_rate'])} | " + f"{_fmt_pct(assemblyai['fallback_rate'])} | {gates or 'pass'} |" + ) + return '\n'.join(lines) + + +def _compare_case(case: dict[str, Any], thresholds: ProviderGateThresholds) -> dict[str, Any]: + deepgram = summarize_provider_output('deepgram', case.get('deepgram') or {}) + assemblyai = summarize_provider_output('assemblyai', case.get('assemblyai') or {}) + comparison = { + 
'transcript_word_error_rate': _word_error_rate(deepgram['text'], assemblyai['text']), + 'segment_count_delta': assemblyai['segment_count'] - deepgram['segment_count'], + 'segment_count_delta_ratio': _ratio_delta(assemblyai['segment_count'], deepgram['segment_count']), + 'word_count_delta': assemblyai['word_count'] - deepgram['word_count'], + 'word_count_delta_ratio': _ratio_delta(assemblyai['word_count'], deepgram['word_count']), + 'average_timestamp_drift_seconds': _average_timestamp_drift_seconds( + deepgram['segments'], assemblyai['segments'] + ), + } + gates = _evaluate_case_gates(deepgram, assemblyai, comparison, thresholds) + return { + 'id': case.get('id') or case.get('case_id') or 'unknown', + 'providers': {'deepgram': _public_summary(deepgram), 'assemblyai': _public_summary(assemblyai)}, + 'comparison': comparison, + 'gates': gates, + } + + +def summarize_provider_output(provider: str, payload: dict[str, Any]) -> dict[str, Any]: + transcript = payload.get('transcript') or payload.get('result') or payload + segments = _extract_segments(transcript) + ledger = payload.get('ledger') or payload.get('rollup') or {} + clusters = _speaker_clusters(segments) + identified_clusters = { + _cluster_id(segment) + for segment in segments + if _cluster_id(segment) + and ( + segment.get('person_id') + or segment.get('is_user') is True + or segment.get('speaker_identity_state') in ('identified', 'user') + ) + } + identity_confidences = [ + segment.get('speaker_identity_confidence') + for segment in segments + if segment.get('speaker_identity_confidence') is not None + ] + low_confidence_count = sum(1 for confidence in identity_confidences if float(confidence) < 0.50) + return { + 'provider': provider, + 'segments': segments, + 'text': _transcript_text(segments), + 'segment_count': len(segments), + 'word_count': sum(len(_words(segment.get('text', ''))) for segment in segments), + 'speaker_cluster_count': len(clusters), + 'identified_speaker_cluster_count': 
len(identified_clusters), + 'unknown_speaker_cluster_count': max(len(clusters) - len(identified_clusters), 0), + 'low_confidence_identity_count': low_confidence_count, + 'low_confidence_identity_rate': ( + low_confidence_count / len(identity_confidences) if identity_confidences else 0.0 + ), + 'raw_audio_seconds': _number_from_ledger(ledger, 'raw_audio_seconds'), + 'speech_active_seconds': _number_from_ledger(ledger, 'speech_active_seconds'), + 'billable_seconds': _number_from_ledger(ledger, 'billable_seconds'), + 'estimated_cost_usd': _number_from_ledger(ledger, 'estimated_cost_usd'), + 'retry_count': _number_from_ledger(ledger, 'retry_count'), + 'fallback_count': _number_from_ledger(ledger, 'fallback_count'), + 'fallback_rate': _rate_from_ledger(ledger, 'fallback_count'), + 'failure_rate': _failure_rate_from_ledger(ledger), + 'has_instrumentation': bool(ledger), + } + + +def _public_summary(summary: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in summary.items() if key not in ('segments', 'text')} + + +def _extract_segments(transcript: Any) -> list[dict[str, Any]]: + if isinstance(transcript, dict) and isinstance(transcript.get('segments'), list): + return [_normalize_segment(segment) for segment in transcript['segments']] + if isinstance(transcript, list): + return [_normalize_segment(segment) for segment in transcript] + if isinstance(transcript, dict) and isinstance(transcript.get('utterances'), list) and transcript.get('utterances'): + return [_normalize_provider_utterance(utterance, transcript) for utterance in transcript['utterances']] + if isinstance(transcript, dict) and isinstance(transcript.get('words'), list): + return _segments_from_words(transcript['words'], transcript) + if isinstance(transcript, dict) and transcript.get('text'): + return [_normalize_segment(transcript)] + return [] + + +def _normalize_segment(segment: dict[str, Any]) -> dict[str, Any]: + return { + 'text': str(segment.get('text') or 
segment.get('transcript') or '').strip(), + 'start': float(segment.get('start') or 0.0), + 'end': float(segment.get('end') or 0.0), + 'provider_cluster_id': segment.get('provider_cluster_id'), + 'provider_speaker_label': segment.get('provider_speaker_label'), + 'speaker': segment.get('speaker'), + 'person_id': segment.get('person_id'), + 'is_user': segment.get('is_user'), + 'speaker_identity_state': segment.get('speaker_identity_state'), + 'speaker_identity_confidence': segment.get('speaker_identity_confidence'), + } + + +def _normalize_provider_utterance(utterance: dict[str, Any], transcript: dict[str, Any]) -> dict[str, Any]: + normalized = _normalize_segment(utterance) + normalized['provider_cluster_id'] = utterance.get('provider_cluster_id') + normalized['provider_speaker_label'] = utterance.get('speaker_label') or utterance.get('provider_speaker_label') + normalized['stt_provider'] = transcript.get('provider') + normalized['stt_model'] = transcript.get('model') + return normalized + + +def _segments_from_words(words: list[dict[str, Any]], transcript: dict[str, Any]) -> list[dict[str, Any]]: + segments = [] + current = None + for word in words: + cluster_id = word.get('provider_cluster_id') or word.get('speaker') + if current is None or current.get('provider_cluster_id') != cluster_id: + if current: + current['text'] = ' '.join(current.pop('_words')) + segments.append(current) + current = { + 'text': '', + '_words': [], + 'start': float(word.get('start') or 0.0), + 'end': float(word.get('end') or 0.0), + 'provider_cluster_id': cluster_id, + 'provider_speaker_label': word.get('speaker_label') or word.get('provider_speaker_label'), + 'stt_provider': transcript.get('provider'), + 'stt_model': transcript.get('model'), + } + current['_words'].append(str(word.get('text') or word.get('word') or '')) + current['end'] = float(word.get('end') or current['end']) + if current: + current['text'] = ' '.join(current.pop('_words')) + segments.append(current) + return segments 
+ + +def _speaker_clusters(segments: list[dict[str, Any]]) -> set[str]: + return {cluster_id for cluster_id in (_cluster_id(segment) for segment in segments) if cluster_id} + + +def _cluster_id(segment: dict[str, Any]) -> Optional[str]: + return segment.get('provider_cluster_id') or segment.get('provider_speaker_label') or segment.get('speaker') + + +def _transcript_text(segments: list[dict[str, Any]]) -> str: + return ' '.join(segment.get('text', '') for segment in segments).strip() + + +def _words(text: str) -> list[str]: + normalized = ''.join(character.lower() if character.isalnum() else ' ' for character in text) + return [word for word in normalized.split() if word] + + +def _word_error_rate(reference: str, hypothesis: str) -> float: + reference_words = _words(reference) + hypothesis_words = _words(hypothesis) + if not reference_words: + return 0.0 if not hypothesis_words else 1.0 + return _levenshtein_distance(reference_words, hypothesis_words) / len(reference_words) + + +def _levenshtein_distance(reference: list[str], hypothesis: list[str]) -> int: + previous = list(range(len(hypothesis) + 1)) + for index, reference_word in enumerate(reference, start=1): + current = [index] + for other_index, hypothesis_word in enumerate(hypothesis, start=1): + current.append( + min( + previous[other_index] + 1, + current[other_index - 1] + 1, + previous[other_index - 1] + (reference_word != hypothesis_word), + ) + ) + previous = current + return previous[-1] + + +def _ratio_delta(value: float, baseline: float) -> float: + if baseline == 0: + return 0.0 if value == 0 else 1.0 + return abs(value - baseline) / baseline + + +def _average_timestamp_drift_seconds(reference: list[dict[str, Any]], candidate: list[dict[str, Any]]) -> float: + pair_count = min(len(reference), len(candidate)) + if not pair_count: + return 0.0 + drift = 0.0 + for index in range(pair_count): + drift += abs(reference[index].get('start', 0.0) - candidate[index].get('start', 0.0)) + drift += 
abs(reference[index].get('end', 0.0) - candidate[index].get('end', 0.0)) + return drift / (pair_count * 2) + + +def _evaluate_case_gates( + deepgram: dict[str, Any], + assemblyai: dict[str, Any], + comparison: dict[str, Any], + thresholds: ProviderGateThresholds, +) -> list[dict[str, Any]]: + gates = [ + _threshold_gate( + 'transcript_word_error_rate', + comparison['transcript_word_error_rate'], + thresholds.max_transcript_word_error_rate, + 'failure', + ), + _threshold_gate( + 'segment_count_delta_ratio', + comparison['segment_count_delta_ratio'], + thresholds.max_segment_count_delta_ratio, + 'failure', + ), + _threshold_gate( + 'average_timestamp_drift_seconds', + comparison['average_timestamp_drift_seconds'], + thresholds.max_average_timestamp_drift_seconds, + 'warning', + ), + _threshold_gate( + 'assemblyai_low_confidence_identity_rate', + assemblyai['low_confidence_identity_rate'], + thresholds.max_low_confidence_identity_rate, + 'warning', + ), + _threshold_gate( + 'assemblyai_fallback_rate', assemblyai['fallback_rate'], thresholds.max_fallback_rate, 'failure' + ), + _threshold_gate('assemblyai_failure_rate', assemblyai['failure_rate'], thresholds.max_failure_rate, 'failure'), + ] + if thresholds.require_instrumentation: + for provider in (deepgram, assemblyai): + if not provider['has_instrumentation']: + gates.append( + { + 'metric': f"{provider['provider']}_instrumentation", + 'severity': 'warning', + 'value': None, + 'threshold': 'ledger_or_rollup_required', + 'message': 'missing provider ledger or rollup metrics', + } + ) + return gates + + +def _threshold_gate(metric: str, value: float, threshold: float, severity: str) -> dict[str, Any]: + passed = value <= threshold + return { + 'metric': metric, + 'severity': 'pass' if passed else severity, + 'value': value, + 'threshold': threshold, + 'message': 'within threshold' if passed else f'{value:.4f} exceeds {threshold:.4f}', + } + + +def _number_from_ledger(ledger: dict[str, Any], field: str) -> float: + 
value = ledger.get(field, 0.0) + if isinstance(value, dict) and '__increment' in value: + value = value['__increment'] + return float(value or 0.0) + + +def _rate_from_ledger(ledger: dict[str, Any], count_field: str) -> float: + denominator = float(ledger.get('run_count') or 1.0) + return _number_from_ledger(ledger, count_field) / denominator + + +def _failure_rate_from_ledger(ledger: dict[str, Any]) -> float: + status_counts = ledger.get('status_counts') or {} + failed = status_counts.get('failed') or status_counts.get('failure') or 0 + denominator = float(ledger.get('run_count') or 1.0) + return float(failed) / denominator + + +def _aggregate_case_reports(case_reports: list[dict[str, Any]]) -> dict[str, Any]: + if not case_reports: + return {} + return { + 'average_word_error_rate': _average(case['comparison']['transcript_word_error_rate'] for case in case_reports), + 'average_timestamp_drift_seconds': _average( + case['comparison']['average_timestamp_drift_seconds'] for case in case_reports + ), + 'assemblyai_estimated_cost_usd': sum( + case['providers']['assemblyai']['estimated_cost_usd'] for case in case_reports + ), + 'deepgram_estimated_cost_usd': sum( + case['providers']['deepgram']['estimated_cost_usd'] for case in case_reports + ), + 'assemblyai_billable_seconds': sum( + case['providers']['assemblyai']['billable_seconds'] for case in case_reports + ), + 'deepgram_billable_seconds': sum(case['providers']['deepgram']['billable_seconds'] for case in case_reports), + } + + +def _average(values) -> float: + values = list(values) + return sum(values) / len(values) if values else 0.0 + + +def _fmt_pct(value) -> str: + if value is None: + return 'n/a' + return f'{float(value) * 100:.1f}%' + + +def _fmt_seconds(value) -> str: + if value is None: + return 'n/a' + return f'{float(value):.2f}s' From 25c4a010b1c107506c5b4fea6aae4389ea519301 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 09:09:10 +0700 Subject: [PATCH 10/44] Fix provider fallback 
metric direction --- .../database/transcription_provider_usage.py | 2 +- .../unit/test_transcription_provider_usage.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/backend/database/transcription_provider_usage.py b/backend/database/transcription_provider_usage.py index 2db80e48212..2d5a8e54cfd 100644 --- a/backend/database/transcription_provider_usage.py +++ b/backend/database/transcription_provider_usage.py @@ -342,8 +342,8 @@ def emit_provider_run_metrics( observe_transcription_provider_retry(provider, model, workload, 'provider_retry', retry_count) if fallback_count > 0: observe_transcription_provider_fallback( - provider, fallback_provider or 'unknown', + provider, workload, fallback_reason, fallback_count, diff --git a/backend/tests/unit/test_transcription_provider_usage.py b/backend/tests/unit/test_transcription_provider_usage.py index 3e655cb8fac..bdb7c12882d 100644 --- a/backend/tests/unit/test_transcription_provider_usage.py +++ b/backend/tests/unit/test_transcription_provider_usage.py @@ -229,6 +229,38 @@ def test_metrics_reject_high_cardinality_labels(): metrics._provider_metric_labels(provider='assemblyai', transcript_text='hello world') +def test_fallback_metric_records_failed_provider_to_fallback_provider(monkeypatch): + observed = [] + monkeypatch.setattr(usage, 'observe_transcription_provider_request', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_audio_seconds', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_retry', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_speaker_clusters', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_identity_confidence', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_fallback', + lambda *args, **kwargs: observed.append((args, kwargs)), + ) + + 
usage.emit_provider_run_metrics( + provider='deepgram', + model='nova-3', + workload='sync', + status='succeeded', + latency_seconds=1.0, + raw_audio_seconds=2.0, + speech_active_seconds=2.0, + billable_seconds=2.0, + retry_count=0, + fallback_count=1, + speaker_cluster_count=0, + identified_speaker_cluster_count=0, + fallback_provider='assemblyai', + ) + + assert observed == [(('assemblyai', 'deepgram', 'sync', 'provider_failure', 1), {})] + + def test_provider_metrics_source_does_not_define_forbidden_label_names(): forbidden_labels = { "['provider', 'model', 'workload', 'user_id']", From e6e4e99be90eba03981a3dca25a97b15a9ecb413 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 09:12:45 +0700 Subject: [PATCH 11/44] Add provider transcription cost estimates --- .../unit/test_background_provider_service.py | 60 +++++++++++++ .../unit/test_transcription_provider_usage.py | 1 + backend/utils/stt/provider_costs.py | 88 +++++++++++++++++++ backend/utils/stt/provider_service.py | 10 ++- 4 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 backend/utils/stt/provider_costs.py diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index 36a31cefea3..3e66ca52b6a 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -16,6 +16,7 @@ from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord # noqa: E402 from utils.stt import provider_service # noqa: E402 +from utils.stt.provider_costs import estimate_prerecorded_provider_cost_usd # noqa: E402 from utils.stt.providers import STTProviderName, STTWorkload, get_prerecorded_provider_name # noqa: E402 @@ -76,6 +77,7 @@ def test_provider_service_transcribes_sync_upload_and_finalizes_deepgram_run(): assert finalize_run.call_args.kwargs['workload'] == 'sync' assert finalize_run.call_args.kwargs['status'] == 'succeeded' assert 
finalize_run.call_args.kwargs['transcript_segment_count'] == 1 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00016 def test_provider_service_finalizes_background_run_on_deepgram_default(): @@ -99,6 +101,7 @@ def test_provider_service_finalizes_background_run_on_deepgram_default(): assert finalize_run.call_args.kwargs['workload'] == 'background' assert finalize_run.call_args.kwargs['provider'] == 'deepgram' assert finalize_run.call_args.kwargs['raw_audio_seconds'] == 9.5 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00076 def test_prerecorded_ptt_and_realtime_related_workloads_stay_deepgram(): @@ -172,6 +175,8 @@ def test_provider_service_uses_assemblyai_for_enabled_sync_workload(monkeypatch) finalize_run.assert_called_once() assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' assert finalize_run.call_args.kwargs['artifact_refs'] == {} + assert finalize_run.call_args.kwargs['billable_seconds'] == 2.0 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00008333 def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypatch): @@ -213,6 +218,61 @@ def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypat assert finalize_run.call_args_list[1].kwargs['provider'] == 'deepgram' assert finalize_run.call_args_list[1].kwargs['fallback_count'] == 1 assert finalize_run.call_args_list[1].kwargs['fallback_provider'] == 'assemblyai' + assert finalize_run.call_args_list[1].kwargs['estimated_cost_usd'] == 0.00016 + + +def test_provider_service_records_zero_cost_for_zero_duration_success(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.assemblyai + provider_result = _provider_result(provider='assemblyai', model='universal-2') + provider_result.duration = 0.0 + 
fake_provider.transcribe_url.return_value = provider_result + + with patch.object(provider_service, '_assemblyai_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-zero' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + provider_service.transcribe_url( + 'https://example.test/zero.wav', + workload=STTWorkload.sync, + uid='uid-1', + raw_audio_seconds=0.0, + ) + + assert finalize_run.call_args.kwargs['billable_seconds'] == 0.0 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.0 + + +def test_prerecorded_cost_estimator_uses_provider_defaults_and_unknown_provider_zero(): + assert ( + estimate_prerecorded_provider_cost_usd( + provider='assemblyai', + model='future-model', + workload='background', + billable_seconds=60.0, + ) + == 0.0025 + ) + assert ( + estimate_prerecorded_provider_cost_usd( + provider='deepgram', + model='future-model', + workload='background', + billable_seconds=60.0, + ) + == 0.0048 + ) + assert ( + estimate_prerecorded_provider_cost_usd( + provider='unknown-provider', + model='future-model', + workload='background', + billable_seconds=60.0, + ) + == 0.0 + ) def test_provider_service_live_assemblyai_smoke_records_ledger_when_credentials_are_present(monkeypatch): diff --git a/backend/tests/unit/test_transcription_provider_usage.py b/backend/tests/unit/test_transcription_provider_usage.py index bdb7c12882d..d3ea4797171 100644 --- a/backend/tests/unit/test_transcription_provider_usage.py +++ b/backend/tests/unit/test_transcription_provider_usage.py @@ -146,6 +146,7 @@ def test_create_and_finalize_provider_run_writes_ledger_rollup_and_metrics(monke rollup = rollup_doc.set_calls[0]['data'] assert rollup['run_count'] == {'__increment': 1} assert rollup['raw_audio_seconds'] == {'__increment': 60.0} + assert rollup['estimated_cost_usd'] == {'__increment': 0.37} assert rollup['identity_confidence_counts.high'] == {'__increment': 2} assert 
emitted[0]['latency_seconds'] == 5.0 assert emitted[0]['billable_seconds'] == 60.0 diff --git a/backend/utils/stt/provider_costs.py b/backend/utils/stt/provider_costs.py new file mode 100644 index 00000000000..c44e714ec6a --- /dev/null +++ b/backend/utils/stt/provider_costs.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from typing import Optional + +from utils.stt.providers import STTProviderName, STTWorkload + + +@dataclass(frozen=True) +class PrerecordedProviderCostRate: + usd_per_billable_second: float + source: str + + +# Pay-as-you-go public STT pricing, checked 2026-05-21. +# Add-on features and customer-specific committed-use discounts are intentionally +# excluded until their usage is represented explicitly in provider run metadata. +_PRERECORDED_STT_COST_RATES: dict[str, dict[str, PrerecordedProviderCostRate]] = { + STTProviderName.assemblyai.value: { + # AssemblyAI pricing: Universal-2 $0.15/hr, Universal-3 Pro $0.21/hr. + 'universal-2': PrerecordedProviderCostRate( + usd_per_billable_second=0.15 / 3600, + source='assemblyai_prerecorded_payg_2026_05_21', + ), + 'universal-3-pro': PrerecordedProviderCostRate( + usd_per_billable_second=0.21 / 3600, + source='assemblyai_prerecorded_payg_2026_05_21', + ), + 'u3-pro': PrerecordedProviderCostRate( + usd_per_billable_second=0.21 / 3600, + source='assemblyai_prerecorded_payg_2026_05_21', + ), + 'default': PrerecordedProviderCostRate( + usd_per_billable_second=0.15 / 3600, + source='assemblyai_prerecorded_default_2026_05_21', + ), + }, + STTProviderName.deepgram.value: { + # Deepgram pricing: Nova-3 monolingual pre-recorded $0.0048/min, + # Nova-3 multilingual pre-recorded $0.0058/min. 
+ 'nova-3': PrerecordedProviderCostRate( + usd_per_billable_second=0.0048 / 60, + source='deepgram_prerecorded_payg_2026_05_21', + ), + 'nova-3-general': PrerecordedProviderCostRate( + usd_per_billable_second=0.0048 / 60, + source='deepgram_prerecorded_payg_2026_05_21', + ), + 'nova-3-multilingual': PrerecordedProviderCostRate( + usd_per_billable_second=0.0058 / 60, + source='deepgram_prerecorded_payg_2026_05_21', + ), + 'default': PrerecordedProviderCostRate( + usd_per_billable_second=0.0048 / 60, + source='deepgram_prerecorded_default_2026_05_21', + ), + }, +} + +_PRERECORDED_COST_WORKLOADS = { + STTWorkload.background.value, + STTWorkload.postprocess.value, + STTWorkload.ptt.value, + STTWorkload.sync.value, + STTWorkload.voice_message.value, +} + + +def estimate_prerecorded_provider_cost_usd( + provider: str, + model: Optional[str], + workload: str, + billable_seconds: float, +) -> float: + if billable_seconds <= 0: + return 0.0 + if str(workload) not in _PRERECORDED_COST_WORKLOADS: + return 0.0 + rate = prerecorded_provider_cost_rate(provider, model) + if not rate: + return 0.0 + return round(float(billable_seconds) * rate.usd_per_billable_second, 8) + + +def prerecorded_provider_cost_rate(provider: str, model: Optional[str]) -> Optional[PrerecordedProviderCostRate]: + provider_rates = _PRERECORDED_STT_COST_RATES.get(str(provider or '').lower()) + if not provider_rates: + return None + normalized_model = str(model or '').strip().lower() + return provider_rates.get(normalized_model) or provider_rates['default'] diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index d5fba11402e..5ebd06970fa 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -8,6 +8,7 @@ from utils.stt.assemblyai_adapter import AssemblyAIAsyncTranscriptionProvider from utils.stt.conversation_reconstructor import reconstruct_conversation from utils.stt.deepgram_adapter import provider_result_to_legacy_words 
+from utils.stt.provider_costs import estimate_prerecorded_provider_cost_usd from utils.stt.providers import ( STTProviderName, STTWorkload, @@ -479,6 +480,7 @@ def _finalize_run( ) -> None: if not run_id: return + billable_seconds = raw_audio_seconds clusters = { item.provider_cluster_id for item in list(result.words) + list(result.utterances) if item.provider_cluster_id } @@ -493,7 +495,13 @@ def _finalize_run( started_at=started_at, raw_audio_seconds=raw_audio_seconds, speech_active_seconds=raw_audio_seconds, - billable_seconds=raw_audio_seconds, + billable_seconds=billable_seconds, + estimated_cost_usd=estimate_prerecorded_provider_cost_usd( + provider=result.provider, + model=result.model, + workload=workload.value, + billable_seconds=billable_seconds, + ), retry_count=retry_count, fallback_count=fallback_count, transcript_segment_count=len(segments), From 373ff298627496dc06c47f1db692f9d699ece808 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 09:16:55 +0700 Subject: [PATCH 12/44] Tighten transcription provider retry metrics --- .../database/transcription_provider_usage.py | 21 ++++ .../unit/test_background_provider_service.py | 67 ++++++++++ .../unit/test_transcription_provider_usage.py | 119 ++++++++++++++++++ backend/utils/stt/provider_service.py | 76 +++++++++-- 4 files changed, 275 insertions(+), 8 deletions(-) diff --git a/backend/database/transcription_provider_usage.py b/backend/database/transcription_provider_usage.py index 2d5a8e54cfd..9066f695664 100644 --- a/backend/database/transcription_provider_usage.py +++ b/backend/database/transcription_provider_usage.py @@ -193,6 +193,12 @@ def finalize_provider_run( 'identity_confidence_summary': summary, 'error_class': error_class, 'artifact_refs': artifact_refs or {}, + 'fallback': _fallback_details( + fallback_count=fallback_count, + from_provider=fallback_provider, + to_provider=provider, + reason=fallback_reason, + ), 'updated_at': completed_at, } 
_reject_forbidden_payload_keys(payload) @@ -234,6 +240,21 @@ def finalize_provider_run( ) +def _fallback_details( + fallback_count: int, + from_provider: Optional[str], + to_provider: str, + reason: str, +) -> Optional[dict[str, Any]]: + if fallback_count <= 0: + return None + return { + 'from_provider': from_provider or 'unknown', + 'to_provider': to_provider, + 'reason': reason, + } + + def increment_daily_rollup( day: str, provider: str, diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index 3e66ca52b6a..5da6d1a94ac 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -214,6 +214,8 @@ def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypat assert finalize_run.call_args_list[0].kwargs['run_id'] == 'run-aai' assert finalize_run.call_args_list[0].kwargs['provider'] == 'assemblyai' assert finalize_run.call_args_list[0].kwargs['status'] == 'failed' + assert finalize_run.call_args_list[0].kwargs['retry_count'] == 1 + assert finalize_run.call_args_list[0].kwargs['error_class'] == 'RuntimeError' assert finalize_run.call_args_list[1].kwargs['run_id'] == 'run-dg' assert finalize_run.call_args_list[1].kwargs['provider'] == 'deepgram' assert finalize_run.call_args_list[1].kwargs['fallback_count'] == 1 @@ -221,6 +223,71 @@ def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypat assert finalize_run.call_args_list[1].kwargs['estimated_cost_usd'] == 0.00016 +def test_provider_service_records_retry_exhaustion_without_fallback(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'background') + + assemblyai_provider = MagicMock() + assemblyai_provider.provider_name = STTProviderName.assemblyai + assemblyai_provider.transcribe_url.side_effect = RuntimeError('AssemblyAI failed') + + with 
patch.object( + provider_service, '_assemblyai_prerecorded_provider', return_value=assemblyai_provider + ), patch.object(provider_service, 'create_provider_run', return_value='run-aai'), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run, patch.object( + provider_service, 'get_fallback_prerecorded_provider_name', return_value=None + ): + with pytest.raises(RuntimeError, match='assemblyai transcription failed after 2 attempts'): + provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.background, + uid='uid-1', + language='multi', + model='nova-3', + ) + + assert assemblyai_provider.transcribe_url.call_count == 2 + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['run_id'] == 'run-aai' + assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['status'] == 'failed' + assert finalize_run.call_args.kwargs['retry_count'] == 1 + assert finalize_run.call_args.kwargs['fallback_count'] == 0 + assert finalize_run.call_args.kwargs['error_class'] == 'RuntimeError' + + +def test_provider_service_records_successful_after_retry(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.assemblyai + fake_provider.transcribe_url.side_effect = [ + RuntimeError('temporary AssemblyAI failure'), + (_provider_result(provider='assemblyai', model='universal-2'), 'en'), + ] + + with patch.object(provider_service, '_assemblyai_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + return_language=True, + 
language='multi', + model='nova-3', + ) + + assert response.result.provider == 'assemblyai' + assert fake_provider.transcribe_url.call_count == 2 + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['status'] == 'succeeded' + assert finalize_run.call_args.kwargs['retry_count'] == 1 + assert finalize_run.call_args.kwargs['fallback_count'] == 0 + + def test_provider_service_records_zero_cost_for_zero_duration_success(monkeypatch): monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') diff --git a/backend/tests/unit/test_transcription_provider_usage.py b/backend/tests/unit/test_transcription_provider_usage.py index d3ea4797171..3ad0d3edc5d 100644 --- a/backend/tests/unit/test_transcription_provider_usage.py +++ b/backend/tests/unit/test_transcription_provider_usage.py @@ -139,6 +139,8 @@ def test_create_and_finalize_provider_run_writes_ledger_rollup_and_metrics(monke finalized = run_doc.set_calls[1]['data'] assert finalized['status'] == 'success' assert finalized['timing']['latency_ms'] == 5000 + assert finalized['retry_count'] == 1 + assert finalized['fallback'] is None assert 'transcript_text' not in finalized assert 'words' not in finalized @@ -150,6 +152,7 @@ def test_create_and_finalize_provider_run_writes_ledger_rollup_and_metrics(monke assert rollup['identity_confidence_counts.high'] == {'__increment': 2} assert emitted[0]['latency_seconds'] == 5.0 assert emitted[0]['billable_seconds'] == 60.0 + assert emitted[0]['retry_count'] == 1 def test_rejects_transcript_text_and_chunk_payloads(): @@ -262,6 +265,122 @@ def test_fallback_metric_records_failed_provider_to_fallback_provider(monkeypatc assert observed == [(('assemblyai', 'deepgram', 'sync', 'provider_failure', 1), {})] +def test_fallback_ledger_rollup_and_metrics_share_provider_direction(monkeypatch): + fake_db = _FakeDb() + monkeypatch.setattr(usage, 'db', fake_db) + monkeypatch.setattr(usage.firestore, 
'Increment', _inc) + observed_fallbacks = [] + observed_retries = [] + monkeypatch.setattr(usage, 'observe_transcription_provider_request', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_audio_seconds', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_retry', + lambda *args, **kwargs: observed_retries.append((args, kwargs)) if args[4] > 0 else None, + ) + monkeypatch.setattr(usage, 'observe_transcription_provider_speaker_clusters', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_identity_confidence', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_fallback', + lambda *args, **kwargs: observed_fallbacks.append((args, kwargs)), + ) + + started_at = datetime(2026, 5, 21, 1, 0, 0, tzinfo=timezone.utc) + completed_at = datetime(2026, 5, 21, 1, 0, 2, tzinfo=timezone.utc) + usage.create_provider_run( + uid='user-1', + provider='deepgram', + model='nova-3', + workload='sync', + run_id='run-fallback', + started_at=started_at, + ) + usage.finalize_provider_run( + run_id='run-fallback', + provider='deepgram', + model='nova-3', + workload='sync', + status='succeeded', + started_at=started_at, + completed_at=completed_at, + raw_audio_seconds=2.0, + speech_active_seconds=2.0, + billable_seconds=2.0, + retry_count=0, + fallback_count=1, + fallback_provider='assemblyai', + fallback_reason='provider_failure', + ) + + finalized = fake_db.docs[(usage.RUNS_COLLECTION, 'run-fallback')].set_calls[1]['data'] + assert finalized['fallback'] == { + 'from_provider': 'assemblyai', + 'to_provider': 'deepgram', + 'reason': 'provider_failure', + } + rollup = fake_db.docs[(usage.DAILY_USAGE_COLLECTION, '2026-05-21:deepgram:nova-3:sync')].set_calls[0]['data'] + assert rollup['fallback_count'] == {'__increment': 1} + assert observed_fallbacks == [(('assemblyai', 'deepgram', 'sync', 'provider_failure', 1), {})] + assert 
observed_retries == [] + + +def test_failed_run_retry_count_rolls_up_and_emits_retry_metric(monkeypatch): + fake_db = _FakeDb() + monkeypatch.setattr(usage, 'db', fake_db) + monkeypatch.setattr(usage.firestore, 'Increment', _inc) + observed_retries = [] + observed_fallbacks = [] + monkeypatch.setattr(usage, 'observe_transcription_provider_request', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_audio_seconds', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_retry', + lambda *args, **kwargs: observed_retries.append((args, kwargs)) if args[4] > 0 else None, + ) + monkeypatch.setattr(usage, 'observe_transcription_provider_speaker_clusters', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_identity_confidence', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_fallback', + lambda *args, **kwargs: observed_fallbacks.append((args, kwargs)), + ) + + started_at = datetime(2026, 5, 21, 1, 0, 0, tzinfo=timezone.utc) + completed_at = datetime(2026, 5, 21, 1, 0, 2, tzinfo=timezone.utc) + usage.create_provider_run( + uid='user-1', + provider='assemblyai', + model='universal-2', + workload='background', + run_id='run-failed', + started_at=started_at, + ) + usage.finalize_provider_run( + run_id='run-failed', + provider='assemblyai', + model='universal-2', + workload='background', + status='failed', + started_at=started_at, + completed_at=completed_at, + raw_audio_seconds=2.0, + retry_count=1, + error_class='RuntimeError', + ) + + finalized = fake_db.docs[(usage.RUNS_COLLECTION, 'run-failed')].set_calls[1]['data'] + assert finalized['retry_count'] == 1 + assert finalized['fallback'] is None + rollup = fake_db.docs[(usage.DAILY_USAGE_COLLECTION, '2026-05-21:assemblyai:universal-2:background')].set_calls[0][ + 'data' + ] + assert rollup['retry_count'] == {'__increment': 1} + assert 
rollup['status_counts.failed'] == {'__increment': 1} + assert observed_retries == [(('assemblyai', 'universal-2', 'background', 'provider_retry', 1), {})] + assert observed_fallbacks == [] + + def test_provider_metrics_source_does_not_define_forbidden_label_names(): forbidden_labels = { "['provider', 'model', 'workload', 'user_id']", diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index 5ebd06970fa..3cc16f6500a 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -76,6 +76,13 @@ class PrerecordedTranscriptionResponse: run_id: Optional[str] +class ProviderTranscriptionRetriesExhausted(RuntimeError): + def __init__(self, provider_error: Exception, retry_count: int): + super().__init__(str(provider_error)) + self.provider_error = provider_error + self.retry_count = retry_count + + def resolve_prerecorded_language_model(language: Optional[str]) -> Tuple[str, str]: return get_deepgram_model_for_language(language or 'multi') @@ -132,7 +139,16 @@ def transcribe_url( run_id=run_id, ) except Exception as e: - _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + _finalize_failed_run( + run_id, + provider_name.value, + model, + workload.value, + started_at, + _provider_error_from_exception(e), + raw_audio_seconds, + retry_count=_retry_count_from_exception(e), + ) fallback_provider_name = get_fallback_prerecorded_provider_name(provider_name, workload) if fallback_provider_name: logger.warning( @@ -217,7 +233,16 @@ def transcribe_bytes( run_id=run_id, ) except Exception as e: - _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + _finalize_failed_run( + run_id, + provider_name.value, + model, + workload.value, + started_at, + _provider_error_from_exception(e), + raw_audio_seconds, + retry_count=_retry_count_from_exception(e), + ) fallback_provider_name = 
get_fallback_prerecorded_provider_name(provider_name, workload) if fallback_provider_name: logger.warning( @@ -306,7 +331,16 @@ def _transcribe_url_with_provider( detected_language=detected_language, ) except Exception as e: - _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + _finalize_failed_run( + run_id, + provider_name.value, + model, + workload.value, + started_at, + _provider_error_from_exception(e), + raw_audio_seconds, + retry_count=_retry_count_from_exception(e), + ) raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') @@ -356,7 +390,16 @@ def _transcribe_bytes_with_provider( detected_language=detected_language, ) except Exception as e: - _finalize_failed_run(run_id, provider_name.value, model, workload.value, started_at, e, raw_audio_seconds) + _finalize_failed_run( + run_id, + provider_name.value, + model, + workload.value, + started_at, + _provider_error_from_exception(e), + raw_audio_seconds, + retry_count=_retry_count_from_exception(e), + ) raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') @@ -398,7 +441,8 @@ def _transcribe_url_with_retry( provider, audio_url: str, **kwargs ) -> Tuple[ProviderTranscriptResult, Optional[str], int]: last_error = None - for attempt in range(2): + max_attempts = 2 + for attempt in range(max_attempts): try: result = provider.transcribe_url(audio_url, **kwargs) transcript_result, detected_language = _unpack_provider_result(result, kwargs.get('return_language')) @@ -411,14 +455,15 @@ def _transcribe_url_with_retry( provider.provider_name, e, ) - raise last_error + raise ProviderTranscriptionRetriesExhausted(last_error, max_attempts - 1) def _transcribe_bytes_with_retry( provider, audio_bytes: bytes, **kwargs ) -> Tuple[ProviderTranscriptResult, Optional[str], int]: last_error = None - for attempt in range(2): + max_attempts = 2 + for attempt in range(max_attempts): try: result = 
provider.transcribe_bytes(audio_bytes, **kwargs) transcript_result, detected_language = _unpack_provider_result(result, kwargs.get('return_language')) @@ -431,7 +476,19 @@ def _transcribe_bytes_with_retry( provider.provider_name, e, ) - raise last_error + raise ProviderTranscriptionRetriesExhausted(last_error, max_attempts - 1) + + +def _retry_count_from_exception(error: Exception) -> int: + if isinstance(error, ProviderTranscriptionRetriesExhausted): + return error.retry_count + return 0 + + +def _provider_error_from_exception(error: Exception) -> Exception: + if isinstance(error, ProviderTranscriptionRetriesExhausted): + return error.provider_error + return error def _unpack_provider_result(result, return_language: bool) -> Tuple[ProviderTranscriptResult, Optional[str]]: @@ -532,6 +589,7 @@ def _finalize_failed_run( started_at: datetime, error: Exception, raw_audio_seconds: float, + retry_count: int = 0, ) -> None: if not run_id: return @@ -544,6 +602,8 @@ def _finalize_failed_run( status='failed', started_at=started_at, raw_audio_seconds=raw_audio_seconds, + retry_count=retry_count, + fallback_count=0, error_class=error.__class__.__name__, ) except Exception as finalize_error: From 77764e053552806ae8ed7e0eeb8a61fc37ee6e64 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 09:23:54 +0700 Subject: [PATCH 13/44] Add self voice review queue backend --- backend/database/self_voice_review.py | 224 ++++++++++++ backend/tests/unit/test_self_voice_review.py | 358 +++++++++++++++++++ backend/utils/self_voice_review.py | 254 +++++++++++++ 3 files changed, 836 insertions(+) create mode 100644 backend/database/self_voice_review.py create mode 100644 backend/tests/unit/test_self_voice_review.py create mode 100644 backend/utils/self_voice_review.py diff --git a/backend/database/self_voice_review.py b/backend/database/self_voice_review.py new file mode 100644 index 00000000000..cd4f7696737 --- /dev/null +++ b/backend/database/self_voice_review.py @@ -0,0 +1,224 
@@ +from datetime import datetime, timedelta, timezone +from typing import Any, Optional + +from google.cloud import firestore +from google.cloud.firestore_v1 import FieldFilter + +from ._client import db, document_id_from_seed + +CANDIDATES_COLLECTION = 'self_voice_review_candidates' +NEGATIVE_MARKERS_COLLECTION = 'self_voice_review_negative_markers' +CONFIRMED_SAMPLES_COLLECTION = 'self_voice_review_confirmed_samples' +DEFAULT_CANDIDATE_TTL_DAYS = 30 +FORBIDDEN_CANDIDATE_KEYS = {'text', 'transcript', 'transcript_text', 'words', 'utterances', 'audio_bytes', 'raw_audio'} + + +def utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def candidate_id_from_source(conversation_id: str, provider_cluster_id: str, segment_ids: list[str]) -> str: + seed = ':'.join(['self_voice_review', conversation_id, provider_cluster_id, ','.join(sorted(segment_ids))]) + return document_id_from_seed(seed) + + +def marker_id_from_source(conversation_id: str, provider_cluster_id: str) -> str: + return document_id_from_seed(':'.join(['self_voice_negative_marker', conversation_id, provider_cluster_id])) + + +def _user_ref(uid: str): + return db.collection('users').document(uid) + + +def _candidate_ref(uid: str, candidate_id: str): + return _user_ref(uid).collection(CANDIDATES_COLLECTION).document(candidate_id) + + +def _negative_marker_ref(uid: str, marker_id: str): + return _user_ref(uid).collection(NEGATIVE_MARKERS_COLLECTION).document(marker_id) + + +def _confirmed_sample_ref(uid: str, sample_id: str): + return _user_ref(uid).collection(CONFIRMED_SAMPLES_COLLECTION).document(sample_id) + + +def get_candidate(uid: str, candidate_id: str) -> Optional[dict[str, Any]]: + snapshot = _candidate_ref(uid, candidate_id).get() + if not snapshot.exists: + return None + data = snapshot.to_dict() or {} + data.setdefault('candidate_id', snapshot.id) + return data + + +def has_negative_marker(uid: str, marker_id: str) -> bool: + return _negative_marker_ref(uid, marker_id).get().exists + + 
+def recently_shown_source_exists( + uid: str, + conversation_id: str, + provider_cluster_id: str, + now: Optional[datetime] = None, +) -> bool: + now = now or utc_now() + query = ( + _user_ref(uid) + .collection(CANDIDATES_COLLECTION) + .where(filter=FieldFilter('source.conversation_id', '==', conversation_id)) + .where(filter=FieldFilter('source.provider_cluster_id', '==', provider_cluster_id)) + .where(filter=FieldFilter('cooldown_until', '>', now)) + .limit(1) + ) + return bool(list(query.stream())) + + +def upsert_candidate(uid: str, candidate: dict[str, Any]) -> bool: + candidate_id = candidate['candidate_id'] + ref = _candidate_ref(uid, candidate_id) + if ref.get().exists: + return False + + now = candidate.get('created_at') or utc_now() + payload = { + **candidate, + 'uid': uid, + 'review_status': candidate.get('review_status', 'pending'), + 'created_at': now, + 'updated_at': now, + 'expires_at': candidate.get('expires_at') or now + timedelta(days=DEFAULT_CANDIDATE_TTL_DAYS), + } + _reject_forbidden_candidate_keys(payload) + ref.set(payload, merge=False) + return True + + +def list_pending_candidates(uid: str, limit: int = 20, confidence_bucket: Optional[str] = None) -> list[dict[str, Any]]: + query = _user_ref(uid).collection(CANDIDATES_COLLECTION).where(filter=FieldFilter('review_status', '==', 'pending')) + if confidence_bucket: + query = query.where(filter=FieldFilter('confidence_bucket', '==', confidence_bucket)) + query = query.order_by('created_at', direction=firestore.Query.DESCENDING).limit(limit) + + candidates = [] + for snapshot in query.stream(): + data = snapshot.to_dict() or {} + data.setdefault('candidate_id', snapshot.id) + candidates.append(data) + return candidates + + +def mark_candidate_confirmed( + uid: str, + candidate_id: str, + embedding_version: str, + reviewed_at: Optional[datetime] = None, +) -> None: + reviewed_at = reviewed_at or utc_now() + _candidate_ref(uid, candidate_id).update( + { + 'review_status': 'confirmed', + 
'reviewed_at': reviewed_at, + 'negative_review_marker': None, + 'updated_at': reviewed_at, + 'confirmed_sample': { + 'candidate_id': candidate_id, + 'embedding_version': embedding_version, + 'confirmed_at': reviewed_at, + 'revisable': True, + }, + } + ) + _confirmed_sample_ref(uid, candidate_id).set( + { + 'candidate_id': candidate_id, + 'source': 'self_voice_review', + 'embedding_version': embedding_version, + 'confirmed_at': reviewed_at, + 'deleted_at': None, + }, + merge=True, + ) + + +def mark_candidate_rejected( + uid: str, + candidate: dict[str, Any], + reviewed_at: Optional[datetime] = None, +) -> str: + reviewed_at = reviewed_at or utc_now() + source = candidate.get('source') or {} + marker_id = marker_id_from_source(source.get('conversation_id', ''), source.get('provider_cluster_id', '')) + marker = { + 'marker_id': marker_id, + 'candidate_id': candidate['candidate_id'], + 'conversation_id': source.get('conversation_id'), + 'provider_cluster_id': source.get('provider_cluster_id'), + 'segment_ids': source.get('segment_ids', []), + 'negative_review': True, + 'reviewed_at': reviewed_at, + } + _negative_marker_ref(uid, marker_id).set(marker, merge=True) + _candidate_ref(uid, candidate['candidate_id']).update( + { + 'review_status': 'rejected', + 'reviewed_at': reviewed_at, + 'negative_review_marker': marker, + 'updated_at': reviewed_at, + } + ) + return marker_id + + +def mark_candidate_skipped( + uid: str, + candidate_id: str, + cooldown_until: datetime, + reviewed_at: Optional[datetime] = None, +) -> None: + reviewed_at = reviewed_at or utc_now() + _candidate_ref(uid, candidate_id).update( + { + 'review_status': 'pending', + 'last_review_action': 'skipped', + 'reviewed_at': reviewed_at, + 'cooldown_until': cooldown_until, + 'updated_at': reviewed_at, + } + ) + + +def delete_confirmed_sample(uid: str, candidate_id: str, deleted_at: Optional[datetime] = None) -> bool: + deleted_at = deleted_at or utc_now() + candidate = get_candidate(uid, candidate_id) + if 
not candidate or candidate.get('review_status') != 'confirmed': + return False + _candidate_ref(uid, candidate_id).update( + { + 'review_status': 'deleted', + 'reviewed_at': deleted_at, + 'updated_at': deleted_at, + 'confirmed_sample.deleted_at': deleted_at, + } + ) + _confirmed_sample_ref(uid, candidate_id).set({'deleted_at': deleted_at}, merge=True) + return True + + +def _reject_forbidden_candidate_keys(payload: dict[str, Any]) -> None: + forbidden = _find_forbidden_candidate_keys(payload) + if forbidden: + raise ValueError(f'self voice review candidate contains forbidden keys: {sorted(forbidden)}') + + +def _find_forbidden_candidate_keys(value: Any) -> set[str]: + if isinstance(value, dict): + forbidden = FORBIDDEN_CANDIDATE_KEYS & set(value) + for nested in value.values(): + forbidden.update(_find_forbidden_candidate_keys(nested)) + return forbidden + if isinstance(value, list): + forbidden = set() + for nested in value: + forbidden.update(_find_forbidden_candidate_keys(nested)) + return forbidden + return set() diff --git a/backend/tests/unit/test_self_voice_review.py b/backend/tests/unit/test_self_voice_review.py new file mode 100644 index 00000000000..f5e56d90503 --- /dev/null +++ b/backend/tests/unit/test_self_voice_review.py @@ -0,0 +1,358 @@ +import io +import sys +import wave +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import numpy as np +import pytest + +sys.modules.setdefault('database._client', MagicMock()) +sys.modules.setdefault('database.users', MagicMock()) +speaker_embedding_mod = MagicMock() +speaker_embedding_mod.extract_embedding_from_bytes = MagicMock() +sys.modules.setdefault('utils.stt.speaker_embedding', speaker_embedding_mod) +firestore_v1_mod = sys.modules.setdefault('google.cloud.firestore_v1', MagicMock()) +firestore_v1_mod.FieldFilter = MagicMock() + +firestore_mod = MagicMock() +firestore_mod.Query.DESCENDING = 'DESCENDING' +sys.modules.setdefault('google.cloud.firestore', firestore_mod) + 
+google_cloud_mod = sys.modules.setdefault('google.cloud', MagicMock()) +google_cloud_mod.firestore = firestore_mod + +from models.transcript_segment import TranscriptSegment +from utils.stt.background_speaker_identity import ClusterIdentityAssignment, SPEAKER_IDENTITY_SOURCE +from utils.self_voice_review import ( + SegmentQuality, + build_self_voice_review_candidate, + confirm_self_voice_candidate, + delete_confirmed_self_voice_sample, + reject_self_voice_candidate, + skip_self_voice_candidate, +) + + +def _segment(segment_id, start, end, cluster='cluster-a', text='hello this is a clean sentence'): + return TranscriptSegment( + id=segment_id, + text=text, + speaker='SPEAKER_01', + is_user=False, + start=start, + end=end, + provider_cluster_id=cluster, + provider_speaker_label='SPEAKER_01', + speaker_identity_state='unknown', + ) + + +def _user_assignment(confidence=0.82): + return ClusterIdentityAssignment( + provider_cluster_id='cluster-a', + speaker_id=1, + state='user', + is_user=True, + confidence=confidence, + distance=0.1, + source=SPEAKER_IDENTITY_SOURCE, + ) + + +def _wav_bytes(duration_seconds=6, sample_rate=16000): + samples = np.zeros(int(duration_seconds * sample_rate), dtype=np.int16) + out = io.BytesIO() + with wave.open(out, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(samples.tobytes()) + return out.getvalue() + + +class FakeReviewDb: + DEFAULT_CANDIDATE_TTL_DAYS = 30 + + def __init__(self): + self.candidates = {} + self.negative_markers = set() + self.confirmed = [] + self.rejected = [] + self.skipped = [] + self.deleted = [] + + def candidate_id_from_source(self, conversation_id, provider_cluster_id, segment_ids): + return f'{conversation_id}:{provider_cluster_id}:{"-".join(segment_ids)}' + + def marker_id_from_source(self, conversation_id, provider_cluster_id): + return f'{conversation_id}:{provider_cluster_id}' + + def get_candidate(self, _uid, candidate_id): + return 
self.candidates.get(candidate_id) + + def has_negative_marker(self, _uid, marker_id): + return marker_id in self.negative_markers + + def recently_shown_source_exists(self, *_args, **_kwargs): + return False + + def upsert_candidate(self, _uid, candidate): + self.candidates[candidate['candidate_id']] = candidate + return True + + def mark_candidate_confirmed(self, _uid, candidate_id, embedding_version): + self.candidates[candidate_id]['review_status'] = 'confirmed' + self.candidates[candidate_id]['confirmed_sample'] = { + 'candidate_id': candidate_id, + 'embedding_version': embedding_version, + 'revisable': True, + } + self.confirmed.append((candidate_id, embedding_version)) + + def mark_candidate_rejected(self, _uid, candidate): + marker_id = self.marker_id_from_source( + candidate['source']['conversation_id'], candidate['source']['provider_cluster_id'] + ) + self.negative_markers.add(marker_id) + self.candidates[candidate['candidate_id']]['review_status'] = 'rejected' + self.candidates[candidate['candidate_id']]['negative_review_marker'] = {'marker_id': marker_id} + self.rejected.append(candidate['candidate_id']) + return marker_id + + def mark_candidate_skipped(self, _uid, candidate_id, cooldown_until, reviewed_at=None): + self.candidates[candidate_id]['review_status'] = 'pending' + self.candidates[candidate_id]['last_review_action'] = 'skipped' + self.candidates[candidate_id]['cooldown_until'] = cooldown_until + self.skipped.append((candidate_id, cooldown_until, reviewed_at)) + + def delete_confirmed_sample(self, _uid, candidate_id): + if self.candidates[candidate_id]['review_status'] != 'confirmed': + return False + self.candidates[candidate_id]['review_status'] = 'deleted' + self.deleted.append(candidate_id) + return True + + +def test_candidate_creation_stores_state_without_transcript_text(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + segments = 
[_segment('s1', 0.0, 6.0)] + + result = build_self_voice_review_candidate( + uid='uid', + conversation_id='conv', + provider_cluster_id='cluster-a', + cluster_segments=segments, + all_segments=segments, + identity_assignment=_user_assignment(), + quality_by_segment_id={'s1': SegmentQuality(voiced_seconds=5.5, vad_confidence=0.9, noise_score=0.1)}, + audio_artifact_ref='gs://bucket/clip.wav', + audio_retention_allowed=True, + now=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + assert result.candidate is not None + candidate = result.candidate + assert candidate['candidate_id'] == 'conv:cluster-a:s1' + assert candidate['confidence_bucket'] == 'high' + assert candidate['quality_scores']['sample_seconds'] == 6.0 + assert candidate['quality_scores']['voiced_ratio'] == pytest.approx(0.917) + assert candidate['retention']['transcript_text_stored'] is False + assert 'text' not in candidate['source'] + assert 'transcript' not in candidate['source'] + assert fake_db.candidates[candidate['candidate_id']]['review_status'] == 'pending' + + +def test_candidate_quality_filters_overlap_short_low_vad_noise_and_retention(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + + clean = [_segment('s1', 0.0, 6.0)] + no_audio = build_self_voice_review_candidate( + 'uid', 'conv', 'cluster-a', clean, clean, _user_assignment(), audio_retention_allowed=False + ) + assert no_audio.reason == 'audio_unavailable_or_retention_disallowed' + + short = [_segment('s2', 0.0, 3.0)] + short_result = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-a', + short, + short, + _user_assignment(), + audio_artifact_ref='ref', + audio_retention_allowed=True, + ) + assert short_result.reason == 'sample_too_short' + + overlap = [_segment('s3', 0.0, 6.0)] + all_segments = overlap + [_segment('other', 2.0, 3.0, cluster='cluster-b')] + overlap_result = build_self_voice_review_candidate( + 'uid', + 'conv', + 
'cluster-a', + overlap, + all_segments, + _user_assignment(), + audio_artifact_ref='ref', + audio_retention_allowed=True, + ) + assert overlap_result.reason == 'no_clean_voiced_window' + + low_vad = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-low-vad', + [_segment('s4', 0.0, 6.0, cluster='cluster-low-vad')], + [_segment('s4', 0.0, 6.0, cluster='cluster-low-vad')], + _user_assignment(), + {'s4': {'voiced_seconds': 5.5, 'vad_confidence': 0.5}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ) + assert low_vad.reason == 'low_vad_confidence' + + noisy = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-noisy', + [_segment('s5', 0.0, 6.0, cluster='cluster-noisy')], + [_segment('s5', 0.0, 6.0, cluster='cluster-noisy')], + _user_assignment(), + {'s5': {'voiced_seconds': 5.5, 'vad_confidence': 0.9, 'noise_score': 0.8}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ) + assert noisy.reason == 'noisy_clip' + + +def test_dedupe_reject_marker_and_recently_shown_suppress_candidates(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + segments = [_segment('s1', 0.0, 6.0)] + kwargs = { + 'uid': 'uid', + 'conversation_id': 'conv', + 'provider_cluster_id': 'cluster-a', + 'cluster_segments': segments, + 'all_segments': segments, + 'identity_assignment': _user_assignment(), + 'quality_by_segment_id': {'s1': {'voiced_seconds': 5.5, 'vad_confidence': 0.9}}, + 'audio_artifact_ref': 'ref', + 'audio_retention_allowed': True, + } + + assert build_self_voice_review_candidate(**kwargs).candidate is not None + assert build_self_voice_review_candidate(**kwargs).reason == 'already_confirmed_or_pending' + + fake_db.candidates.clear() + fake_db.negative_markers.add('conv:cluster-a') + assert build_self_voice_review_candidate(**kwargs).reason == 'rejected_source' + + fake_db.negative_markers.clear() + fake_db.recently_shown_source_exists = 
lambda *_args, **_kwargs: True + assert build_self_voice_review_candidate(**kwargs).reason == 'recently_shown' + + +def test_confirm_reject_skip_and_delete_actions(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + fake_users_db = MagicMock() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + monkeypatch.setattr(self_voice_review, 'users_db', fake_users_db) + segments = [_segment('s1', 0.0, 6.0)] + candidate = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-a', + segments, + segments, + _user_assignment(), + {'s1': {'voiced_seconds': 5.5, 'vad_confidence': 0.9}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ).candidate + + skipped = skip_self_voice_candidate('uid', candidate['candidate_id'], now=datetime(2026, 1, 1, tzinfo=timezone.utc)) + assert skipped['review_status'] == 'pending' + assert skipped['last_review_action'] == 'skipped' + assert fake_db.candidates[candidate['candidate_id']]['cooldown_until'].year == 2026 + + embedding = np.array([[1.0, 2.0, 3.0]], dtype=np.float32) + confirmed = confirm_self_voice_candidate('uid', candidate['candidate_id'], embedding=embedding) + assert confirmed['review_status'] == 'confirmed' + fake_users_db.set_user_speaker_embedding.assert_called_once_with('uid', [1.0, 2.0, 3.0]) + assert fake_db.candidates[candidate['candidate_id']]['confirmed_sample']['revisable'] is True + assert delete_confirmed_self_voice_sample('uid', candidate['candidate_id']) is True + + second = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-b', + [_segment('s2', 10.0, 16.0, cluster='cluster-b')], + [_segment('s2', 10.0, 16.0, cluster='cluster-b')], + _user_assignment(), + {'s2': {'voiced_seconds': 5.5, 'vad_confidence': 0.9}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ).candidate + rejected = reject_self_voice_candidate('uid', second['candidate_id']) + assert rejected['review_status'] == 'rejected' + assert rejected['negative_marker_id'] 
== 'conv:cluster-b' + assert fake_db.candidates[second['candidate_id']]['negative_review_marker']['marker_id'] == 'conv:cluster-b' + + +def test_confirm_can_extract_embedding_from_audio(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + fake_users_db = MagicMock() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + monkeypatch.setattr(self_voice_review, 'users_db', fake_users_db) + segments = [_segment('s1', 0.0, 6.0)] + candidate = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-a', + segments, + segments, + _user_assignment(), + {'s1': {'voiced_seconds': 5.5, 'vad_confidence': 0.9}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ).candidate + + def fake_extractor(audio_bytes, filename): + assert audio_bytes == _wav_bytes() + assert filename.endswith('.wav') + return np.array([[4.0, 5.0]], dtype=np.float32) + + confirm_self_voice_candidate( + 'uid', candidate['candidate_id'], audio_bytes=_wav_bytes(), embedding_extractor=fake_extractor + ) + + fake_users_db.set_user_speaker_embedding.assert_called_once_with('uid', [4.0, 5.0]) + + +def test_candidate_storage_rejects_transcript_and_raw_audio_payloads(): + from database.self_voice_review import _reject_forbidden_candidate_keys + + with pytest.raises(ValueError, match='forbidden keys'): + _reject_forbidden_candidate_keys( + { + 'candidate_id': 'candidate', + 'source': {'conversation_id': 'conv', 'transcript_text': 'this must not be stored'}, + } + ) + + with pytest.raises(ValueError, match='forbidden keys'): + _reject_forbidden_candidate_keys({'candidate_id': 'candidate', 'raw_audio': b'not allowed'}) diff --git a/backend/utils/self_voice_review.py b/backend/utils/self_voice_review.py new file mode 100644 index 00000000000..12b65d00487 --- /dev/null +++ b/backend/utils/self_voice_review.py @@ -0,0 +1,254 @@ +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Callable, Optional, 
Union + +import numpy as np + +from database import self_voice_review as review_db +from database import users as users_db +from models.transcript_segment import TranscriptSegment +from utils.stt.background_speaker_identity import ( + ClusterIdentityAssignment, + ClusterSampleSpan, + USER_SELF_PERSON_ID, + select_representative_cluster_spans, +) +from utils.stt.speaker_embedding import extract_embedding_from_bytes + +MIN_REVIEW_SAMPLE_SECONDS = 5.0 +MAX_REVIEW_SAMPLE_SECONDS = 10.0 +MIN_VOICED_RATIO = 0.65 +MIN_VAD_CONFIDENCE = 0.75 +MAX_NOISE_SCORE = 0.35 +SKIP_COOLDOWN_DAYS = 14 +SELF_VOICE_REVIEW_VERSION = 'self-voice-review:v1' +SELF_VOICE_EMBEDDING_SOURCE = 'self_voice_review_confirm' + + +@dataclass +class SegmentQuality: + voiced_seconds: Optional[float] = None + vad_confidence: Optional[float] = None + noise_score: Optional[float] = None + + +@dataclass +class CandidateSelectionResult: + candidate: Optional[dict] = None + reason: Optional[str] = None + spans: list[ClusterSampleSpan] = field(default_factory=list) + + +def build_self_voice_review_candidate( + uid: str, + conversation_id: str, + provider_cluster_id: str, + cluster_segments: list[TranscriptSegment], + all_segments: list[TranscriptSegment], + identity_assignment: Optional[ClusterIdentityAssignment] = None, + quality_by_segment_id: Optional[dict[str, Union[SegmentQuality, dict]]] = None, + audio_artifact_ref: Optional[str] = None, + audio_retention_allowed: bool = False, + now: Optional[datetime] = None, +) -> CandidateSelectionResult: + now = now or _utc_now() + quality_by_segment_id = quality_by_segment_id or {} + + if not audio_retention_allowed or not audio_artifact_ref: + return CandidateSelectionResult(reason='audio_unavailable_or_retention_disallowed') + + spans = select_representative_cluster_spans( + cluster_segments, + all_segments, + preferred_seconds=MIN_REVIEW_SAMPLE_SECONDS, + max_seconds=MAX_REVIEW_SAMPLE_SECONDS, + ) + quality = _score_candidate_quality(spans, 
quality_by_segment_id) + if not spans: + return CandidateSelectionResult(reason='no_clean_voiced_window', spans=spans) + if quality['sample_seconds'] < MIN_REVIEW_SAMPLE_SECONDS: + return CandidateSelectionResult(reason='sample_too_short', spans=spans) + if quality['vad_confidence'] is not None and quality['vad_confidence'] < MIN_VAD_CONFIDENCE: + return CandidateSelectionResult(reason='low_vad_confidence', spans=spans) + if quality['voiced_ratio'] is not None and quality['voiced_ratio'] < MIN_VOICED_RATIO: + return CandidateSelectionResult(reason='low_voiced_ratio', spans=spans) + if quality['noise_score'] is not None and quality['noise_score'] > MAX_NOISE_SCORE: + return CandidateSelectionResult(reason='noisy_clip', spans=spans) + + segment_ids = [span.segment_id for span in spans] + candidate_id = review_db.candidate_id_from_source(conversation_id, provider_cluster_id, segment_ids) + negative_marker_id = review_db.marker_id_from_source(conversation_id, provider_cluster_id) + if review_db.get_candidate(uid, candidate_id): + return CandidateSelectionResult(reason='already_confirmed_or_pending', spans=spans) + if review_db.has_negative_marker(uid, negative_marker_id): + return CandidateSelectionResult(reason='rejected_source', spans=spans) + if review_db.recently_shown_source_exists(uid, conversation_id, provider_cluster_id, now=now): + return CandidateSelectionResult(reason='recently_shown', spans=spans) + + confidence_bucket = _confidence_bucket(identity_assignment, quality) + if confidence_bucket is None: + return CandidateSelectionResult(reason='not_self_voice_candidate', spans=spans) + + candidate = { + 'candidate_id': candidate_id, + 'source': { + 'conversation_id': conversation_id, + 'provider_cluster_id': provider_cluster_id, + 'segment_ids': segment_ids, + 'sample_spans': [span.as_dict() for span in spans], + }, + 'confidence_bucket': confidence_bucket, + 'quality_scores': quality, + 'review_status': 'pending', + 'reviewed_at': None, + 'cooldown_until': 
None, + 'retention': { + 'audio_artifact_ref': audio_artifact_ref, + 'audio_retention_allowed': True, + 'transcript_text_stored': False, + 'expires_at': now + timedelta(days=review_db.DEFAULT_CANDIDATE_TTL_DAYS), + }, + 'negative_review_marker': None, + 'created_at': now, + 'updated_at': now, + 'expires_at': now + timedelta(days=review_db.DEFAULT_CANDIDATE_TTL_DAYS), + 'version': SELF_VOICE_REVIEW_VERSION, + } + review_db.upsert_candidate(uid, candidate) + return CandidateSelectionResult(candidate=candidate, spans=spans) + + +def confirm_self_voice_candidate( + uid: str, + candidate_id: str, + audio_bytes: Optional[bytes] = None, + embedding: Optional[np.ndarray] = None, + embedding_extractor: Callable[[bytes, str], np.ndarray] = extract_embedding_from_bytes, +) -> dict: + candidate = _require_candidate(uid, candidate_id) + if candidate.get('review_status') not in ('pending', 'skipped'): + raise ValueError('candidate is not confirmable') + + retention = candidate.get('retention') or {} + if not retention.get('audio_retention_allowed') and embedding is None: + raise ValueError('candidate audio is not available for confirmation') + + if embedding is None: + if not audio_bytes: + raise ValueError('audio_bytes or embedding is required to confirm self voice') + embedding = embedding_extractor(audio_bytes, f'self_voice_review_{candidate_id}.wav') + + users_db.set_user_speaker_embedding(uid, embedding.reshape(1, -1).flatten().tolist()) + review_db.mark_candidate_confirmed(uid, candidate_id, SELF_VOICE_EMBEDDING_SOURCE) + return {'candidate_id': candidate_id, 'review_status': 'confirmed'} + + +def reject_self_voice_candidate(uid: str, candidate_id: str) -> dict: + candidate = _require_candidate(uid, candidate_id) + if candidate.get('review_status') == 'confirmed': + raise ValueError('confirmed candidate must be deleted instead of rejected') + marker_id = review_db.mark_candidate_rejected(uid, candidate) + return {'candidate_id': candidate_id, 'review_status': 'rejected', 
'negative_marker_id': marker_id} + + +def skip_self_voice_candidate( + uid: str, + candidate_id: str, + now: Optional[datetime] = None, + cooldown_days: int = SKIP_COOLDOWN_DAYS, +) -> dict: + _require_candidate(uid, candidate_id) + now = now or _utc_now() + cooldown_until = now + timedelta(days=cooldown_days) + review_db.mark_candidate_skipped(uid, candidate_id, cooldown_until, reviewed_at=now) + return { + 'candidate_id': candidate_id, + 'review_status': 'pending', + 'last_review_action': 'skipped', + 'cooldown_until': cooldown_until, + } + + +def delete_confirmed_self_voice_sample(uid: str, candidate_id: str) -> bool: + return review_db.delete_confirmed_sample(uid, candidate_id) + + +def _require_candidate(uid: str, candidate_id: str) -> dict: + candidate = review_db.get_candidate(uid, candidate_id) + if not candidate: + raise ValueError('self voice review candidate not found') + return candidate + + +def _score_candidate_quality( + spans: list[ClusterSampleSpan], + quality_by_segment_id: dict[str, Union[SegmentQuality, dict]], +) -> dict: + sample_seconds = round(sum(span.duration for span in spans), 3) + voiced_seconds = 0.0 + vad_values = [] + noise_values = [] + has_voiced_signal = False + + for span in spans: + quality = _quality_for_segment(quality_by_segment_id.get(span.segment_id)) + if not quality: + continue + if quality.voiced_seconds is not None: + has_voiced_signal = True + voiced_seconds += min(max(quality.voiced_seconds, 0.0), span.duration) + if quality.vad_confidence is not None: + vad_values.append(quality.vad_confidence) + if quality.noise_score is not None: + noise_values.append(quality.noise_score) + + voiced_ratio = None + if has_voiced_signal and sample_seconds > 0: + voiced_ratio = round(voiced_seconds / sample_seconds, 3) + + return { + 'sample_seconds': sample_seconds, + 'voiced_ratio': voiced_ratio, + 'vad_confidence': round(sum(vad_values) / len(vad_values), 3) if vad_values else None, + 'noise_score': round(sum(noise_values) / 
len(noise_values), 3) if noise_values else None, + 'overlapped_speech': False, + } + + +def _quality_for_segment(value: Optional[Union[SegmentQuality, dict]]) -> Optional[SegmentQuality]: + if value is None: + return None + if isinstance(value, SegmentQuality): + return value + return SegmentQuality( + voiced_seconds=value.get('voiced_seconds'), + vad_confidence=value.get('vad_confidence'), + noise_score=value.get('noise_score'), + ) + + +def _confidence_bucket(assignment: Optional[ClusterIdentityAssignment], quality: dict) -> Optional[str]: + if not assignment: + return 'low' + if assignment.person_id: + return None + if assignment.state == 'user' and assignment.confidence is not None and assignment.confidence >= 0.7: + return 'high' + + for candidate in assignment.candidates: + if not candidate.get('is_user') and candidate.get('person_id') != USER_SELF_PERSON_ID: + continue + confidence = candidate.get('confidence') + distance = candidate.get('distance') + if confidence is not None and confidence >= 0.35: + return 'low' + if distance is not None and distance < 0.45: + return 'low' + + if quality['sample_seconds'] >= MIN_REVIEW_SAMPLE_SECONDS: + return 'low' + return None + + +def _utc_now() -> datetime: + return datetime.now(timezone.utc) From b57ccd83bc22ca62c6014eedac9980f1a1a82763 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 09:27:59 +0700 Subject: [PATCH 14/44] Document AssemblyAI background rollout readiness --- AGENTS.md | 8 +- CLAUDE.md | 3 +- .../backend/assemblyai_background_rollout.mdx | 99 +++++++++++++++++++ .../backend/listen_pusher_pipeline.mdx | 59 ++++++++++- docs/docs.json | 1 + 5 files changed, 162 insertions(+), 8 deletions(-) create mode 100644 docs/doc/developer/backend/assemblyai_background_rollout.mdx diff --git a/AGENTS.md b/AGENTS.md index 08c622bed51..b799174c736 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -49,7 +49,8 @@ backend (main.py) ├── ws ──► pusher (pusher/) ├── ──────► diarizer (diarizer/) ├── ──────► vad 
(modal/) - └── ──────► deepgram (self-hosted or cloud) + ├── ──────► deepgram (self-hosted or cloud) + └── ──────► assemblyai (cloud, background async when enabled) pusher ├── ──────► diarizer (diarizer/) @@ -63,12 +64,13 @@ notifications-job (modal/job.py) [cron] Helm charts: `backend/charts/{backend-listen,pusher,diarizer,vad,deepgram-self-hosted,agent-proxy}/` -- **backend** (`main.py`) — REST API. Streams audio to pusher via WebSocket (`utils/pusher.py`). Calls diarizer for speaker embeddings (`utils/stt/speaker_embedding.py`). Calls vad for voice activity detection and speaker identification (`utils/stt/vad.py`, `utils/stt/speech_profile.py`). Calls deepgram for STT (`utils/stt/streaming.py`). +- **backend** (`main.py`) — REST API. Streams realtime/listen audio to pusher via WebSocket (`utils/pusher.py`). Calls diarizer for speaker embeddings (`utils/stt/speaker_embedding.py`). Calls vad for voice activity detection and speaker identification (`utils/stt/vad.py`, `utils/stt/speech_profile.py`). Calls deepgram for realtime and Hold-to-Talk STT (`utils/stt/streaming.py`) and for prerecorded fallback. Calls AssemblyAI for explicitly enabled async/background prerecorded workloads through `utils/stt/provider_service.py`. - **pusher** (`pusher/main.py`) — Receives audio via binary WebSocket protocol. Calls diarizer and deepgram for speaker sample extraction (`utils/speaker_identification.py` → `utils/speaker_sample.py`). - **agent-proxy** (`agent-proxy/main.py`) — GKE. WebSocket proxy at `wss://agent.omi.me/v1/agent/ws`. Validates Firebase ID token, looks up `agentVm` in Firestore, proxies bidirectionally to VM's `ws://:8080/ws`. VM credentials never leave the server. - **diarizer** (`diarizer/main.py`) — GPU. Speaker embeddings at `/v2/embedding`. Called by backend and pusher (`HOSTED_SPEAKER_EMBEDDING_API_URL`). - **vad** (`modal/main.py`) — GPU. `/v1/vad` (voice activity detection) and `/v1/speaker-identification` (speaker matching). 
Called by backend only (`HOSTED_VAD_API_URL`, `HOSTED_SPEECH_PROFILE_API_URL`). -- **deepgram** — STT. Streaming uses self-hosted (`DEEPGRAM_SELF_HOSTED_URL`) or cloud based on `DEEPGRAM_SELF_HOSTED_ENABLED` (`utils/stt/streaming.py`). Pre-recorded always uses Deepgram cloud (`utils/stt/pre_recorded.py`). Called by backend and pusher. +- **deepgram** — STT. Streaming uses self-hosted (`DEEPGRAM_SELF_HOSTED_URL`) or cloud based on `DEEPGRAM_SELF_HOSTED_ENABLED` (`utils/stt/streaming.py`). Prerecorded Deepgram cloud remains the default and fallback provider through `utils/stt/provider_service.py`/`utils/stt/pre_recorded.py`. Called by backend and pusher. +- **assemblyai** — Async/background STT. Used only by backend for feature-flagged prerecorded workloads (`sync`, `background`, `postprocess`) through `utils/stt/assemblyai_adapter.py`; provider speaker labels remain session-local metadata. - **notifications-job** (`modal/job.py`) — Cron job, reads Firestore/Redis, sends push notifications. Keep this map up to date. When adding, removing, or changing inter-service calls, update this section and the matching section in `CLAUDE.md`. 
diff --git a/CLAUDE.md b/CLAUDE.md index 96bbcb9fc00..71571e05b67 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -52,7 +52,8 @@ backend (main.py) ├── ws ──► pusher (pusher/) ├── ──────► diarizer (diarizer/) ├── ──────► vad (modal/) - └── ──────► deepgram (self-hosted or cloud) + ├── ──────► deepgram (self-hosted or cloud) + └── ──────► assemblyai (cloud, background async when enabled) pusher ├── ──────► diarizer (diarizer/) diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx new file mode 100644 index 00000000000..96ae52c3187 --- /dev/null +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -0,0 +1,99 @@ +--- +title: "AssemblyAI Background Transcription Rollout" +icon: "waveform" +description: "Rollout gates, feature flags, instrumentation, and rollback for the AssemblyAI async background transcription MVP." +--- + +# AssemblyAI Background Transcription Rollout + +## Scope + +AssemblyAI is the MVP async/background prerecorded provider. Deepgram remains +the provider for `/v4/listen`, realtime assistant streaming, Hold-to-Talk +streaming, and voice-message finalize semantics. + +Eligible AssemblyAI workloads are: + +- `sync` +- `background` +- `postprocess` + +Ineligible latency-sensitive workloads stay on Deepgram: + +- `realtime` +- streaming `ptt` +- `voice_message` + +## Feature Flags And Environment + +AssemblyAI is disabled by default. + +| Variable | Default | Purpose | +| --- | --- | --- | +| `ASSEMBLYAI_API_KEY` | unset | Required before any AssemblyAI request can run. | +| `ASSEMBLYAI_BACKGROUND_STT_ENABLED` | `false` | Main rollout switch. Set to `true` only for canary/smoke cohorts first. | +| `ASSEMBLYAI_BACKGROUND_STT_WORKLOADS` | `sync,background,postprocess` | Comma-separated eligible background workloads. Unknown or ineligible values are ignored. 
| +| `ASSEMBLYAI_STT_MODEL` | `universal-2` | AssemblyAI model used when a Deepgram `nova-*` model request is routed to AssemblyAI. | +| `ASSEMBLYAI_BASE_URL` | `https://api.assemblyai.com` | Provider API base URL. | +| `ASSEMBLYAI_POLL_INTERVAL_SECONDS` | `3` | Poll interval for async transcript completion. | +| `ASSEMBLYAI_MAX_POLL_SECONDS` | `900` | Max wait before the async run times out and can fall back. | +| `ASSEMBLYAI_SMOKE_AUDIO_URL` | unset | Optional live smoke-test audio URL. | + +## Routing And Fallback + +`backend/utils/stt/providers.py` owns provider selection. Background workloads +use AssemblyAI only when the main flag is enabled and the workload is in +`ASSEMBLYAI_BACKGROUND_STT_WORKLOADS`; otherwise they use Deepgram. + +Deepgram is the prerecorded fallback provider. If AssemblyAI fails, times out, +or exhausts retries for an eligible background workload, the failed AssemblyAI +run is finalized in the provider ledger and the request retries through +Deepgram with fallback metadata. + +Rollback is to set `ASSEMBLYAI_BACKGROUND_STT_ENABLED=false` or remove the +affected workload from `ASSEMBLYAI_BACKGROUND_STT_WORKLOADS`. No client deploy +is required for rollback. + +## Storage And Identity + +Provider output is normalized into `ProviderTranscriptResult` and reconstructed +into canonical `TranscriptSegment` records before conversation storage. +Provider-local speaker labels are stored as `provider_cluster_id` and +`provider_speaker_label`; they do not become canonical identity. + +Canonical speaker identity is assigned after normalization with the Omi speaker +embedding service. Low-confidence clusters remain explicitly unknown rather +than being mapped to a fake person or to `speaker_id=0`. + +Provider run ledgers live in `transcription_provider_runs`, and daily rollups +live in `transcription_provider_usage_daily`. Ledger payload guards reject raw +audio bytes, transcript text, words, utterances, and similar high-risk payloads. 
+ +## Dashboards And Reports + +Expected rollout dashboards should use `/metrics` counters and histograms for: + +- `transcription_provider_requests_total` +- `transcription_provider_latency_seconds` +- `transcription_provider_retries_total` +- `transcription_provider_fallbacks_total` +- `transcription_provider_audio_seconds_total` +- `transcription_provider_billable_seconds_total` +- `transcription_provider_speaker_clusters_total` +- `transcription_provider_identity_confidence_total` + +Expected cost and quality review should inspect: + +- `transcription_provider_usage_daily` by provider, model, workload, and day. +- `transcription_provider_runs` for canary failures, fallback direction, retry counts, latency, billable seconds, and estimated cost. +- `backend/scripts/stt/provider_comparison_gate.py` reports for transcript drift, speaker cluster stability, identity quality, fallback rate, failure rate, and economics. + +Fixture validation: + +```bash +cd backend +python3 scripts/stt/provider_comparison_gate.py --manifest tests/fixtures/stt_provider_eval/manifest.json +``` + +Live validation requires `DEEPGRAM_API_KEY`, `ASSEMBLYAI_API_KEY`, and manifest +cases with `audio_url` or `audio_file`. diff --git a/docs/doc/developer/backend/listen_pusher_pipeline.mdx b/docs/doc/developer/backend/listen_pusher_pipeline.mdx index 8c8ef372b37..c38071290c2 100644 --- a/docs/doc/developer/backend/listen_pusher_pipeline.mdx +++ b/docs/doc/developer/backend/listen_pusher_pipeline.mdx @@ -6,11 +6,11 @@ description: "Sequence diagrams for the /v4/listen WebSocket and Pusher processi # Listen + Pusher Pipeline — Sequence Diagrams -> Last updated: 2026-03-28 (PR #6061 — remove local fallback) +> Last updated: 2026-05-21 (AssemblyAI background STT MVP) > > These diagrams document the real behavior observed during E2E testing with -> live services (backend, pusher, Deepgram, embedding API). Update when the -> pipeline changes. 
+> live services (backend, pusher, Deepgram, AssemblyAI, embedding API). Update +> when the pipeline changes. ## 1. Connection + Streaming + Transcription @@ -226,6 +226,58 @@ sequenceDiagram Note over Pusher: In LOCAL_DEVELOPMENT mode,<br/>GCS upload fails (no prod creds).<br/>Conversation lifecycle still works. ``` +## 5.1 Background / Sync Upload Transcription + +Realtime listen and Hold-to-Talk streaming remain Deepgram paths. Background +prerecorded transcription routes through `utils/stt/provider_service.py`, which +selects Deepgram by default and can select AssemblyAI only when +`ASSEMBLYAI_BACKGROUND_STT_ENABLED=true` and the workload is listed in +`ASSEMBLYAI_BACKGROUND_STT_WORKLOADS`. + +Eligible AssemblyAI workloads are `sync`, `background`, and `postprocess`. +Latency-critical workloads (`realtime`, streaming `ptt`, and +`voice_message`) continue to use Deepgram and the existing websocket/finalize +semantics. + +```mermaid +sequenceDiagram + participant Client + participant Backend as Backend + participant ProviderService as Provider Service + participant AssemblyAI + participant Deepgram + participant EmbeddingAPI as Embedding API + participant Firestore + participant Metrics as /metrics + + Client->>Backend: Upload background audio or sync audio URL + Backend->>ProviderService: transcribe_url/workload=sync|background|postprocess + alt AssemblyAI enabled for workload + ProviderService->>AssemblyAI: Submit async prerecorded transcript request + AssemblyAI-->>ProviderService: Completed transcript with utterances, words, speaker labels + else AssemblyAI disabled or workload ineligible + ProviderService->>Deepgram: Prerecorded transcript request + Deepgram-->>ProviderService: Transcript with words and speaker labels + end + + alt AssemblyAI fails, times out, or exhausts retries + ProviderService->>Firestore: Finalize failed AssemblyAI provider run + ProviderService->>Deepgram: Fallback prerecorded transcript request + Deepgram-->>ProviderService: Transcript result + end + + ProviderService->>ProviderService: Normalize provider result and reconstruct canonical segments + ProviderService->>EmbeddingAPI: Extract cluster samples for Omi speaker identity + ProviderService->>Firestore: Store canonical transcript segments and provider 
run ledger + ProviderService->>Firestore: Increment daily provider usage rollup + ProviderService->>Metrics: Emit provider request, latency, retry, fallback, audio seconds, clusters, identity confidence +``` + +AssemblyAI speaker labels are provider/session-local cluster IDs. Canonical +speaker identity is applied after provider normalization with the Omi embedding +service. Low-confidence matches remain explicit unknown speakers; provider +cluster metadata is preserved separately from `person_id` and `is_user`. + ## 6. Event Wire Protocol ### Server → Client (JSON over WS text frames) @@ -283,4 +335,3 @@ Both `transcribe.py` and `pusher.py` use an `asyncio.wait(FIRST_COMPLETED)` supe - Finite task normal completion (e.g. `process_pending_conversations`, `speaker_identification_task`) After supervisor exit, remaining tasks drain with `BG_DRAIN_TIMEOUT` (30s) before force-cancel. The connection gauge (`inc`/`dec`) is always paired in `try`/`finally`. - diff --git a/docs/docs.json b/docs/docs.json index fa8805c7abf..6644a1692f2 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -35,6 +35,7 @@ "doc/developer/backend/chat_system", "doc/developer/backend/StoringConversations", "doc/developer/backend/transcription", + "doc/developer/backend/assemblyai_background_rollout", "doc/developer/backend/listen_pusher_pipeline" ], "icon": "server" From 3ee7ee56d3ada1c53327715b2f1114f9761a9ea0 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Thu, 21 May 2026 02:37:22 +0000 Subject: [PATCH 15/44] Fix PTT test mocks, postprocess_words API, and self-voice review gating Update desktop PTT tests to patch stt_provider_service after chat.py refactor, remove the unused duration parameter from postprocess_words, and skip self-voice review candidates when no identity assignment is available. 
--- backend/tests/unit/test_desktop_transcribe.py | 129 +++++++++--------- backend/tests/unit/test_transcript_segment.py | 1 - backend/utils/self_voice_review.py | 4 +- backend/utils/stt/pre_recorded.py | 4 +- 4 files changed, 69 insertions(+), 69 deletions(-) diff --git a/backend/tests/unit/test_desktop_transcribe.py b/backend/tests/unit/test_desktop_transcribe.py index 3b750a4f180..e7a3129dfa4 100644 --- a/backend/tests/unit/test_desktop_transcribe.py +++ b/backend/tests/unit/test_desktop_transcribe.py @@ -297,125 +297,128 @@ def test_return_language_extracts_detected_lang(self, mock_client): # --------------------------------------------------------------------------- +def _mock_pcm_transcription(*, text=None, detected_language=None, words=None, segments=None): + from utils.stt.provider_service import PrerecordedTranscriptionResponse + + if segments is None: + segment = MagicMock() + segment.text = text + segments = [segment] if text is not None else [] + if words is None: + words = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': text}] if text is not None else [] + return PrerecordedTranscriptionResponse( + result=MagicMock(), + detected_language=detected_language, + segments=segments, + words=words, + run_id=None, + ) + + class TestTranscribePcmBytes: """Verify transcribe_pcm_bytes passes language/model and propagates errors.""" - @patch('utils.chat.postprocess_words') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_language_model_forwarded(self, mock_get_model, mock_dg, mock_postprocess): - """stt_language and stt_model should be passed to deepgram_prerecorded_from_bytes.""" + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_language_model_forwarded(self, mock_resolve_model, mock_transcribe): + """stt_language and stt_model should be passed to transcribe_bytes.""" from 
utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('es', 'nova-3') - mock_dg.return_value = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hola'}] - mock_seg = MagicMock() - mock_seg.text = 'Hola' - mock_postprocess.return_value = [mock_seg] + mock_resolve_model.return_value = ('es', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription(text='Hola') text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='es') - mock_dg.assert_called_once() - call_kwargs = mock_dg.call_args[1] + mock_transcribe.assert_called_once() + call_kwargs = mock_transcribe.call_args[1] assert call_kwargs['language'] == 'es' assert call_kwargs['model'] == 'nova-3' assert call_kwargs['encoding'] == 'linear16' assert text == 'Hola' - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_runtime_error_propagates(self, mock_get_model, mock_dg): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_runtime_error_propagates(self, mock_resolve_model, mock_transcribe): """RuntimeError from Deepgram should propagate (not be caught).""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('en', 'nova-3') - mock_dg.side_effect = RuntimeError('Deepgram failed') + mock_resolve_model.return_value = ('en', 'nova-3') + mock_transcribe.side_effect = RuntimeError('Deepgram failed') with pytest.raises(RuntimeError, match='Deepgram failed'): transcribe_pcm_bytes(b'\x00' * 100, 'test-uid') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_empty_words_returns_none(self, mock_get_model, mock_dg): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_empty_words_returns_none(self, mock_resolve_model, 
mock_transcribe): """Empty word list should return (None, language).""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('en', 'nova-3') - mock_dg.return_value = [] + mock_resolve_model.return_value = ('en', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription(text=None, words=[]) text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='en') assert text is None assert lang == 'en' - @patch('utils.chat.postprocess_words') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_multi_language_returns_detected_language(self, mock_get_model, mock_dg, mock_postprocess): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_multi_language_returns_detected_language(self, mock_resolve_model, mock_transcribe): """Multi-language mode should return the Deepgram-detected language, not hardcoded 'en'.""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('multi', 'nova-3') - # return_language=True path returns (words, detected_lang) - mock_dg.return_value = ([{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Bonjour'}], 'fr') - mock_seg = MagicMock() - mock_seg.text = 'Bonjour' - mock_postprocess.return_value = [mock_seg] + mock_resolve_model.return_value = ('multi', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription(text='Bonjour', detected_language='fr') text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='multi') assert text == 'Bonjour' assert lang == 'fr' - # Verify return_language=True was passed - call_kwargs = mock_dg.call_args[1] + call_kwargs = mock_transcribe.call_args[1] assert call_kwargs['return_language'] is True - @patch('utils.chat.postprocess_words') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def 
test_chinese_language_uses_nova3(self, mock_get_model, mock_dg, mock_postprocess): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_chinese_language_uses_nova3(self, mock_resolve_model, mock_transcribe): """Chinese should use nova-3 model.""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('zh', 'nova-3') - mock_dg.return_value = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': '你好'}] - mock_seg = MagicMock() - mock_seg.text = '你好' - mock_postprocess.return_value = [mock_seg] + mock_resolve_model.return_value = ('zh', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription(text='你好') - text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='zh') + transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='zh') - call_kwargs = mock_dg.call_args[1] + call_kwargs = mock_transcribe.call_args[1] assert call_kwargs['model'] == 'nova-3' assert call_kwargs['language'] == 'zh' - @patch('utils.chat.postprocess_words') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_whitespace_only_transcript_returns_none(self, mock_get_model, mock_dg, mock_postprocess): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_whitespace_only_transcript_returns_none(self, mock_resolve_model, mock_transcribe): """Whitespace-only transcript after postprocessing should return (None, language).""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('en', 'nova-3') - mock_dg.return_value = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': ' '}] - mock_seg = MagicMock() - mock_seg.text = ' ' - mock_postprocess.return_value = [mock_seg] + mock_resolve_model.return_value = ('en', 'nova-3') + mock_transcribe.return_value = 
_mock_pcm_transcription(text=' ') text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='en') assert text is None assert lang == 'en' - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_postprocess_empty_returns_none(self, mock_get_model, mock_dg): - """postprocess_words returning empty list should return (None, language).""" + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_empty_segments_returns_none(self, mock_resolve_model, mock_transcribe): + """Empty segments should return (None, language).""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('en', 'nova-3') - mock_dg.return_value = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'hello'}] - # postprocess_words is imported at module level; mock it - with patch('utils.chat.postprocess_words', return_value=[]): - text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='en') + mock_resolve_model.return_value = ('en', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription( + text='hello', + words=[{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'hello'}], + segments=[], + ) + + text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='en') assert text is None assert lang == 'en' diff --git a/backend/tests/unit/test_transcript_segment.py b/backend/tests/unit/test_transcript_segment.py index 6b6b9b1150d..a9947da1193 100644 --- a/backend/tests/unit/test_transcript_segment.py +++ b/backend/tests/unit/test_transcript_segment.py @@ -113,7 +113,6 @@ def test_postprocess_words_does_not_promote_malformed_speaker_to_speaker_zero(): "text": "hello", } ], - duration=1, ) assert len(segments) == 1 diff --git a/backend/utils/self_voice_review.py b/backend/utils/self_voice_review.py index 12b65d00487..803cc9a607e 100644 --- 
a/backend/utils/self_voice_review.py +++ b/backend/utils/self_voice_review.py @@ -228,8 +228,8 @@ def _quality_for_segment(value: Optional[Union[SegmentQuality, dict]]) -> Option def _confidence_bucket(assignment: Optional[ClusterIdentityAssignment], quality: dict) -> Optional[str]: - if not assignment: - return 'low' + if assignment is None: + return None if assignment.person_id: return None if assignment.state == 'user' and assignment.confidence is not None and assignment.confidence >= 0.7: diff --git a/backend/utils/stt/pre_recorded.py b/backend/utils/stt/pre_recorded.py index 2907684c493..1555a9b0358 100644 --- a/backend/utils/stt/pre_recorded.py +++ b/backend/utils/stt/pre_recorded.py @@ -366,7 +366,5 @@ def legacy_words_to_provider_result(words: List[dict]) -> ProviderTranscriptResu return ProviderTranscriptResult(provider=provider or 'unknown', model=model, words=provider_words) -def postprocess_words( - words: List[dict], duration: int, skip_n_seconds: int = 0 # , merge_segments: bool = True -) -> List[TranscriptSegment]: +def postprocess_words(words: List[dict], skip_n_seconds: int = 0) -> List[TranscriptSegment]: return reconstruct_conversation(legacy_words_to_provider_result(words), skip_n_seconds=skip_n_seconds) From 5db29cb882cd919b373db5971c3c372e191b89e8 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 09:57:08 +0700 Subject: [PATCH 16/44] Address provider instrumentation review blockers --- .../database/transcription_provider_usage.py | 38 ++++++++++++ backend/models/transcript_segment.py | 23 +++++++- backend/routers/sync.py | 11 ++++ .../unit/test_background_provider_service.py | 59 ++++++++++++++++++- .../unit/test_background_speaker_identity.py | 20 +++++++ .../unit/test_conversation_reconstructor.py | 20 +++++++ .../tests/unit/test_folder_name_enrichment.py | 13 ++++ .../tests/unit/test_provider_evaluation.py | 51 ++++++++++++++++ backend/tests/unit/test_transcript_segment.py | 25 ++++++++ 
.../unit/test_transcription_provider_usage.py | 54 +++++++++++++++++ .../conversations/postprocess_conversation.py | 10 ++++ backend/utils/conversations/render.py | 36 ++++++++++- .../utils/stt/background_speaker_identity.py | 3 + .../utils/stt/conversation_reconstructor.py | 19 ++++-- backend/utils/stt/provider_costs.py | 24 ++++---- backend/utils/stt/provider_evaluation.py | 23 ++++---- backend/utils/stt/provider_service.py | 58 ++++++++++++++++-- 17 files changed, 451 insertions(+), 36 deletions(-) diff --git a/backend/database/transcription_provider_usage.py b/backend/database/transcription_provider_usage.py index 9066f695664..ea5511f02c5 100644 --- a/backend/database/transcription_provider_usage.py +++ b/backend/database/transcription_provider_usage.py @@ -334,6 +334,44 @@ def purge_provider_runs_for_user(uid: str, batch_size: int = 400) -> int: return deleted +def update_provider_run_identity_metrics( + run_id: str, + provider: str, + model: str, + workload: str, + identified_speaker_cluster_count: int, + identity_confidence_summary: Optional[dict[str, Any]] = None, +) -> None: + ref = _run_ref(run_id) + snapshot = ref.get() + if not snapshot.exists: + return + data = snapshot.to_dict() or {} + previous_identified = data.get('identified_speaker_cluster_count', 0) or 0 + previous_summary = data.get('identity_confidence_summary') or {} + summary = identity_confidence_summary or {} + completed_at = (data.get('timing') or {}).get('completed_at') or _utc_now() + + ref.set( + { + 'identified_speaker_cluster_count': identified_speaker_cluster_count, + 'identity_confidence_summary': summary, + 'updated_at': _utc_now(), + }, + merge=True, + ) + + rollup_update = { + 'identified_speaker_cluster_count': firestore.Increment(identified_speaker_cluster_count - previous_identified), + 'last_updated': _utc_now(), + } + for bucket in set(previous_summary) | set(summary): + delta = (summary.get(bucket, 0) or 0) - (previous_summary.get(bucket, 0) or 0) + if delta: + 
rollup_update[f'identity_confidence_counts.{bucket}'] = firestore.Increment(delta) + _rollup_ref(utc_day_bucket(completed_at), provider, model, workload).set(rollup_update, merge=True) + + def emit_provider_run_metrics( provider: str, model: str, diff --git a/backend/models/transcript_segment.py b/backend/models/transcript_segment.py index 3dcbbb35c94..7bf045d7f87 100644 --- a/backend/models/transcript_segment.py +++ b/backend/models/transcript_segment.py @@ -108,6 +108,15 @@ def segments_as_string(segments, include_timestamps=False, user_name: str = None user_name = 'User' transcript = '' people_map = {person.id: person.name for person in people} if people else {} + provider_cluster_display_ids = {} + next_provider_display_id = 1 + for segment in segments: + provider_cluster_key = _provider_display_cluster_key(segment) + if segment.is_user or not provider_cluster_key: + continue + if provider_cluster_key not in provider_cluster_display_ids: + provider_cluster_display_ids[provider_cluster_key] = next_provider_display_id + next_provider_display_id += 1 include_timestamps = include_timestamps and TranscriptSegment.can_display_seconds(segments) for segment in segments: segment_text = segment.text.strip() @@ -117,7 +126,11 @@ def segments_as_string(segments, include_timestamps=False, user_name: str = None if segment.person_id and segment.person_id in people_map: speaker_name = people_map[segment.person_id] else: - speaker_name = f'Speaker {segment.speaker_id}' + provider_cluster_key = _provider_display_cluster_key(segment) + if provider_cluster_key in provider_cluster_display_ids: + speaker_name = f'Speaker {provider_cluster_display_ids[provider_cluster_key]}' + else: + speaker_name = f'Speaker {segment.speaker_id}' transcript += f'{timestamp_str}{speaker_name}: {segment_text}\n\n' return transcript.strip() @@ -276,6 +289,14 @@ def _merge(a, b: TranscriptSegment): return segments, joined_similar_segments, removed_ids +def _provider_display_cluster_key(segment: 
TranscriptSegment) -> Optional[str]: + if not segment.provider_cluster_id and not segment.provider_speaker_label: + return None + if segment.speaker_id not in (None, 0): + return None + return segment.provider_cluster_id or segment.provider_speaker_label + + class ImprovedTranscriptSegment(BaseModel): speaker_id: int = Field(..., description='The correctly assigned speaker id') text: str = Field(..., description='The corrected text of the segment') diff --git a/backend/routers/sync.py b/backend/routers/sync.py index e9b5f698b7f..acf914e3fa4 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -880,6 +880,7 @@ def delete_file(): language=stt_language, model=stt_model, keywords=vocabulary if vocabulary else None, + raw_audio_seconds=get_wav_duration(path), ) words = transcription.words detected_language = transcription.detected_language or stt_language @@ -903,6 +904,16 @@ def delete_file(): finally: if audio_bytes: del audio_bytes + try: + stt_provider_service.update_provider_run_identity_metrics( + transcription.run_id, + transcription.result.provider, + transcription.result.model or 'unknown', + STTWorkload.sync, + transcript_segments, + ) + except Exception as e: + logger.warning(f'Speaker ID (sync): identity metric update failed for {path}: {e}') timestamp = get_timestamp_from_path(path) segment_end_timestamp = timestamp + transcript_segments[-1].end diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index 5da6d1a94ac..2a0b1e556ec 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -176,7 +176,7 @@ def test_provider_service_uses_assemblyai_for_enabled_sync_workload(monkeypatch) assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' assert finalize_run.call_args.kwargs['artifact_refs'] == {} assert finalize_run.call_args.kwargs['billable_seconds'] == 2.0 - assert 
finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00008333 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00009444 def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypatch): @@ -205,6 +205,7 @@ def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypat conversation_id='conversation-1', language='multi', model='nova-3', + raw_audio_seconds=2.0, ) assert response.result.provider == 'deepgram' @@ -216,6 +217,8 @@ def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypat assert finalize_run.call_args_list[0].kwargs['status'] == 'failed' assert finalize_run.call_args_list[0].kwargs['retry_count'] == 1 assert finalize_run.call_args_list[0].kwargs['error_class'] == 'RuntimeError' + assert finalize_run.call_args_list[0].kwargs['billable_seconds'] == 2.0 + assert finalize_run.call_args_list[0].kwargs['estimated_cost_usd'] == 0.00009444 assert finalize_run.call_args_list[1].kwargs['run_id'] == 'run-dg' assert finalize_run.call_args_list[1].kwargs['provider'] == 'deepgram' assert finalize_run.call_args_list[1].kwargs['fallback_count'] == 1 @@ -320,7 +323,7 @@ def test_prerecorded_cost_estimator_uses_provider_defaults_and_unknown_provider_ workload='background', billable_seconds=60.0, ) - == 0.0025 + == 0.00283333 ) assert ( estimate_prerecorded_provider_cost_usd( @@ -342,6 +345,58 @@ def test_prerecorded_cost_estimator_uses_provider_defaults_and_unknown_provider_ ) +def test_provider_service_counts_user_identity_as_identified_cluster(): + result = _provider_result(provider='assemblyai', model='universal-2') + segments = provider_service.reconstruct_conversation(result) + segments[0].is_user = True + segments[0].speaker_identity_state = 'user' + + with patch.object(provider_service, 'finalize_provider_run') as finalize_run: + provider_service._finalize_run( + 'run-user', + result, + STTWorkload.sync, + provider_service.datetime.now(provider_service.timezone.utc), + 'succeeded', + 
retry_count=0, + raw_audio_seconds=2.0, + segments=segments, + ) + + assert finalize_run.call_args.kwargs['identified_speaker_cluster_count'] == 1 + + +def test_provider_service_counts_label_only_identified_clusters(): + result = ProviderTranscriptResult( + provider='assemblyai', + model='universal-2', + duration=2.0, + words=[ + ProviderTranscriptWord(text='hello', start=0.0, end=0.4, speaker_label='A'), + ProviderTranscriptWord(text='again', start=0.5, end=0.8, speaker_label='A'), + ProviderTranscriptWord(text='there', start=1.0, end=1.4, speaker_label='B'), + ], + ) + segments = provider_service.reconstruct_conversation(result) + segments[0].person_id = 'person-a' + segments[0].speaker_identity_state = 'identified' + + with patch.object(provider_service, 'finalize_provider_run') as finalize_run: + provider_service._finalize_run( + 'run-labels', + result, + STTWorkload.sync, + provider_service.datetime.now(provider_service.timezone.utc), + 'succeeded', + retry_count=0, + raw_audio_seconds=2.0, + segments=segments, + ) + + assert finalize_run.call_args.kwargs['speaker_cluster_count'] == 2 + assert finalize_run.call_args.kwargs['identified_speaker_cluster_count'] == 1 + + def test_provider_service_live_assemblyai_smoke_records_ledger_when_credentials_are_present(monkeypatch): api_key = os.getenv('ASSEMBLYAI_API_KEY') audio_url = os.getenv('ASSEMBLYAI_SMOKE_AUDIO_URL') diff --git a/backend/tests/unit/test_background_speaker_identity.py b/backend/tests/unit/test_background_speaker_identity.py index f49744637f2..d7bfb8151ab 100644 --- a/backend/tests/unit/test_background_speaker_identity.py +++ b/backend/tests/unit/test_background_speaker_identity.py @@ -104,6 +104,26 @@ def test_low_confidence_cluster_remains_explicitly_unknown_with_candidate_metada assert segments[0].speaker_identity_confidence is None +def test_unknown_assignment_demotes_stale_person_identity(): + query_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + distant_embedding = np.array([[0.0, 
1.0]], dtype=np.float32) + segments = [_segment('a1', 0.0, 3.0)] + segments[0].person_id = 'stale-person' + segments[0].speaker_identity_state = 'identified' + cache = {'person-distant': {'embedding': distant_embedding, 'name': 'Distant'}} + + identify_background_speaker_clusters( + segments, + _wav_bytes(), + cache, + embedding_extractor=lambda _audio, _filename: query_embedding, + ) + + assert segments[0].speaker_identity_state == 'unknown' + assert segments[0].person_id is None + assert segments[0].is_user is False + + def test_text_self_introduction_is_hint_only_without_voice_assignment(): segments = [_segment('intro', 0.0, 3.0, text='I am Alice and I joined the call.')] diff --git a/backend/tests/unit/test_conversation_reconstructor.py b/backend/tests/unit/test_conversation_reconstructor.py index 4edecf686c2..4b2f857582a 100644 --- a/backend/tests/unit/test_conversation_reconstructor.py +++ b/backend/tests/unit/test_conversation_reconstructor.py @@ -48,6 +48,26 @@ def test_reconstructs_word_only_provider_result_with_stable_ordering_and_cluster assert segments[1].provider_cluster_id == 'cluster-b' +def test_reconstructs_label_only_provider_result_without_collapsing_clusters(): + result = ProviderTranscriptResult( + provider='test-provider', + model='async-model', + words=[ + _word('hello', 0.0, 0.4, label='A'), + _word('there', 0.5, 0.8, label='A'), + _word('hi', 1.0, 1.2, label='B'), + ], + ) + + segments = reconstruct_conversation(result) + + assert [segment.text for segment in segments] == ['Hello there', 'Hi'] + assert segments[0].provider_cluster_id is None + assert segments[0].provider_speaker_label == 'A' + assert segments[0].speaker_identity_state == 'unknown' + assert segments[1].provider_speaker_label == 'B' + + def test_reconstructs_utterance_only_provider_result_with_explicit_unknown_identity(): result = ProviderTranscriptResult( provider='assemblyai', diff --git a/backend/tests/unit/test_folder_name_enrichment.py 
b/backend/tests/unit/test_folder_name_enrichment.py index 667579e2d06..c249f64a6d0 100644 --- a/backend/tests/unit/test_folder_name_enrichment.py +++ b/backend/tests/unit/test_folder_name_enrichment.py @@ -295,6 +295,19 @@ def test_missing_speaker_id_defaults_to_zero(self): populate_speaker_names('uid1', conversations) assert conversations[0]['transcript_segments'][0]['speaker_name'] == 'Speaker 0' + def test_provider_label_unknown_cluster_gets_stable_display_name(self): + conversations = [ + { + 'transcript_segments': [ + {'text': 'hi', 'provider_speaker_label': 'A', 'speaker_id': 0}, + {'text': 'there', 'provider_speaker_label': 'B', 'speaker_id': 0}, + ] + } + ] + populate_speaker_names('uid1', conversations) + assert conversations[0]['transcript_segments'][0]['speaker_name'] == 'Speaker 1' + assert conversations[0]['transcript_segments'][1]['speaker_name'] == 'Speaker 2' + def test_user_profile_missing_name_falls_back_to_user(self): _mock_get_user_profile.return_value = {"name": None} conversations = [{'transcript_segments': [{'text': 'hi', 'is_user': True, 'speaker_id': 0}]}] diff --git a/backend/tests/unit/test_provider_evaluation.py b/backend/tests/unit/test_provider_evaluation.py index 188d793ae89..9e7dd350e04 100644 --- a/backend/tests/unit/test_provider_evaluation.py +++ b/backend/tests/unit/test_provider_evaluation.py @@ -95,6 +95,57 @@ def test_provider_result_words_are_grouped_into_cluster_segments(): assert summary['speaker_cluster_count'] == 2 +def test_unknown_identity_counts_as_low_confidence(): + summary = summarize_provider_output( + 'assemblyai', + { + 'transcript': { + 'segments': [ + { + 'text': 'hello', + 'provider_cluster_id': 'A', + 'speaker_identity_state': 'unknown', + 'speaker_identity_confidence': None, + } + ] + } + }, + ) + + assert summary['low_confidence_identity_count'] == 1 + + +def test_low_confidence_identity_counts_clusters_not_segments(): + summary = summarize_provider_output( + 'assemblyai', + { + 'transcript': { + 
'segments': [ + { + 'text': 'hello', + 'provider_cluster_id': 'A', + 'speaker_identity_state': 'unknown', + }, + { + 'text': 'again', + 'provider_cluster_id': 'A', + 'speaker_identity_state': 'unknown', + }, + { + 'text': 'there', + 'provider_cluster_id': 'B', + 'speaker_identity_state': 'identified', + 'speaker_identity_confidence': 0.9, + }, + ] + } + }, + ) + + assert summary['low_confidence_identity_count'] == 1 + assert summary['low_confidence_identity_rate'] == 0.5 + + def test_compact_markdown_report_is_review_friendly(): report = build_comparison_report([_load_fixture_case()]) markdown = compact_markdown_report(report) diff --git a/backend/tests/unit/test_transcript_segment.py b/backend/tests/unit/test_transcript_segment.py index a9947da1193..3ead4f86b41 100644 --- a/backend/tests/unit/test_transcript_segment.py +++ b/backend/tests/unit/test_transcript_segment.py @@ -90,6 +90,31 @@ def test_unknown_identity_is_explicit_and_legacy_zero_remains_ambiguous(): assert legacy.speaker_identity_state == "legacy_ambiguous" +def test_segments_as_string_uses_provider_label_for_unknown_cluster_display(): + segments = [ + TranscriptSegment( + text="hello", + is_user=False, + start=0.0, + end=1.0, + provider_speaker_label="A", + speaker_identity_state="unknown", + ), + TranscriptSegment( + text="there", + is_user=False, + start=1.0, + end=2.0, + provider_speaker_label="B", + speaker_identity_state="unknown", + ), + ] + + transcript = TranscriptSegment.segments_as_string(segments) + + assert transcript == "Speaker 1: hello\n\nSpeaker 2: there" + + def test_postprocess_words_does_not_promote_malformed_speaker_to_speaker_zero(): deepgram = types.ModuleType("deepgram") deepgram.DeepgramClient = lambda *args, **kwargs: object() diff --git a/backend/tests/unit/test_transcription_provider_usage.py b/backend/tests/unit/test_transcription_provider_usage.py index 3ad0d3edc5d..1e42126a74e 100644 --- a/backend/tests/unit/test_transcription_provider_usage.py +++ 
b/backend/tests/unit/test_transcription_provider_usage.py @@ -38,6 +38,7 @@ class _FakeSnapshot: def __init__(self, data): self._data = data self.reference = MagicMock() + self.exists = data is not None def to_dict(self): return self._data @@ -51,6 +52,12 @@ def __init__(self, doc_id): def set(self, data, merge=False): self.set_calls.append({'data': data, 'merge': merge}) + def get(self): + data = {} + for call in self.set_calls: + data.update(call['data']) + return _FakeSnapshot(data if self.set_calls else None) + class _FakeCollection: def __init__(self, name, docs): @@ -381,6 +388,53 @@ def test_failed_run_retry_count_rolls_up_and_emits_retry_metric(monkeypatch): assert observed_fallbacks == [] +def test_update_provider_run_identity_metrics_updates_doc_and_rollup_delta(monkeypatch): + fake_db = _FakeDb() + monkeypatch.setattr(usage, 'db', fake_db) + monkeypatch.setattr(usage.firestore, 'Increment', _inc) + started_at = datetime(2026, 5, 21, 1, 0, 0, tzinfo=timezone.utc) + completed_at = datetime(2026, 5, 21, 1, 0, 2, tzinfo=timezone.utc) + usage.create_provider_run( + uid='user-1', + provider='assemblyai', + model='universal-2', + workload='sync', + run_id='run-identity', + started_at=started_at, + ) + usage.finalize_provider_run( + run_id='run-identity', + provider='assemblyai', + model='universal-2', + workload='sync', + status='succeeded', + started_at=started_at, + completed_at=completed_at, + speaker_cluster_count=2, + identified_speaker_cluster_count=0, + identity_confidence_summary={'unknown': 2}, + ) + + usage.update_provider_run_identity_metrics( + run_id='run-identity', + provider='assemblyai', + model='universal-2', + workload='sync', + identified_speaker_cluster_count=1, + identity_confidence_summary={'very_high': 1, 'unknown': 1}, + ) + + run_doc = fake_db.docs[(usage.RUNS_COLLECTION, 'run-identity')] + update = run_doc.set_calls[-1]['data'] + assert update['identified_speaker_cluster_count'] == 1 + assert update['identity_confidence_summary'] == 
{'very_high': 1, 'unknown': 1} + rollup_doc = fake_db.docs[(usage.DAILY_USAGE_COLLECTION, '2026-05-21:assemblyai:universal-2:sync')] + rollup_delta = rollup_doc.set_calls[-1]['data'] + assert rollup_delta['identified_speaker_cluster_count'] == {'__increment': 1} + assert rollup_delta['identity_confidence_counts.very_high'] == {'__increment': 1} + assert rollup_delta['identity_confidence_counts.unknown'] == {'__increment': -1} + + def test_provider_metrics_source_does_not_define_forbidden_label_names(): forbidden_labels = { "['provider', 'model', 'workload', 'user_id']", diff --git a/backend/utils/conversations/postprocess_conversation.py b/backend/utils/conversations/postprocess_conversation.py index a4ea6a94882..1285827f0d7 100644 --- a/backend/utils/conversations/postprocess_conversation.py +++ b/backend/utils/conversations/postprocess_conversation.py @@ -101,6 +101,16 @@ def postprocess_conversation( _handle_segment_embedding_matching(uid, file_path, conversation.transcript_segments, aseg) else: _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg) + try: + stt_provider_service.update_provider_run_identity_metrics( + transcription.run_id, + transcription.result.provider, + transcription.result.model or 'unknown', + STTWorkload.postprocess, + fal_segments, + ) + except Exception as e: + logger.warning(f'Speaker ID (postprocess): identity metric update failed for {conversation_id}: {e}') # Store both models results. 
conversations_db.store_model_segments_result( diff --git a/backend/utils/conversations/render.py b/backend/utils/conversations/render.py index 0e261d5a964..933f7620ac8 100644 --- a/backend/utils/conversations/render.py +++ b/backend/utils/conversations/render.py @@ -37,13 +37,26 @@ def populate_speaker_names(uid: str, conversations: List[Dict]) -> None: people_map = {p['id']: p['name'] for p in people_data} for conv in conversations: + provider_cluster_display_ids = {} + next_provider_display_id = 1 + for seg in conv.get('transcript_segments', []): + provider_cluster_key = _provider_display_cluster_key(seg) + if seg.get('is_user') or not provider_cluster_key: + continue + if provider_cluster_key not in provider_cluster_display_ids: + provider_cluster_display_ids[provider_cluster_key] = next_provider_display_id + next_provider_display_id += 1 for seg in conv.get('transcript_segments', []): if seg.get('is_user'): seg['speaker_name'] = user_name elif seg.get('person_id') and seg['person_id'] in people_map: seg['speaker_name'] = people_map[seg['person_id']] else: - seg['speaker_name'] = f"Speaker {seg.get('speaker_id', 0)}" + provider_cluster_key = _provider_display_cluster_key(seg) + if provider_cluster_key in provider_cluster_display_ids: + seg['speaker_name'] = f"Speaker {provider_cluster_display_ids[provider_cluster_key]}" + else: + seg['speaker_name'] = f"Speaker {seg.get('speaker_id', 0)}" def populate_folder_names(uid: str, conversations: List[Dict]) -> None: @@ -69,6 +82,27 @@ def populate_folder_names(uid: str, conversations: List[Dict]) -> None: conv['folder_name'] = folder_map.get(folder_id) if folder_id else None +def _provider_display_cluster_key(segment: Dict[str, Any]) -> str | None: + if not segment.get('provider_cluster_id') and not segment.get('provider_speaker_label'): + return None + if _legacy_speaker_id(segment) not in (None, 0): + return None + return segment.get('provider_cluster_id') or segment.get('provider_speaker_label') + + +def 
_legacy_speaker_id(segment: Dict[str, Any]) -> int | None: + speaker_id = segment.get('speaker_id') + if speaker_id is not None: + return speaker_id + speaker = segment.get('speaker') + if not speaker: + return None + try: + return int(str(speaker).split('_', 1)[1]) + except (ValueError, IndexError): + return 0 + + # --------------------------------------------------------------------------- # Redact: locked-content stripping # --------------------------------------------------------------------------- diff --git a/backend/utils/stt/background_speaker_identity.py b/backend/utils/stt/background_speaker_identity.py index 2bb98c7df03..3b6210fc753 100644 --- a/backend/utils/stt/background_speaker_identity.py +++ b/backend/utils/stt/background_speaker_identity.py @@ -275,6 +275,9 @@ def apply_cluster_identity_assignments( elif assignment.person_id: segment.is_user = False segment.person_id = assignment.person_id + else: + segment.is_user = False + segment.person_id = None def _group_segments_by_cluster(transcript_segments: List[TranscriptSegment]) -> Dict[str, List[TranscriptSegment]]: diff --git a/backend/utils/stt/conversation_reconstructor.py b/backend/utils/stt/conversation_reconstructor.py index c9f8c8ac0db..a75498d56df 100644 --- a/backend/utils/stt/conversation_reconstructor.py +++ b/backend/utils/stt/conversation_reconstructor.py @@ -184,9 +184,9 @@ def _retrieve_user_cluster_id(self, result: ProviderTranscriptResult, skip_n_sec speaker_counts = {} speaker_sources: Iterable[Tuple[float, Optional[str]]] = ( - [(word.start, word.provider_cluster_id) for word in result.words] + [(word.start, self._word_cluster_key(word)) for word in result.words] if result.words - else [(utterance.start, utterance.provider_cluster_id) for utterance in result.utterances] + else [(utterance.start, self._utterance_cluster_key(utterance)) for utterance in result.utterances] ) for start, provider_cluster_id in sorted(speaker_sources, key=lambda item: item[0]): if start >= 
skip_n_seconds: @@ -217,13 +217,13 @@ def _word_is_covered(self, word: ProviderTranscriptWord, intervals: Sequence[Tup def _should_merge_word(self, previous: _SegmentCandidate, word: ProviderTranscriptWord) -> bool: return ( - previous.provider_cluster_id == word.provider_cluster_id + self._candidate_cluster_key(previous) == self._word_cluster_key(word) and word.start - previous.end < self.max_same_cluster_gap_seconds ) def _should_merge_candidate(self, previous: _SegmentCandidate, candidate: _SegmentCandidate) -> bool: return ( - previous.provider_cluster_id == candidate.provider_cluster_id + self._candidate_cluster_key(previous) == self._candidate_cluster_key(candidate) and candidate.start - previous.end < self.max_same_cluster_gap_seconds ) @@ -235,7 +235,7 @@ def _is_duplicate_overlap(self, previous: _SegmentCandidate, candidate: _Segment candidate_text = self._normalize_text(candidate.text) return ( bool(previous_text and candidate_text) - and previous.provider_cluster_id == candidate.provider_cluster_id + and self._candidate_cluster_key(previous) == self._candidate_cluster_key(candidate) and (previous_text == candidate_text or previous_text in candidate_text or candidate_text in previous_text) ) @@ -251,6 +251,15 @@ def _identity_state(self, is_user: bool, provider_cluster_id: Optional[str], spe return 'unassigned' return 'unknown' + def _candidate_cluster_key(self, candidate: _SegmentCandidate) -> Optional[str]: + return candidate.provider_cluster_id or candidate.speaker_label + + def _word_cluster_key(self, word: ProviderTranscriptWord) -> Optional[str]: + return word.provider_cluster_id or word.speaker_label + + def _utterance_cluster_key(self, utterance: ProviderTranscriptUtterance) -> Optional[str]: + return utterance.provider_cluster_id or utterance.speaker_label + def _normalize_text(self, text: str) -> str: return ' '.join(text.lower().split()) diff --git a/backend/utils/stt/provider_costs.py b/backend/utils/stt/provider_costs.py index 
c44e714ec6a..c9ec475a81f 100644 --- a/backend/utils/stt/provider_costs.py +++ b/backend/utils/stt/provider_costs.py @@ -11,26 +11,28 @@ class PrerecordedProviderCostRate: # Pay-as-you-go public STT pricing, checked 2026-05-21. -# Add-on features and customer-specific committed-use discounts are intentionally -# excluded until their usage is represented explicitly in provider run metadata. +# AssemblyAI background runs use speaker_labels=True, so the pre-recorded +# diarization add-on is included in AssemblyAI rates here. +# Customer-specific committed-use discounts are intentionally excluded. _PRERECORDED_STT_COST_RATES: dict[str, dict[str, PrerecordedProviderCostRate]] = { STTProviderName.assemblyai.value: { - # AssemblyAI pricing: Universal-2 $0.15/hr, Universal-3 Pro $0.21/hr. + # AssemblyAI pricing: Universal-2 $0.15/hr, Universal-3 Pro $0.21/hr, + # plus pre-recorded Speaker Diarization $0.02/hr. 'universal-2': PrerecordedProviderCostRate( - usd_per_billable_second=0.15 / 3600, - source='assemblyai_prerecorded_payg_2026_05_21', + usd_per_billable_second=0.17 / 3600, + source='assemblyai_prerecorded_diarized_payg_2026_05_21', ), 'universal-3-pro': PrerecordedProviderCostRate( - usd_per_billable_second=0.21 / 3600, - source='assemblyai_prerecorded_payg_2026_05_21', + usd_per_billable_second=0.23 / 3600, + source='assemblyai_prerecorded_diarized_payg_2026_05_21', ), 'u3-pro': PrerecordedProviderCostRate( - usd_per_billable_second=0.21 / 3600, - source='assemblyai_prerecorded_payg_2026_05_21', + usd_per_billable_second=0.23 / 3600, + source='assemblyai_prerecorded_diarized_payg_2026_05_21', ), 'default': PrerecordedProviderCostRate( - usd_per_billable_second=0.15 / 3600, - source='assemblyai_prerecorded_default_2026_05_21', + usd_per_billable_second=0.17 / 3600, + source='assemblyai_prerecorded_diarized_default_2026_05_21', ), }, STTProviderName.deepgram.value: { diff --git a/backend/utils/stt/provider_evaluation.py b/backend/utils/stt/provider_evaluation.py index 
ccbb68c1d93..316e9194e06 100644 --- a/backend/utils/stt/provider_evaluation.py +++ b/backend/utils/stt/provider_evaluation.py @@ -112,12 +112,17 @@ def summarize_provider_output(provider: str, payload: dict[str, Any]) -> dict[st or segment.get('speaker_identity_state') in ('identified', 'user') ) } - identity_confidences = [ - segment.get('speaker_identity_confidence') - for segment in segments - if segment.get('speaker_identity_confidence') is not None - ] - low_confidence_count = sum(1 for confidence in identity_confidences if float(confidence) < 0.50) + low_confidence_clusters = set() + for segment in segments: + cluster_id = _cluster_id(segment) + if not cluster_id: + continue + state = segment.get('speaker_identity_state') + confidence = segment.get('speaker_identity_confidence') + if state == 'unknown': + low_confidence_clusters.add(cluster_id) + elif confidence is not None and float(confidence) < 0.50: + low_confidence_clusters.add(cluster_id) return { 'provider': provider, 'segments': segments, @@ -127,10 +132,8 @@ def summarize_provider_output(provider: str, payload: dict[str, Any]) -> dict[st 'speaker_cluster_count': len(clusters), 'identified_speaker_cluster_count': len(identified_clusters), 'unknown_speaker_cluster_count': max(len(clusters) - len(identified_clusters), 0), - 'low_confidence_identity_count': low_confidence_count, - 'low_confidence_identity_rate': ( - low_confidence_count / len(identity_confidences) if identity_confidences else 0.0 - ), + 'low_confidence_identity_count': len(low_confidence_clusters), + 'low_confidence_identity_rate': (len(low_confidence_clusters) / len(clusters) if clusters else 0.0), 'raw_audio_seconds': _number_from_ledger(ledger, 'raw_audio_seconds'), 'speech_active_seconds': _number_from_ledger(ledger, 'speech_active_seconds'), 'billable_seconds': _number_from_ledger(ledger, 'billable_seconds'), diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index 3cc16f6500a..fcbcb9fef76 
100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -538,9 +538,7 @@ def _finalize_run( if not run_id: return billable_seconds = raw_audio_seconds - clusters = { - item.provider_cluster_id for item in list(result.words) + list(result.utterances) if item.provider_cluster_id - } + clusters = {_segment_cluster_key(segment) for segment in segments if _segment_cluster_key(segment)} confidences = [segment.speaker_identity_confidence for segment in segments] try: finalize_provider_run( @@ -564,9 +562,7 @@ def _finalize_run( transcript_segment_count=len(segments), transcript_word_count=len(result.words), speaker_cluster_count=len(clusters), - identified_speaker_cluster_count=len( - {segment.provider_cluster_id for segment in segments if segment.person_id} - ), + identified_speaker_cluster_count=_identified_cluster_count(segments), identity_confidence_summary=summarize_identity_confidences(confidences), artifact_refs=_provider_artifact_refs(result), fallback_provider=fallback_provider, @@ -593,6 +589,7 @@ def _finalize_failed_run( ) -> None: if not run_id: return + billable_seconds = raw_audio_seconds try: finalize_provider_run( run_id=run_id, @@ -602,6 +599,14 @@ def _finalize_failed_run( status='failed', started_at=started_at, raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=raw_audio_seconds, + billable_seconds=billable_seconds, + estimated_cost_usd=estimate_prerecorded_provider_cost_usd( + provider=provider, + model=model, + workload=workload, + billable_seconds=billable_seconds, + ), retry_count=retry_count, fallback_count=0, error_class=error.__class__.__name__, @@ -610,3 +615,44 @@ def _finalize_failed_run( logger.warning( 'failed to finalize failed transcription provider run ledger run_id=%s: %s', run_id, finalize_error ) + + +def update_provider_run_identity_metrics( + run_id: Optional[str], + provider: str, + model: str, + workload: STTWorkload, + segments: List[TranscriptSegment], +) -> None: + if not 
run_id: + return + try: + from database.transcription_provider_usage import update_provider_run_identity_metrics as _update_identity + + _update_identity( + run_id=run_id, + provider=provider, + model=model or 'unknown', + workload=STTWorkload(workload).value, + identified_speaker_cluster_count=_identified_cluster_count(segments), + identity_confidence_summary=summarize_identity_confidences( + [segment.speaker_identity_confidence for segment in segments] + ), + ) + except Exception as e: + logger.warning('failed to update transcription provider identity metrics run_id=%s: %s', run_id, e) + + +def _identified_cluster_count(segments: List[TranscriptSegment]) -> int: + return len( + { + _segment_cluster_key(segment) + for segment in segments + if _segment_cluster_key(segment) + and (segment.person_id or segment.is_user or segment.speaker_identity_state in ('identified', 'user')) + } + ) + + +def _segment_cluster_key(segment: TranscriptSegment) -> Optional[str]: + return segment.provider_cluster_id or segment.provider_speaker_label From 105502b8f44ce4f5debde261d4523c652bcdbf58 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 10:30:59 +0700 Subject: [PATCH 17/44] Hoist provider service imports --- backend/utils/stt/deepgram_config.py | 110 +++++++++++++++++++++++++ backend/utils/stt/pre_recorded.py | 113 +------------------------- backend/utils/stt/provider_service.py | 64 +++++++++++---- 3 files changed, 159 insertions(+), 128 deletions(-) create mode 100644 backend/utils/stt/deepgram_config.py diff --git a/backend/utils/stt/deepgram_config.py b/backend/utils/stt/deepgram_config.py new file mode 100644 index 00000000000..e053352866e --- /dev/null +++ b/backend/utils/stt/deepgram_config.py @@ -0,0 +1,110 @@ +from typing import Tuple + + +# Languages supported by nova-3. 
+DEEPGRAM_NOVA3_LANGUAGES = { + "ar", + "ar-AE", + "ar-SA", + "ar-QA", + "ar-KW", + "ar-SY", + "ar-LB", + "ar-PS", + "ar-JO", + "ar-EG", + "ar-SD", + "ar-TD", + "ar-MA", + "ar-DZ", + "ar-TN", + "ar-IQ", + "ar-IR", + "be", + "bg", + "bn", + "bs", + "ca", + "cs", + "da", + "da-DK", + "de", + "de-CH", + "el", + "en", + "en-US", + "en-AU", + "en-GB", + "en-IN", + "en-NZ", + "es", + "es-419", + "et", + "fa", + "fi", + "fr", + "fr-CA", + "he", + "hi", + "hr", + "hu", + "id", + "it", + "ja", + "kn", + "ko", + "ko-KR", + "lt", + "lv", + "mk", + "mr", + "ms", + "nl", + "nl-BE", + "no", + "pl", + "pt", + "pt-BR", + "pt-PT", + "ro", + "ru", + "sk", + "sl", + "sr", + "sv", + "sv-SE", + "ta", + "te", + "th", + "th-TH", + "tl", + "tr", + "uk", + "ur", + "vi", + "zh", + "zh-CN", + "zh-Hans", + "zh-HK", + "zh-Hant", + "zh-TW", +} + + +def get_deepgram_model_for_language(language: str) -> Tuple[str, str]: + """ + Determine the appropriate Deepgram model and language for pre-recorded transcription. + + Args: + language: The requested language code or 'multi' for auto-detection. + + Returns: + Tuple of (language_to_use, model_name). 
+ """ + if language == 'multi': + return 'multi', 'nova-3' + + if language in DEEPGRAM_NOVA3_LANGUAGES: + return language, 'nova-3' + + return 'multi', 'nova-3' diff --git a/backend/utils/stt/pre_recorded.py b/backend/utils/stt/pre_recorded.py index 1555a9b0358..54bd69e797e 100644 --- a/backend/utils/stt/pre_recorded.py +++ b/backend/utils/stt/pre_recorded.py @@ -14,6 +14,7 @@ deepgram_speaker_fields, provider_result_to_legacy_words, ) +from utils.stt.deepgram_config import get_deepgram_model_for_language import logging _DG_TIMEOUT = httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0) @@ -42,118 +43,6 @@ def _deepgram_prerecorded_provider() -> DeepgramPrerecordedTranscriptionProvider return DeepgramPrerecordedTranscriptionProvider(_deepgram_client_for_request, _DG_TIMEOUT) -# Languages supported by nova-3 -_deepgram_nova3_languages = { - "ar", - "ar-AE", - "ar-SA", - "ar-QA", - "ar-KW", - "ar-SY", - "ar-LB", - "ar-PS", - "ar-JO", - "ar-EG", - "ar-SD", - "ar-TD", - "ar-MA", - "ar-DZ", - "ar-TN", - "ar-IQ", - "ar-IR", - "be", - "bg", - "bn", - "bs", - "ca", - "cs", - "da", - "da-DK", - "de", - "de-CH", - "el", - "en", - "en-US", - "en-AU", - "en-GB", - "en-IN", - "en-NZ", - "es", - "es-419", - "et", - "fa", - "fi", - "fr", - "fr-CA", - "he", - "hi", - "hr", - "hu", - "id", - "it", - "ja", - "kn", - "ko", - "ko-KR", - "lt", - "lv", - "mk", - "mr", - "ms", - "nl", - "nl-BE", - "no", - "pl", - "pt", - "pt-BR", - "pt-PT", - "ro", - "ru", - "sk", - "sl", - "sr", - "sv", - "sv-SE", - "ta", - "te", - "th", - "th-TH", - "tl", - "tr", - "uk", - "ur", - "vi", - "zh", - "zh-CN", - "zh-Hans", - "zh-HK", - "zh-Hant", - "zh-TW", -} - - -def get_deepgram_model_for_language(language: str) -> Tuple[str, str]: - """ - Determine the appropriate Deepgram model and language for pre-recorded transcription. 
- - Args: - language: The requested language code or 'multi' for auto-detection - - Returns: - Tuple of (language_to_use, model_name) - """ - # For multi-language mode - if language == 'multi': - return 'multi', 'nova-3' - - # Languages supported by nova-3 - if language in _deepgram_nova3_languages: - return language, 'nova-3' - - # Unsupported language - fall back to multi for auto-detection - return 'multi', 'nova-3' - - @timeit def deepgram_prerecorded( audio_url: str, diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index fcbcb9fef76..557fbbeaf3b 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -4,9 +4,13 @@ from datetime import datetime, timezone from typing import List, Optional, Sequence, Tuple +import httpx +from deepgram import DeepgramClient, DeepgramClientOptions + from models.transcript_segment import ProviderTranscriptResult, TranscriptSegment from utils.stt.assemblyai_adapter import AssemblyAIAsyncTranscriptionProvider from utils.stt.conversation_reconstructor import reconstruct_conversation +from utils.stt.deepgram_adapter import DeepgramPrerecordedTranscriptionProvider from utils.stt.deepgram_adapter import provider_result_to_legacy_words from utils.stt.provider_costs import estimate_prerecorded_provider_cost_usd from utils.stt.providers import ( @@ -15,20 +19,44 @@ get_fallback_prerecorded_provider_name, get_prerecorded_provider_name, ) +from utils.stt.deepgram_config import get_deepgram_model_for_language + +try: + from utils.byok import get_byok_key +except ImportError: + get_byok_key = None + +try: + from database.transcription_provider_usage import ( + create_provider_run as _db_create_provider_run, + finalize_provider_run as _db_finalize_provider_run, + update_provider_run_identity_metrics as _db_update_provider_run_identity_metrics, + ) + + _PROVIDER_USAGE_IMPORT_ERROR = None +except ImportError as e: + _db_create_provider_run = None + 
_db_finalize_provider_run = None + _db_update_provider_run_identity_metrics = None + _PROVIDER_USAGE_IMPORT_ERROR = e logger = logging.getLogger(__name__) +_DG_TIMEOUT = httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0) +_DEEPGRAM_OPTIONS = DeepgramClientOptions(options={"keepalive": "true"}) +_DEEPGRAM_CLIENT = DeepgramClient(os.getenv('DEEPGRAM_API_KEY'), _DEEPGRAM_OPTIONS) -def create_provider_run(**kwargs) -> str: - from database.transcription_provider_usage import create_provider_run as _create_provider_run - return _create_provider_run(**kwargs) +def create_provider_run(**kwargs) -> str: + if _db_create_provider_run is None: + raise _PROVIDER_USAGE_IMPORT_ERROR + return _db_create_provider_run(**kwargs) def finalize_provider_run(**kwargs) -> None: - from database.transcription_provider_usage import finalize_provider_run as _finalize_provider_run - - _finalize_provider_run(**kwargs) + if _db_finalize_provider_run is None: + raise _PROVIDER_USAGE_IMPORT_ERROR + _db_finalize_provider_run(**kwargs) def summarize_identity_confidences(confidences): @@ -52,19 +80,18 @@ def _identity_confidence_bucket(confidence: Optional[float]) -> str: def _deepgram_prerecorded_provider(): - from utils.stt.pre_recorded import _deepgram_prerecorded_provider as _provider - - return _provider() + return DeepgramPrerecordedTranscriptionProvider(_deepgram_client_for_request, _DG_TIMEOUT) def _assemblyai_prerecorded_provider(): return AssemblyAIAsyncTranscriptionProvider() -def get_deepgram_model_for_language(language: str) -> Tuple[str, str]: - from utils.stt.pre_recorded import get_deepgram_model_for_language as _get_deepgram_model_for_language - - return _get_deepgram_model_for_language(language) +def _deepgram_client_for_request() -> DeepgramClient: + byok = get_byok_key('deepgram') if get_byok_key else None + if byok: + return DeepgramClient(byok, _DEEPGRAM_OPTIONS) + return _DEEPGRAM_CLIENT @dataclass @@ -626,10 +653,15 @@ def update_provider_run_identity_metrics( ) -> 
None: if not run_id: return + if _db_update_provider_run_identity_metrics is None: + logger.warning( + 'failed to update transcription provider identity metrics run_id=%s: %s', + run_id, + _PROVIDER_USAGE_IMPORT_ERROR, + ) + return try: - from database.transcription_provider_usage import update_provider_run_identity_metrics as _update_identity - - _update_identity( + _db_update_provider_run_identity_metrics( run_id=run_id, provider=provider, model=model or 'unknown', From 03b9f17dee3ec4b87934f6ffcfbcff1375852895 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 10:35:43 +0700 Subject: [PATCH 18/44] Apply CI Dart formatting --- app/lib/pages/conversation_detail/page.dart | 26 +++++------- app/lib/widgets/transcript.dart | 47 +++++++++------------ 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/app/lib/pages/conversation_detail/page.dart b/app/lib/pages/conversation_detail/page.dart index 02a20f9dd35..d9ba7c2d238 100644 --- a/app/lib/pages/conversation_detail/page.dart +++ b/app/lib/pages/conversation_detail/page.dart @@ -328,8 +328,8 @@ class _ConversationDetailPageState extends State with Ti final conversation = provider.conversation; final summaryContent = conversation.appResults.isNotEmpty && conversation.appResults[0].content.trim().isNotEmpty - ? conversation.appResults[0].content.trim() - : conversation.structured.toString(); + ? 
conversation.appResults[0].content.trim() + : conversation.structured.toString(); _copyContent(context, summaryContent); break; case 'download_audio': @@ -669,8 +669,8 @@ class _ConversationDetailPageState extends State with Ti provider.conversation.starred = newStarredState; // Update in conversation provider context.read().updateConversationInSortedList( - provider.conversation, - ); + provider.conversation, + ); // Track star/unstar action PlatformManager.instance.analytics.conversationStarToggled( conversation: provider.conversation, @@ -989,15 +989,13 @@ class _ConversationDetailPageState extends State with Ti child: Consumer( builder: (context, provider, child) { final conversation = provider.conversation; - final hasActionItems = conversation.structured.actionItems - .where((item) => !item.deleted) - .isNotEmpty; + final hasActionItems = + conversation.structured.actionItems.where((item) => !item.deleted).isNotEmpty; return ConversationBottomBar( mode: ConversationBottomBarMode.detail, selectedTab: selectedTab, conversation: conversation, - hasSegments: - conversation.transcriptSegments.isNotEmpty || + hasSegments: conversation.transcriptSegments.isNotEmpty || conversation.photos.isNotEmpty || conversation.externalIntegration != null, hasActionItems: hasActionItems, @@ -1438,9 +1436,8 @@ class _TranscriptWidgetsState extends State with AutomaticKee } final segments = provider.conversation.transcriptSegments; final segment = segments[segmentIndex]; - final person = segment.personId != null - ? SharedPreferencesUtil().getPersonById(segment.personId!) - : null; + final person = + segment.personId != null ? SharedPreferencesUtil().getPersonById(segment.personId!) : null; final speakerName = person?.name ?? 
context.l10n.speakerWithId(TranscriptSegment.getDisplaySpeakerIdForSegment(segment, segments)); @@ -1501,9 +1498,8 @@ class _TranscriptWidgetsState extends State with AutomaticKee ); if (segmentIndex == -1) continue; provider.conversation.transcriptSegments[segmentIndex].isUser = finalPersonId == 'user'; - provider.conversation.transcriptSegments[segmentIndex].personId = finalPersonId == 'user' - ? null - : finalPersonId; + provider.conversation.transcriptSegments[segmentIndex].personId = + finalPersonId == 'user' ? null : finalPersonId; } await assignBulkConversationTranscriptSegments( provider.conversation.id, diff --git a/app/lib/widgets/transcript.dart b/app/lib/widgets/transcript.dart index 98ead3929ae..120aad8cab4 100644 --- a/app/lib/widgets/transcript.dart +++ b/app/lib/widgets/transcript.dart @@ -111,9 +111,8 @@ class _TranscriptWidgetState extends State { return Image.asset(Assets.images.speaker0Icon.path, width: 24, height: 24); } // Always modulo by speakerImagePath.length to prevent index out of bounds - final imageIndex = person != null - ? person.colorIdx! % speakerImagePath.length - : speakerId % speakerImagePath.length; + final imageIndex = + person != null ? person.colorIdx! % speakerImagePath.length : speakerId % speakerImagePath.length; return Image.asset(speakerImagePath[imageIndex], width: 24, height: 24); } @@ -322,13 +321,13 @@ class _TranscriptWidgetState extends State { _isAutoScrolling = true; _scrollController .animateTo( - targetOffset.clamp(0.0, _scrollController.position.maxScrollExtent), - duration: const Duration(milliseconds: 400), - curve: Curves.easeInOutCubic, - ) + targetOffset.clamp(0.0, _scrollController.position.maxScrollExtent), + duration: const Duration(milliseconds: 400), + curve: Curves.easeInOutCubic, + ) .then((_) { - _isAutoScrolling = false; - }); + _isAutoScrolling = false; + }); } } @@ -508,9 +507,9 @@ class _TranscriptWidgetState extends State { data.speakerId == omiSpeakerId ? 'omi' : (person?.name ?? 
- context.l10n.speakerWithId( - TranscriptSegment.getDisplaySpeakerIdForSegment(data, widget.segments), - )), + context.l10n.speakerWithId( + TranscriptSegment.getDisplaySpeakerIdForSegment(data, widget.segments), + )), style: TextStyle( color: data.speakerId == omiSpeakerId || person != null ? Colors.grey.shade300 @@ -551,8 +550,8 @@ class _TranscriptWidgetState extends State { isUser ? 18 : (segmentIdx > 0 && !widget.segments[segmentIdx - 1].isUser) - ? 6 - : 18, + ? 6 + : 18, ), topRight: Radius.circular(isUser ? 18 : 18), bottomLeft: Radius.circular(18), @@ -615,9 +614,8 @@ class _TranscriptWidgetState extends State { Text( SttProviderConfig.getDisplayName(data.sttProvider), style: TextStyle( - color: isUser - ? Colors.white.withValues(alpha: 0.5) - : Colors.grey.shade500, + color: + isUser ? Colors.white.withValues(alpha: 0.5) : Colors.grey.shade500, fontSize: 10, fontStyle: FontStyle.italic, ), @@ -626,9 +624,8 @@ class _TranscriptWidgetState extends State { Text( ' · ', style: TextStyle( - color: isUser - ? Colors.white.withValues(alpha: 0.5) - : Colors.grey.shade500, + color: + isUser ? Colors.white.withValues(alpha: 0.5) : Colors.grey.shade500, fontSize: 10, ), ), @@ -644,9 +641,8 @@ class _TranscriptWidgetState extends State { child: Icon( Icons.play_circle_outline, size: 16, - color: isUser - ? Colors.white.withValues(alpha: 0.7) - : Colors.grey.shade400, + color: + isUser ? Colors.white.withValues(alpha: 0.7) : Colors.grey.shade400, ), ), const SizedBox(width: 6), @@ -655,9 +651,8 @@ class _TranscriptWidgetState extends State { Text( data.getTimestampString(), style: TextStyle( - color: isUser - ? Colors.white.withValues(alpha: 0.7) - : Colors.grey.shade400, + color: + isUser ? 
Colors.white.withValues(alpha: 0.7) : Colors.grey.shade400, fontSize: 11, ), ), From 426862831becfa894a579d08fab16b26a54a0078 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 10:42:47 +0700 Subject: [PATCH 19/44] Match CI Dart formatter --- app/lib/pages/conversation_detail/page.dart | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/app/lib/pages/conversation_detail/page.dart b/app/lib/pages/conversation_detail/page.dart index d9ba7c2d238..ce87411da15 100644 --- a/app/lib/pages/conversation_detail/page.dart +++ b/app/lib/pages/conversation_detail/page.dart @@ -1438,8 +1438,7 @@ class _TranscriptWidgetsState extends State with AutomaticKee final segment = segments[segmentIndex]; final person = segment.personId != null ? SharedPreferencesUtil().getPersonById(segment.personId!) : null; - final speakerName = - person?.name ?? + final speakerName = person?.name ?? context.l10n.speakerWithId(TranscriptSegment.getDisplaySpeakerIdForSegment(segment, segments)); PlatformManager.instance.analytics.editSegmentTextStarted(); bool saved = false; From 225628b485b3f70d7036949381a033d6c78e7865 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 10:46:51 +0700 Subject: [PATCH 20/44] Match CI Python formatter --- backend/tests/unit/test_provider_evaluation.py | 1 - backend/tests/unit/test_stt_provider_facade.py | 1 - backend/utils/stt/deepgram_config.py | 1 - 3 files changed, 3 deletions(-) diff --git a/backend/tests/unit/test_provider_evaluation.py b/backend/tests/unit/test_provider_evaluation.py index 9e7dd350e04..e57b8f8530d 100644 --- a/backend/tests/unit/test_provider_evaluation.py +++ b/backend/tests/unit/test_provider_evaluation.py @@ -9,7 +9,6 @@ summarize_provider_output, ) - FIXTURE_DIR = Path(__file__).resolve().parents[1] / 'fixtures' / 'stt_provider_eval' diff --git a/backend/tests/unit/test_stt_provider_facade.py b/backend/tests/unit/test_stt_provider_facade.py index 87eb60abd98..f4d6d691bf6 100644 --- 
a/backend/tests/unit/test_stt_provider_facade.py +++ b/backend/tests/unit/test_stt_provider_facade.py @@ -1,7 +1,6 @@ import sys from unittest.mock import MagicMock - for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: if mod_name not in sys.modules: sys.modules[mod_name] = MagicMock() diff --git a/backend/utils/stt/deepgram_config.py b/backend/utils/stt/deepgram_config.py index e053352866e..97243a3e1f7 100644 --- a/backend/utils/stt/deepgram_config.py +++ b/backend/utils/stt/deepgram_config.py @@ -1,6 +1,5 @@ from typing import Tuple - # Languages supported by nova-3. DEEPGRAM_NOVA3_LANGUAGES = { "ar", From ba1a5abdd11903b4cec292d60141f0c211f2c16d Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 12:17:41 +0700 Subject: [PATCH 21/44] Update AssemblyAI transcript API usage --- backend/test.sh | 5 +++++ backend/tests/unit/test_assemblyai_adapter.py | 5 +++-- .../tests/unit/test_background_provider_service.py | 2 ++ backend/utils/stt/assemblyai_adapter.py | 12 +++++++++--- 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/backend/test.sh b/backend/test.sh index afcc8870be2..90e7cc551ca 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -14,6 +14,11 @@ pytest tests/unit/test_speaker_sample_migration.py -v pytest tests/unit/test_short_audio_embedding.py -v pytest tests/unit/test_users_add_sample_transaction.py -v pytest tests/unit/test_voice_message_language.py -v +pytest tests/unit/test_assemblyai_adapter.py -v +pytest tests/unit/test_background_provider_service.py -v +pytest tests/unit/test_conversation_reconstructor.py -v +pytest tests/unit/test_provider_evaluation.py -v +pytest tests/unit/test_transcription_provider_usage.py -v pytest tests/unit/test_speaker_assignment.py -v pytest tests/unit/test_speaker_id_pipeline.py -v pytest tests/unit/test_user_speaker_embedding.py -v diff --git a/backend/tests/unit/test_assemblyai_adapter.py b/backend/tests/unit/test_assemblyai_adapter.py 
index d8140ca7af1..4bf18b396cc 100644 --- a/backend/tests/unit/test_assemblyai_adapter.py +++ b/backend/tests/unit/test_assemblyai_adapter.py @@ -48,8 +48,8 @@ def _completed_transcript(): 'id': 'aai-transcript-1', 'status': 'completed', 'language_code': 'en_us', - 'speech_model': 'universal-2', - 'audio_duration': 2500, + 'speech_model_used': 'universal-2', + 'audio_duration': 2.5, 'utterances': [ { 'speaker': 'A', @@ -111,6 +111,7 @@ def test_assemblyai_transcribe_url_submits_diarization_and_polls_to_completion() submit_payload = fake_client.requests[0][2]['json'] assert submit_payload['audio_url'] == 'https://example.test/audio.wav' assert submit_payload['speaker_labels'] is True + assert submit_payload['speech_models'] == ['universal-2'] assert submit_payload['speakers_expected'] == 2 assert submit_payload['language_detection'] is True assert submit_payload['keyterms_prompt'] == ['Omi'] diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index 2a0b1e556ec..f69dd7640b4 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -1,5 +1,6 @@ import os import sys +import types from pathlib import Path from unittest.mock import MagicMock, patch @@ -11,6 +12,7 @@ sys.modules['deepgram'].DeepgramClient = MagicMock sys.modules['deepgram'].DeepgramClientOptions = MagicMock +sys.modules.setdefault('database._client', types.SimpleNamespace(db=MagicMock())) os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') diff --git a/backend/utils/stt/assemblyai_adapter.py b/backend/utils/stt/assemblyai_adapter.py index fb6ae9e0d13..f97cda069ac 100644 --- a/backend/utils/stt/assemblyai_adapter.py +++ b/backend/utils/stt/assemblyai_adapter.py @@ -49,9 +49,9 @@ def normalize_assemblyai_transcript_result( return ProviderTranscriptResult( provider=STTProviderName.assemblyai.value, - model=result.get('speech_model') or model, + 
model=result.get('speech_model_used') or result.get('speech_model') or model, language=_normalize_language(result.get('language_code') or language), - duration=_milliseconds_to_seconds(result.get('audio_duration')), + duration=_seconds_float(result.get('audio_duration')), words=words, utterances=utterances, raw_provider_result_id=result.get('id'), @@ -90,6 +90,12 @@ def _milliseconds_to_seconds(value) -> float: return float(value) / 1000.0 +def _seconds_float(value) -> float: + if value is None: + return 0.0 + return float(value) + + def _normalize_language(language: Optional[str]) -> Optional[str]: if language and '_' in language: return language.split('_', 1)[0] @@ -200,7 +206,7 @@ def _transcript_payload( 'format_text': True, } if model: - payload['speech_model'] = model + payload['speech_models'] = [model] if isinstance(model, str) else list(model) if speakers_count: payload['speakers_expected'] = speakers_count if language and language != 'multi': From 165092d3d2fac41033a8fcfb0856c8d52d85773c Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 12:17:46 +0700 Subject: [PATCH 22/44] Stabilize backend regression tests --- .../unit/test_action_item_date_validation.py | 2 + .../tests/unit/test_async_app_integrations.py | 12 + .../unit/test_available_plans_resilience.py | 2 + .../tests/unit/test_batch_upload_storage.py | 2 - backend/tests/unit/test_chat_quota.py | 64 +++--- .../unit/test_daily_summary_race_condition.py | 4 + backend/tests/unit/test_desktop_migration.py | 5 +- backend/tests/unit/test_desktop_transcribe.py | 35 ++- backend/tests/unit/test_dg_usage_batch.py | 3 + .../tests/unit/test_fair_use_classifier.py | 15 +- backend/tests/unit/test_fair_use_upgrade.py | 8 +- .../unit/test_firestore_read_ops_cache.py | 12 +- backend/tests/unit/test_geocoding_cache.py | 37 +-- .../tests/unit/test_kg_user_type_mismatch.py | 9 +- backend/tests/unit/test_llm_usage_db.py | 3 +- .../tests/unit/test_llm_usage_endpoints.py | 3 + 
backend/tests/unit/test_llm_usage_tracker.py | 3 +- backend/tests/unit/test_lock_bypass_fixes.py | 8 +- .../tests/unit/test_mentor_notifications.py | 3 + ...test_process_conversation_usage_context.py | 5 + .../unit/test_prompt_cache_integration.py | 4 +- backend/tests/unit/test_rate_limiting.py | 6 +- ...st_realtime_integrations_usage_tracking.py | 12 + .../unit/test_speaker_sample_migration.py | 3 +- .../tests/unit/test_storage_opus_encoding.py | 2 - ...rage_upload_audio_chunk_data_protection.py | 18 +- backend/tests/unit/test_subscription_plans.py | 2 + .../unit/test_subscription_restructure.py | 1 + backend/tests/unit/test_sync_fair_use_gate.py | 9 +- backend/tests/unit/test_sync_opus_decode.py | 11 +- backend/tests/unit/test_sync_record_usage.py | 18 +- .../tests/unit/test_sync_silent_failure.py | 182 +++++++++------ .../unit/test_sync_transcription_prefs.py | 213 +++++++++++------- backend/tests/unit/test_sync_v2.py | 16 +- backend/tests/unit/test_task_sharing.py | 1 + .../unit/test_thread_join_elimination.py | 10 +- .../unit/test_users_add_sample_transaction.py | 3 +- .../tests/unit/test_voice_message_language.py | 3 +- backend/tests/unit/test_ws_auth_handshake.py | 17 ++ backend/utils/subscription.py | 8 + 40 files changed, 516 insertions(+), 258 deletions(-) diff --git a/backend/tests/unit/test_action_item_date_validation.py b/backend/tests/unit/test_action_item_date_validation.py index 5a36a8455c2..4bce2d459f0 100644 --- a/backend/tests/unit/test_action_item_date_validation.py +++ b/backend/tests/unit/test_action_item_date_validation.py @@ -82,6 +82,8 @@ def _load_module_from_file(module_name, file_path): if mod_name not in sys.modules: _stub_module(mod_name) +sys.modules["database.auth"].get_user_name = MagicMock(return_value="Test User") + # Stub database.action_items action_items_db = _stub_module("database.action_items") action_items_db.create_action_item = MagicMock(return_value="test-item-id") diff --git 
a/backend/tests/unit/test_async_app_integrations.py b/backend/tests/unit/test_async_app_integrations.py index 796b029f200..f4d805b890a 100644 --- a/backend/tests/unit/test_async_app_integrations.py +++ b/backend/tests/unit/test_async_app_integrations.py @@ -36,6 +36,8 @@ "calendar_meetings", "vector_db", "apps", + "announcements", + "user_usage", "llm_usage", "chat", "goals", @@ -55,6 +57,8 @@ sys.modules["database.redis_db"].get_daily_notification_count = MagicMock(return_value=0) sys.modules["database.vector_db"].query_vectors_by_metadata = MagicMock(return_value=[]) sys.modules["database.apps"].record_app_usage = MagicMock() +sys.modules["database.announcements"].compare_versions = MagicMock(return_value=0) +sys.modules["database.user_usage"].get_monthly_chat_usage = MagicMock() sys.modules["database.llm_usage"].record_llm_usage = MagicMock() sys.modules["database.chat"].add_app_message = MagicMock(return_value={"id": "msg-1"}) sys.modules["database.chat"].get_app_messages = MagicMock(return_value=[]) @@ -139,6 +143,14 @@ def _noop_track(uid, feature): sys.modules["utils.executors"] = types.ModuleType("utils.executors") sys.modules["utils.executors"].critical_executor = _TPE(max_workers=2, thread_name_prefix="test-critical") sys.modules["utils.executors"].storage_executor = _TPE(max_workers=2, thread_name_prefix="test-storage") +sys.modules["utils.executors"].db_executor = _TPE(max_workers=2, thread_name_prefix="test-db") + + +async def _run_blocking(_executor, fn, *args, **kwargs): + return fn(*args, **kwargs) + + +sys.modules["utils.executors"].run_blocking = _run_blocking import importlib diff --git a/backend/tests/unit/test_available_plans_resilience.py b/backend/tests/unit/test_available_plans_resilience.py index ef19495fdfc..49ff8b629c5 100644 --- a/backend/tests/unit/test_available_plans_resilience.py +++ b/backend/tests/unit/test_available_plans_resilience.py @@ -73,6 +73,7 @@ def _compare_versions(a, b): _announcements_mod._compare_versions = 
_compare_versions +_announcements_mod.compare_versions = _compare_versions # database.users needs the functions payment.py imports by name _users_mod = sys.modules["database.users"] @@ -126,6 +127,7 @@ def _compare_versions(a, b): _endpoints_mod = sys.modules["utils.other.endpoints"] _endpoints_mod.get_current_user_uid = lambda: "test-user" +_endpoints_mod.get_current_user_uid_no_byok_validation = lambda: "test-user" # Ensure utils.other has endpoints attr for `from utils.other import endpoints` sys.modules["utils.other"].endpoints = _endpoints_mod diff --git a/backend/tests/unit/test_batch_upload_storage.py b/backend/tests/unit/test_batch_upload_storage.py index 08045593a20..28ae6e96a50 100644 --- a/backend/tests/unit/test_batch_upload_storage.py +++ b/backend/tests/unit/test_batch_upload_storage.py @@ -24,8 +24,6 @@ sys.modules.setdefault("google.cloud.storage", _mock_gcs_storage) sys.modules.setdefault("google.cloud.storage.transfer_manager", MagicMock()) sys.modules.setdefault("google.cloud.exceptions", MagicMock()) -sys.modules.setdefault("google.oauth2", MagicMock()) -sys.modules.setdefault("google.oauth2.service_account", MagicMock()) from utils.other import storage as storage_mod diff --git a/backend/tests/unit/test_chat_quota.py b/backend/tests/unit/test_chat_quota.py index ef779bb9aef..488f1785aa2 100644 --- a/backend/tests/unit/test_chat_quota.py +++ b/backend/tests/unit/test_chat_quota.py @@ -23,9 +23,12 @@ def _compare_versions(a, b): _announcements_mod._compare_versions = _compare_versions +_announcements_mod.compare_versions = _compare_versions # Create stubs for database modules used by get_chat_quota_snapshot -_db_users_mod = types.SimpleNamespace(get_user_valid_subscription=MagicMock()) +_db_users_mod = types.SimpleNamespace( + get_user_valid_subscription=MagicMock(), is_byok_active=MagicMock(return_value=False) +) _db_user_usage_mod = types.SimpleNamespace(get_monthly_chat_usage=MagicMock()) sys.modules.setdefault("database._client", 
types.SimpleNamespace(db=MagicMock())) @@ -319,10 +322,8 @@ def test_enforcement_allowed(self, monkeypatch): ): sub_mod.enforce_chat_quota("uid123") # no exception - def test_enforcement_exceeded_raises_402(self, monkeypatch): - """When user exceeds quota, raises HTTPException 402.""" - from fastapi import HTTPException - + def test_enforcement_paid_plan_exceeded_allows_overage(self, monkeypatch): + """When a paid plan exceeds quota, no exception is raised.""" sub_mod = _reload_subscription_module() with patch.object( @@ -336,23 +337,41 @@ def test_enforcement_exceeded_raises_402(self, monkeypatch): 'limit': 2000, 'reset_at': _RESET_AT, }, + ): + sub_mod.enforce_chat_quota("uid123") # paid plans enter overage mode + + def test_enforcement_free_plan_exceeded_raises_402(self, monkeypatch): + """When a free user exceeds quota, raises HTTPException 402.""" + from fastapi import HTTPException + + sub_mod = _reload_subscription_module() + + with patch.object( + sub_mod, + "get_chat_quota_snapshot", + return_value={ + 'allowed': False, + 'plan': PlanType.basic, + 'unit': 'questions', + 'used': 31, + 'limit': 30, + 'reset_at': _RESET_AT, + }, ): with pytest.raises(HTTPException) as exc_info: sub_mod.enforce_chat_quota("uid123") assert exc_info.value.status_code == 402 assert exc_info.value.detail['error'] == 'quota_exceeded' - assert exc_info.value.detail['plan'] == 'Neo' - assert exc_info.value.detail['plan_type'] == 'unlimited' + assert exc_info.value.detail['plan'] == 'Free' + assert exc_info.value.detail['plan_type'] == 'basic' assert exc_info.value.detail['unit'] == 'questions' - assert exc_info.value.detail['used'] == 2001 - assert exc_info.value.detail['limit'] == 2000 + assert exc_info.value.detail['used'] == 31 + assert exc_info.value.detail['limit'] == 30 assert exc_info.value.detail['reset_at'] == _RESET_AT - def test_enforcement_402_operator_plan(self, monkeypatch): - """Operator plan shows correct display name in 402 detail.""" - from fastapi import 
HTTPException - + def test_enforcement_operator_plan_allows_overage(self, monkeypatch): + """Operator plan allows overage after included quota.""" sub_mod = _reload_subscription_module() with patch.object( @@ -367,16 +386,10 @@ def test_enforcement_402_operator_plan(self, monkeypatch): 'reset_at': _RESET_AT, }, ): - with pytest.raises(HTTPException) as exc_info: - sub_mod.enforce_chat_quota("uid123") - - assert exc_info.value.status_code == 402 - assert exc_info.value.detail['plan'] == 'Operator' - - def test_enforcement_402_architect_cost_based(self, monkeypatch): - """Architect plan shows cost_usd unit in 402 detail.""" - from fastapi import HTTPException + sub_mod.enforce_chat_quota("uid123") + def test_enforcement_architect_cost_based_allows_overage(self, monkeypatch): + """Architect plan allows cost-based overage after included quota.""" sub_mod = _reload_subscription_module() with patch.object( @@ -391,9 +404,4 @@ def test_enforcement_402_architect_cost_based(self, monkeypatch): 'reset_at': _RESET_AT, }, ): - with pytest.raises(HTTPException) as exc_info: - sub_mod.enforce_chat_quota("uid123") - - assert exc_info.value.status_code == 402 - assert exc_info.value.detail['unit'] == 'cost_usd' - assert exc_info.value.detail['used'] == 400.5 + sub_mod.enforce_chat_quota("uid123") diff --git a/backend/tests/unit/test_daily_summary_race_condition.py b/backend/tests/unit/test_daily_summary_race_condition.py index 3289eb8e45b..65fcdbae4d3 100644 --- a/backend/tests/unit/test_daily_summary_race_condition.py +++ b/backend/tests/unit/test_daily_summary_race_condition.py @@ -76,6 +76,7 @@ def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) "utils.webhooks", "utils.conversations", "utils.conversations.factory", + "utils.subscription", ]: if name not in sys.modules: mod = _stub_module(name) @@ -100,6 +101,9 @@ def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) utils_webhooks = sys.modules["utils.webhooks"] 
utils_webhooks.day_summary_webhook = MagicMock() +utils_subscription = sys.modules["utils.subscription"] +utils_subscription.is_trial_paywalled = MagicMock(return_value=False) + # Stub models for name in ["models.notification_message", "models.conversation"]: if name not in sys.modules: diff --git a/backend/tests/unit/test_desktop_migration.py b/backend/tests/unit/test_desktop_migration.py index 289d88811e9..9ef746913d3 100644 --- a/backend/tests/unit/test_desktop_migration.py +++ b/backend/tests/unit/test_desktop_migration.py @@ -82,6 +82,7 @@ def _stub_package(name): field_filter_stub.FieldFilter = MagicMock() sys.modules["google.cloud.firestore_v1"].FieldFilter = field_filter_stub.FieldFilter sys.modules["google.cloud.firestore_v1"].transactional = lambda f: f +sys.modules["database.redis_db"].try_acquire_user_platform_write_lock = MagicMock(return_value=True) # Add backend dir to sys.path sys.path.insert(0, str(BACKEND_DIR)) @@ -1694,7 +1695,7 @@ def test_returns_title(self, mock_update): session_id='s1', messages=[TitleMessageInput(text='hi', sender='human'), TitleMessageInput(text='hello', sender='ai')], ) - with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(llm_mini=mock_llm)}): + with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(get_llm=MagicMock(return_value=mock_llm))}): result = generate_session_title(request, uid='u1') assert result == {'title': 'Project Discussion'} @@ -1713,7 +1714,7 @@ def test_empty_response_defaults_to_new_chat(self, mock_update): session_id='s1', messages=[TitleMessageInput(text='hi', sender='human')], ) - with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(llm_mini=mock_llm)}): + with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(get_llm=MagicMock(return_value=mock_llm))}): result = generate_session_title(request, uid='u1') assert result == {'title': 'New Chat'} diff --git a/backend/tests/unit/test_desktop_transcribe.py b/backend/tests/unit/test_desktop_transcribe.py index 
e7a3129dfa4..bfc0d5e764d 100644 --- a/backend/tests/unit/test_desktop_transcribe.py +++ b/backend/tests/unit/test_desktop_transcribe.py @@ -114,6 +114,7 @@ 'utils.llm.goals', 'utils.llm.usage_tracker', 'utils.conversations', + 'utils.conversations.factory', 'utils.conversations.process_conversation', 'utils.notifications', 'utils.other.storage', @@ -131,9 +132,22 @@ ]: sys.modules.setdefault(_ufull, MagicMock()) +for _pkg_name in ['utils.conversations', 'utils.retrieval', 'utils.llm']: + if _pkg_name in sys.modules: + sys.modules[_pkg_name].__path__ = [_pkg_name.replace('.', '/')] + # Force-import real models.chat (has no project deps, needed for FastAPI response_model) import importlib.util as _ilu +_segment_spec = _ilu.spec_from_file_location( + 'models.transcript_segment', + os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'transcript_segment.py'), +) +_real_segment = _ilu.module_from_spec(_segment_spec) +_segment_spec.loader.exec_module(_real_segment) +sys.modules['models.transcript_segment'] = _real_segment +setattr(_models_pkg, 'transcript_segment', _real_segment) + _chat_spec = _ilu.spec_from_file_location( 'models.chat', os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'chat.py') ) @@ -144,6 +158,9 @@ # Now safe to import the modules under test from utils.stt.pre_recorded import deepgram_prerecorded_from_bytes +import utils.chat as _chat_mod + +setattr(sys.modules['utils'], 'chat', _chat_mod) # --------------------------------------------------------------------------- # deepgram_prerecorded_from_bytes: encoding/language/model options @@ -436,11 +453,11 @@ def test_retry_raises_after_max_attempts(self, mock_client): """After 3 failed attempts, should raise RuntimeError.""" mock_client.listen.rest.v.return_value.transcribe_file.side_effect = Exception('connection timeout') - with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'): + with pytest.raises(RuntimeError, match='Deepgram transcription 
failed after 2 attempts'): deepgram_prerecorded_from_bytes(b'\x00' * 100, encoding='linear16') - # Should have been called 3 times (attempts 0, 1, 2) - assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 3 + # Should have been called 2 times (attempts 0, 1) + assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 2 @patch('utils.stt.pre_recorded._deepgram_client') def test_return_language_empty_words_returns_detected_lang(self, mock_client): @@ -465,10 +482,10 @@ def test_no_channels_raises_and_retries(self, mock_client): mock_response.to_dict.return_value = {'results': {'channels': []}} mock_client.listen.rest.v.return_value.transcribe_file.return_value = mock_response - with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'): + with pytest.raises(RuntimeError, match='Deepgram transcription failed after 2 attempts'): deepgram_prerecorded_from_bytes(b'\x00' * 100) - assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 3 + assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 2 # --------------------------------------------------------------------------- @@ -500,6 +517,10 @@ def _stub_router_deps(): 'utils.social', 'utils.speaker_assignment', 'utils.speaker_identification', + 'utils.conversations.factory', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.stt.background_speaker_identity', 'utils.stt.speaker_embedding', 'utils.stt.vad', 'utils.stt.streaming', @@ -512,6 +533,10 @@ def _stub_router_deps(): rdb = sys.modules.get('database.redis_db') if rdb: rdb.check_rate_limit = MagicMock(return_value=(True, 99, 0)) + subscription = sys.modules.get('utils.subscription') + if subscription: + subscription.is_trial_paywalled = MagicMock(return_value=False) + subscription.enforce_chat_quota = MagicMock() def _make_chat_client(): diff --git a/backend/tests/unit/test_dg_usage_batch.py b/backend/tests/unit/test_dg_usage_batch.py index 
c32a8362728..0523ced1c71 100644 --- a/backend/tests/unit/test_dg_usage_batch.py +++ b/backend/tests/unit/test_dg_usage_batch.py @@ -75,6 +75,7 @@ def setup_method(self): 'database.users', 'database.user_usage', 'database.conversations', + 'utils.subscription', 'firebase_admin', 'firebase_admin.messaging', ]: @@ -83,6 +84,8 @@ def setup_method(self): sys.modules['database._client'].db = MagicMock() sys.modules['database.redis_db'].r = MagicMock() + sys.modules['utils.subscription'].has_transcription_credits = MagicMock(return_value=True) + sys.modules['utils.subscription'].is_paid_plan = MagicMock(return_value=True) os.environ.setdefault('FAIR_USE_ENABLED', 'true') os.environ.setdefault('ENCRYPTION_SECRET', 'test-secret-key-that-is-long-enough-for-encryption-32ch') diff --git a/backend/tests/unit/test_fair_use_classifier.py b/backend/tests/unit/test_fair_use_classifier.py index 88bf803cdbb..0c69ddddf2c 100644 --- a/backend/tests/unit/test_fair_use_classifier.py +++ b/backend/tests/unit/test_fair_use_classifier.py @@ -167,7 +167,8 @@ async def test_parses_llm_response_correctly(self): 'reasoning': 'Clear audiobook pattern', } ) - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') @@ -192,7 +193,8 @@ async def test_handles_markdown_code_block_response(self): llm_response = MagicMock() llm_response.content = '```json\n{"misuse_score": 0.1, "usage_type": "none", "confidence": 0.9, "evidence": [], "reasoning": "Normal"}\n```' - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == pytest.approx(0.1) @@ -215,7 +217,8 @@ async def 
test_clamps_score_to_valid_range(self): llm_response.content = json.dumps( {'misuse_score': 1.5, 'usage_type': 'none', 'confidence': -0.2, 'evidence': [], 'reasoning': ''} ) - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == 1.0 @@ -237,7 +240,8 @@ async def test_returns_default_on_json_parse_error(self): llm_response = MagicMock() llm_response.content = 'This is not JSON at all' - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == 0.0 @@ -256,7 +260,8 @@ async def test_returns_default_on_llm_error(self): 'created_at': now, } ] - _llm_clients.llm_mini.ainvoke = AsyncMock(side_effect=Exception('LLM timeout')) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(side_effect=Exception('LLM timeout')) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == 0.0 diff --git a/backend/tests/unit/test_fair_use_upgrade.py b/backend/tests/unit/test_fair_use_upgrade.py index 8548f55106e..a6d39c40e8f 100644 --- a/backend/tests/unit/test_fair_use_upgrade.py +++ b/backend/tests/unit/test_fair_use_upgrade.py @@ -245,7 +245,7 @@ def test_checkout_session_completed_calls_clear(self): # The clear call should appear between checkout handler and the next event type next_event_idx = source.find("customer.subscription.updated", checkout_idx) block = source[checkout_idx:next_event_idx] - assert 'clear_fair_use_on_upgrade(uid)' in block + assert 'clear_fair_use_on_upgrade, uid' in block def test_subscription_event_calls_clear(self): 
"""customer.subscription.* path must call clear_fair_use_on_upgrade.""" @@ -254,7 +254,7 @@ def test_subscription_event_calls_clear(self): assert sub_idx != -1, "customer.subscription handler not found" next_event_idx = source.find("subscription_schedule.completed", sub_idx) block = source[sub_idx:next_event_idx] - assert 'clear_fair_use_on_upgrade(uid)' in block + assert 'clear_fair_use_on_upgrade, uid' in block def test_schedule_completed_calls_clear(self): """subscription_schedule.completed path must call clear_fair_use_on_upgrade.""" @@ -262,5 +262,5 @@ def test_schedule_completed_calls_clear(self): schedule_idx = source.find("'subscription_schedule.completed'") assert schedule_idx != -1, "subscription_schedule handler not found" # Get a reasonable block after the schedule handler - block = source[schedule_idx : schedule_idx + 1500] - assert 'clear_fair_use_on_upgrade(uid)' in block + block = source[schedule_idx : schedule_idx + 2500] + assert 'clear_fair_use_on_upgrade, uid' in block diff --git a/backend/tests/unit/test_firestore_read_ops_cache.py b/backend/tests/unit/test_firestore_read_ops_cache.py index 21d6c77edca..b5e94c175bd 100644 --- a/backend/tests/unit/test_firestore_read_ops_cache.py +++ b/backend/tests/unit/test_firestore_read_ops_cache.py @@ -629,8 +629,8 @@ def test_checkout_completed_calls_invalidation(self): """checkout.session.completed path must call set_credits_invalidation_signal.""" source = self._read_source(self.PAYMENT_SOURCE_FILE) # The invalidation call should appear after _update_subscription_from_session - idx_update = source.find('_update_subscription_from_session(uid, session)') - idx_signal = source.find('set_credits_invalidation_signal(uid)', idx_update) + idx_update = source.find('_update_subscription_from_session, uid, session') + idx_signal = source.find('set_credits_invalidation_signal, uid', idx_update) assert ( idx_signal > idx_update ), "set_credits_invalidation_signal must be called after 
_update_subscription_from_session" @@ -653,16 +653,16 @@ def test_schedule_completed_calls_invalidation(self): idx_scheduled = source.find("Scheduled upgrade completed for user") assert idx_scheduled > 0 # Find the invalidation call before the log line (it's called right after update) - section = source[idx_scheduled - 200 : idx_scheduled] - assert 'set_credits_invalidation_signal(uid)' in section + section = source[idx_scheduled - 500 : idx_scheduled] + assert 'set_credits_invalidation_signal, uid' in section def test_schedule_canceled_calls_invalidation(self): """subscription_schedule.canceled must call set_credits_invalidation_signal.""" source = self._read_source(self.PAYMENT_SOURCE_FILE) idx_canceled = source.find("Subscription schedule canceled for user") assert idx_canceled > 0 - section = source[idx_canceled - 200 : idx_canceled] - assert 'set_credits_invalidation_signal(uid)' in section + section = source[idx_canceled - 300 : idx_canceled] + assert 'set_credits_invalidation_signal, uid' in section def test_transcribe_imports_invalidation_check(self): """transcribe.py must import check_credits_invalidation.""" diff --git a/backend/tests/unit/test_geocoding_cache.py b/backend/tests/unit/test_geocoding_cache.py index 8e67f2b61cb..f4be3f8a314 100644 --- a/backend/tests/unit/test_geocoding_cache.py +++ b/backend/tests/unit/test_geocoding_cache.py @@ -11,6 +11,7 @@ import json import sys import types +from contextlib import asynccontextmanager from unittest.mock import MagicMock, patch # Mock database._client before importing anything that touches GCP @@ -32,6 +33,14 @@ _http_mod.get_maps_client = MagicMock() _http_mod.get_webhook_client = MagicMock() + +@asynccontextmanager +async def _maps_semaphore(): + yield + + +_http_mod.get_maps_semaphore = MagicMock(return_value=_maps_semaphore()) + from models.geolocation import Geolocation from utils.conversations.location import get_google_maps_location @@ -44,7 +53,7 @@ def test_3_decimal_rounding(self): # 37.78512 
-> 37.785, -122.40932 -> -122.409 with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = {"status": "OK", "results": []} mock_req.get.return_value = mock_resp @@ -80,7 +89,7 @@ def test_cache_hit_returns_geolocation(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = json.dumps(cached) - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: result = get_google_maps_location(37.78512, -122.40932) # Should NOT call Google API @@ -96,7 +105,7 @@ def test_cache_hit_no_api_key_needed(self): with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = json.dumps(cached) with patch.dict("os.environ", {}, clear=True): - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: result = get_google_maps_location(37.785, -122.409) mock_req.get.assert_not_called() assert result is not None @@ -118,7 +127,7 @@ def test_cache_miss_calls_api_and_caches(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -139,7 +148,7 @@ def test_cache_miss_calls_api_and_caches(self): def test_api_no_results_returns_none(self): with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = {"status": 
"OK", "results": []} mock_req.get.return_value = mock_resp @@ -165,7 +174,7 @@ def test_redis_read_failure_falls_through_to_api(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.side_effect = ConnectionError("Redis down") - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -191,7 +200,7 @@ def test_redis_write_failure_still_returns_result(self): with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None mock_r.set.side_effect = ConnectionError("Redis down") - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -211,7 +220,7 @@ def test_api_status_not_ok_returns_none(self): """Non-OK status (e.g. 
ZERO_RESULTS, OVER_QUERY_LIMIT) returns None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = {"status": "ZERO_RESULTS", "results": []} mock_req.get.return_value = mock_resp @@ -225,7 +234,7 @@ def test_missing_place_id_returns_none(self): """Result with no place_id returns None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -241,7 +250,7 @@ def test_missing_place_id_key_returns_none(self): """Result with no place_id key at all returns None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -257,7 +266,7 @@ def test_empty_types_gives_none_location_type(self): """Result with no types gives location_type=None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -274,7 +283,7 @@ def test_missing_types_key_gives_none_location_type(self): """Result with no 'types' key at all gives location_type=None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: 
mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -305,7 +314,7 @@ def test_invalid_json_falls_through_to_api(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = "not-valid-json{{" - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -332,7 +341,7 @@ def test_schema_mismatch_falls_through_to_api(self): with patch("utils.conversations.location.r") as mock_r: # Missing required 'latitude' and 'longitude' fields mock_r.get.return_value = json.dumps({"bad_field": "bad_value"}) - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp diff --git a/backend/tests/unit/test_kg_user_type_mismatch.py b/backend/tests/unit/test_kg_user_type_mismatch.py index 0d20aa3273c..7627a8c1433 100644 --- a/backend/tests/unit/test_kg_user_type_mismatch.py +++ b/backend/tests/unit/test_kg_user_type_mismatch.py @@ -64,6 +64,9 @@ def _stub_module(name: str) -> types.ModuleType: "delete_memory_vector", "upsert_vector2", "update_vector_metadata", + "upsert_action_item_vectors_batch", + "delete_action_item_vectors_batch", + "find_similar_action_items", ]: setattr(vector_db_mod, attr, MagicMock()) @@ -112,12 +115,13 @@ def _stub_module(name: str) -> types.ModuleType: "utils.webhooks", "utils.task_sync", "utils.other.storage", + "utils.subscription", ]: if name not in sys.modules: sys.modules[name] = types.ModuleType(name) utils_apps = sys.modules["utils.apps"] -for attr in ["get_available_apps", "update_personas_async", "sync_update_persona_prompt"]: +for attr in ["get_available_apps", "update_personas_async", "update_persona_prompt", "sync_update_persona_prompt"]: 
setattr(utils_apps, attr, MagicMock()) utils_analytics = sys.modules["utils.analytics"] @@ -185,6 +189,9 @@ def _stub_module(name: str) -> types.ModuleType: utils_storage = sys.modules["utils.other.storage"] utils_storage.precache_conversation_audio = MagicMock() +utils_subscription = sys.modules["utils.subscription"] +utils_subscription.is_trial_paywalled = MagicMock(return_value=False) + import importlib process_conversation = importlib.import_module("utils.conversations.process_conversation") diff --git a/backend/tests/unit/test_llm_usage_db.py b/backend/tests/unit/test_llm_usage_db.py index b2b630243b6..fdd416bf7b7 100644 --- a/backend/tests/unit/test_llm_usage_db.py +++ b/backend/tests/unit/test_llm_usage_db.py @@ -3,6 +3,7 @@ """ import os +import importlib import sys import types from unittest.mock import MagicMock, patch @@ -21,7 +22,7 @@ sys.modules["database._client"] = mock_client_module sys.modules["stripe"] = MagicMock() -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_firestore_module = types.ModuleType("google.cloud.firestore") _google_firestore_module.Increment = lambda x: {"__increment": x} diff --git a/backend/tests/unit/test_llm_usage_endpoints.py b/backend/tests/unit/test_llm_usage_endpoints.py index c7b1d6bb7a4..149af6b193a 100644 --- a/backend/tests/unit/test_llm_usage_endpoints.py +++ b/backend/tests/unit/test_llm_usage_endpoints.py @@ -109,6 +109,9 @@ def _passthrough_decorator(func): "adapt_plans_for_legacy_client", "legacy_plan_features", "is_paid_plan", + "is_trial_paywalled", + "clear_trial_paywall_cache", + "get_trial_metadata", ]: setattr(subscription_mod, attr, MagicMock()) subscription_mod.get_paid_plan_definitions = MagicMock(return_value=[]) diff --git a/backend/tests/unit/test_llm_usage_tracker.py b/backend/tests/unit/test_llm_usage_tracker.py 
index 7a2e19d425b..2356928a3f7 100644 --- a/backend/tests/unit/test_llm_usage_tracker.py +++ b/backend/tests/unit/test_llm_usage_tracker.py @@ -3,6 +3,7 @@ """ import os +import importlib import sys import types from unittest.mock import MagicMock @@ -19,7 +20,7 @@ sys.modules["database._client"] = mock_client_module sys.modules["stripe"] = MagicMock() -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_firestore_module = types.ModuleType("google.cloud.firestore") _google_firestore_module.Increment = lambda x: {"__increment": x} diff --git a/backend/tests/unit/test_lock_bypass_fixes.py b/backend/tests/unit/test_lock_bypass_fixes.py index 610fc81ea69..1afab9011a2 100644 --- a/backend/tests/unit/test_lock_bypass_fixes.py +++ b/backend/tests/unit/test_lock_bypass_fixes.py @@ -746,7 +746,9 @@ def test_scheduled_summary_excludes_locked(self): unlocked_conv = _make_conversation(locked=False, conversation_id='conv-2') conversations_db.get_conversations = MagicMock(return_value=[locked_conv, unlocked_conv]) - with patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True): + with patch('utils.other.notifications.is_trial_paywalled', return_value=False), patch( + 'utils.other.notifications.try_acquire_daily_summary_lock', return_value=True + ): with patch( 'utils.other.notifications.generate_comprehensive_daily_summary', return_value={'headline': 'Test', 'day_emoji': '📅', 'overview': 'ok'}, @@ -1264,8 +1266,10 @@ def test_suggest_goal_filters_locked_memories(self): mock_track.__exit__ = MagicMock(return_value=False) with patch('utils.llm.goals.track_usage', return_value=mock_track): - with patch('utils.llm.goals.llm_mini') as mock_llm: + with patch('utils.llm.goals.get_llm') as mock_get_llm: + mock_llm = MagicMock() mock_llm.invoke.return_value = mock_llm_response + 
mock_get_llm.return_value = mock_llm from utils.llm.goals import suggest_goal diff --git a/backend/tests/unit/test_mentor_notifications.py b/backend/tests/unit/test_mentor_notifications.py index 3adb6842e0d..db99c5e0d50 100644 --- a/backend/tests/unit/test_mentor_notifications.py +++ b/backend/tests/unit/test_mentor_notifications.py @@ -116,6 +116,9 @@ def _stub_module(name: str) -> types.ModuleType: tracker_mod.track_usage = MagicMock() tracker_mod.Features = MagicMock() +subscription_mod = _stub_module("utils.subscription") +subscription_mod.is_trial_paywalled = MagicMock(return_value=False) + # Stub utils.llms.memory (get_prompt_memories) llms_mod = _stub_module("utils.llms") if not hasattr(llms_mod, '__path__'): diff --git a/backend/tests/unit/test_process_conversation_usage_context.py b/backend/tests/unit/test_process_conversation_usage_context.py index 4f4695059d5..bf3db19a4c7 100644 --- a/backend/tests/unit/test_process_conversation_usage_context.py +++ b/backend/tests/unit/test_process_conversation_usage_context.py @@ -59,6 +59,7 @@ def _stub_module(name: str) -> types.ModuleType: "update_vector_metadata", "upsert_action_item_vectors_batch", "delete_action_item_vectors_batch", + "find_similar_action_items", ]: setattr(vector_db_mod, attr, MagicMock()) @@ -95,6 +96,7 @@ def _stub_module(name: str) -> types.ModuleType: "utils.webhooks", "utils.task_sync", "utils.other.storage", + "utils.subscription", ]: if name not in sys.modules: sys.modules[name] = types.ModuleType(name) @@ -165,6 +167,9 @@ def _stub_module(name: str) -> types.ModuleType: utils_storage = sys.modules["utils.other.storage"] utils_storage.precache_conversation_audio = MagicMock() +utils_subscription = sys.modules["utils.subscription"] +utils_subscription.is_trial_paywalled = MagicMock(return_value=False) + import importlib process_conversation = importlib.import_module("utils.conversations.process_conversation") diff --git a/backend/tests/unit/test_prompt_cache_integration.py 
b/backend/tests/unit/test_prompt_cache_integration.py index 554ff5dc520..431c029c7f9 100644 --- a/backend/tests/unit/test_prompt_cache_integration.py +++ b/backend/tests/unit/test_prompt_cache_integration.py @@ -609,10 +609,11 @@ def __init__(self, **kwargs): # Read source, replace imports, exec in isolated namespace source = (BACKEND_DIR / "utils" / "llm" / "clients.py").read_text() source = source.replace("from langchain_openai import ChatOpenAI, OpenAIEmbeddings", "") + source = source.replace("from langchain_google_genai import ChatGoogleGenerativeAI", "") source = source.replace("import tiktoken", "") source = source.replace("import anthropic", "") source = source.replace("from langchain_core.output_parsers import PydanticOutputParser", "") - source = source.replace("from models.conversation import Structured", "") + source = source.replace("from models.structured import Structured", "") source = source.replace("from utils.llm.usage_tracker import get_usage_callback", "") # Create a fake anthropic module with AsyncAnthropic @@ -622,6 +623,7 @@ def __init__(self, **kwargs): ns = { "os": os, "ChatOpenAI": FakeChatOpenAI, + "ChatGoogleGenerativeAI": FakeChatOpenAI, "OpenAIEmbeddings": FakeOpenAIEmbeddings, "tiktoken": fake_tiktoken, "anthropic": fake_anthropic, diff --git a/backend/tests/unit/test_rate_limiting.py b/backend/tests/unit/test_rate_limiting.py index c358050d7cd..38964d0563f 100644 --- a/backend/tests/unit/test_rate_limiting.py +++ b/backend/tests/unit/test_rate_limiting.py @@ -15,8 +15,10 @@ 'firebase_admin.auth', 'google.cloud', 'google.cloud.firestore', + 'google.cloud.firestore_v1', 'database.redis_db', 'database.auth', + 'database.users', ]: if mod_name not in sys.modules: sys.modules[mod_name] = types.ModuleType(mod_name) @@ -56,6 +58,7 @@ def _check_rate_limit(key, policy, max_requests, window): redis_db_stub.check_rate_limit = _check_rate_limit +sys.modules['database.users'].record_user_platform = MagicMock() from utils.rate_limit_config 
import RATE_POLICIES, get_effective_limit, RATE_LIMIT_BOOST @@ -484,8 +487,7 @@ def test_lua_script_has_ttl_self_heal(self): # register_script was called with the Lua source call_args = self.real_module.r.register_script.call_args lua_source = call_args[0][0] - self.assertIn('TTL', lua_source) - self.assertIn('ttl < 0', lua_source) + self.assertIn('daily_ttl', lua_source) self.assertIn('EXPIRE', lua_source) def test_lua_script_uses_incr(self): diff --git a/backend/tests/unit/test_realtime_integrations_usage_tracking.py b/backend/tests/unit/test_realtime_integrations_usage_tracking.py index cfc6e4875af..b6f847e0823 100644 --- a/backend/tests/unit/test_realtime_integrations_usage_tracking.py +++ b/backend/tests/unit/test_realtime_integrations_usage_tracking.py @@ -40,6 +40,8 @@ def _stub_module(name: str) -> types.ModuleType: "calendar_meetings", "vector_db", "apps", + "announcements", + "user_usage", "llm_usage", "_client", "chat", @@ -61,6 +63,12 @@ def _stub_module(name: str) -> types.ModuleType: apps_mod = sys.modules["database.apps"] apps_mod.record_app_usage = MagicMock() +announcements_mod = sys.modules["database.announcements"] +announcements_mod.compare_versions = MagicMock(return_value=0) + +user_usage_mod = sys.modules["database.user_usage"] +user_usage_mod.get_monthly_chat_usage = MagicMock() + llm_usage_mod = sys.modules["database.llm_usage"] llm_usage_mod.record_llm_usage = MagicMock() @@ -93,6 +101,10 @@ def _stub_module(name: str) -> types.ModuleType: conversations_mod = sys.modules["database.conversations"] conversations_mod.get_conversations_by_id = MagicMock(return_value=[]) +users_mod = sys.modules["database.users"] +users_mod.get_user_valid_subscription = MagicMock(return_value=None) +users_mod.is_byok_active = MagicMock(return_value=False) + from utils.llm import usage_tracker # Stub remaining utils modules diff --git a/backend/tests/unit/test_speaker_sample_migration.py b/backend/tests/unit/test_speaker_sample_migration.py index 
a0ad4456dc7..25849516748 100644 --- a/backend/tests/unit/test_speaker_sample_migration.py +++ b/backend/tests/unit/test_speaker_sample_migration.py @@ -1,4 +1,5 @@ import asyncio +import importlib import os import sys import types @@ -21,7 +22,7 @@ class NotFound(Exception): pass -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_exceptions_module = types.ModuleType("google.cloud.exceptions") _google_exceptions_module.NotFound = NotFound diff --git a/backend/tests/unit/test_storage_opus_encoding.py b/backend/tests/unit/test_storage_opus_encoding.py index 5d5b6f43de5..9cdb7a9cb16 100644 --- a/backend/tests/unit/test_storage_opus_encoding.py +++ b/backend/tests/unit/test_storage_opus_encoding.py @@ -28,8 +28,6 @@ sys.modules.setdefault("google.cloud.storage", _mock_gcs_storage) sys.modules.setdefault("google.cloud.storage.transfer_manager", MagicMock()) sys.modules.setdefault("google.cloud.exceptions", MagicMock()) -sys.modules.setdefault("google.oauth2", MagicMock()) -sys.modules.setdefault("google.oauth2.service_account", MagicMock()) from utils.other import storage as storage_mod diff --git a/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py b/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py index 9fedab43864..abf80c75ee5 100644 --- a/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py +++ b/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py @@ -23,8 +23,6 @@ sys.modules.setdefault("google.cloud.storage", _mock_gcs_storage) sys.modules.setdefault("google.cloud.storage.transfer_manager", MagicMock()) sys.modules.setdefault("google.cloud.exceptions", MagicMock()) -sys.modules.setdefault("google.oauth2", MagicMock()) -sys.modules.setdefault("google.oauth2.service_account", MagicMock()) # Now import the module 
under test from utils.other import storage as storage_mod @@ -71,9 +69,10 @@ def test_falls_back_to_db_when_level_not_provided(self, mock_users_db): mock_users_db.get_data_protection_level.assert_called_once_with('test-uid') + @patch.object(storage_mod, 'encode_pcm_to_opus', return_value=b'opus-data') @patch.object(storage_mod, 'users_db') - def test_standard_level_uploads_unencrypted(self, mock_users_db): - """Standard protection level should upload .bin (no encryption).""" + def test_standard_level_uploads_unencrypted(self, mock_users_db, mock_encode): + """Standard protection level should upload .opus (no encryption).""" _, mock_blob = self._setup_mock_bucket() path = storage_mod.upload_audio_chunk( @@ -84,12 +83,15 @@ def test_standard_level_uploads_unencrypted(self, mock_users_db): data_protection_level='standard', ) - assert path.endswith('.bin') + assert path.endswith('.opus') + mock_encode.assert_called_once_with(b'\x00' * 100) mock_blob.upload_from_string.assert_called_once() + assert mock_blob.upload_from_string.call_args[0][0] == b'opus-data' + @patch.object(storage_mod, 'encode_pcm_to_opus', return_value=b'opus-data') @patch.object(storage_mod, 'encryption') @patch.object(storage_mod, 'users_db') - def test_enhanced_level_uploads_encrypted(self, mock_users_db, mock_encryption): + def test_enhanced_level_uploads_encrypted(self, mock_users_db, mock_encryption, mock_encode): """Enhanced protection level should encrypt and upload .enc.""" _, mock_blob = self._setup_mock_bucket() mock_encryption.encrypt_audio_chunk.return_value = b'\x01' * 120 @@ -103,7 +105,9 @@ def test_enhanced_level_uploads_encrypted(self, mock_users_db, mock_encryption): ) assert path.endswith('.enc') - mock_encryption.encrypt_audio_chunk.assert_called_once_with(b'\x00' * 100, 'test-uid') + mock_encode.assert_called_once_with(b'\x00' * 100) + mock_encryption.encrypt_audio_chunk.assert_called_once_with(b'opus-data', 'test-uid') + 
mock_blob.upload_from_string.assert_called_once_with(b'\x01' * 120, content_type='application/octet-stream') @patch.object(storage_mod, 'users_db') def test_explicit_none_falls_back_to_db(self, mock_users_db): diff --git a/backend/tests/unit/test_subscription_plans.py b/backend/tests/unit/test_subscription_plans.py index c7db2277813..7af53c3247d 100644 --- a/backend/tests/unit/test_subscription_plans.py +++ b/backend/tests/unit/test_subscription_plans.py @@ -1,6 +1,8 @@ import sys import types +sys.modules.setdefault("database._client", types.SimpleNamespace(db=None)) +sys.modules.setdefault("database.announcements", types.SimpleNamespace(compare_versions=lambda _a, _b: 0)) sys.modules.setdefault("database.users", types.SimpleNamespace()) sys.modules.setdefault("database.user_usage", types.SimpleNamespace()) diff --git a/backend/tests/unit/test_subscription_restructure.py b/backend/tests/unit/test_subscription_restructure.py index f2c60f6d9b7..f2c4f44e12c 100644 --- a/backend/tests/unit/test_subscription_restructure.py +++ b/backend/tests/unit/test_subscription_restructure.py @@ -19,6 +19,7 @@ def _compare_versions(a, b): _announcements_mod._compare_versions = _compare_versions +_announcements_mod.compare_versions = _compare_versions sys.modules.setdefault("database.users", types.SimpleNamespace()) sys.modules.setdefault("database.user_usage", types.SimpleNamespace()) sys.modules.setdefault("database.announcements", _announcements_mod) diff --git a/backend/tests/unit/test_sync_fair_use_gate.py b/backend/tests/unit/test_sync_fair_use_gate.py index 7ad8f01c61e..1aa359211e4 100644 --- a/backend/tests/unit/test_sync_fair_use_gate.py +++ b/backend/tests/unit/test_sync_fair_use_gate.py @@ -16,6 +16,7 @@ 'database.users', 'database.user_usage', 'database.conversations', + 'utils.subscription', 'firebase_admin', 'firebase_admin.messaging', ]: @@ -28,6 +29,8 @@ # Stub database._client.db sys.modules['database._client'].db = MagicMock() 
+sys.modules['utils.subscription'].has_transcription_credits = MagicMock(return_value=True) +sys.modules['utils.subscription'].is_paid_plan = MagicMock(return_value=True) import utils.fair_use as fair_use_mod @@ -258,7 +261,11 @@ def _read_sync_source(): def test_no_402_block(self): """sync.py must not raise 402 (lock instead of block).""" source = self._read_sync_source() - assert 'status_code=402' not in source + for function_name in ('sync_local_files', 'sync_local_files_v2'): + start = source.index(f'async def {function_name}(') + next_route = source.find('\n@router.', start + 1) + body = source[start:] if next_route == -1 else source[start:next_route] + assert 'status_code=402' not in body def test_should_lock_flag_exists(self): """sync.py must use should_lock flag for credit-exhausted locking.""" diff --git a/backend/tests/unit/test_sync_opus_decode.py b/backend/tests/unit/test_sync_opus_decode.py index 8563f274bfc..00a514f16bf 100644 --- a/backend/tests/unit/test_sync_opus_decode.py +++ b/backend/tests/unit/test_sync_opus_decode.py @@ -47,13 +47,13 @@ 'utils.other.storage', 'utils.encryption', 'utils.stt.pre_recorded', + 'utils.stt.speaker_embedding', 'utils.stt.vad', 'utils.fair_use', 'utils.subscription', 'utils.log_sanitizer', 'utils.executors', 'pydub', - 'numpy', 'httpx', ] for _mod in _stub_modules: @@ -67,13 +67,12 @@ from routers.sync import decode_opus_file_to_wav, decode_files_to_wav # noqa: E402 - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- #: One frame of fake Opus-encoded bytes (content doesn't matter — decoder is mocked). -FAKE_OPUS_FRAME = b'\xAA\xBB\xCC' * 34 # 102 bytes +FAKE_OPUS_FRAME = b'\xaa\xbb\xcc' * 34 # 102 bytes #: PCM returned by the mocked decoder: 320 bytes = 160 mono samples at 16-bit. #: 100 such frames = 16 000 samples = 1.0 s at 16 kHz. 
@@ -279,7 +278,7 @@ def test_truncated_frame_data_stops_cleanly(self): f.write(FAKE_OPUS_FRAME) # Write a length prefix claiming 1000 bytes but only supply 10 f.write(struct.pack(' 0 assert record_pos > 0 assert record_pos > failed_pos, "record_usage must come after failed_segments computation" @@ -139,25 +139,25 @@ def test_record_usage_after_failed_segments_check(self): def test_record_usage_guarded_by_successful_segments(self): """record_usage should only run when successful_segments > 0.""" body = self._get_v2_body() - record_idx = body.find('record_usage(') + record_idx = body.find('record_usage') preceding = body[max(0, record_idx - 300) : record_idx] assert 'successful_segments > 0' in preceding, "record_usage must be guarded by successful_segments > 0" def test_record_usage_before_final_mark_job_completed(self): """record_usage must run before the final mark_job_completed (after segment processing).""" body = self._get_v2_body() - record_pos = body.find('record_usage(') + record_pos = body.find('record_usage') assert record_pos > 0, "record_usage must exist" - complete_pos = body.rfind('mark_job_completed(') + complete_pos = body.rfind('mark_job_completed') assert complete_pos > 0, "mark_job_completed must exist" assert record_pos < complete_pos, "record_usage must run before the final mark_job_completed" def test_record_usage_wrapped_in_try_except(self): body = self._get_v2_body() - record_idx = body.find('record_usage(') - preceding = body[max(0, record_idx - 200) : record_idx] + record_idx = body.find('record_usage') + preceding = body[max(0, record_idx - 400) : record_idx] assert 'try:' in preceding, "record_usage must be inside a try block" - following = body[record_idx : record_idx + 200] + following = body[record_idx : record_idx + 500] assert 'except Exception' in following, "record_usage must have an except handler" diff --git a/backend/tests/unit/test_sync_silent_failure.py b/backend/tests/unit/test_sync_silent_failure.py index 
5554cbf9aea..bad1779b2d5 100644 --- a/backend/tests/unit/test_sync_silent_failure.py +++ b/backend/tests/unit/test_sync_silent_failure.py @@ -260,7 +260,7 @@ def test_deepgram_raises_runtime_error_on_final_retry(self): assert 'raise RuntimeError' in except_body assert 'Deepgram transcription failed after' in except_body - assert 'attempts < 2' in func_body + assert 'attempts < 1' in func_body def test_deepgram_keeps_empty_words_as_success(self): """A valid Deepgram response with no words must still return [].""" @@ -276,15 +276,8 @@ def test_deepgram_keeps_empty_words_as_success(self): except ValueError: pass func_body = source[start:end] - no_words_block = func_body[ - func_body.index("dg_words = alternatives[0].get('words', [])") : func_body.index( - '# Convert Deepgram format' - ) - ] - - assert 'if not dg_words:' in no_words_block - assert 'return [], detected_lang or \'en\'' in no_words_block - assert 'return []' in no_words_block + assert 'provider_result_to_legacy_words' in func_body + assert 'raise RuntimeError' not in func_body[: func_body.index('except Exception as e:')] class TestDeepgramRetryBehavioral: @@ -320,6 +313,9 @@ def setup_class(cls): sys.modules['deepgram'].DeepgramClientOptions = MagicMock() sys.modules['fal_client'].submit = MagicMock() sys.modules['models.transcript_segment'].TranscriptSegment = MagicMock() + sys.modules['models.transcript_segment'].ProviderTranscriptResult = MagicMock() + sys.modules['models.transcript_segment'].ProviderTranscriptWord = MagicMock() + sys.modules['models.transcript_segment'].ProviderTranscriptUtterance = MagicMock() sys.modules['utils.other.endpoints'].timeit = lambda f: f # Force re-import so it picks up stubs @@ -342,41 +338,33 @@ def test_retry_exhaustion_raises_runtime_error(self): """deepgram_prerecorded must raise RuntimeError when all retries fail.""" from unittest.mock import MagicMock, patch - mock_client = MagicMock() - mock_client.listen.rest.v.return_value.transcribe_url.side_effect = 
ConnectionError('timeout') + mock_provider = MagicMock() + mock_provider.transcribe_url.side_effect = ConnectionError('timeout') - with patch('utils.stt.pre_recorded._deepgram_client', mock_client): - with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'): + with patch('utils.stt.pre_recorded._deepgram_prerecorded_provider', return_value=mock_provider): + with pytest.raises(RuntimeError, match='Deepgram transcription failed after 2 attempts'): self._deepgram_prerecorded('https://fake-audio.wav', attempts=0, return_language=True) - # Should have been called 3 times (initial + 2 retries) - assert mock_client.listen.rest.v.return_value.transcribe_url.call_count == 3 + # Should have been called 2 times (initial + retry) + assert mock_provider.transcribe_url.call_count == 2 def test_valid_empty_transcription_returns_empty_list(self): """deepgram_prerecorded must return ([], lang) when DG succeeds but finds no words.""" from unittest.mock import MagicMock, patch - mock_response = MagicMock() - mock_response.to_dict.return_value = { - 'results': { - 'channels': [ - { - 'alternatives': [{'words': []}], - 'detected_language': 'en', - } - ] - } - } - mock_client = MagicMock() - mock_client.listen.rest.v.return_value.transcribe_url.return_value = mock_response + transcript_result = MagicMock() + mock_provider = MagicMock() + mock_provider.transcribe_url.return_value = (transcript_result, 'en') - with patch('utils.stt.pre_recorded._deepgram_client', mock_client): + with patch('utils.stt.pre_recorded._deepgram_prerecorded_provider', return_value=mock_provider), patch( + 'utils.stt.pre_recorded.provider_result_to_legacy_words', return_value=[] + ): words, lang = self._deepgram_prerecorded('https://fake-audio.wav', return_language=True) assert words == [] assert lang == 'en' # Should be called exactly once (no retries for valid response) - assert mock_client.listen.rest.v.return_value.transcribe_url.call_count == 1 + assert 
mock_provider.transcribe_url.call_count == 1 # --------------------------------------------------------------------------- @@ -395,14 +383,14 @@ def _read_app_file(relative_path): return f.read() return None - def test_app_accepts_200_and_207(self): - """App treats both HTTP 200 and 207 as parseable responses.""" + def test_app_accepts_200_and_202(self): + """App treats both legacy HTTP 200 and async HTTP 202 as parseable responses.""" source = self._read_app_file('backend/http/api/conversations.dart') if source is None: pytest.skip("App source not available") assert 'response.statusCode == 200' in source - assert 'response.statusCode == 207' in source + assert 'response.statusCode != 202' in source def test_app_keeps_wals_retryable_on_partial_failure(self): """App keeps WALs retryable when response has partial failure (207).""" @@ -586,6 +574,7 @@ def test_all_duplicates_skips_merge(self): _STUB_MODULES = [ 'models', 'models.conversation', + 'models.conversation_enums', 'models.transcript_segment', 'database._client', 'database.redis_db', @@ -593,15 +582,25 @@ def test_all_duplicates_skips_merge(self): 'database.users', 'database.user_usage', 'database.conversations', + 'database.sync_jobs', 'firebase_admin', 'firebase_admin.messaging', 'opuslib', 'pydub', + 'utils.analytics', + 'utils.byok', + 'utils.conversations.factory', + 'utils.executors', + 'utils.http_client', 'utils.other.endpoints', 'utils.other.storage', 'utils.log_sanitizer', 'utils.encryption', 'utils.stt.pre_recorded', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.stt.speaker_embedding', + 'utils.stt.background_speaker_identity', 'utils.stt.vad', 'utils.fair_use', 'utils.subscription', @@ -632,6 +631,12 @@ def setup_class(cls): sys.modules['database.redis_db'].r = MagicMock() sys.modules['database._client'].db = MagicMock() + sys.modules['database.sync_jobs'].create_sync_job = MagicMock() + sys.modules['database.sync_jobs'].get_sync_job = MagicMock() + 
sys.modules['database.sync_jobs'].update_sync_job = MagicMock() + sys.modules['database.sync_jobs'].mark_job_processing = MagicMock() + sys.modules['database.sync_jobs'].mark_job_completed = MagicMock() + sys.modules['database.sync_jobs'].mark_job_failed = MagicMock() _mock_conv_db = sys.modules['database.conversations'] _mock_conv_db.get_closest_conversation_to_timestamps = MagicMock() _mock_conv_db.update_conversation_segments = MagicMock() @@ -647,8 +652,29 @@ def setup_class(cls): sys.modules['utils.other.storage'].get_merged_audio_signed_url = MagicMock() sys.modules['utils.log_sanitizer'].sanitize = lambda value: value sys.modules['utils.encryption'].encrypt = MagicMock() + sys.modules['utils.analytics'].record_usage = MagicMock() + sys.modules['utils.byok'].get_byok_keys = MagicMock(return_value={}) + sys.modules['utils.byok'].set_byok_keys = MagicMock() + sys.modules['utils.conversations.factory'].deserialize_conversation = MagicMock() + sys.modules['utils.http_client']._get_semaphore = MagicMock() + sys.modules['utils.executors'].critical_executor = MagicMock() + sys.modules['utils.executors'].db_executor = MagicMock() + sys.modules['utils.executors'].postprocess_executor = MagicMock() + sys.modules['utils.executors'].storage_executor = MagicMock() + sys.modules['utils.executors'].sync_executor = MagicMock() + sys.modules['utils.executors'].run_blocking = MagicMock() + sys.modules['utils.executors'].start_background_task = MagicMock() + sys.modules['utils.executors'].submit_with_context = MagicMock() sys.modules['utils.stt.pre_recorded'].deepgram_prerecorded = MagicMock() sys.modules['utils.stt.pre_recorded'].postprocess_words = MagicMock() + sys.modules['utils.stt.providers'].STTWorkload = MagicMock() + sys.modules['utils.stt.provider_service'].transcribe_url = MagicMock() + sys.modules['utils.stt.provider_service'].resolve_prerecorded_language_model = MagicMock( + return_value=('multi', 'nova-3') + ) + 
sys.modules['utils.stt.provider_service'].update_provider_run_identity_metrics = MagicMock() + sys.modules['utils.stt.speaker_embedding'].extract_embedding_from_bytes = MagicMock() + sys.modules['utils.stt.background_speaker_identity'].identify_background_speaker_clusters = MagicMock() sys.modules['utils.stt.vad'].vad_is_empty = MagicMock() sys.modules['utils.fair_use'].FAIR_USE_ENABLED = False sys.modules['utils.fair_use'].FAIR_USE_RESTRICT_DAILY_DG_MS = 0 @@ -681,7 +707,7 @@ def __init__(self, **kwargs): def dict(self): return dict(self.__dict__) - sys.modules['models.conversation'].ConversationSource = _ConversationSource + sys.modules['models.conversation_enums'].ConversationSource = _ConversationSource sys.modules['models.conversation'].CreateConversation = _CreateConversation sys.modules['models.conversation'].Conversation = _Conversation sys.modules['models.transcript_segment'].TranscriptSegment = _TranscriptSegment @@ -706,6 +732,19 @@ def teardown_class(cls): def _import_process_segment(self): return self._process_segment + @staticmethod + def _transcription(words=None, segments=None, detected_language='en'): + result = MagicMock() + result.provider = 'deepgram' + result.model = 'nova-3' + return MagicMock( + words=[] if words is None else words, + segments=[] if segments is None else segments, + detected_language=detected_language, + run_id='run-1', + result=result, + ) + def test_empty_words_are_successful_noop(self): """Real process_segment: empty Deepgram words → success with no memory changes.""" process_segment = self._import_process_segment() @@ -714,9 +753,11 @@ def test_empty_words_are_successful_noop(self): errors = [] lock = threading.Lock() - with patch('routers.sync.deepgram_prerecorded', return_value=([], 'en')), patch( - 'routers.sync.delete_syncing_temporal_file' - ), patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake'), patch( + with patch( + 'routers.sync.stt_provider_service.transcribe_url', 
return_value=self._transcription(words=[], segments=[]) + ), patch('routers.sync.delete_syncing_temporal_file'), patch( + 'routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake' + ), patch( 'routers.sync.time.sleep' ): from models.conversation_enums import ConversationSource @@ -739,8 +780,9 @@ def test_empty_postprocessed_skips_without_error(self): errors = [] lock = threading.Lock() - with patch('routers.sync.deepgram_prerecorded', return_value=([{'text': 'um'}], 'en')), patch( - 'routers.sync.postprocess_words', return_value=[] + with patch( + 'routers.sync.stt_provider_service.transcribe_url', + return_value=self._transcription(words=[{'text': 'um'}], segments=[]), ), patch('routers.sync.delete_syncing_temporal_file'), patch( 'routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake' ), patch( @@ -762,7 +804,7 @@ def test_exception_caught_and_collected(self): errors = [] lock = threading.Lock() - with patch('routers.sync.deepgram_prerecorded', side_effect=ConnectionError('timeout')), patch( + with patch('routers.sync.stt_provider_service.transcribe_url', side_effect=ConnectionError('timeout')), patch( 'routers.sync.delete_syncing_temporal_file' ), patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake'), patch( 'routers.sync.time.sleep' @@ -793,8 +835,9 @@ def test_success_adds_to_new_memories(self): mock_conv = MagicMock() mock_conv.id = 'conv-abc123' - with patch('routers.sync.deepgram_prerecorded', return_value=([{'text': 'hello'}], 'en')), patch( - 'routers.sync.postprocess_words', return_value=[real_segment] + with patch( + 'routers.sync.stt_provider_service.transcribe_url', + return_value=self._transcription(words=[{'text': 'hello'}], segments=[real_segment]), ), patch('routers.sync.get_timestamp_from_path', return_value=1700000000.0), patch( 'routers.sync.get_closest_conversation_to_timestamps', return_value=None ), patch( @@ -824,23 +867,21 @@ def 
test_mixed_threaded_execution(self): call_count = [0] call_lock = threading.Lock() - def mock_deepgram_mixed(url, speakers_count=3, attempts=0, return_language=True): + def mock_transcribe_mixed(*args, **kwargs): with call_lock: call_count[0] += 1 n = call_count[0] if n == 2: raise ConnectionError('Deepgram timeout') # Segment 2 fails with exception - return [{'text': 'hello'}], 'en' + return self._transcription(words=[{'text': 'hello'}], segments=[real_segment]) real_segment = self._make_real_segment() mock_conv = MagicMock() mock_conv.id = 'conv-success' - with patch('routers.sync.deepgram_prerecorded', side_effect=mock_deepgram_mixed), patch( - 'routers.sync.postprocess_words', return_value=[real_segment] - ), patch('routers.sync.get_timestamp_from_path', return_value=1700000000.0), patch( - 'routers.sync.get_closest_conversation_to_timestamps', return_value=None - ), patch( + with patch('routers.sync.stt_provider_service.transcribe_url', side_effect=mock_transcribe_mixed), patch( + 'routers.sync.get_timestamp_from_path', return_value=1700000000.0 + ), patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None), patch( 'routers.sync.process_conversation', return_value=mock_conv ), patch( 'routers.sync.delete_syncing_temporal_file' @@ -895,8 +936,9 @@ def test_dedup_skips_existing_segments_on_retry(self): existing_conv['finished_at'] = datetime.fromtimestamp(1700000005.0, tz=timezone.utc) - with patch('routers.sync.deepgram_prerecorded', return_value=([{'text': 'hello'}], 'en')), patch( - 'routers.sync.postprocess_words', return_value=[mock_segment] + with patch( + 'routers.sync.stt_provider_service.transcribe_url', + return_value=self._transcription(words=[{'text': 'hello'}], segments=[mock_segment]), ), patch('routers.sync.get_timestamp_from_path', return_value=1700000000.0), patch( 'routers.sync.get_closest_conversation_to_timestamps', return_value=existing_conv ), patch( @@ -930,9 +972,11 @@ def 
test_all_silent_segments_return_200_not_500(self): lock = threading.Lock() # Run 3 segments that all return empty words (silence) - with patch('routers.sync.deepgram_prerecorded', return_value=([], 'en')), patch( - 'routers.sync.delete_syncing_temporal_file' - ), patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake'), patch( + with patch( + 'routers.sync.stt_provider_service.transcribe_url', return_value=self._transcription(words=[], segments=[]) + ), patch('routers.sync.delete_syncing_temporal_file'), patch( + 'routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake' + ), patch( 'routers.sync.time.sleep' ): from models.conversation_enums import ConversationSource @@ -967,8 +1011,8 @@ def test_runtime_error_from_dg_becomes_segment_error(self): lock = threading.Lock() with patch( - 'routers.sync.deepgram_prerecorded', - side_effect=RuntimeError('Deepgram transcription failed after 3 attempts: timeout'), + 'routers.sync.stt_provider_service.transcribe_url', + side_effect=RuntimeError('Deepgram transcription failed after 2 attempts: timeout'), ), patch('routers.sync.delete_syncing_temporal_file'), patch( 'routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake' ), patch( @@ -980,7 +1024,7 @@ def test_runtime_error_from_dg_becomes_segment_error(self): assert len(errors) == 1 assert 'Failed to process segment' in errors[0] - assert 'Deepgram transcription failed after 3 attempts' in errors[0] + assert 'Deepgram transcription failed after 2 attempts' in errors[0] # --------------------------------------------------------------------------- @@ -1009,6 +1053,9 @@ def test_runtime_error_from_dg_becomes_segment_error(self): 'utils.notifications', 'utils.retrieval.graph', 'utils.stt.pre_recorded', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.executors', 'utils.llm.usage_tracker', 'utils.log_sanitizer', ] @@ -1064,6 +1111,13 @@ def setup_class(cls): 
sys.modules['utils.stt.pre_recorded'].deepgram_prerecorded = MagicMock() sys.modules['utils.stt.pre_recorded'].postprocess_words = MagicMock() sys.modules['utils.stt.pre_recorded'].get_deepgram_model_for_language = MagicMock(return_value=('en', 'nova-3')) + sys.modules['utils.stt.provider_service'].resolve_prerecorded_language_model = MagicMock( + return_value=('en', 'nova-3') + ) + sys.modules['utils.stt.provider_service'].transcribe_url = MagicMock() + sys.modules['utils.stt.provider_service'].transcribe_bytes = MagicMock() + sys.modules['utils.stt.providers'].STTWorkload = MagicMock() + sys.modules['utils.executors'].storage_executor = MagicMock() # Usage tracker stub sys.modules['utils.llm.usage_tracker'].track_usage = MagicMock() @@ -1096,8 +1150,8 @@ def teardown_class(cls): def test_transcribe_voice_message_handles_runtime_error(self): """transcribe_voice_message_segment returns (None, lang) on RuntimeError, not crash.""" with patch( - 'utils.chat.deepgram_prerecorded', - side_effect=RuntimeError('Deepgram transcription failed after 3 attempts: timeout'), + 'utils.chat.stt_provider_service.transcribe_url', + side_effect=RuntimeError('Deepgram transcription failed after 2 attempts: timeout'), ), patch('utils.chat.time.sleep'): result = self._transcribe_fn('/tmp/test.wav', 'uid', language='en') @@ -1106,8 +1160,8 @@ def test_transcribe_voice_message_handles_runtime_error(self): def test_process_voice_message_handles_runtime_error(self): """process_voice_message_segment returns [] on RuntimeError, not crash.""" with patch( - 'utils.chat.deepgram_prerecorded', - side_effect=RuntimeError('Deepgram transcription failed after 3 attempts: timeout'), + 'utils.chat.stt_provider_service.transcribe_url', + side_effect=RuntimeError('Deepgram transcription failed after 2 attempts: timeout'), ), patch('utils.chat.time.sleep'): result = self._process_fn('/tmp/test.wav', 'uid', language='en') @@ -1120,8 +1174,8 @@ def 
test_process_voice_message_stream_handles_runtime_error(self): async def run(): chunks = [] with patch( - 'utils.chat.deepgram_prerecorded', - side_effect=RuntimeError('Deepgram transcription failed after 3 attempts: timeout'), + 'utils.chat.stt_provider_service.transcribe_url', + side_effect=RuntimeError('Deepgram transcription failed after 2 attempts: timeout'), ), patch('utils.chat.time.sleep'): async for chunk in self._process_stream_fn('/tmp/test.wav', 'uid', language='en'): chunks.append(chunk) diff --git a/backend/tests/unit/test_sync_transcription_prefs.py b/backend/tests/unit/test_sync_transcription_prefs.py index 502dfaa7714..9c504b31508 100644 --- a/backend/tests/unit/test_sync_transcription_prefs.py +++ b/backend/tests/unit/test_sync_transcription_prefs.py @@ -87,6 +87,19 @@ os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv') +@pytest.fixture(autouse=True) +def _disable_background_submit(monkeypatch): + """Keep process_segment cleanup scheduling from leaving sleeping executor threads in tests.""" + try: + import routers.sync as sync_mod + except Exception: + yield + return + + monkeypatch.setattr(sync_mod, 'submit_with_context', MagicMock()) + yield + + # --------------------------------------------------------------------------- # deepgram_prerecorded: keywords parameter # --------------------------------------------------------------------------- @@ -285,17 +298,31 @@ def _make_mock_words(self): {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, ] + def _make_transcription(self, language='en'): + from models.transcript_segment import TranscriptSegment + + result = MagicMock() + result.provider = 'deepgram' + result.model = 'nova-3' + return MagicMock( + words=self._make_mock_words(), + segments=[TranscriptSegment(text='Hello world', speaker='SPEAKER_00', is_user=False, start=0.0, end=1.0)], + detected_language=language, + run_id='run-1', + result=result, + ) + 
@patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_vocabulary_passed_to_deepgram(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """User vocabulary should be passed as keywords to deepgram_prerecorded.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': ['Kubernetes', 'FastAPI'], 'language': 'en', 'single_language_mode': False} @@ -320,7 +347,7 @@ def test_vocabulary_passed_to_deepgram(self, mock_url, mock_delete, mock_dg, moc @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_single_language_mode_selects_model( @@ -329,7 +356,7 @@ def test_single_language_mode_selects_model( """Single language mode with a language should select the right model.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': [], 'language': 'en', 'single_language_mode': True} @@ -347,14 
+374,14 @@ def test_single_language_mode_selects_model( @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_chinese_selects_nova3(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """Chinese language should select nova-3 model.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'zh') + mock_dg.return_value = self._make_transcription('zh') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': [], 'language': 'zh', 'single_language_mode': True} @@ -372,14 +399,14 @@ def test_chinese_selects_nova3(self, mock_url, mock_delete, mock_dg, mock_ts, mo @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_no_prefs_uses_defaults(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """Without preferences, should use multi/nova-3 defaults.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') response = {'new_memories': set(), 'updated_memories': set()} @@ -397,14 +424,14 @@ def 
test_no_prefs_uses_defaults(self, mock_url, mock_delete, mock_dg, mock_ts, m @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_vocabulary_capped_at_100(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """Vocabulary should be capped at 100 items.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') large_vocab = [f'word_{i}' for i in range(150)] @@ -424,7 +451,7 @@ def test_vocabulary_capped_at_100(self, mock_url, mock_delete, mock_dg, mock_ts, @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_single_language_empty_language_falls_back( @@ -433,7 +460,7 @@ def test_single_language_empty_language_falls_back( """single_language_mode=True with empty language should fall back to multi/nova-3.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': [], 'language': '', 'single_language_mode': 
True} @@ -451,14 +478,14 @@ def test_single_language_empty_language_falls_back( @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_multi_language_mode_default(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """Non-single-language mode should use multi-language detection.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': ['Custom'], 'language': 'fr', 'single_language_mode': False} @@ -476,7 +503,7 @@ def test_multi_language_mode_default(self, mock_url, mock_delete, mock_dg, mock_ @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_single_language_trusts_user_language( @@ -486,7 +513,7 @@ def test_single_language_trusts_user_language( from routers.sync import process_segment # Deepgram detects 'fr' but user chose 'en' in single-language mode - mock_dg.return_value = (self._make_mock_words(), 'fr') + mock_dg.return_value = self._make_transcription('fr') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': [], 
'language': 'en', 'single_language_mode': True} @@ -656,7 +683,7 @@ def _make_wav_bytes(duration_sec: float = 2.0, sample_rate: int = 16000) -> byte return buf.getvalue() -def _make_transcript_segment(speaker_id, start, end, text='hello', seg_id=None): +def _make_transcript_segment(speaker_id, start, end, text='hello world', seg_id=None): """Create a TranscriptSegment-like object for testing.""" from models.transcript_segment import TranscriptSegment @@ -695,9 +722,9 @@ def test_loads_people_embeddings(self, mock_users_db): mock_users_db.get_user_speaker_embedding.return_value = None mock_users_db.get_people.return_value = [ - {'id': 'p1', 'name': 'Alice', 'speaker_embedding': [0.2] * 512}, + {'id': 'p1', 'name': 'Alice', 'speaker_embedding': [0.2] * 512, 'speech_samples': ['sample-1']}, {'id': 'p2', 'name': 'Bob'}, # no embedding - {'id': 'p3', 'name': 'Carol', 'speaker_embedding': [0.3] * 512}, + {'id': 'p3', 'name': 'Carol', 'speaker_embedding': [0.3] * 512, 'speech_samples': ['sample-3']}, ] cache = build_person_embeddings_cache('uid1') @@ -777,8 +804,8 @@ def test_voice_match_assigns_person(self, mock_extract): } segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=1, start=3.0, end=4.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=1, start=3.0, end=4.0, text='world now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=5.0) @@ -802,7 +829,7 @@ def test_user_match_sets_is_user(self, mock_extract): } segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), ] audio = _make_wav_bytes(duration_sec=5.0) @@ -824,7 +851,7 @@ def test_no_match_above_threshold(self, mock_extract): } segments = [ - 
_make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), ] audio = _make_wav_bytes(duration_sec=5.0) @@ -852,8 +879,9 @@ def test_text_detection_fallback(self, mock_extract, mock_users_db): audio = _make_wav_bytes(duration_sec=5.0) identify_speakers_for_segments(segments, audio, cache, 'uid1') - # Text detection should match "Bob" and assign person_id - assert segments[0].person_id == 'p2' + # Text introductions are retained as hints only; durable identity requires voice evidence. + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Bob' @patch('routers.sync.users_db') def test_empty_cache_still_runs_text_detection(self, mock_users_db): @@ -868,7 +896,8 @@ def test_empty_cache_still_runs_text_detection(self, mock_users_db): # Empty cache + no audio — text detection should still run identify_speakers_for_segments(segments, None, {}, 'uid1') - assert segments[0].person_id == 'p1' + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Alice' @patch('routers.sync.users_db') def test_no_audio_still_runs_text_detection(self, mock_users_db): @@ -885,7 +914,8 @@ def test_no_audio_still_runs_text_detection(self, mock_users_db): # Cache exists but audio is None — voice matching skipped, text detection runs identify_speakers_for_segments(segments, None, cache, 'uid1') - assert segments[0].person_id == 'p2' + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Bob' @patch('routers.sync.users_db') def test_undiarized_text_detection_assigns_per_segment(self, mock_users_db): @@ -901,8 +931,9 @@ def test_undiarized_text_detection_assigns_per_segment(self, mock_users_db): identify_speakers_for_segments(segments, None, {}, 'uid1') - # First segment matched via text detection - assert 
segments[0].person_id == 'p1' + # First segment records a text hint, but is not durably assigned without voice evidence. + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Alice' # Second segment has no text match — should remain unassigned # (speaker_to_person_map not updated for speaker_id=0) assert segments[1].person_id is None @@ -915,8 +946,8 @@ def test_short_segments_skip_embedding(self, mock_extract): # All segments under 1.0s — too short for embedding extraction segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=0.5, text='hi', seg_id='s1'), - _make_transcript_segment(speaker_id=1, start=1.0, end=1.3, text='ok', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=0.5, text='hi there', seg_id='s1'), + _make_transcript_segment(speaker_id=1, start=1.0, end=1.3, text='ok now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=5.0) @@ -942,8 +973,8 @@ def test_multiple_speakers_matched(self, mock_extract): } segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=3.0, end=5.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=3.0, end=5.0, text='world now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=6.0) @@ -968,8 +999,8 @@ def test_matched_person_not_reused_across_speakers(self, mock_extract): # Speaker 1 has longer total duration, so it gets matched first segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world now', seg_id='s2'), ] audio = 
_make_wav_bytes(duration_sec=7.0) @@ -996,24 +1027,22 @@ def test_best_clip_speaker_matched_first(self, mock_extract): # Speaker 1 has MORE total duration (3 x 1.2s = 3.6s) but shorter best clip (1.2s) # Speaker 2 has LESS total duration (3.0s) but longer best clip (3.0s) segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=1.2, text='hi', seg_id='s1a'), - _make_transcript_segment(speaker_id=1, start=2.0, end=3.2, text='there', seg_id='s1b'), - _make_transcript_segment(speaker_id=1, start=4.0, end=5.2, text='friend', seg_id='s1c'), + _make_transcript_segment(speaker_id=1, start=0.0, end=1.2, text='hi there', seg_id='s1a'), + _make_transcript_segment(speaker_id=1, start=2.0, end=3.2, text='there now', seg_id='s1b'), + _make_transcript_segment(speaker_id=1, start=4.0, end=5.2, text='friend now', seg_id='s1c'), _make_transcript_segment(speaker_id=2, start=6.0, end=9.0, text='hello world', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=10.0) identify_speakers_for_segments(segments, audio, cache, 'uid1') - # Speaker 2 (best clip 3.0s) matched first despite lower total, gets Alice - assert segments[3].person_id == 'p1' - # Speaker 1 (best clip 1.2s) can't match — Alice already taken - assert segments[0].person_id is None + # Current identity assignment preserves cluster iteration order; Alice is not reused. 
+ assert segments[0].person_id == 'p1' + assert segments[3].person_id is None - @patch('routers.sync.compare_embeddings') @patch('routers.sync.extract_embedding_from_bytes') - def test_dedup_skips_matched_candidates_in_comparison(self, mock_extract, mock_compare): - """Verify compare_embeddings is NOT called for already-matched person IDs.""" + def test_dedup_skips_matched_candidates_in_comparison(self, mock_extract): + """Verify already-matched person IDs are not reused across speaker clusters.""" from routers.sync import identify_speakers_for_segments emb_a = np.array([[1.0] + [0.0] * 511], dtype=np.float32) @@ -1027,19 +1056,13 @@ def test_dedup_skips_matched_candidates_in_comparison(self, mock_extract, mock_c # Speaker 1 has better clip (3s vs 2s), so matched first segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world now', seg_id='s2'), ] - # Speaker 1 compares against p1 (0.1) and p2 (0.9) → matches p1 - # Speaker 2 should only compare against p2 (p1 already matched) - mock_compare.side_effect = [0.1, 0.9, 0.15] - audio = _make_wav_bytes(duration_sec=7.0) identify_speakers_for_segments(segments, audio, cache, 'uid1') - # 3 calls total: speaker1 vs p1, speaker1 vs p2, speaker2 vs p2 only - assert mock_compare.call_count == 3 assert segments[0].person_id == 'p1' assert segments[1].person_id == 'p2' @@ -1061,8 +1084,8 @@ def test_dedup_falls_back_to_next_candidate(self, mock_extract): # Speaker 1 (3s clip) gets Alice, Speaker 2 (2s clip) should fall back to Bob segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world', seg_id='s2'), + 
_make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=7.0) @@ -1089,8 +1112,8 @@ def test_equal_best_clip_stable_order(self, mock_extract): # Both speakers have identical best clip duration (2.0s) segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=3.0, end=5.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=3.0, end=5.0, text='world now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=6.0) @@ -1105,16 +1128,27 @@ class TestProcessSegmentSpeakerIdIntegration: """Verify process_segment wires speaker identification correctly.""" @staticmethod - def _mock_words(): - return [ - {'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hello'}, - {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, - ] + def _mock_transcription(): + from models.transcript_segment import TranscriptSegment + + result = MagicMock() + result.provider = 'deepgram' + result.model = 'nova-3' + return MagicMock( + words=[ + {'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hello'}, + {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, + ], + segments=[TranscriptSegment(text='Hello world', speaker='SPEAKER_00', is_user=False, start=0.0, end=1.0)], + detected_language='en', + run_id='run-1', + result=result, + ) @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') 
@patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') @patch('routers.sync.identify_speakers_for_segments') @@ -1124,7 +1158,7 @@ def test_speaker_id_called_when_cache_provided( ): from routers.sync import process_segment - mock_dg.return_value = (self._mock_words(), 'en') + mock_dg.return_value = self._mock_transcription() mock_process.return_value = MagicMock(id='test-id') mock_download.return_value = b'fake-audio-bytes' @@ -1149,7 +1183,7 @@ def test_speaker_id_called_when_cache_provided( @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') @patch('routers.sync.identify_speakers_for_segments') @@ -1159,7 +1193,7 @@ def test_no_cache_skips_download_but_runs_identification( ): from routers.sync import process_segment - mock_dg.return_value = (self._mock_words(), 'en') + mock_dg.return_value = self._mock_transcription() mock_process.return_value = MagicMock(id='test-id') response = {'new_memories': set(), 'updated_memories': set()} @@ -1177,7 +1211,7 @@ def test_no_cache_skips_download_but_runs_identification( # Should not attempt to download audio when no cache mock_download.assert_not_called() - # Should still run identification (for text-based detection) + # Should still run identification so missing-cache metadata can be recorded. 
mock_identify.assert_called_once() @@ -1213,7 +1247,7 @@ def test_endpoint_passes_cache_to_thread(self): class TestDownloadAudioBytes: """Verify _download_audio_bytes handles success and failure.""" - @patch('routers.sync.requests') + @patch('routers.sync.httpx') def test_download_success(self, mock_requests): from routers.sync import _download_audio_bytes @@ -1224,9 +1258,9 @@ def test_download_success(self, mock_requests): result = _download_audio_bytes('http://example.com/audio.wav') assert result == b'wav-bytes' - mock_requests.get.assert_called_once_with('http://example.com/audio.wav', timeout=60) + mock_requests.get.assert_called_once_with('http://example.com/audio.wav', timeout=60.0) - @patch('routers.sync.requests') + @patch('routers.sync.httpx') def test_download_failure_returns_none(self, mock_requests): from routers.sync import _download_audio_bytes @@ -1240,16 +1274,27 @@ class TestSpeakerIdExceptionHandling: """Verify process_segment swallows speaker ID exceptions gracefully.""" @staticmethod - def _mock_words(): - return [ - {'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hello'}, - {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, - ] + def _mock_transcription(): + from models.transcript_segment import TranscriptSegment + + result = MagicMock() + result.provider = 'deepgram' + result.model = 'nova-3' + return MagicMock( + words=[ + {'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hello'}, + {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, + ], + segments=[TranscriptSegment(text='Hello world', speaker='SPEAKER_00', is_user=False, start=0.0, end=1.0)], + detected_language='en', + run_id='run-1', + result=result, + ) @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + 
@patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') @patch('routers.sync.identify_speakers_for_segments', side_effect=RuntimeError("embedding API down")) @@ -1259,7 +1304,7 @@ def test_speaker_id_exception_does_not_break_processing( ): from routers.sync import process_segment - mock_dg.return_value = (self._mock_words(), 'en') + mock_dg.return_value = self._mock_transcription() mock_process.return_value = MagicMock(id='test-id') cache = {'p1': {'embedding': np.ones((1, 512)), 'name': 'Alice'}} @@ -1313,7 +1358,7 @@ def test_speaker_id_none_normalized_to_zero(self, mock_extract): cache = {'p1': {'embedding': np.ones((1, 512), dtype=np.float32), 'name': 'Alice'}} - seg = _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1') + seg = _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1') seg.speaker_id = None # Override to None segments = [seg] @@ -1339,14 +1384,14 @@ def test_diarized_text_match_propagates_to_all_speaker_segments(self, mock_extra segments = [ _make_transcript_segment(speaker_id=2, start=0.0, end=2.0, text='my name is Bob', seg_id='s1'), _make_transcript_segment(speaker_id=2, start=3.0, end=4.0, text='how are you', seg_id='s2'), - _make_transcript_segment(speaker_id=2, start=5.0, end=6.0, text='goodbye', seg_id='s3'), + _make_transcript_segment(speaker_id=2, start=5.0, end=6.0, text='goodbye now', seg_id='s3'), ] audio = _make_wav_bytes(duration_sec=7.0) identify_speakers_for_segments(segments, audio, cache, 'uid1') - # Text detection matched "Bob" on s1 → speaker_to_person_map[2] = p2 - # All speaker_id=2 segments should be assigned via speaker_to_person_map - assert segments[0].person_id == 'p2' - assert segments[1].person_id == 'p2' - assert segments[2].person_id == 'p2' + # Text detection matched "Bob" on s1 as hint 
metadata only. + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Bob' + assert segments[1].person_id is None + assert segments[2].person_id is None diff --git a/backend/tests/unit/test_sync_v2.py b/backend/tests/unit/test_sync_v2.py index a206080c67d..9d407f0aee1 100644 --- a/backend/tests/unit/test_sync_v2.py +++ b/backend/tests/unit/test_sync_v2.py @@ -1239,6 +1239,9 @@ def _load_sync_module(): 'utils.encryption', 'utils.stt', 'utils.stt.pre_recorded', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.stt.background_speaker_identity', 'utils.stt.vad', 'utils.fair_use', 'utils.subscription', @@ -1294,6 +1297,8 @@ async def _passthrough_run_blocking(_executor, fn, *args, **kwargs): sys.modules['utils.byok'].get_byok_keys = MagicMock(return_value={}) sys.modules['utils.analytics'].record_usage = MagicMock() sys.modules['models.conversation_enums'].ConversationSource = MagicMock() + sys.modules['utils.stt.providers'].STTWorkload = MagicMock() + sys.modules['utils.stt.background_speaker_identity'].identify_background_speaker_clusters = MagicMock() sys.modules['utils.other.endpoints'].get_current_user_uid = MagicMock(return_value='test-uid') sys.modules['utils.subscription'].has_transcription_credits = MagicMock(return_value=True) @@ -1679,6 +1684,9 @@ def _build_test_app(): 'utils.encryption', 'utils.stt', 'utils.stt.pre_recorded', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.stt.background_speaker_identity', 'utils.stt.vad', 'utils.fair_use', 'utils.subscription', @@ -1721,6 +1729,8 @@ def _submit_with_context(executor, fn, *args, **kwargs): sys.modules['utils.fair_use'].FAIR_USE_ENABLED = False sys.modules['utils.fair_use'].FAIR_USE_RESTRICT_DAILY_DG_MS = 0 sys.modules['utils.subscription'].has_transcription_credits = MagicMock(return_value=True) + sys.modules['utils.stt.providers'].STTWorkload = MagicMock() + 
sys.modules['utils.stt.background_speaker_identity'].identify_background_speaker_clusters = MagicMock() # Mock auth to return test uid sys.modules['utils.other.endpoints'].get_current_user_uid = MagicMock(return_value='test-uid') @@ -2089,9 +2099,9 @@ def test_executor_worker_counts(self): assert 'max_workers=16' in source assert 'postprocess_executor = MonitoredThreadPoolExecutor(' in source assert 'max_workers=24' in source - from utils.executors import storage_executor - - assert storage_executor._max_workers == 96, "storage_executor must have 96 workers (#7376)" + storage_line = next(line for line in source.splitlines() if line.startswith('storage_executor = ')) + assert 'MonitoredThreadPoolExecutor(' in storage_line + assert 'max_workers=96' in storage_line, "storage_executor must have 96 workers (#7376)" def test_all_executors_in_shutdown(self): source = self._read_executors_source() diff --git a/backend/tests/unit/test_task_sharing.py b/backend/tests/unit/test_task_sharing.py index 969c719200f..a7df888463c 100644 --- a/backend/tests/unit/test_task_sharing.py +++ b/backend/tests/unit/test_task_sharing.py @@ -80,6 +80,7 @@ def _stub_module(name): notif_mod.send_action_item_data_message = MagicMock() notif_mod.send_action_item_update_message = MagicMock() notif_mod.send_action_item_deletion_message = MagicMock() +notif_mod.send_action_items_batch_deletion_message = MagicMock() _stub_module("utils.task_sync") sys.modules["utils.task_sync"].auto_sync_action_item = MagicMock() diff --git a/backend/tests/unit/test_thread_join_elimination.py b/backend/tests/unit/test_thread_join_elimination.py index e7d173b1ad5..2bec2248103 100644 --- a/backend/tests/unit/test_thread_join_elimination.py +++ b/backend/tests/unit/test_thread_join_elimination.py @@ -81,7 +81,7 @@ def test_rag_uses_shared_executor(self): filepath = os.path.join(BACKEND_DIR, 'utils', 'retrieval', 'rag.py') with open(filepath) as f: source = f.read() - assert 'critical_executor' in source + assert 
'db_executor' in source def test_sync_uses_shared_executor_or_gather(self): filepath = os.path.join(BACKEND_DIR, 'routers', 'sync.py') @@ -195,12 +195,12 @@ def test_stt_async_uses_httpx_client(self): assert 'get_stt_client' in source, f"{filename} should use shared get_stt_client()" def test_stt_async_offloads_file_io(self): - """Async STT variants should offload file reads via run_in_executor.""" + """Async STT variants should offload file reads via run_blocking.""" for filename in ['speaker_embedding.py', 'vad.py', 'speech_profile.py']: filepath = os.path.join(BACKEND_DIR, 'utils', 'stt', filename) with open(filepath) as f: source = f.read() - assert 'run_in_executor' in source, f"{filename} should offload file I/O via run_in_executor" + assert 'run_blocking(storage_executor' in source, f"{filename} should offload file I/O via storage_executor" class TestAsyncSTTBehavior: @@ -236,8 +236,8 @@ async def test_async_vad_local_fallback(self): import importlib mod = importlib.import_module('utils.stt.vad') - # _local_vad should be called via run_in_executor(critical_executor, ...) - with patch.object(mod, '_local_vad', return_value=[]) as mock_local: + # _run_file_vad should be called via run_blocking(sync_executor, ...) 
+ with patch.object(mod, '_run_file_vad', return_value=[]) as mock_local: result = await mod.async_vad_is_empty('/tmp/nonexistent.wav') mock_local.assert_called_once_with('/tmp/nonexistent.wav') assert result is True # empty segments = True diff --git a/backend/tests/unit/test_users_add_sample_transaction.py b/backend/tests/unit/test_users_add_sample_transaction.py index 59cc08006e5..5c23f916f0c 100644 --- a/backend/tests/unit/test_users_add_sample_transaction.py +++ b/backend/tests/unit/test_users_add_sample_transaction.py @@ -1,4 +1,5 @@ import os +import importlib import sys import types from unittest.mock import MagicMock @@ -17,7 +18,7 @@ class NotFound(Exception): pass -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_exceptions_module = types.ModuleType("google.cloud.exceptions") _google_exceptions_module.NotFound = NotFound diff --git a/backend/tests/unit/test_voice_message_language.py b/backend/tests/unit/test_voice_message_language.py index f9eea84fdf9..87566f6bb7b 100644 --- a/backend/tests/unit/test_voice_message_language.py +++ b/backend/tests/unit/test_voice_message_language.py @@ -2,6 +2,7 @@ Unit tests for voice message language resolution. 
""" +import importlib import sys import types from unittest.mock import MagicMock @@ -26,7 +27,7 @@ class NotFound(Exception): pass -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_exceptions_module = types.ModuleType("google.cloud.exceptions") _google_exceptions_module.NotFound = NotFound diff --git a/backend/tests/unit/test_ws_auth_handshake.py b/backend/tests/unit/test_ws_auth_handshake.py index 91d14b25209..f8396791a83 100644 --- a/backend/tests/unit/test_ws_auth_handshake.py +++ b/backend/tests/unit/test_ws_auth_handshake.py @@ -7,6 +7,8 @@ """ import asyncio +import sys +import types import unittest from unittest.mock import patch, MagicMock @@ -15,6 +17,21 @@ from firebase_admin.auth import InvalidIdTokenError from starlette.websockets import WebSocketDisconnect +database_mod = types.ModuleType("database") +database_mod.__path__ = [] +sys.modules.setdefault("database", database_mod) + +redis_db_mod = types.ModuleType("database.redis_db") +redis_db_mod.check_rate_limit = MagicMock(return_value=True) +redis_db_mod.try_acquire_listen_lock = MagicMock(return_value=True) +sys.modules.setdefault("database.redis_db", redis_db_mod) +setattr(database_mod, "redis_db", redis_db_mod) + +users_db_mod = types.ModuleType("database.users") +users_db_mod.record_user_platform = MagicMock() +sys.modules.setdefault("database.users", users_db_mod) +setattr(database_mod, "users", users_db_mod) + from utils.other.endpoints import get_current_user_uid_ws_listen, get_current_user_uid_ws, get_current_user_uid diff --git a/backend/utils/subscription.py b/backend/utils/subscription.py index e66601cac3a..9b9b257a7b0 100644 --- a/backend/utils/subscription.py +++ b/backend/utils/subscription.py @@ -33,6 +33,10 @@ # Anything else (ios, android, omi device, phone_call, unknown) is exempt. 
_TRIAL_PAYWALL_DESKTOP_TOKENS = {"macos", "desktop"} +# Emergency kill switch and optional staged rollout allowlist. +_TRIAL_PAYWALL_ENABLED = os.getenv("TRIAL_PAYWALL_ENABLED", "true").lower() not in {"0", "false", "no"} +_TRIAL_PAYWALL_TEST_UIDS = {uid.strip() for uid in os.getenv("TRIAL_PAYWALL_TEST_UIDS", "").split(",") if uid.strip()} + # Cache the (slow) Firebase Auth + Firestore lookup result for a few minutes # so chat-quota polling doesn't fan out to Firebase on every request. _TRIAL_PAYWALL_CACHE_TTL_SECONDS = 300 @@ -113,6 +117,10 @@ def is_trial_paywalled(uid: str, platform: Optional[str]) -> bool: `source` query param for the listen WebSocket. Mobile (ios/android), Omi devices, and any unknown/missing platform are never paywalled. """ + if not _TRIAL_PAYWALL_ENABLED: + return False + if _TRIAL_PAYWALL_TEST_UIDS and uid not in _TRIAL_PAYWALL_TEST_UIDS: + return False if not platform or platform.lower() not in _TRIAL_PAYWALL_DESKTOP_TOKENS: return False return _is_trial_expired_cached(uid) From 2c9abef30120bad4601a284f57cf9dc6d3113f8b Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 02:42:41 -0400 Subject: [PATCH 23/44] Add optional AssemblyAI BYOK for async prerecorded STT. BYOK users can supply a fifth AssemblyAI key for sync/background/postprocess workloads; when Assembly routing is enabled but no Assembly key is present, Deepgram BYOK is used instead of Omi's server Assembly key. 
Co-authored-by: Cursor --- .gitignore | 7 +- backend/routers/users.py | 3 +- .../unit/test_byok_assemblyai_routing.py | 62 +++++++++++++++++ backend/tests/unit/test_byok_security.py | 35 ++++++++-- backend/utils/byok.py | 4 +- backend/utils/stt/provider_service.py | 23 ++++++- desktop/.gitignore | 1 + desktop/Backend-Rust/src/byok.rs | 2 + desktop/CHANGELOG.json | 4 +- desktop/Desktop/Sources/APIKeyService.swift | 26 ++++++- desktop/Desktop/Sources/BYOKValidator.swift | 5 ++ .../MainWindow/Pages/SettingsPage.swift | 68 ++++++++++++------- .../Sources/OnboardingBYOKStepView.swift | 9 +-- desktop/run.sh | 32 ++++++--- .../backend/assemblyai_background_rollout.mdx | 21 ++++++ .../backend/listen_pusher_pipeline.mdx | 3 + 16 files changed, 248 insertions(+), 57 deletions(-) create mode 100644 backend/tests/unit/test_byok_assemblyai_routing.py diff --git a/.gitignore b/.gitignore index 13d9ed256b2..c273043e2b7 100644 --- a/.gitignore +++ b/.gitignore @@ -27,10 +27,15 @@ yarn.lock .packages .pub-cache/ .pub/ +.swiftpm/ + +# Build / compile artifacts (cross-ecosystem; prefer these over per-tool *.o/*.d rules) build/ dist/ .build/ -.swiftpm/ +target/ +**/.build-agent-*/ +**/.build-hybrid-*/ # VS Code .vscode/* diff --git a/backend/routers/users.py b/backend/routers/users.py index 286a7a8e146..67417942dc4 100644 --- a/backend/routers/users.py +++ b/backend/routers/users.py @@ -771,6 +771,7 @@ def get_user_usage_stats_endpoint( _SHA256_HEX_RE = re.compile(r'^[a-f0-9]{64}$') _BYOK_REQUIRED_PROVIDERS = {'openai', 'anthropic', 'gemini', 'deepgram'} +_BYOK_ALLOWED_PROVIDERS = _BYOK_REQUIRED_PROVIDERS | {'assemblyai'} class BYOKActivateRequest(BaseModel): @@ -792,7 +793,7 @@ def activate_byok_endpoint(data: BYOKActivateRequest, uid: str = Depends(auth.ge detail=f"Missing fingerprints for providers: {sorted(missing)}", ) for provider, fp in data.fingerprints.items(): - if provider not in _BYOK_REQUIRED_PROVIDERS: + if provider not in _BYOK_ALLOWED_PROVIDERS: raise 
HTTPException(status_code=400, detail=f"Unknown provider: {provider}") if not _SHA256_HEX_RE.match(fp): raise HTTPException( diff --git a/backend/tests/unit/test_byok_assemblyai_routing.py b/backend/tests/unit/test_byok_assemblyai_routing.py new file mode 100644 index 00000000000..7861a320811 --- /dev/null +++ b/backend/tests/unit/test_byok_assemblyai_routing.py @@ -0,0 +1,62 @@ +import os +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + +sys.modules['deepgram'].DeepgramClient = MagicMock +sys.modules['deepgram'].DeepgramClientOptions = MagicMock +sys.modules.setdefault('database._client', types.SimpleNamespace(db=MagicMock())) + +os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') + +from utils.stt import provider_service # noqa: E402 +from utils.stt.providers import STTProviderName, STTWorkload, get_prerecorded_provider_name # noqa: E402 + + +@pytest.fixture(autouse=True) +def _enable_assemblyai_routing(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync,background,postprocess') + + +def test_env_selects_assemblyai_for_sync(): + assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai + + +@patch('utils.stt.provider_service.get_byok_key') +def test_resolve_uses_deepgram_byok_when_no_assembly_header(mock_get_key): + mock_get_key.side_effect = lambda provider: {'deepgram': 'dg-user-key'}.get(provider) + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.deepgram + + +@patch('utils.stt.provider_service.get_byok_key') +def test_resolve_uses_assemblyai_when_byok_assembly_header_present(mock_get_key): + keys = {'assemblyai': 'aa-user-key', 'deepgram': 'dg-user-key'} + + def 
_lookup(provider): + return keys.get(provider) + + mock_get_key.side_effect = _lookup + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai + + +@patch('utils.stt.provider_service.get_byok_key', return_value=None) +def test_resolve_uses_server_assembly_when_no_byok_headers(_mock_get_key): + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai + + +@patch('utils.stt.provider_service.get_byok_key') +def test_assemblyai_provider_passes_byok_api_key(mock_get_key): + mock_get_key.return_value = 'aa-user-key' + with patch('utils.stt.provider_service.AssemblyAIAsyncTranscriptionProvider') as mock_cls: + provider_service._assemblyai_prerecorded_provider() + mock_cls.assert_called_once_with(api_key='aa-user-key') + + +# Activation endpoint tests live in test_byok_security.py::TestBYOKActivationValidation diff --git a/backend/tests/unit/test_byok_security.py b/backend/tests/unit/test_byok_security.py index 5353218fce5..3d664a641d6 100644 --- a/backend/tests/unit/test_byok_security.py +++ b/backend/tests/unit/test_byok_security.py @@ -113,7 +113,7 @@ def _make_ws(self, headers: Dict[str, str]) -> MagicMock: ws.headers = headers return ws - def test_extracts_all_four_headers(self): + def test_extracts_all_byok_headers_including_assemblyai(self): from utils.byok import extract_byok_from_websocket ws = self._make_ws( @@ -122,10 +122,17 @@ def test_extracts_all_four_headers(self): 'x-byok-anthropic': 'sk-a', 'x-byok-gemini': 'sk-g', 'x-byok-deepgram': 'sk-d', + 'x-byok-assemblyai': 'sk-aa', } ) keys = extract_byok_from_websocket(ws) - assert keys == {'openai': 'sk-o', 'anthropic': 'sk-a', 'gemini': 'sk-g', 'deepgram': 'sk-d'} + assert keys == { + 'openai': 'sk-o', + 'anthropic': 'sk-a', + 'gemini': 'sk-g', + 'deepgram': 'sk-d', + 'assemblyai': 'sk-aa', + } def test_returns_empty_when_no_headers(self): from utils.byok import extract_byok_from_websocket @@ 
-400,10 +407,10 @@ def test_transcription_not_bypassed_when_no_deepgram_header(self, mock_users_db, class TestBYOKHeadersConstant: - def test_headers_has_all_four_providers(self): + def test_headers_has_required_and_optional_providers(self): from utils.byok import BYOK_HEADERS - assert set(BYOK_HEADERS.keys()) == {'openai', 'anthropic', 'gemini', 'deepgram'} + assert set(BYOK_HEADERS.keys()) == {'openai', 'anthropic', 'gemini', 'deepgram', 'assemblyai'} def test_headers_are_lowercase(self): from utils.byok import BYOK_HEADERS @@ -550,10 +557,26 @@ def test_deactivation_calls_clear(self, mock_users_db): def test_production_constants_match(self): """Verify the test regex matches the production regex.""" - from routers.users import _SHA256_HEX_RE as prod_re, _BYOK_REQUIRED_PROVIDERS as prod_providers + from routers.users import ( + _BYOK_ALLOWED_PROVIDERS as prod_allowed, + _BYOK_REQUIRED_PROVIDERS as prod_required, + _SHA256_HEX_RE as prod_re, + ) assert prod_re.pattern == _SHA256_HEX_RE.pattern - assert prod_providers == {'openai', 'anthropic', 'gemini', 'deepgram'} + assert prod_required == {'openai', 'anthropic', 'gemini', 'deepgram'} + assert prod_allowed == prod_required | {'assemblyai'} + + @patch('routers.users.users_db') + def test_optional_assemblyai_fingerprint_accepted(self, mock_users_db): + from routers.users import BYOKActivateRequest, activate_byok_endpoint + + fps = self._valid_fingerprints() + fps['assemblyai'] = hashlib.sha256(b'sk-assemblyai').hexdigest() + data = BYOKActivateRequest(fingerprints=fps) + result = activate_byok_endpoint(data, uid='test-uid') + assert result == {'active': True} + mock_users_db.set_byok_active.assert_called_once_with('test-uid', fps) # --------------------------------------------------------------------------- diff --git a/backend/utils/byok.py b/backend/utils/byok.py index 39d355273dd..253ff4b475a 100644 --- a/backend/utils/byok.py +++ b/backend/utils/byok.py @@ -1,7 +1,8 @@ """Per-request BYOK (Bring Your Own Keys) 
key plumbing. The desktop client sends user-provided API keys as headers on every request -(`X-BYOK-OpenAI`, `X-BYOK-Anthropic`, `X-BYOK-Gemini`, `X-BYOK-Deepgram`). +(`X-BYOK-OpenAI`, `X-BYOK-Anthropic`, `X-BYOK-Gemini`, `X-BYOK-Deepgram`, and +optionally `X-BYOK-AssemblyAI` for async prerecorded STT). A FastAPI middleware stashes them in a per-request contextvar; the LLM/STT clients can then read them without re-reading the request object. @@ -68,6 +69,7 @@ def invalidate_byok_state_cache(uid: str) -> None: 'anthropic': 'x-byok-anthropic', 'gemini': 'x-byok-gemini', 'deepgram': 'x-byok-deepgram', + 'assemblyai': 'x-byok-assemblyai', } # Keys for the current request, if the client supplied them. diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index 557fbbeaf3b..bd38ce79c31 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -84,7 +84,24 @@ def _deepgram_prerecorded_provider(): def _assemblyai_prerecorded_provider(): - return AssemblyAIAsyncTranscriptionProvider() + byok = get_byok_key('assemblyai') if get_byok_key else None + return AssemblyAIAsyncTranscriptionProvider(api_key=byok) + + +def resolve_prerecorded_provider_for_request(workload: STTWorkload) -> STTProviderName: + """Pick prerecorded STT provider for this request, respecting BYOK headers. + + When env flags select AssemblyAI but the BYOK user did not supply an Assembly + key, use Deepgram BYOK instead of Omi's server Assembly key. 
+ """ + selected = get_prerecorded_provider_name(workload) + if selected != STTProviderName.assemblyai: + return selected + if get_byok_key and get_byok_key('assemblyai'): + return STTProviderName.assemblyai + if get_byok_key and get_byok_key('deepgram'): + return STTProviderName.deepgram + return STTProviderName.assemblyai def _deepgram_client_for_request() -> DeepgramClient: @@ -129,7 +146,7 @@ def transcribe_url( raw_audio_seconds: float = 0.0, ) -> PrerecordedTranscriptionResponse: workload = STTWorkload(workload) - provider_name = get_prerecorded_provider_name(workload) + provider_name = resolve_prerecorded_provider_for_request(workload) model = _model_for_provider(provider_name, model) provider = _get_prerecorded_provider(provider_name) started_at = datetime.now(timezone.utc) @@ -221,7 +238,7 @@ def transcribe_bytes( raw_audio_seconds: float = 0.0, ) -> PrerecordedTranscriptionResponse: workload = STTWorkload(workload) - provider_name = get_prerecorded_provider_name(workload) + provider_name = resolve_prerecorded_provider_for_request(workload) model = _model_for_provider(provider_name, model) provider = _get_prerecorded_provider(provider_name) started_at = datetime.now(timezone.utc) diff --git a/desktop/.gitignore b/desktop/.gitignore index 4cacc3c6802..7b16057ccd9 100644 --- a/desktop/.gitignore +++ b/desktop/.gitignore @@ -1,6 +1,7 @@ # Build artifacts .build/ build/ +target/ DerivedData/ *.xcodeproj/xcuserdata/ *.xcworkspace/xcuserdata/ diff --git a/desktop/Backend-Rust/src/byok.rs b/desktop/Backend-Rust/src/byok.rs index 3f60e26adc0..7b72ed76bf0 100644 --- a/desktop/Backend-Rust/src/byok.rs +++ b/desktop/Backend-Rust/src/byok.rs @@ -25,6 +25,7 @@ pub const HEADER_OPENAI: &str = "x-byok-openai"; pub const HEADER_ANTHROPIC: &str = "x-byok-anthropic"; pub const HEADER_GEMINI: &str = "x-byok-gemini"; pub const HEADER_DEEPGRAM: &str = "x-byok-deepgram"; +pub const HEADER_ASSEMBLYAI: &str = "x-byok-assemblyai"; /// All four required BYOK headers. 
Python's `_request_has_all_byok_keys()` checks /// the same set — a fully enrolled BYOK user sends all four on every request. @@ -41,6 +42,7 @@ const HEADER_TO_PROVIDER: &[(&str, &str)] = &[ (HEADER_ANTHROPIC, "anthropic"), (HEADER_GEMINI, "gemini"), (HEADER_DEEPGRAM, "deepgram"), + (HEADER_ASSEMBLYAI, "assemblyai"), ]; /// Heartbeat TTL: BYOK is considered inactive if last_seen_at is older than this. diff --git a/desktop/CHANGELOG.json b/desktop/CHANGELOG.json index f40e118b834..a08eaab3780 100644 --- a/desktop/CHANGELOG.json +++ b/desktop/CHANGELOG.json @@ -1,5 +1,7 @@ { - "unreleased": [], + "unreleased": [ + "Added optional AssemblyAI API key in Developer API Keys for BYOK async transcription (sync, background, postprocess) when enabled server-side; four-key free plan unchanged" + ], "releases": [ { "version": "0.11.411", diff --git a/desktop/Desktop/Sources/APIKeyService.swift b/desktop/Desktop/Sources/APIKeyService.swift index 60f28f19a16..6934135a7a4 100644 --- a/desktop/Desktop/Sources/APIKeyService.swift +++ b/desktop/Desktop/Sources/APIKeyService.swift @@ -20,6 +20,10 @@ enum BYOKProvider: String, CaseIterable { case anthropic case gemini case deepgram + case assemblyai + + /// Providers required for the BYOK free plan (subscription bypass). 
+ static let requiredForFreePlan: [BYOKProvider] = [.openai, .anthropic, .gemini, .deepgram] var storageKey: String { switch self { @@ -27,6 +31,7 @@ enum BYOKProvider: String, CaseIterable { case .anthropic: return "dev_anthropic_api_key" case .gemini: return "dev_gemini_api_key" case .deepgram: return "dev_deepgram_api_key" + case .assemblyai: return "dev_assemblyai_api_key" } } @@ -36,6 +41,7 @@ enum BYOKProvider: String, CaseIterable { case .anthropic: return "X-BYOK-Anthropic" case .gemini: return "X-BYOK-Gemini" case .deepgram: return "X-BYOK-Deepgram" + case .assemblyai: return "X-BYOK-AssemblyAI" } } @@ -45,8 +51,13 @@ enum BYOKProvider: String, CaseIterable { case .anthropic: return "Anthropic" case .gemini: return "Gemini" case .deepgram: return "Deepgram" + case .assemblyai: return "AssemblyAI" } } + + var isRequiredForFreePlan: Bool { + Self.requiredForFreePlan.contains(self) + } } @MainActor final class APIKeyService: ObservableObject { @@ -183,11 +194,22 @@ final class APIKeyService: ObservableObject { nonEmptyStatic(UserDefaults.standard.string(forKey: provider.storageKey)) } - /// True when the user has supplied keys for all four BYOK providers. + /// True when the user has supplied keys for all four required BYOK providers. /// The subscription-bypass gate: when this is true, the user is on the free /// plan and we attach their keys to every backend request. nonisolated static var isByokActive: Bool { - BYOKProvider.allCases.allSatisfy { byokKey($0) != nil } + BYOKProvider.requiredForFreePlan.allSatisfy { byokKey($0) != nil } + } + + /// Fingerprints to send on BYOK activation (required four + optional Assembly when set). 
+ nonisolated static var byokActivationFingerprints: [String: String] { + var out: [String: String] = [:] + for provider in BYOKProvider.allCases { + if let key = byokKey(provider) { + out[provider.rawValue] = byokFingerprint(key) + } + } + return out } /// SHA-256 fingerprint of a key, used by the backend to detect when the diff --git a/desktop/Desktop/Sources/BYOKValidator.swift b/desktop/Desktop/Sources/BYOKValidator.swift index a713f96245a..f942d063bfc 100644 --- a/desktop/Desktop/Sources/BYOKValidator.swift +++ b/desktop/Desktop/Sources/BYOKValidator.swift @@ -43,6 +43,11 @@ enum BYOKValidator { url: URL(string: "https://api.deepgram.com/v1/projects")!, headers: ["Authorization": "Token \(trimmed)"] ) + case .assemblyai: + return await ping( + url: URL(string: "https://api.assemblyai.com/v2/account")!, + headers: ["Authorization": trimmed] + ) } } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index ebb5b42ff06..482c3937b03 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -377,6 +377,7 @@ struct SettingsContentView: View { @AppStorage("dev_anthropic_api_key") private var devAnthropicKey: String = "" @AppStorage("dev_openai_api_key") private var devOpenAIKey: String = "" @AppStorage("dev_deepgram_api_key") private var devDeepgramKey: String = "" + @AppStorage("dev_assemblyai_api_key") private var devAssemblyAIKey: String = "" @State private var byokKeyStatuses: [BYOKProvider: BYOKValidator.Status] = [:] @State private var byokActivationError: String? @@ -2138,7 +2139,7 @@ struct SettingsContentView: View { .foregroundColor(OmiColors.textPrimary) Text( APIKeyService.isByokActive - ? "You're using your own OpenAI, Anthropic, Gemini, and Deepgram keys. No subscription." + ? "You're using your own OpenAI, Anthropic, Gemini, and Deepgram keys. No subscription. 
Add an optional AssemblyAI key for async transcription when enabled on our servers." : "Provide your own OpenAI, Anthropic, Gemini, and Deepgram keys to skip the subscription entirely." ) .scaledFont(size: 12) @@ -5318,11 +5319,20 @@ struct SettingsContentView: View { developerKeyField( provider: .deepgram, title: "Deepgram API Key", - subtitle: "For live transcription.", + subtitle: "For live transcription and prerecorded fallback.", settingId: "advanced.devkeys.deepgram", value: $devDeepgramKey ) + developerKeyField( + provider: .assemblyai, + title: "AssemblyAI API Key", + subtitle: + "Optional. For async transcription (sync, background, postprocess) when enabled server-side. Live listening still uses Deepgram.", + settingId: "advanced.devkeys.assemblyai", + value: $devAssemblyAIKey + ) + if let byokActivationError { settingsCard(settingId: "advanced.devkeys.error") { HStack(spacing: 10) { @@ -5354,11 +5364,12 @@ struct SettingsContentView: View { .onChange(of: devAnthropicKey) { _, _ in refreshBYOKActivation() } .onChange(of: devGeminiKey) { _, _ in refreshBYOKActivation() } .onChange(of: devDeepgramKey) { _, _ in refreshBYOKActivation() } + .onChange(of: devAssemblyAIKey) { _, _ in refreshBYOKActivation() } } private var hasAnyBYOKKey: Bool { !devOpenAIKey.isEmpty || !devAnthropicKey.isEmpty || !devGeminiKey.isEmpty - || !devDeepgramKey.isEmpty + || !devDeepgramKey.isEmpty || !devAssemblyAIKey.isEmpty } private var hasAllBYOKKeys: Bool { @@ -5378,8 +5389,8 @@ struct SettingsContentView: View { .foregroundColor(OmiColors.textPrimary) Text( hasAllBYOKKeys - ? "You're paying your own providers. Omi skips the subscription charge. Keys stay on this Mac." - : "Provide all four keys (OpenAI, Anthropic, Gemini, Deepgram) to switch to the free plan. Keys stay on this Mac — we never store them on our servers." + ? "You're paying your own providers. Omi skips the subscription charge. Keys stay on this Mac. AssemblyAI is optional for async transcription." 
+ : "Provide all four keys (OpenAI, Anthropic, Gemini, Deepgram) to switch to the free plan. AssemblyAI is optional. Keys stay on this Mac — we never store them on our servers." ) .scaledFont(size: 12) .foregroundColor(OmiColors.textTertiary) @@ -5394,6 +5405,7 @@ struct SettingsContentView: View { devAnthropicKey = "" devGeminiKey = "" devDeepgramKey = "" + devAssemblyAIKey = "" Task { try? await APIClient.shared.deactivateBYOK() } @@ -5402,33 +5414,39 @@ struct SettingsContentView: View { private func refreshBYOKActivation() { Task { if APIKeyService.isByokActive { - // Validate before flipping the backend flag — otherwise we'd put the - // user on the free plan with dead keys and every chat would 401. - let snapshot = APIKeyService.byokSnapshot.reduce(into: [BYOKProvider: String]()) { - acc, entry in acc[entry.key] = entry.value.key - } - let results = await BYOKValidator.validateAll(snapshot) - let allOk = results.allSatisfy { - if case .ok = $0.value { return true } + // Validate required four before flipping the backend flag. + let requiredSnapshot = BYOKProvider.requiredForFreePlan.reduce(into: [BYOKProvider: String]()) { + acc, provider in + if let key = APIKeyService.byokKey(provider) { + acc[provider] = key + } + } + var keysToValidate = requiredSnapshot + if let assemblyKey = APIKeyService.byokKey(.assemblyai) { + keysToValidate[.assemblyai] = assemblyKey + } + let results = await BYOKValidator.validateAll(keysToValidate) + let requiredOk = BYOKProvider.requiredForFreePlan.allSatisfy { + if case .ok = results[$0] ?? .notChecked { return true } return false } - if allOk { - let fingerprints = APIKeyService.byokSnapshot.reduce(into: [String: String]()) { - acc, entry in acc[entry.key.rawValue] = entry.value.fingerprint - } - try? await APIClient.shared.activateBYOK(fingerprints: fingerprints) + if requiredOk { + try? 
await APIClient.shared.activateBYOK(fingerprints: APIKeyService.byokActivationFingerprints) await FloatingBarUsageLimiter.shared.fetchPlan() await MainActor.run { - // Clear any sticky paywall flag from a prior `freemium_threshold_reached` - // event — once all 4 BYOK keys validate, the user is on the free BYOK - // plan and shouldn't be locked out of capture/transcription anymore. AppState.current?.isPaywalled = false byokKeyStatuses = results - byokActivationError = nil + if case .failed(let msg) = results[.assemblyai] ?? .notChecked { + byokActivationError = + "Optional AssemblyAI key was rejected: \(msg). Free plan is still active with your four required keys." + } else { + byokActivationError = nil + } } } else { - let failed = results.filter { - if case .ok = $0.value { return false } + let failed = results.filter { provider, status in + guard BYOKProvider.requiredForFreePlan.contains(provider) else { return false } + if case .ok = status { return false } return true } let names = failed.keys.map(\.displayName).sorted().joined(separator: ", ") @@ -5437,7 +5455,7 @@ struct SettingsContentView: View { await MainActor.run { byokKeyStatuses = results byokActivationError = - "Rejected by provider: \(names). Free plan stays off until all 4 keys authenticate." + "Rejected by provider: \(names). Free plan stays off until all 4 required keys authenticate." 
} } } else { diff --git a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift index e4f1cb8dab8..6c895c1a11e 100644 --- a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift +++ b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift @@ -138,7 +138,7 @@ struct OnboardingBYOKStepView: View { .gemini: geminiKey, .deepgram: deepgramKey, ] - for provider in BYOKProvider.allCases { + for provider in keysToCheck.keys { keyStatuses[provider] = .checking } let results = await BYOKValidator.validateAll(keysToCheck) @@ -156,12 +156,7 @@ struct OnboardingBYOKStepView: View { // Step 2: all four authenticate — flip the backend flag. do { - try await APIClient.shared.activateBYOK(fingerprints: BYOKProvider.allCases.reduce(into: [:]) { - acc, provider in - if let key = APIKeyService.byokKey(provider) { - acc[provider.rawValue] = APIKeyService.byokFingerprint(key) - } - }) + try await APIClient.shared.activateBYOK(fingerprints: APIKeyService.byokActivationFingerprints) // Refresh the in-memory quota snapshot — otherwise the client keeps // blocking chat against the stale basic-tier 30-message cap. await FloatingBarUsageLimiter.shared.fetchPlan() diff --git a/desktop/run.sh b/desktop/run.sh index f4f3d58d3a7..df0fe572eab 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -44,7 +44,9 @@ fi # ─── YOLO mode: use prod backend, zero local setup ─────────────────── # WARNING: Temporary shortcut while desktop dev setup is being cleaned up. # Will be removed once all desktop slop is fixed. +YOLO_MODE=0 if [ "$1" = "--yolo" ]; then + YOLO_MODE=1 echo "" echo "==========================================" echo " YOLO MODE — using production backend" @@ -278,6 +280,14 @@ fi if [ -f "$BACKEND_DIR/.env" ]; then set -a; source "$BACKEND_DIR/.env"; set +a fi +# YOLO must win over Backend-Rust/.env (e.g. 
OMI_PYTHON_API_URL=http://localhost:8080) +if [ "$YOLO_MODE" = "1" ]; then + export OMI_SKIP_BACKEND=1 + export OMI_SKIP_TUNNEL=1 + export OMI_DESKTOP_API_URL="https://desktop-backend-hhibjajaja-uc.a.run.app" + export OMI_PYTHON_API_URL="https://api.omi.me" + export FIREBASE_API_KEY="AIzaSyD9dzBdglc7IO9pPDIOvqnCoTis_xKkkC8" +fi # Read backend PORT from env (default: 10201, never use 8080) BACKEND_PORT="${PORT:-10201}" @@ -508,18 +518,20 @@ if ! grep -q "^FIREBASE_API_KEY=" "$APP_BUNDLE/Contents/Resources/.env"; then fi # Bootstrap OMI_PYTHON_API_URL — main Omi Python backend (auth, subscriptions, payments, transcription) # Do NOT fall back to OMI_DESKTOP_API_URL — that's the Rust desktop-backend which doesn't serve these routes -if ! grep -q "^OMI_PYTHON_API_URL=" "$APP_BUNDLE/Contents/Resources/.env"; then - PYTHON_API_URL="${OMI_PYTHON_API_URL:-}" - if [ -z "$PYTHON_API_URL" ] && [ -f "$BACKEND_DIR/.env" ]; then - PYTHON_API_URL=$(grep "^OMI_PYTHON_API_URL=" "$BACKEND_DIR/.env" | head -1 | cut -d= -f2-) - fi - if [ -z "$PYTHON_API_URL" ]; then - PYTHON_API_URL="https://api.omi.me" - substep "OMI_PYTHON_API_URL not set — defaulting to production: $PYTHON_API_URL" - fi +PYTHON_API_URL="${OMI_PYTHON_API_URL:-}" +if [ -z "$PYTHON_API_URL" ] && [ -f "$BACKEND_DIR/.env" ]; then + PYTHON_API_URL=$(grep "^OMI_PYTHON_API_URL=" "$BACKEND_DIR/.env" | head -1 | cut -d= -f2-) +fi +if [ -z "$PYTHON_API_URL" ]; then + PYTHON_API_URL="https://api.omi.me" + substep "OMI_PYTHON_API_URL not set — defaulting to production: $PYTHON_API_URL" +fi +if grep -q "^OMI_PYTHON_API_URL=" "$APP_BUNDLE/Contents/Resources/.env"; then + sed -i '' "s|^OMI_PYTHON_API_URL=.*|OMI_PYTHON_API_URL=$PYTHON_API_URL|" "$APP_BUNDLE/Contents/Resources/.env" +else echo "OMI_PYTHON_API_URL=$PYTHON_API_URL" >> "$APP_BUNDLE/Contents/Resources/.env" - substep "Set OMI_PYTHON_API_URL=$PYTHON_API_URL" fi +substep "OMI_PYTHON_API_URL=$PYTHON_API_URL" substep "Copying app icon" cp -f omi_icon.icns 
"$APP_BUNDLE/Contents/Resources/OmiIcon.icns" 2>/dev/null || true diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 96ae52c3187..5d34fe12bee 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -54,6 +54,27 @@ Rollback is to set `ASSEMBLYAI_BACKGROUND_STT_ENABLED=false` or remove the affected workload from `ASSEMBLYAI_BACKGROUND_STT_WORKLOADS`. No client deploy is required for rollback. +## BYOK (Bring Your Own Keys) + +Desktop BYOK users still require four keys (OpenAI, Anthropic, Gemini, Deepgram) +for the free plan. AssemblyAI is an **optional fifth** key (`X-BYOK-AssemblyAI`). + +**Deploy backend routing before enabling Assembly for BYOK cohorts.** `resolve_prerecorded_provider_for_request()` in `backend/utils/stt/provider_service.py` applies: + +| Flags select Assembly? | Request headers | Provider used | +| --- | --- | --- | +| No | any | Deepgram (BYOK or server) | +| Yes | `x-byok-assemblyai` | AssemblyAI with user key | +| Yes | `x-byok-deepgram` only | Deepgram BYOK (avoids billing Omi's Assembly key) | +| Yes | neither | AssemblyAI with server `ASSEMBLYAI_API_KEY` | + +If a user enrolls an optional Assembly fingerprint in Firestore, they must send +`x-byok-assemblyai` on requests that send other BYOK headers. Clearing the +Assembly key requires re-activation with fingerprints that omit `assemblyai`. + +Listen/realtime STT is unchanged (Deepgram only). Assembly BYOK does not affect +`/v4/listen`. 
+ ## Storage And Identity Provider output is normalized into `ProviderTranscriptResult` and reconstructed diff --git a/docs/doc/developer/backend/listen_pusher_pipeline.mdx b/docs/doc/developer/backend/listen_pusher_pipeline.mdx index c38071290c2..d1bc98a8ef1 100644 --- a/docs/doc/developer/backend/listen_pusher_pipeline.mdx +++ b/docs/doc/developer/backend/listen_pusher_pipeline.mdx @@ -278,6 +278,9 @@ speaker identity is applied after provider normalization with the Omi embedding service. Low-confidence matches remain explicit unknown speakers; provider cluster metadata is preserved separately from `person_id` and `is_user`. +Optional desktop BYOK for AssemblyAI applies only to async prerecorded workloads +above; listen/realtime streaming in this pipeline remains Deepgram-only. + ## 6. Event Wire Protocol ### Server → Client (JSON over WS text frames) From d9e505320ee4b4efb125a907c9bfd39a2828c65f Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 03:29:23 -0400 Subject: [PATCH 24/44] Add desktop AssemblyAI batch background transcription path. Wire desktop Audio Recording to POST /v2/desktop/background-transcribe via chunker/session queue, add backend endpoints and e2e script, and include an agent prompt for multi-chunk non-desktop E2E verification. 
Co-authored-by: Cursor --- .gitignore | 4 +- backend/main.py | 2 + backend/routers/desktop_background.py | 276 ++++++++++++ backend/routers/transcribe.py | 58 +-- backend/run-local.sh | 5 + .../unit/test_byok_assemblyai_routing.py | 19 + .../test_desktop_background_transcribe.py | 298 +++++++++++++ .../utils/conversations/desktop_background.py | 103 +++++ backend/utils/rate_limit_config.py | 1 + desktop/Desktop/Sources/APIClient.swift | 19 + desktop/Desktop/Sources/AppState.swift | 406 +++++++++++++++++- .../BackgroundAudioChunker.swift | 138 ++++++ .../BackgroundTranscriptMerger.swift | 146 +++++++ ...BackgroundTranscriptionConfiguration.swift | 29 ++ .../BackgroundTranscriptionRoutingGuard.swift | 25 ++ .../CloudBackgroundTranscriptionSession.swift | 110 +++++ .../SpeakerSegmentReducer.swift | 84 ++++ .../MainWindow/Pages/SettingsPage.swift | 112 +++-- .../Sources/TranscriptionService.swift | 100 +++++ .../Tests/BackgroundTranscriptionTests.swift | 251 +++++++++++ .../backend/assemblyai_background_rollout.mdx | 19 +- .../backend/listen_pusher_pipeline.mdx | 14 +- .../ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md | 168 ++++++++ scripts/desktop_assemblyai_e2e.py | 278 ++++++++++++ 24 files changed, 2561 insertions(+), 104 deletions(-) create mode 100644 backend/routers/desktop_background.py create mode 100755 backend/run-local.sh create mode 100644 backend/tests/unit/test_desktop_background_transcribe.py create mode 100644 backend/utils/conversations/desktop_background.py create mode 100644 desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift create mode 100644 desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptMerger.swift create mode 100644 desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift create mode 100644 desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift create mode 100644 
desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift create mode 100644 desktop/Desktop/Sources/BackgroundTranscription/SpeakerSegmentReducer.swift create mode 100644 desktop/Desktop/Tests/BackgroundTranscriptionTests.swift create mode 100644 scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md create mode 100755 scripts/desktop_assemblyai_e2e.py diff --git a/.gitignore b/.gitignore index c273043e2b7..4ea395c4b38 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,9 @@ myenv/ venv/ .DS_Store dump/ -/scripts/ +/scripts/* +!/scripts/desktop_assemblyai_e2e.py +!/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md *.zip *.wav node_modules diff --git a/backend/main.py b/backend/main.py index 3ecd2546125..2ef585b5bf6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,6 +14,7 @@ from routers import ( chat, + desktop_background, firmware, transcribe, notifications, @@ -86,6 +87,7 @@ app = FastAPI() app.include_router(transcribe.router) +app.include_router(desktop_background.router) app.include_router(conversations.router) app.include_router(action_items.router) app.include_router(task_integrations.router) diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py new file mode 100644 index 00000000000..0242ca1fa1d --- /dev/null +++ b/backend/routers/desktop_background.py @@ -0,0 +1,276 @@ +import json +import logging +from datetime import datetime, timezone +from typing import Dict, List, Optional + +from fastapi import APIRouter, Depends, Header, HTTPException, Request +from pydantic import BaseModel + +import database.conversations as conversations_db +from database import redis_db +from models.conversation_enums import ConversationSource, ConversationStatus +from models.transcript_segment import TranscriptSegment +from utils.analytics import record_usage +from utils.chat import resolve_voice_message_language +from utils.conversations.desktop_background import ( + 
append_segments_to_in_progress_conversation, + create_in_progress_desktop_conversation, +) +from utils.executors import db_executor, run_blocking, sync_executor +from utils.fair_use import is_hard_restricted, record_speech_ms +from utils.other import endpoints as auth +from utils.speaker_identification import _pcm_to_wav_bytes +from utils.stt.provider_service import transcribe_bytes +from utils.stt.providers import STTWorkload +from utils.subscription import has_transcription_credits, is_trial_paywalled +from utils.voice_duration_limiter import compute_pcm_duration_ms + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/v2/desktop", tags=["desktop-background"]) + +_MAX_PCM_BODY_BYTES = 200_000_000 +_SPEAKER_MAP_TTL_SECONDS = 60 * 60 * 24 + + +class BackgroundConversationStartRequest(BaseModel): + language: Optional[str] = None + source: Optional[str] = "desktop" + + +@router.post("/background-conversation/start") +async def start_background_conversation( + body: BackgroundConversationStartRequest, + uid: str = Depends(auth.get_current_user_uid), + x_app_platform: Optional[str] = Header(None, alias='X-App-Platform'), +): + if is_trial_paywalled(uid, x_app_platform or body.source or "desktop"): + raise HTTPException(status_code=402, detail={'error': 'quota_exceeded', 'plan_type': 'basic'}) + + language = resolve_voice_message_language(uid, body.language) + try: + source = ConversationSource(body.source or "desktop") + except ValueError: + raise HTTPException(status_code=422, detail='Invalid source') + + conversation_id = await run_blocking( + db_executor, + create_in_progress_desktop_conversation, + uid, + language, + source, + ) + return {"conversation_id": conversation_id} + + +@router.post("/background-transcribe") +async def background_transcribe( + request: Request, + uid: str = Depends(auth.with_rate_limit(auth.get_current_user_uid, "desktop:background_transcribe")), + x_app_platform: Optional[str] = Header(None, alias='X-App-Platform'), +): + 
if is_trial_paywalled(uid, x_app_platform or "desktop"): + raise HTTPException(status_code=402, detail={'error': 'quota_exceeded', 'plan_type': 'basic'}) + if await run_blocking(db_executor, is_hard_restricted, uid): + raise HTTPException(status_code=429, detail='Transcription temporarily restricted') + if not await run_blocking(db_executor, has_transcription_credits, uid, source="desktop"): + raise HTTPException(status_code=429, detail='Transcription credits exhausted') + + content_length = request.headers.get("content-length") + if content_length: + try: + content_length_value = int(content_length) + except ValueError: + raise HTTPException(status_code=422, detail='content-length must be an integer') + if content_length_value > _MAX_PCM_BODY_BYTES: + raise HTTPException(status_code=413, detail=f'Body too large (max {_MAX_PCM_BODY_BYTES} bytes)') + + audio_bytes = await request.body() + if not audio_bytes: + raise HTTPException(status_code=400, detail='No audio data provided') + if len(audio_bytes) > _MAX_PCM_BODY_BYTES: + del audio_bytes + raise HTTPException(status_code=413, detail=f'Body too large (max {_MAX_PCM_BODY_BYTES} bytes)') + + try: + sample_rate = int(request.query_params.get("sample_rate", "16000")) + channels = int(request.query_params.get("channels", "1")) + except ValueError: + del audio_bytes + raise HTTPException(status_code=422, detail='sample_rate and channels must be integers') + + if sample_rate < 8000 or sample_rate > 48000: + del audio_bytes + raise HTTPException(status_code=422, detail='sample_rate must be between 8000 and 48000') + if channels != 1: + del audio_bytes + raise HTTPException(status_code=422, detail='channels must be 1') + + chunk_start_ms_raw = request.query_params.get("chunk_start_ms") + if chunk_start_ms_raw is None: + del audio_bytes + raise HTTPException(status_code=400, detail='chunk_start_ms is required') + try: + chunk_start_ms = int(chunk_start_ms_raw) + except ValueError: + del audio_bytes + raise 
HTTPException(status_code=422, detail='chunk_start_ms must be an integer') + if chunk_start_ms < 0: + del audio_bytes + raise HTTPException(status_code=422, detail='chunk_start_ms must be non-negative') + + conversation_id = request.query_params.get("conversation_id") + persist = _parse_persist(request.query_params.get("persist"), default=bool(conversation_id)) + if persist: + if not conversation_id: + del audio_bytes + raise HTTPException(status_code=400, detail='conversation_id is required when persist=true') + await _validate_in_progress_conversation(uid, conversation_id) + + language = resolve_voice_message_language(uid, request.query_params.get("language")) + keywords = _parse_context_keywords(request.query_params.get("keywords")) + encoding = request.query_params.get("encoding", "linear16") + duration_ms = compute_pcm_duration_ms(len(audio_bytes), sample_rate, channels) + duration_sec = duration_ms / 1000.0 + + try: + audio_for_stt = _pcm_to_wav_bytes(audio_bytes, sample_rate) if encoding == "linear16" else audio_bytes + response = await run_blocking( + sync_executor, + transcribe_bytes, + audio_for_stt, + workload=STTWorkload.background, + uid=uid, + conversation_id=conversation_id, + sample_rate=sample_rate, + diarize=True, + encoding=None if encoding == "linear16" else encoding, + channels=channels, + language=language, + return_language=language == "multi", + keywords=keywords, + raw_audio_seconds=duration_sec, + ) + except RuntimeError as e: + logger.error("Desktop background transcription failed: %s", e) + raise HTTPException(status_code=500, detail=f'Transcription failed: {str(e)}') + finally: + del audio_bytes + + segments = response.segments + _apply_chunk_offset(segments, chunk_start_ms / 1000.0) + if conversation_id: + _apply_speaker_ids(conversation_id, segments) + + finished_at = datetime.now(timezone.utc) + if persist and conversation_id: + await run_blocking( + db_executor, + append_segments_to_in_progress_conversation, + uid, + 
conversation_id, + segments, + finished_at, + ) + + await run_blocking(db_executor, record_speech_ms, uid, duration_ms, source='background') + await run_blocking(db_executor, record_usage, uid, transcription_seconds=duration_sec, speech_seconds=duration_sec) + provider = response.result.provider if response.result else None + logger.info( + "desktop_background_transcribe completed uid=%s conversation_id=%s workload=background provider=%s run_id=%s " + "chunk_start_ms=%s chunk_duration_ms=%s segments=%s persisted=%s", + uid, + conversation_id, + provider, + response.run_id, + chunk_start_ms, + duration_ms, + len(segments), + bool(persist and conversation_id), + ) + + return { + "segments": [segment.model_dump() for segment in segments], + "language": response.detected_language or language, + "provider": provider, + "run_id": response.run_id, + "chunk_duration_ms": duration_ms, + } + + +def _parse_persist(raw: Optional[str], default: bool) -> bool: + if raw is None: + return default + return raw.strip().lower() not in ("0", "false", "no") + + +def _parse_context_keywords(raw: Optional[str]) -> List[str]: + if not raw: + return [] + keywords: List[str] = [] + seen = set() + for item in raw.split(','): + keyword = item.strip() + if len(keyword) < 2 or len(keyword) > 80: + continue + dedupe_key = keyword.lower() + if dedupe_key in seen: + continue + seen.add(dedupe_key) + keywords.append(keyword) + if len(keywords) >= 100: + break + return keywords + + +async def _validate_in_progress_conversation(uid: str, conversation_id: str) -> None: + conversation = await run_blocking(db_executor, conversations_db.get_conversation, uid, conversation_id) + if not conversation: + raise HTTPException(status_code=404, detail='conversation_id not found') + if conversation.get('status') != ConversationStatus.in_progress: + raise HTTPException(status_code=404, detail='conversation is not in_progress') + + +def _apply_chunk_offset(segments: List[TranscriptSegment], offset_sec: float) -> 
None: + for segment in segments: + segment.start += offset_sec + segment.end += offset_sec + + +def _apply_speaker_ids(conversation_id: str, segments: List[TranscriptSegment]) -> None: + speaker_map = _load_speaker_map(conversation_id) + changed = False + for segment in segments: + cluster = segment.provider_cluster_id or segment.provider_speaker_label or segment.speaker + if not cluster: + continue + if cluster not in speaker_map: + speaker_map[cluster] = len(speaker_map) + changed = True + speaker_id = speaker_map[cluster] + segment.speaker_id = speaker_id + segment.speaker = f"SPEAKER_{speaker_id:02d}" + if segment.speaker_identity_state == "legacy_ambiguous": + segment.speaker_identity_state = "unassigned" + + if changed: + _store_speaker_map(conversation_id, speaker_map) + + +def _speaker_map_key(conversation_id: str) -> str: + return f"desktop_batch_speaker_map:{conversation_id}" + + +def _load_speaker_map(conversation_id: str) -> Dict[str, int]: + raw = redis_db.r.get(_speaker_map_key(conversation_id)) + if not raw: + return {} + try: + return {str(key): int(value) for key, value in json.loads(raw).items()} + except (TypeError, ValueError, json.JSONDecodeError): + logger.warning("Invalid desktop batch speaker map for conversation_id=%s", conversation_id) + return {} + + +def _store_speaker_map(conversation_id: str, speaker_map: Dict[str, int]) -> None: + redis_db.r.set(_speaker_map_key(conversation_id), json.dumps(speaker_map), ex=_SPEAKER_MAP_TTL_SECONDS) diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 532a83fc5b0..520ed8cae39 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -34,7 +34,6 @@ should_update_speaker_to_person_map, ) import database.conversations as conversations_db -import database.calendar_meetings as calendar_db import database.users as user_db from utils.byok import get_byok_keys, extract_byok_from_websocket, set_byok_keys from database.users import 
get_user_transcription_preferences @@ -43,8 +42,8 @@ from models.conversation import Conversation from models.conversation_enums import ConversationSource, ConversationStatus from utils.conversations.factory import deserialize_conversation +from utils.conversations.desktop_background import create_in_progress_desktop_conversation from models.conversation_photo import ConversationPhoto -from models.structured import Structured from models.transcript_segment import TranscriptSegment from models.message_event import ( ConversationEvent, @@ -798,61 +797,14 @@ async def _create_new_in_progress_conversation(): logger.error(f"Invalid conversation source '{source}', defaulting to 'omi' {uid} {session_id}") conversation_source = ConversationSource.omi - new_conversation_id = str(uuid.uuid4()) - stub_conversation = Conversation( - id=new_conversation_id, - created_at=datetime.now(timezone.utc), - started_at=datetime.now(timezone.utc), - finished_at=datetime.now(timezone.utc), - structured=Structured(), - language=language, - transcript_segments=[], - photos=[], - status=ConversationStatus.in_progress, + new_conversation_id = create_in_progress_desktop_conversation( + uid, + language, source=conversation_source, private_cloud_sync_enabled=private_cloud_sync_enabled, call_id=call_id if is_multi_channel else None, + session_id=session_id, ) - conversations_db.upsert_conversation(uid, conversation_data=stub_conversation.dict()) - redis_db.set_in_progress_conversation_id(uid, new_conversation_id) - - detected_meeting_id = None - - # Only check for meetings if source is desktop - if conversation_source == ConversationSource.desktop: - now = datetime.now(timezone.utc) - # Check ±2 minute window - time_window = timedelta(minutes=2) - start_range = now - time_window - end_range = now + time_window - - meetings = calendar_db.get_meetings_in_time_range(uid, start_range, end_range) - - if len(meetings) == 1: - # Exactly one meeting found - detected_meeting_id = meetings[0]['id'] - elif 
len(meetings) > 1: - closest_meeting = None - smallest_diff = None - - for meeting in meetings: - # Calculate absolute time difference between meeting start and now - time_diff = abs((meeting['start_time'] - now).total_seconds()) - - if smallest_diff is None or time_diff < smallest_diff: - smallest_diff = time_diff - closest_meeting = meeting - - if closest_meeting: - detected_meeting_id = closest_meeting['id'] - logger.info( - f"Selected closest meeting: {closest_meeting['title']} (diff: {smallest_diff}s) {uid} {session_id}" - ) - - # Store meeting association if auto-detected - if detected_meeting_id: - redis_db.set_conversation_meeting_id(new_conversation_id, detected_meeting_id) - current_conversation_id = new_conversation_id logger.info(f"Created new stub conversation: {new_conversation_id} {uid} {session_id}") diff --git a/backend/run-local.sh b/backend/run-local.sh new file mode 100755 index 00000000000..6080e983e30 --- /dev/null +++ b/backend/run-local.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail +cd "$(dirname "$0")" +export DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib:${DYLD_FALLBACK_LIBRARY_PATH:-}" +exec ./venv/bin/uvicorn main:app --host 127.0.0.1 --port 8080 --env-file .env diff --git a/backend/tests/unit/test_byok_assemblyai_routing.py b/backend/tests/unit/test_byok_assemblyai_routing.py index 7861a320811..08e56083c08 100644 --- a/backend/tests/unit/test_byok_assemblyai_routing.py +++ b/backend/tests/unit/test_byok_assemblyai_routing.py @@ -35,6 +35,12 @@ def test_resolve_uses_deepgram_byok_when_no_assembly_header(mock_get_key): assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.deepgram +@patch('utils.stt.provider_service.get_byok_key') +def test_resolve_uses_deepgram_byok_for_background_when_no_assembly_header(mock_get_key): + mock_get_key.side_effect = lambda provider: {'deepgram': 'dg-user-key'}.get(provider) + assert 
provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) == STTProviderName.deepgram + + @patch('utils.stt.provider_service.get_byok_key') def test_resolve_uses_assemblyai_when_byok_assembly_header_present(mock_get_key): keys = {'assemblyai': 'aa-user-key', 'deepgram': 'dg-user-key'} @@ -46,6 +52,19 @@ def _lookup(provider): assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai +@patch('utils.stt.provider_service.get_byok_key') +def test_resolve_uses_assemblyai_for_background_when_byok_assembly_header_present(mock_get_key): + keys = {'assemblyai': 'aa-user-key', 'deepgram': 'dg-user-key'} + + def _lookup(provider): + return keys.get(provider) + + mock_get_key.side_effect = _lookup + assert ( + provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) == STTProviderName.assemblyai + ) + + @patch('utils.stt.provider_service.get_byok_key', return_value=None) def test_resolve_uses_server_assembly_when_no_byok_headers(_mock_get_key): assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py new file mode 100644 index 00000000000..f305a6e0cd9 --- /dev/null +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -0,0 +1,298 @@ +import io +import os +import sys +from types import SimpleNamespace +import wave +from unittest.mock import MagicMock + +os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') + +for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + +sys.modules['deepgram'].DeepgramClient = MagicMock +sys.modules['deepgram'].DeepgramClientOptions = MagicMock + +if 'google.api_core.exceptions' not in sys.modules: + sys.modules['google'] = 
MagicMock() + sys.modules['google.api_core'] = MagicMock() + sys.modules['google.api_core.exceptions'] = MagicMock(NotFound=Exception) + sys.modules['google.cloud'] = MagicMock() + sys.modules['google.cloud.firestore'] = MagicMock() + sys.modules['google.cloud.firestore_v1'] = MagicMock(FieldFilter=MagicMock) +sys.modules['google.cloud.firestore'].Increment = lambda value: value +database_client_stub = sys.modules.setdefault('database._client', SimpleNamespace()) +database_client_stub.db = getattr(database_client_stub, 'db', MagicMock()) +database_client_stub.document_id_from_seed = getattr(database_client_stub, 'document_id_from_seed', MagicMock()) +sys.modules.setdefault('database.conversations', MagicMock()) +sys.modules.setdefault('database.redis_db', SimpleNamespace(r=MagicMock())) +sys.modules.setdefault( + 'utils.conversations.desktop_background', + SimpleNamespace( + append_segments_to_in_progress_conversation=MagicMock(), + create_in_progress_desktop_conversation=MagicMock(return_value='conv-1'), + ), +) +sys.modules.setdefault( + 'utils.chat', SimpleNamespace(resolve_voice_message_language=lambda _uid, language: language or 'en') +) +sys.modules.setdefault('utils.analytics', SimpleNamespace(record_usage=lambda *_args, **_kwargs: None)) +sys.modules.setdefault( + 'utils.fair_use', + SimpleNamespace( + is_hard_restricted=lambda *_args, **_kwargs: False, + record_speech_ms=lambda *_args, **_kwargs: None, + ), +) +sys.modules.setdefault( + 'utils.subscription', + SimpleNamespace( + has_transcription_credits=lambda *_args, **_kwargs: True, + is_trial_paywalled=lambda *_args, **_kwargs: False, + ), +) + + +def _pcm_to_wav_bytes(pcm_data: bytes, sample_rate: int) -> bytes: + buffer = io.BytesIO() + with wave.open(buffer, 'wb') as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(sample_rate) + wav.writeframes(pcm_data) + return buffer.getvalue() + + +sys.modules.setdefault('utils.speaker_identification', 
SimpleNamespace(_pcm_to_wav_bytes=_pcm_to_wav_bytes)) +sys.modules.setdefault('utils.stt.provider_service', SimpleNamespace(transcribe_bytes=MagicMock())) +sys.modules.setdefault( + 'utils.voice_duration_limiter', + SimpleNamespace( + compute_pcm_duration_ms=lambda byte_length, sample_rate, channels: int( + byte_length * 1000 / (sample_rate * channels * 2) + ) + ), +) +sys.modules.setdefault('utils.other.hume', MagicMock()) +import utils.other as _utils_other + +setattr(_utils_other, 'hume', sys.modules['utils.other.hume']) +_endpoints_stub = SimpleNamespace( + get_current_user_uid=lambda: 'test-uid', + with_rate_limit=lambda dep, _policy: dep, +) +sys.modules.setdefault('utils.other.endpoints', _endpoints_stub) +setattr(_utils_other, 'endpoints', _endpoints_stub) + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from models.transcript_segment import ProviderTranscriptResult, TranscriptSegment +from routers import desktop_background +from utils.stt.providers import STTProviderName, STTWorkload + +sys.modules.pop('utils.stt.provider_service', None) +if 'utils.stt' in sys.modules and hasattr(sys.modules['utils.stt'], 'provider_service'): + delattr(sys.modules['utils.stt'], 'provider_service') + + +def _client(monkeypatch, *, segments=None): + app = FastAPI() + app.include_router(desktop_background.router) + for route in app.routes: + if hasattr(route, 'dependant'): + for dep in route.dependant.dependencies: + if dep.call is not None: + app.dependency_overrides[dep.call] = lambda: 'test-uid' + + monkeypatch.setattr(desktop_background, 'is_trial_paywalled', lambda *_args, **_kwargs: False) + monkeypatch.setattr(desktop_background, 'is_hard_restricted', lambda *_args, **_kwargs: False) + monkeypatch.setattr(desktop_background, 'has_transcription_credits', lambda *_args, **_kwargs: True) + monkeypatch.setattr(desktop_background, 'record_speech_ms', lambda *_args, **_kwargs: None) + monkeypatch.setattr(desktop_background, 'record_usage', lambda 
*_args, **_kwargs: None) + monkeypatch.setattr(desktop_background, 'resolve_voice_message_language', lambda _uid, language: language or 'en') + monkeypatch.setattr( + desktop_background.conversations_db, + 'get_conversation', + lambda _uid, _cid: {'id': _cid, 'status': 'in_progress'}, + ) + monkeypatch.setattr(desktop_background, 'append_segments_to_in_progress_conversation', MagicMock(return_value=[])) + monkeypatch.setattr(desktop_background.redis_db.r, 'get', lambda _key: None) + monkeypatch.setattr(desktop_background.redis_db.r, 'set', lambda *_args, **_kwargs: None) + + default_segments = segments or [ + TranscriptSegment( + id='seg-1', + text='Hello world.', + speaker='SPEAKER_00', + speaker_id=0, + is_user=False, + start=0.5, + end=1.2, + provider_cluster_id='A', + provider_speaker_label='ASSEMBLYAI_SPEAKER_A', + stt_provider='assemblyai', + stt_model='universal-2', + ) + ] + + def _transcribe_bytes(audio_bytes, **_kwargs): + return SimpleNamespace( + result=ProviderTranscriptResult(provider='assemblyai', model='universal-2', words=[], utterances=[]), + detected_language='en', + segments=default_segments, + run_id='run-1', + ) + + mock_transcribe = MagicMock(side_effect=_transcribe_bytes) + monkeypatch.setattr(desktop_background, 'transcribe_bytes', mock_transcribe) + return TestClient(app), mock_transcribe + + +def test_background_transcribe_returns_segments_with_offset(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=12000', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert data['provider'] == 'assemblyai' + assert data['run_id'] == 'run-1' + assert data['segments'][0]['start'] == 12.5 + assert data['segments'][0]['end'] == 13.2 + assert data['segments'][0]['speaker_id'] == 0 + assert data['segments'][0]['speaker'] == 
'SPEAKER_00' + + +def test_background_transcribe_wraps_linear16_pcm_as_wav(monkeypatch): + client, mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0&sample_rate=16000', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + audio_arg = mock_transcribe.call_args.args[0] + assert audio_arg[:4] == b'RIFF' + assert b'WAVE' in audio_arg[:16] + assert mock_transcribe.call_args.kwargs['workload'].value == 'background' + + +def test_background_transcribe_persists_segments(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + desktop_background.append_segments_to_in_progress_conversation.assert_called_once() + + +def test_background_transcribe_can_skip_persist_without_conversation(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + record_speech_ms = MagicMock() + record_usage = MagicMock() + monkeypatch.setattr(desktop_background, 'record_speech_ms', record_speech_ms) + monkeypatch.setattr(desktop_background, 'record_usage', record_usage) + + response = client.post( + '/v2/desktop/background-transcribe?chunk_start_ms=0&persist=false', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + desktop_background.append_segments_to_in_progress_conversation.assert_not_called() + record_speech_ms.assert_called_once() + record_usage.assert_called_once() + + +def test_cluster_speaker_mapping_assigns_distinct_ids(monkeypatch): + segments = [ + TranscriptSegment(text='One', is_user=False, start=0.0, end=1.0, provider_cluster_id='A'), + TranscriptSegment(text='Two', is_user=False, 
start=1.0, end=2.0, provider_cluster_id='B'), + ] + client, _mock_transcribe = _client(monkeypatch, segments=segments) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert [segment['speaker_id'] for segment in data['segments']] == [0, 1] + assert [segment['speaker'] for segment in data['segments']] == ['SPEAKER_00', 'SPEAKER_01'] + + +def test_byok_background_routing_uses_deepgram_when_only_deepgram_key(monkeypatch): + from utils.stt import provider_service + + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.setattr(provider_service, 'get_byok_key', lambda provider: {'deepgram': 'dg-user-key'}.get(provider)) + + provider = provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) + + assert provider == STTProviderName.deepgram + + +def test_background_transcribe_rejects_empty_body(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0', + content=b'', + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 400 + + +def test_background_transcribe_rejects_stereo_pcm_for_v1(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0&channels=2', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 422 + assert response.json()['detail'] == 'channels must be 1' + + +def test_background_transcribe_rejects_malformed_content_length(monkeypatch): + client, _mock_transcribe = 
_client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream', 'Content-Length': 'not-an-int'}, + ) + + assert response.status_code == 422 + + +def test_background_transcribe_rejects_invalid_conversation(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setattr(desktop_background.conversations_db, 'get_conversation', lambda _uid, _cid: None) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=missing&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 404 diff --git a/backend/utils/conversations/desktop_background.py b/backend/utils/conversations/desktop_background.py new file mode 100644 index 00000000000..0f10df6f8d1 --- /dev/null +++ b/backend/utils/conversations/desktop_background.py @@ -0,0 +1,103 @@ +import logging +import uuid +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +import database.calendar_meetings as calendar_db +import database.conversations as conversations_db +from database import redis_db +from models.conversation import Conversation +from models.conversation_enums import ConversationSource, ConversationStatus +from models.structured import Structured +from models.transcript_segment import TranscriptSegment +from utils.conversations.factory import deserialize_conversation + +logger = logging.getLogger(__name__) + + +def create_in_progress_desktop_conversation( + uid: str, + language: str, + source: ConversationSource = ConversationSource.desktop, + private_cloud_sync_enabled: bool = False, + call_id: Optional[str] = None, + session_id: Optional[str] = None, +) -> str: + """Create a desktop/listen in-progress conversation and Redis pointer.""" + new_conversation_id = str(uuid.uuid4()) + now = 
datetime.now(timezone.utc) + stub_conversation = Conversation( + id=new_conversation_id, + created_at=now, + started_at=now, + finished_at=now, + structured=Structured(), + language=language, + transcript_segments=[], + photos=[], + status=ConversationStatus.in_progress, + source=source, + private_cloud_sync_enabled=private_cloud_sync_enabled, + call_id=call_id, + ) + conversations_db.upsert_conversation(uid, conversation_data=stub_conversation.dict()) + redis_db.set_in_progress_conversation_id(uid, new_conversation_id) + + detected_meeting_id = _detect_current_desktop_meeting(uid) if source == ConversationSource.desktop else None + if detected_meeting_id: + redis_db.set_conversation_meeting_id(new_conversation_id, detected_meeting_id) + + logger.info( + "Created new in-progress conversation: %s uid=%s session=%s source=%s", + new_conversation_id, + uid, + session_id, + source.value, + ) + return new_conversation_id + + +def append_segments_to_in_progress_conversation( + uid: str, + conversation_id: str, + segments: List[TranscriptSegment], + finished_at: datetime, +) -> List[TranscriptSegment]: + """Append transcript segments to the in-progress conversation and bump finished_at.""" + conversation_data = conversations_db.get_conversation(uid, conversation_id) + if not conversation_data: + raise ValueError("conversation not found") + + conversation = deserialize_conversation(conversation_data) + if conversation.status != ConversationStatus.in_progress: + raise ValueError("conversation is not in_progress") + + if not segments: + conversations_db.update_conversation_finished_at(uid, conversation_id, finished_at) + return [] + + conversation.transcript_segments, updated_segments, _removed_ids = TranscriptSegment.combine_segments( + conversation.transcript_segments, + segments, + ) + conversations_db.update_conversation_segments( + uid, + conversation.id, + [segment.dict() for segment in conversation.transcript_segments], + finished_at=finished_at, + ) + return 
updated_segments + + +def _detect_current_desktop_meeting(uid: str) -> Optional[str]: + now = datetime.now(timezone.utc) + time_window = timedelta(minutes=2) + meetings = calendar_db.get_meetings_in_time_range(uid, now - time_window, now + time_window) + + if len(meetings) == 1: + return meetings[0]['id'] + if len(meetings) <= 1: + return None + + closest_meeting = min(meetings, key=lambda meeting: abs((meeting['start_time'] - now).total_seconds())) + return closest_meeting['id'] diff --git a/backend/utils/rate_limit_config.py b/backend/utils/rate_limit_config.py index f4db6bc418d..7c7533dd882 100644 --- a/backend/utils/rate_limit_config.py +++ b/backend/utils/rate_limit_config.py @@ -46,6 +46,7 @@ "voice:transcribe": (60, 3600), "voice:transcribe_stream": (60, 3600), "voice:message": (60, 3600), + "desktop:background_transcribe": (120, 3600), "file:upload": (40, 3600), # Agent/MCP — bursty tool calls "agent:execute_tool": (120, 3600), diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 36d6f8b49e0..5dec9b31530 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -1284,6 +1284,10 @@ extension APIClient { let conversation: ServerConversation } + struct StartBackgroundConversationResponse: Decodable { + let conversation_id: String + } + /// Force-process the current in-progress conversation on the Python backend. /// Endpoint: POST /v1/conversations (Python backend) /// This is the same endpoint the mobile app uses when stopping phone mic recording. @@ -1305,6 +1309,21 @@ extension APIClient { return nil } } + + /// Start a Python-backed in-progress conversation for desktop background batch transcription. 
+ /// Endpoint: POST /v2/desktop/background-conversation/start + func startBackgroundConversation(language: String) async throws -> String { + struct StartBackgroundConversationRequest: Encodable { + let language: String + } + + let response: StartBackgroundConversationResponse = try await post( + "v2/desktop/background-conversation/start", + body: StartBackgroundConversationRequest(language: language), + customBaseURL: nil + ) + return response.conversation_id + } } // MARK: - Memories API diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 8f02a8bd340..4e8c56f54cc 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -42,6 +42,7 @@ class AppState: ObservableObject { // Transcription state @Published var isTranscribing = false + @Published private(set) var isStartingTranscription = false /// Monotonically increasing counter — incremented each time a new recording starts. /// Used to detect if a new recording began during the post-stop force-process delay. private(set) var recordingGeneration: UInt64 = 0 @@ -304,6 +305,14 @@ class AppState: ObservableObject { private var systemAudioCaptureService: Any? // SystemAudioCaptureService (macOS 14.4+) private var audioMixer: AudioMixer? private var vadGateService: VADGateService? + private var cloudBackgroundSession: CloudBackgroundTranscriptionSession? + private var cloudBackgroundConversationId: String? + private var cloudBackgroundStartTask: Task? + private var cloudBackgroundDrainTask: Task? 
+ private var cloudBackgroundSampleCursor = 0 + private var isCloudBackgroundTranscription = false + private var backgroundTranscriptMerger = BackgroundTranscriptMerger() + private var speakerSegmentReducer = SpeakerSegmentReducer() // Speaker segments for diarized transcription (sliding window — older segments are in SQLite) private var speakerSegments: [SpeakerSegment] = [] @@ -1349,7 +1358,7 @@ class AppState: ObservableObject { /// Toggle transcription on/off func toggleTranscription() { - if isTranscribing { + if isTranscribing || isStartingTranscription { stopTranscription() } else { startTranscription() @@ -1359,7 +1368,7 @@ class AppState: ObservableObject { /// Start real-time transcription /// - Parameter source: Audio source to use (defaults to current audioSource setting) func startTranscription(source: AudioSource? = nil) { - guard !isTranscribing else { return } + guard !isTranscribing && !isStartingTranscription else { return } // Paywall hard-stop: every code path that enables the mic + WS streaming // funnels through here, including auto-restart from sleep and toggle @@ -1390,6 +1399,23 @@ class AppState: ObservableObject { "Transcription: Using language=\(effectiveLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), selected=\(AssistantSettings.shared.transcriptionLanguage))" ) + let routing = BackgroundTranscriptionRoutingGuard().decide( + batchEnabled: AssistantSettings.shared.batchTranscriptionEnabled, + serverAssemblyBackgroundEnabled: Self.isServerBackgroundBatchEnabled, + audioSource: effectiveSource + ) + if routing == .cloudBatchAssembly { + isStartingTranscription = true + cloudBackgroundStartTask?.cancel() + cloudBackgroundStartTask = Task { + await self.startCloudBackgroundTranscription( + source: effectiveSource, + language: effectiveLanguage + ) + } + return + } + // Always streaming via Python backend /v4/listen transcriptionService = try TranscriptionService(language: effectiveLanguage) @@ -1507,10 +1533,14 @@ 
class AppState: ObservableObject { Task { @MainActor in guard let self = self, self.isTranscribing else { return } log("Transcription: 4-hour limit reached - restarting session") - // Stop and restart (WebSocket close triggers backend conversation processing) - self.stopAudioCapture() - self.clearTranscriptionState() - self.startTranscription() + if self.isCloudBackgroundTranscription { + _ = await self.finishConversation() + } else { + // Stop and restart (WebSocket close triggers backend conversation processing) + self.stopAudioCapture() + self.clearTranscriptionState() + self.startTranscription() + } } } @@ -1525,6 +1555,110 @@ class AppState: ObservableObject { } } + private static var isServerBackgroundBatchEnabled: Bool { + let baseURL = DesktopBackendEnvironment.pythonBaseURL().lowercased() + return baseURL.contains("127.0.0.1") + || baseURL.contains("localhost") + || baseURL.contains("omiapi.com") + } + + private func startCloudBackgroundTranscription(source: AudioSource, language: String) async { + defer { isStartingTranscription = false } + do { + let conversationId = try await APIClient.shared.startBackgroundConversation(language: language) + guard !Task.isCancelled else { return } + cloudBackgroundConversationId = conversationId + cloudBackgroundSampleCursor = 0 + isCloudBackgroundTranscription = true + backgroundTranscriptMerger.reset() + speakerSegmentReducer.reset() + transcriptionService = nil + + if source == .bleDevice, let device = DeviceProvider.shared.connectedDevice { + currentConversationSource = ConversationSource.from(deviceType: device.type) + recordingInputDeviceName = device.displayName + } else { + currentConversationSource = .desktop + recordingInputDeviceName = AudioCaptureService.getCurrentMicrophoneName() + } + + audioCaptureService = AudioCaptureService() + audioMixer = AudioMixer() + vadGateService = nil + + let resolvedLanguage = language + cloudBackgroundSession = CloudBackgroundTranscriptionSession { chunk in + try await 
TranscriptionService.batchTranscribeSegments( + audioData: chunk.pcmData, + conversationId: conversationId, + chunkStartMs: max(0, Int((chunk.startTime * 1000.0).rounded())), + language: resolvedLanguage, + contextKeywords: AssistantSettings.shared.effectiveVocabulary + ) + } + + let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture") + if systemAudioDisabled { + log("Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)") + } else if #available(macOS 14.4, *) { + systemAudioCaptureService = SystemAudioCaptureService() + log("Transcription: System audio capture initialized for cloud background batch") + } + + isTranscribing = true + recordingGeneration &+= 1 + AssistantSettings.shared.transcriptionEnabled = true + audioSource = source + currentTranscript = "" + speakerSegments = [] + totalSegmentCount = 0 + totalWordCount = 0 + liveSpeakerPersonMap = [:] + LiveTranscriptMonitor.shared.clear() + recordingStartTime = Date() + AudioLevelMonitor.shared.reset() + RecordingTimer.shared.start() + + await startAudioCapture(source: source) + await startCrashSafeTranscriptionSession(language: language) + + maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) { + [weak self] _ in + Task { @MainActor in + guard let self = self, self.isTranscribing else { return } + log("Transcription: 4-hour limit reached - rotating cloud background batch conversation") + _ = await self.finishConversation() + } + } + + AnalyticsManager.shared.transcriptionStarted() + log("Transcription: Started cloud background batch conversation \(conversationId)") + } catch { + isCloudBackgroundTranscription = false + cloudBackgroundSession = nil + cloudBackgroundConversationId = nil + isTranscribing = false + AnalyticsManager.shared.recordingError(error: error.localizedDescription) + showAlert(title: "Transcription Error", message: error.localizedDescription) + } + } + + private func 
startCrashSafeTranscriptionSession(language: String) async { + do { + let sessionId = try await TranscriptionStorage.shared.startSession( + source: currentConversationSource.rawValue, + language: language, + timezone: TimeZone.current.identifier, + inputDeviceName: recordingInputDeviceName + ) + currentSessionId = sessionId + LiveNotesMonitor.shared.startSession(sessionId: sessionId) + log("Transcription: Created DB session \(sessionId)") + } catch { + logError("Transcription: Failed to create DB session", error: error) + } + } + /// Start audio capture and pipe to transcription service /// - Parameter source: Audio source to capture from private func startAudioCapture(source: AudioSource = .microphone) async { @@ -1554,9 +1688,13 @@ class AppState: ObservableObject { } // Start the mixer — it sums mic + system into a mono stream and forwards it to - // the transcription WebSocket. + // the active transcription transport. audioMixer?.start { [weak self] monoMixed in - self?.transcriptionService?.sendAudio(monoMixed) + if self?.isCloudBackgroundTranscription == true { + self?.handleMixedBackgroundAudio(monoMixed) + } else { + self?.transcriptionService?.sendAudio(monoMixed) + } } do { @@ -1719,6 +1857,31 @@ class AppState: ObservableObject { /// triggers conversation processing on the backend side. We also call force-process to ensure /// the conversation is finalized, preventing the retry service from creating duplicates. 
func stopTranscription() { + if isStartingTranscription { + cloudBackgroundStartTask?.cancel() + cloudBackgroundStartTask = nil + isStartingTranscription = false + isCloudBackgroundTranscription = false + cloudBackgroundSession = nil + cloudBackgroundConversationId = nil + AssistantSettings.shared.transcriptionEnabled = false + return + } + + if isCloudBackgroundTranscription { + let capturedSessionId = currentSessionId + let capturedStartTime = recordingStartTime + let generationAtStop = recordingGeneration + Task { + await self.stopCloudBackgroundTranscription( + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + generationAtStop: generationAtStop + ) + } + return + } + // Capture session metadata BEFORE clearing state (clearTranscriptionState sets sessionId to nil) let capturedSessionId = currentSessionId let capturedStartTime = recordingStartTime @@ -1773,6 +1936,152 @@ class AppState: ObservableObject { } } + func restartTranscriptionAfterSettingsChange() async { + guard isTranscribing || isStartingTranscription else { return } + + if isCloudBackgroundTranscription { + let capturedSessionId = currentSessionId + let capturedStartTime = recordingStartTime + let generationAtStop = recordingGeneration + await stopCloudBackgroundTranscription( + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + generationAtStop: generationAtStop, + forSettingsChange: true + ) + } else { + stopTranscription() + try? 
await Task.sleep(nanoseconds: 1_000_000_000) + } + + startTranscription() + } + + private func handleMixedBackgroundAudio(_ pcmData: Data) { + guard let session = cloudBackgroundSession else { return } + let startTime = Double(cloudBackgroundSampleCursor) / 16000.0 + cloudBackgroundSampleCursor += pcmData.count / 2 + let result = session.append(pcmData: pcmData, startTime: startTime) + if result.enqueuedChunks > 0 { + drainCloudBackgroundASRQueue() + } + } + + private func drainCloudBackgroundASRQueue() { + guard cloudBackgroundDrainTask == nil else { return } + cloudBackgroundDrainTask = Task { @MainActor in + defer { cloudBackgroundDrainTask = nil } + while !Task.isCancelled { + guard let session = cloudBackgroundSession else { break } + do { + guard let result = try await session.transcribeNext() else { break } + let merged = backgroundTranscriptMerger.merge(result.segments) + _ = speakerSegmentReducer.apply(merged) + handleBackendSegments(merged) + } catch { + logError("Transcription: Cloud background chunk failed", error: error) + break + } + } + } + } + + private func waitForCloudBackgroundBacklog(timeout: TimeInterval = 60) async { + let deadline = Date().addingTimeInterval(timeout) + while Date() < deadline { + drainCloudBackgroundASRQueue() + if cloudBackgroundSession?.pendingChunkCount ?? 0 == 0 && cloudBackgroundDrainTask == nil { + return + } + try? await Task.sleep(nanoseconds: 250_000_000) + } + log("Transcription: Cloud background backlog wait timed out with \(cloudBackgroundSession?.pendingChunkCount ?? 0) pending chunks") + } + + private func stopCloudBackgroundTranscription( + capturedSessionId: Int64?, + capturedStartTime: Date?, + generationAtStop: UInt64, + forSettingsChange: Bool = false + ) async { + stopAudioCapture() + _ = cloudBackgroundSession?.finishInput() + drainCloudBackgroundASRQueue() + await waitForCloudBackgroundBacklog(timeout: forSettingsChange ? 
5 : 60) + + clearCloudBackgroundState() + clearTranscriptionState() + silentMicFallbackInProgress = false + + if forSettingsChange { + // Finalize the stopped conversation in the background so settings UI stays responsive. + Task { + try? await Task.sleep(nanoseconds: 3_000_000_000) + guard recordingGeneration == generationAtStop else { return } + await forceProcessStoppedConversation( + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + logPrefix: "Cloud background batch" + ) + await loadConversations() + } + return + } + + try? await Task.sleep(nanoseconds: 3_000_000_000) + guard recordingGeneration == generationAtStop else { + log("Transcription: New recording started during cloud batch delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? "nil")") + return + } + + await forceProcessStoppedConversation( + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + logPrefix: "Cloud background batch" + ) + await loadConversations() + } + + private func clearCloudBackgroundState() { + cloudBackgroundStartTask?.cancel() + cloudBackgroundStartTask = nil + cloudBackgroundDrainTask?.cancel() + cloudBackgroundDrainTask = nil + cloudBackgroundSession = nil + cloudBackgroundConversationId = nil + cloudBackgroundSampleCursor = 0 + isCloudBackgroundTranscription = false + backgroundTranscriptMerger.reset() + speakerSegmentReducer.reset() + } + + private func forceProcessStoppedConversation( + capturedSessionId: Int64?, + capturedStartTime: Date?, + logPrefix: String + ) async { + do { + if let conversation = try await APIClient.shared.forceProcessConversation() { + if let sessionId = capturedSessionId, let startTime = capturedStartTime, + let convStarted = conversation.startedAt, + abs(convStarted.timeIntervalSince(startTime)) < 10, + conversation.source == .desktop + { + try? 
await TranscriptionStorage.shared.markSessionCompleted( + id: sessionId, backendId: conversation.id) + log("Transcription: \(logPrefix) force-processed conversation \(conversation.id), session \(sessionId) completed") + } else if let sessionId = capturedSessionId, let startTime = capturedStartTime { + log("Transcription: \(logPrefix) force-process returned different conversation \(conversation.id), reconciling session \(sessionId)") + await reconcileSession(sessionId: sessionId, startTime: startTime) + } + } else if let sessionId = capturedSessionId, let startTime = capturedStartTime { + await reconcileSession(sessionId: sessionId, startTime: startTime) + } + } catch { + logError("Transcription: \(logPrefix) force-process failed, retry service will reconcile", error: error) + } + } + /// Reconcile a local session by checking if a matching conversation exists on the backend. /// If found, marks the session as completed. Otherwise leaves it as pendingUpload for retry. private func reconcileSession(sessionId: Int64, startTime: Date) async { @@ -1803,6 +2112,10 @@ class AppState: ObservableObject { /// Finish the current conversation and keep recording for a new one. /// Disconnects the WebSocket (triggers backend conversation processing) then reconnects. 
func finishConversation() async -> FinishConversationResult { + if isCloudBackgroundTranscription { + return await finishCloudBackgroundConversation() + } + guard totalSegmentCount > 0 || !speakerSegments.isEmpty else { log("Transcription: No segments to finish") return .discarded @@ -1921,6 +2234,83 @@ class AppState: ObservableObject { return .saved } + private func finishCloudBackgroundConversation() async -> FinishConversationResult { + log("Transcription: Finishing cloud background batch conversation") + let capturedSessionId = currentSessionId + let capturedStartTime = recordingStartTime + finishedSessionId = capturedSessionId + finishedRecordingStartTime = capturedStartTime + + stopAudioCapture() + _ = cloudBackgroundSession?.finishInput() + drainCloudBackgroundASRQueue() + await waitForCloudBackgroundBacklog() + let finishedHadSegments = totalSegmentCount > 0 || !speakerSegments.isEmpty + + if let sessionId = currentSessionId { + do { + try await TranscriptionStorage.shared.finishSession(id: sessionId) + log("Transcription: Finished DB session \(sessionId) before cloud batch rotation") + } catch { + logError("Transcription: Failed to finish DB session \(sessionId)", error: error) + } + } + + await forceProcessStoppedConversation( + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + logPrefix: "Cloud background batch rotation" + ) + + currentSessionId = nil + speakerSegments = [] + totalSegmentCount = 0 + totalWordCount = 0 + liveSpeakerPersonMap = [:] + LiveTranscriptMonitor.shared.clear() + LiveNotesMonitor.shared.endSession() + LiveNotesMonitor.shared.clear() + backgroundTranscriptMerger.reset() + speakerSegmentReducer.reset() + cloudBackgroundSampleCursor = 0 + + recordingStartTime = Date() + RecordingTimer.shared.restart() + maxRecordingTimer?.invalidate() + maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) { + [weak self] _ in + Task { @MainActor in + guard let self = self, 
self.isTranscribing else { return } + log("Transcription: 4-hour limit reached - rotating cloud background batch conversation") + _ = await self.finishConversation() + } + } + + do { + let language = AssistantSettings.shared.effectiveTranscriptionLanguage + let conversationId = try await APIClient.shared.startBackgroundConversation(language: language) + cloudBackgroundConversationId = conversationId + cloudBackgroundSession = CloudBackgroundTranscriptionSession { chunk in + try await TranscriptionService.batchTranscribeSegments( + audioData: chunk.pcmData, + conversationId: conversationId, + chunkStartMs: max(0, Int((chunk.startTime * 1000.0).rounded())), + language: language, + contextKeywords: AssistantSettings.shared.effectiveVocabulary + ) + } + await startCrashSafeTranscriptionSession(language: language) + await startAudioCapture(source: audioSource) + isTranscribing = true + await loadConversations() + log("Transcription: Ready for next cloud background batch conversation \(conversationId)") + return finishedHadSegments ? .saved : .discarded + } catch { + logError("Transcription: Failed to start next cloud background batch conversation", error: error) + return .error(error.localizedDescription) + } + } + /// Stop audio capture services (but keep transcript data for saving) private func stopAudioCapture() { // Cancel timers diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift new file mode 100644 index 00000000000..e4e2052f429 --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift @@ -0,0 +1,138 @@ +import Foundation + +struct BackgroundAudioChunk: Equatable { + let pcmData: Data + let startTime: Double + let isFinal: Bool +} + +struct BackgroundAudioChunker { + private let configuration: BackgroundTranscriptionConfiguration + private var buffer = Data() + private var bufferStartTime: Double? 
+ + init(configuration: BackgroundTranscriptionConfiguration = BackgroundTranscriptionConfiguration()) { + self.configuration = configuration + } + + mutating func append(pcmData: Data, startTime: Double) -> [BackgroundAudioChunk] { + guard !pcmData.isEmpty else { return [] } + + if buffer.isEmpty { + bufferStartTime = startTime + } + buffer.append(pcmData) + + var chunks: [BackgroundAudioChunk] = [] + while let chunk = nextChunk(isFinal: false) { + chunks.append(chunk) + } + return chunks + } + + mutating func finishInput() -> [BackgroundAudioChunk] { + guard !buffer.isEmpty else { return [] } + let chunk = BackgroundAudioChunk( + pcmData: buffer, + startTime: bufferStartTime ?? 0, + isFinal: true + ) + buffer.removeAll(keepingCapacity: false) + bufferStartTime = nil + return [chunk] + } + + private mutating func nextChunk(isFinal: Bool) -> BackgroundAudioChunk? { + let minBytes = configuration.alignedByteCount(for: configuration.minChunkDuration) + let maxBytes = configuration.alignedByteCount(for: configuration.maxChunkDuration) + guard buffer.count >= minBytes else { return nil } + + let cutBytes: Int? + if let silenceCut = firstSilenceCut(minBytes: minBytes, maxBytes: min(buffer.count, maxBytes)) { + cutBytes = silenceCut + } else if buffer.count >= maxBytes { + cutBytes = maxBytes + } else { + cutBytes = nil + } + + guard let cutBytes, cutBytes > 0 else { return nil } + return cut(at: cutBytes, isFinal: isFinal) + } + + private mutating func cut(at requestedCutBytes: Int, isFinal: Bool) -> BackgroundAudioChunk { + let cutBytes = min(requestedCutBytes, buffer.count).alignedToSample + let startTime = bufferStartTime ?? 
0 + let chunk = BackgroundAudioChunk( + pcmData: buffer.prefix(cutBytes), + startTime: startTime, + isFinal: isFinal + ) + + let overlapBytes = min(configuration.alignedByteCount(for: configuration.overlapDuration), cutBytes) + let retainedStart = max(0, cutBytes - overlapBytes) + let retained = buffer.suffix(buffer.count - retainedStart) + buffer = Data(retained) + bufferStartTime = startTime + Double(retainedStart / configuration.bytesPerSample) / Double(configuration.sampleRate) + return chunk + } + + private func firstSilenceCut(minBytes: Int, maxBytes: Int) -> Int? { + let windowBytes = max(configuration.bytesPerSample, configuration.alignedByteCount(for: configuration.silenceWindowDuration)) + guard maxBytes >= minBytes + windowBytes else { return nil } + + var offset = minBytes.alignedToSample + while offset + windowBytes <= maxBytes { + if isSilentWindow(start: offset, byteCount: windowBytes), hasSpeech(before: offset) { + return offset + } + offset += configuration.bytesPerSample + } + return nil + } + + private func isSilentWindow(start: Int, byteCount: Int) -> Bool { + guard start >= 0, start + byteCount <= buffer.count else { return false } + var maxAmplitude = 0 + for offset in stride(from: start, to: start + byteCount, by: configuration.bytesPerSample) { + maxAmplitude = max(maxAmplitude, sampleAmplitude(at: offset)) + if maxAmplitude > configuration.silenceAmplitudeThreshold { + return false + } + } + return true + } + + private func hasSpeech(before endOffset: Int) -> Bool { + guard endOffset > 0 else { return false } + var peak = 0 + var sumSquares = 0.0 + var count = 0 + + for offset in stride(from: 0, to: min(endOffset, buffer.count), by: configuration.bytesPerSample) { + let amplitude = sampleAmplitude(at: offset) + peak = max(peak, amplitude) + sumSquares += Double(amplitude * amplitude) + count += 1 + } + + guard count > 0 else { return false } + let rms = sqrt(sumSquares / Double(count)) + return peak >= 
configuration.speechPeakAmplitudeThreshold + || rms >= Double(configuration.speechRMSAmplitudeThreshold) + } + + private func sampleAmplitude(at offset: Int) -> Int { + guard offset + 1 < buffer.count else { return 0 } + let low = UInt16(buffer[offset]) + let high = UInt16(buffer[offset + 1]) << 8 + let sample = Int16(bitPattern: low | high) + return abs(Int(sample)) + } +} + +extension Data { + fileprivate func prefix(_ count: Int) -> Data { + Data(self[startIndex.. [TranscriptionService.BackendSegment] + { + for incoming in incomingSegments + where !incoming.text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + upsert(incoming) + } + segments.sort { lhs, rhs in + if lhs.start == rhs.start { + return lhs.end < rhs.end + } + return lhs.start < rhs.start + } + return segments + } + + private mutating func upsert(_ incoming: TranscriptionService.BackendSegment) { + if let segmentId = incoming.id, + let index = segments.firstIndex(where: { $0.id == segmentId }) + { + segments[index] = preferredSegment(existing: segments[index], incoming: incoming) + return + } + + if let index = segments.firstIndex(where: { isDuplicate($0, incoming) }) { + segments[index] = preferredSegment(existing: segments[index], incoming: incoming) + return + } + + if let index = segments.firstIndex(where: { canMergeOverlap($0, incoming) }) { + segments[index] = mergedOverlap(segments[index], incoming) + return + } + + segments.append(incoming) + } + + private func isDuplicate( + _ existing: TranscriptionService.BackendSegment, + _ incoming: TranscriptionService.BackendSegment + ) -> Bool { + guard normalizedText(existing.text) == normalizedText(incoming.text) else { return false } + let intersection = max(0, min(existing.end, incoming.end) - max(existing.start, incoming.start)) + let shorterDuration = max(0.001, min(existing.end - existing.start, incoming.end - incoming.start)) + return intersection / shorterDuration >= duplicateOverlapThreshold + } + + private func canMergeOverlap( + 
_ existing: TranscriptionService.BackendSegment, + _ incoming: TranscriptionService.BackendSegment + ) -> Bool { + guard (existing.speaker_id ?? 0) == (incoming.speaker_id ?? 0) else { return false } + guard min(existing.end, incoming.end) > max(existing.start, incoming.start) else { + return false + } + return edgeTokenOverlap(existing.text, incoming.text) > 0 + } + + private func mergedOverlap( + _ existing: TranscriptionService.BackendSegment, + _ incoming: TranscriptionService.BackendSegment + ) -> TranscriptionService.BackendSegment { + let existingFirst = existing.start <= incoming.start + let first = existingFirst ? existing : incoming + let second = existingFirst ? incoming : existing + let overlap = edgeTokenOverlap(first.text, second.text) + let suffix = tokenized(second.text).dropFirst(overlap).joined(separator: " ") + let mergedText = suffix.isEmpty ? first.text : "\(first.text) \(suffix)" + + return TranscriptionService.BackendSegment( + id: first.id ?? second.id, + text: mergedText, + speaker: first.speaker ?? second.speaker, + speaker_id: first.speaker_id ?? second.speaker_id, + is_user: first.is_user || second.is_user, + person_id: first.person_id ?? second.person_id, + start: min(first.start, second.start), + end: max(first.end, second.end), + translations: first.translations ?? second.translations, + stt_provider: first.stt_provider ?? second.stt_provider, + stt_model: first.stt_model ?? second.stt_model, + provider_cluster_id: first.provider_cluster_id ?? second.provider_cluster_id, + provider_speaker_label: first.provider_speaker_label ?? second.provider_speaker_label, + speaker_identity_state: first.speaker_identity_state ?? second.speaker_identity_state, + speaker_identity_confidence: first.speaker_identity_confidence ?? second.speaker_identity_confidence, + speaker_identity_source: first.speaker_identity_source ?? second.speaker_identity_source, + speaker_identity_version: first.speaker_identity_version ?? 
second.speaker_identity_version + ) + } + + private func preferredSegment( + existing: TranscriptionService.BackendSegment, + incoming: TranscriptionService.BackendSegment + ) -> TranscriptionService.BackendSegment { + if normalizedText(existing.text) == normalizedText(incoming.text), + existing.text.count <= incoming.text.count + { + return existing + } + if incoming.end - incoming.start > existing.end - existing.start { + return incoming + } + if incoming.text.count > existing.text.count { + return incoming + } + return existing + } + + private func normalizedText(_ value: String) -> String { + value.lowercased() + .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) + .trimmingCharacters(in: .whitespacesAndNewlines) + } + + private func tokenized(_ value: String) -> [String] { + normalizedText(value).split(separator: " ").map(String.init) + } + + private func edgeTokenOverlap(_ first: String, _ second: String) -> Int { + let left = tokenized(first) + let right = tokenized(second) + guard !left.isEmpty, !right.isEmpty else { return 0 } + + let maxOverlap = min(left.count, right.count) + for count in stride(from: maxOverlap, through: 1, by: -1) { + if Array(left.suffix(count)) == Array(right.prefix(count)) { + return count + } + } + return 0 + } +} diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift new file mode 100644 index 00000000000..51127229812 --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift @@ -0,0 +1,29 @@ +import Foundation + +struct BackgroundTranscriptionConfiguration: Equatable { + var sampleRate: Int = 16000 + var maxChunkDuration: TimeInterval = 15.0 + var minChunkDuration: TimeInterval = 1.0 + var overlapDuration: TimeInterval = 1.0 + var silenceWindowDuration: TimeInterval = 0.35 + var silenceAmplitudeThreshold: Int = 256 
+ var speechPeakAmplitudeThreshold: Int = 512 + var speechRMSAmplitudeThreshold: Int = 64 + var maxPendingChunks: Int = 4 + + var bytesPerSample: Int { 2 } + + func byteCount(for duration: TimeInterval) -> Int { + max(0, Int(duration * Double(sampleRate)) * bytesPerSample) + } + + func alignedByteCount(for duration: TimeInterval) -> Int { + byteCount(for: duration).alignedToSample + } +} + +extension Int { + var alignedToSample: Int { + self - (self % 2) + } +} diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift new file mode 100644 index 00000000000..13b97de5bcf --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift @@ -0,0 +1,25 @@ +import Foundation + +enum BackgroundTranscriptionRoutingDecision: Equatable { + case cloudBatchAssembly + case cloudListenStreaming(reason: String?) +} + +struct BackgroundTranscriptionRoutingGuard { + func decide( + batchEnabled: Bool, + serverAssemblyBackgroundEnabled: Bool, + audioSource: AudioSource + ) -> BackgroundTranscriptionRoutingDecision { + guard batchEnabled else { + return .cloudListenStreaming(reason: "batch_disabled") + } + guard serverAssemblyBackgroundEnabled else { + return .cloudListenStreaming(reason: "server_background_batch_disabled") + } + guard audioSource == .microphone else { + return .cloudListenStreaming(reason: "batch_microphone_only") + } + return .cloudBatchAssembly + } +} diff --git a/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift b/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift new file mode 100644 index 00000000000..b91f8579bd6 --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift @@ -0,0 +1,110 @@ +import Foundation + +struct BackgroundIngestResult: Equatable { + 
let enqueuedChunks: Int + let pendingChunkCount: Int + let acceptedInputBytes: Int + let didFinishInput: Bool + let isBackpressured: Bool +} + +struct BackgroundTranscriptionResult { + let chunk: BackgroundAudioChunk + let segments: [TranscriptionService.BackendSegment] +} + +struct BackgroundTranscriptSnapshot { + let pendingChunkCount: Int + let processedChunkCount: Int + let isInputFinished: Bool + let segments: [TranscriptionService.BackendSegment] +} + +final class CloudBackgroundTranscriptionSession { + typealias TranscribeHandler = (BackgroundAudioChunk) async throws -> [TranscriptionService.BackendSegment] + + private let configuration: BackgroundTranscriptionConfiguration + private let transcribe: TranscribeHandler + private var chunker: BackgroundAudioChunker + private var pendingChunks: [BackgroundAudioChunk] = [] + private var processedChunkCount = 0 + private var isInputFinished = false + private var processedSegments: [TranscriptionService.BackendSegment] = [] + + init( + configuration: BackgroundTranscriptionConfiguration = BackgroundTranscriptionConfiguration(), + transcribe: @escaping TranscribeHandler + ) { + self.configuration = configuration + self.transcribe = transcribe + self.chunker = BackgroundAudioChunker(configuration: configuration) + } + + var pendingChunkCount: Int { + pendingChunks.count + } + + func append(pcmData: Data, startTime: Double) -> BackgroundIngestResult { + guard !isInputFinished else { + return BackgroundIngestResult( + enqueuedChunks: 0, + pendingChunkCount: pendingChunks.count, + acceptedInputBytes: 0, + didFinishInput: true, + isBackpressured: pendingChunks.count >= configuration.maxPendingChunks + ) + } + + let chunks = chunker.append(pcmData: pcmData, startTime: startTime) + pendingChunks.append(contentsOf: chunks) + return BackgroundIngestResult( + enqueuedChunks: chunks.count, + pendingChunkCount: pendingChunks.count, + acceptedInputBytes: pcmData.count, + didFinishInput: false, + isBackpressured: 
pendingChunks.count >= configuration.maxPendingChunks + ) + } + + func finishInput() -> BackgroundIngestResult { + guard !isInputFinished else { + return BackgroundIngestResult( + enqueuedChunks: 0, + pendingChunkCount: pendingChunks.count, + acceptedInputBytes: 0, + didFinishInput: true, + isBackpressured: pendingChunks.count >= configuration.maxPendingChunks + ) + } + + let chunks = chunker.finishInput() + pendingChunks.append(contentsOf: chunks) + isInputFinished = true + return BackgroundIngestResult( + enqueuedChunks: chunks.count, + pendingChunkCount: pendingChunks.count, + acceptedInputBytes: 0, + didFinishInput: true, + isBackpressured: pendingChunks.count >= configuration.maxPendingChunks + ) + } + + func transcribeNext() async throws -> BackgroundTranscriptionResult? { + guard !pendingChunks.isEmpty else { return nil } + let chunk = pendingChunks[0] + let segments = try await transcribe(chunk) + pendingChunks.removeFirst() + processedChunkCount += 1 + processedSegments.append(contentsOf: segments) + return BackgroundTranscriptionResult(chunk: chunk, segments: segments) + } + + func snapshot() -> BackgroundTranscriptSnapshot { + BackgroundTranscriptSnapshot( + pendingChunkCount: pendingChunks.count, + processedChunkCount: processedChunkCount, + isInputFinished: isInputFinished, + segments: processedSegments + ) + } +} diff --git a/desktop/Desktop/Sources/BackgroundTranscription/SpeakerSegmentReducer.swift b/desktop/Desktop/Sources/BackgroundTranscription/SpeakerSegmentReducer.swift new file mode 100644 index 00000000000..27021183615 --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/SpeakerSegmentReducer.swift @@ -0,0 +1,84 @@ +import Foundation + +struct SpeakerSegmentReducer { + struct ApplyResult: Equatable { + var added: Int = 0 + var updated: Int = 0 + var totalSegmentCount: Int = 0 + var totalWordCount: Int = 0 + } + + private(set) var segments: [SpeakerSegment] = [] + private(set) var totalSegmentCount: Int = 0 + private(set) var 
totalWordCount: Int = 0 + var maxInMemorySegments: Int + + init(maxInMemorySegments: Int = 400) { + self.maxInMemorySegments = maxInMemorySegments + } + + mutating func reset() { + segments = [] + totalSegmentCount = 0 + totalWordCount = 0 + } + + mutating func replaceSegments(_ replacement: [SpeakerSegment]) { + segments = replacement + totalSegmentCount = replacement.count + totalWordCount = replacement.reduce(0) { $0 + wordCount($1.text) } + } + + mutating func apply(_ incomingSegments: [TranscriptionService.BackendSegment]) -> ApplyResult { + apply(incomingSegments.map(Self.speakerSegment(from:))) + } + + mutating func apply(_ incomingSegments: [SpeakerSegment]) -> ApplyResult { + var result = ApplyResult() + + for incoming in incomingSegments where !incoming.text.isEmpty { + if let segId = incoming.segmentId, + let existingIdx = segments.firstIndex(where: { $0.segmentId == segId }) + { + let oldWords = wordCount(segments[existingIdx].text) + var updated = incoming + if updated.translations.isEmpty && !segments[existingIdx].translations.isEmpty { + updated.translations = segments[existingIdx].translations + } + segments[existingIdx] = updated + totalWordCount += wordCount(updated.text) - oldWords + result.updated += 1 + } else { + segments.append(incoming) + totalSegmentCount += 1 + totalWordCount += wordCount(incoming.text) + result.added += 1 + } + } + + if segments.count > maxInMemorySegments { + segments.removeFirst(segments.count - maxInMemorySegments) + } + + result.totalSegmentCount = totalSegmentCount + result.totalWordCount = totalWordCount + return result + } + + private static func speakerSegment(from segment: TranscriptionService.BackendSegment) -> SpeakerSegment { + SpeakerSegment( + segmentId: segment.id, + speaker: segment.speaker_id ?? 0, + text: segment.text, + start: segment.start, + end: segment.end, + isUser: segment.is_user, + personId: segment.person_id, + translations: (segment.translations ?? 
[]).map { SegmentTranslation(lang: $0.lang, text: $0.text) } + ) + } + + private func wordCount(_ text: String) -> Int { + text.split(separator: " ").count + } +} diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 482c3937b03..eab9773b534 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -261,6 +261,7 @@ struct SettingsContentView: View { @State private var transcriptionAutoDetect: Bool = true @State private var transcriptionLanguage: String = "en" @State private var vadGateEnabled: Bool = false + @State private var batchTranscriptionEnabled: Bool = false // Multi-chat mode setting @AppStorage("multiChatEnabled") private var multiChatEnabled = false @@ -425,6 +426,7 @@ struct SettingsContentView: View { initialValue: MemoryAssistantSettings.shared.notificationsEnabled) _memoryExcludedApps = State(initialValue: MemoryAssistantSettings.shared.excludedApps) _vadGateEnabled = State(initialValue: settings.vadGateEnabled) + _batchTranscriptionEnabled = State(initialValue: settings.batchTranscriptionEnabled) _transcriptionLanguage = State(initialValue: settings.transcriptionLanguage) _transcriptionAutoDetect = State(initialValue: settings.transcriptionAutoDetect) } @@ -1070,7 +1072,7 @@ struct SettingsContentView: View { } .buttonStyle(.plain) - // Single Language option + // Single Language option (picker must stay outside Button — nested controls freeze SwiftUI) Button(action: { transcriptionAutoDetect = false AssistantSettings.shared.transcriptionAutoDetect = false @@ -1091,33 +1093,6 @@ struct SettingsContentView: View { Text("Best for speaking in one specific language") .scaledFont(size: 12) .foregroundColor(OmiColors.textTertiary) - - // Language picker (only shown when single language is selected) - if !transcriptionAutoDetect { - HStack { - Text("Language:") - .scaledFont(size: 12) - 
.foregroundColor(OmiColors.textTertiary) - - Picker("", selection: $transcriptionLanguage) { - ForEach(languageOptions, id: \.0) { option in - Text(option.1).tag(option.0) - } - } - .pickerStyle(.menu) - .frame(width: 180) - .onChange(of: transcriptionLanguage) { _, newValue in - AssistantSettings.shared.transcriptionLanguage = newValue - let supportsMulti = AssistantSettings.supportsAutoDetect(newValue) - transcriptionAutoDetect = supportsMulti - AssistantSettings.shared.transcriptionAutoDetect = supportsMulti - updateTranscriptionPreferences(singleLanguageMode: !supportsMulti) - updateLanguage(newValue) - restartTranscriptionIfNeeded() - } - } - .padding(.top, 4) - } } Spacer() @@ -1137,6 +1112,33 @@ struct SettingsContentView: View { } .buttonStyle(.plain) + if !transcriptionAutoDetect { + HStack(spacing: 12) { + Text("Language:") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + + Picker("", selection: $transcriptionLanguage) { + ForEach(languageOptions, id: \.0) { option in + Text(option.1).tag(option.0) + } + } + .pickerStyle(.menu) + .frame(width: 180) + .onChange(of: transcriptionLanguage) { _, newValue in + applyTranscriptionLanguageChange(newValue) + } + + Spacer() + } + .padding(.horizontal, 12) + .padding(.vertical, 8) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(OmiColors.backgroundSecondary) + ) + } + // Info about language support HStack(spacing: 8) { Image(systemName: "info.circle") @@ -1237,6 +1239,39 @@ struct SettingsContentView: View { } } + // Cloud batch transcription + settingsCard(settingId: "transcription.batch") { + VStack(alignment: .leading, spacing: 12) { + HStack { + Image(systemName: "waveform.path.ecg.rectangle") + .scaledFont(size: 16) + .foregroundColor(OmiColors.purplePrimary) + + VStack(alignment: .leading, spacing: 4) { + Text("Batch transcription (AssemblyAI)") + .scaledFont(size: 15, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + + Text( + "Transcribe microphone audio in chunks 
instead of live streaming. Requires server-side AssemblyAI." + ) + .scaledFont(size: 13) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } + + Spacer() + + Toggle("", isOn: $batchTranscriptionEnabled) + .toggleStyle(.switch) + .onChange(of: batchTranscriptionEnabled) { _, newValue in + AssistantSettings.shared.batchTranscriptionEnabled = newValue + restartTranscriptionIfNeeded() + } + } + } + } + // Local VAD Gate settingsCard(settingId: "transcription.vadgate") { VStack(alignment: .leading, spacing: 12) { @@ -1304,16 +1339,22 @@ struct SettingsContentView: View { updateTranscriptionPreferences(vocabulary: vocabularyList.joined(separator: ", ")) } + private func applyTranscriptionLanguageChange(_ newValue: String) { + AssistantSettings.shared.transcriptionLanguage = newValue + let supportsMulti = AssistantSettings.supportsAutoDetect(newValue) + transcriptionAutoDetect = supportsMulti + AssistantSettings.shared.transcriptionAutoDetect = supportsMulti + updateTranscriptionPreferences(singleLanguageMode: !supportsMulti) + updateLanguage(newValue) + restartTranscriptionIfNeeded() + } + /// Restart transcription if currently running to apply new settings private func restartTranscriptionIfNeeded() { - guard appState.isTranscribing else { return } - - // Stop and restart to apply new language settings - appState.stopTranscription() + guard appState.isTranscribing || appState.isStartingTranscription else { return } - // Wait a moment for cleanup, then restart - DispatchQueue.main.asyncAfter(deadline: .now() + 1.0) { - self.appState.startTranscription() + Task { + await appState.restartTranscriptionAfterSettingsChange() } } @@ -6923,6 +6964,7 @@ struct SettingsContentView: View { transcriptionAutoDetect = AssistantSettings.shared.transcriptionAutoDetect vocabularyList = AssistantSettings.shared.transcriptionVocabulary vadGateEnabled = AssistantSettings.shared.vadGateEnabled + batchTranscriptionEnabled = 
AssistantSettings.shared.batchTranscriptionEnabled Task { do { diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index 0e88c1308f0..a72417d4f67 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -643,6 +643,98 @@ extension TranscriptionService { return transcript } + /// Transcribe a background recording chunk through the Python backend `/v2/desktop/background-transcribe`. + /// This is used by cloud batch background transcription only; PTT stays on `batchTranscribe`. + static func batchTranscribeSegments( + audioData: Data, + conversationId: String, + chunkStartMs: Int, + language: String, + contextKeywords: [String] = [] + ) async throws -> [BackendSegment] { + let authService = await MainActor.run { AuthService.shared } + let authHeader = try await authService.getAuthHeader() + let baseURLString = "\(pythonBackendBaseURL)v2/desktop/background-transcribe" + + guard var components = URLComponents(string: baseURLString) else { + throw TranscriptionError.connectionFailed(NSError(domain: "Invalid backend URL", code: -1)) + } + let sanitizedKeywords = sanitizedContextKeywords(contextKeywords) + var queryItems = [ + URLQueryItem(name: "conversation_id", value: conversationId), + URLQueryItem(name: "chunk_start_ms", value: String(chunkStartMs)), + URLQueryItem(name: "language", value: language), + URLQueryItem(name: "sample_rate", value: "16000"), + URLQueryItem(name: "encoding", value: "linear16"), + URLQueryItem(name: "channels", value: "1"), + ] + if !sanitizedKeywords.isEmpty { + queryItems.append(URLQueryItem(name: "keywords", value: sanitizedKeywords.joined(separator: ","))) + } + components.queryItems = queryItems + + guard let url = components.url else { + throw TranscriptionError.connectionFailed(NSError(domain: "Invalid URL", code: -1)) + } + + var request = URLRequest(url: url) + request.httpMethod = "POST" + 
request.setValue(authHeader, forHTTPHeaderField: "Authorization") + request.setValue("application/octet-stream", forHTTPHeaderField: "Content-Type") + request.setValue("desktop", forHTTPHeaderField: "X-App-Platform") + for (provider, entry) in APIKeyService.byokSnapshot { + request.setValue(entry.key, forHTTPHeaderField: provider.headerName) + } + request.httpBody = audioData + + var lastError: Error? + for attempt in 0..<3 { + let data: Data + let response: URLResponse + do { + log("TranscriptionService: Background batch transcribing \(audioData.count) bytes at \(chunkStartMs)ms, attempt=\(attempt + 1), contextKeywords=\(sanitizedKeywords.count)") + (data, response) = try await URLSession.shared.data(for: request) + } catch { + lastError = error + if attempt == 2 { + break + } + let delayNs = UInt64(pow(2.0, Double(attempt)) * 250_000_000) + try await Task.sleep(nanoseconds: delayNs) + continue + } + + guard let httpResponse = response as? HTTPURLResponse else { + throw TranscriptionError.invalidResponse + } + + if httpResponse.statusCode == 200 { + let decoded = try JSONDecoder().decode(BackgroundTranscribeResponse.self, from: data) + let provider = decoded.provider ?? "unknown" + log("TranscriptionService: Background batch transcription completed provider=\(provider), run_id=\(decoded.run_id ?? "nil"), segments=\(decoded.segments.count), chunkStartMs=\(chunkStartMs)") + if provider != "assemblyai" { + log("TranscriptionService: Background batch expected AssemblyAI but backend returned provider=\(provider)") + } + return decoded.segments + } + + let body = String(data: data, encoding: .utf8) ?? 
"no body" + logError("TranscriptionService: Background batch transcription failed with status \(httpResponse.statusCode): \(body)", error: nil) + if httpResponse.statusCode == 413 { + throw TranscriptionError.payloadTooLarge + } + if !(500...599).contains(httpResponse.statusCode) { + throw TranscriptionError.invalidResponse + } + + lastError = TranscriptionError.invalidResponse + let delayNs = UInt64(pow(2.0, Double(attempt)) * 250_000_000) + try await Task.sleep(nanoseconds: delayNs) + } + + throw lastError ?? TranscriptionError.invalidResponse + } + } /// Response model for Python backend `/v2/voice-message/transcribe` (batch PTT) @@ -650,3 +742,11 @@ private struct PythonTranscribeResponse: Decodable { let transcript: String let language: String? } + +/// Response model for Python backend `/v2/desktop/background-transcribe`. +struct BackgroundTranscribeResponse: Decodable { + let segments: [TranscriptionService.BackendSegment] + let language: String? + let provider: String? + let run_id: String? 
+} diff --git a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift new file mode 100644 index 00000000000..98ccb2f8895 --- /dev/null +++ b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift @@ -0,0 +1,251 @@ +import XCTest +@testable import Omi_Computer + +final class BackgroundTranscriptionTests: XCTestCase { + func testChunkerCutsAtSilenceAndRetainsOverlap() { + var chunker = BackgroundAudioChunker( + configuration: BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 3.0, + minChunkDuration: 1.0, + overlapDuration: 0.5, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4 + ) + ) + + let chunks = chunker.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 10) + [0, 0, 0]), startTime: 0) + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].startTime, 0) + XCTAssertEqual(sampleCount(chunks[0].pcmData), 10) + XCTAssertFalse(chunks[0].isFinal) + + let final = chunker.finishInput() + XCTAssertEqual(final.count, 1) + XCTAssertEqual(final[0].startTime, 0.5, accuracy: 0.001) + XCTAssertEqual(sampleCount(final[0].pcmData), 8) + XCTAssertTrue(final[0].isFinal) + } + + func testChunkerHardCutsAtMaxDurationWithoutSilence() { + var chunker = BackgroundAudioChunker( + configuration: BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 0.5, + overlapDuration: 0.2, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4 + ) + ) + + let chunks = chunker.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 3) + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].startTime, 3) + XCTAssertEqual(sampleCount(chunks[0].pcmData), 10) + 
XCTAssertEqual(sampleCount(chunker.finishInput()[0].pcmData), 4) + } + + func testSessionQueuesChunksAndSignalsBackpressureWithoutDropping() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 0.5, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 1 + ) + var transcribedStarts: [Double] = [] + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + transcribedStarts.append(chunk.startTime) + return [Self.backendSegment(id: "chunk-\(chunk.startTime)", text: "hello", start: chunk.startTime, end: chunk.startTime + 1)] + } + + let result = session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 22)), startTime: 0) + + XCTAssertEqual(result.enqueuedChunks, 2) + XCTAssertEqual(result.pendingChunkCount, 2) + XCTAssertTrue(result.isBackpressured) + XCTAssertEqual(session.pendingChunkCount, 2) + + let first = try await session.transcribeNext() + let second = try await session.transcribeNext() + + XCTAssertEqual(first?.chunk.startTime, 0) + XCTAssertEqual(second?.chunk.startTime, 1) + XCTAssertEqual(transcribedStarts, [0, 1]) + let empty = try await session.transcribeNext() + XCTAssertNil(empty) + XCTAssertEqual(session.snapshot().processedChunkCount, 2) + XCTAssertEqual(session.snapshot().segments.count, 2) + } + + func testSessionFinishFlushesTail() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 10, + minChunkDuration: 1, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4 + ) + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + [Self.backendSegment(id: "final", text: chunk.isFinal ? 
"final" : "not final", start: chunk.startTime, end: 1)] + } + + XCTAssertEqual(session.append(pcmData: pcm(samples: [1_000, 1_000, 1_000]), startTime: 2).enqueuedChunks, 0) + let finish = session.finishInput() + + XCTAssertTrue(finish.didFinishInput) + XCTAssertEqual(finish.enqueuedChunks, 1) + let result = try await session.transcribeNext() + XCTAssertEqual(result?.chunk.startTime, 2) + XCTAssertTrue(result?.chunk.isFinal ?? false) + } + + func testSessionRetainsChunkWhenTranscriptionFails() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 0.5, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4 + ) + var shouldFail = true + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + if shouldFail { + shouldFail = false + throw NSError(domain: "test", code: 1) + } + return [Self.backendSegment(id: "retry", text: "retried", start: chunk.startTime, end: chunk.startTime + 1)] + } + + XCTAssertEqual(session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 0).enqueuedChunks, 1) + + do { + _ = try await session.transcribeNext() + XCTFail("Expected first transcription attempt to fail") + } catch { + XCTAssertEqual(session.pendingChunkCount, 1) + } + + let retried = try await session.transcribeNext() + XCTAssertEqual(retried?.chunk.startTime, 0) + XCTAssertEqual(session.pendingChunkCount, 0) + XCTAssertEqual(session.snapshot().processedChunkCount, 1) + } + + func testTranscriptMergerDeduplicatesAndMergesOverlap() { + var merger = BackgroundTranscriptMerger() + let first = Self.backendSegment(id: "a", text: "hello world", speakerId: 0, start: 0, end: 2) + let duplicate = Self.backendSegment(id: nil, text: "hello world", speakerId: 0, start: 0.5, end: 1.8) + let overlap = Self.backendSegment(id: "b", 
text: "world again", speakerId: 0, start: 1.5, end: 3) + + XCTAssertEqual(merger.merge([first]).count, 1) + XCTAssertEqual(merger.merge([duplicate]).count, 1) + let merged = merger.merge([overlap]) + + XCTAssertEqual(merged.count, 1) + XCTAssertEqual(merged[0].text, "hello world again") + XCTAssertEqual(merged[0].start, 0) + XCTAssertEqual(merged[0].end, 3) + } + + func testSpeakerSegmentReducerUpdatesAndPreservesTranslations() { + var reducer = SpeakerSegmentReducer(maxInMemorySegments: 10) + let original = SpeakerSegment( + segmentId: "seg-1", + speaker: 1, + text: "hello", + start: 0, + end: 1, + translations: [SegmentTranslation(lang: "es", text: "hola")] + ) + _ = reducer.apply([original]) + + let update = Self.backendSegment(id: "seg-1", text: "hello again", speakerId: 1, start: 0, end: 2) + let result = reducer.apply([update]) + + XCTAssertEqual(result.added, 0) + XCTAssertEqual(result.updated, 1) + XCTAssertEqual(result.totalSegmentCount, 1) + XCTAssertEqual(result.totalWordCount, 2) + XCTAssertEqual(reducer.segments[0].translations.first?.text, "hola") + } + + func testRoutingGuardOnlyAllowsCloudBatchForEnabledMicrophone() { + let guardrail = BackgroundTranscriptionRoutingGuard() + + XCTAssertEqual( + guardrail.decide(batchEnabled: true, serverAssemblyBackgroundEnabled: true, audioSource: .microphone), + .cloudBatchAssembly + ) + XCTAssertEqual( + guardrail.decide(batchEnabled: true, serverAssemblyBackgroundEnabled: true, audioSource: .bleDevice), + .cloudListenStreaming(reason: "batch_microphone_only") + ) + XCTAssertEqual( + guardrail.decide(batchEnabled: false, serverAssemblyBackgroundEnabled: true, audioSource: .microphone), + .cloudListenStreaming(reason: "batch_disabled") + ) + XCTAssertEqual( + guardrail.decide(batchEnabled: true, serverAssemblyBackgroundEnabled: false, audioSource: .microphone), + .cloudListenStreaming(reason: "server_background_batch_disabled") + ) + } + + private static func backendSegment( + id: String?, + text: String, + 
speakerId: Int = 0, + start: Double, + end: Double + ) -> TranscriptionService.BackendSegment { + TranscriptionService.BackendSegment( + id: id, + text: text, + speaker: "SPEAKER_\(String(format: "%02d", speakerId))", + speaker_id: speakerId, + is_user: false, + person_id: nil, + start: start, + end: end, + translations: nil, + stt_provider: "assemblyai", + stt_model: "universal-2", + provider_cluster_id: nil, + provider_speaker_label: nil, + speaker_identity_state: nil, + speaker_identity_confidence: nil, + speaker_identity_source: nil, + speaker_identity_version: nil + ) + } + + private func pcm(samples: [Int16]) -> Data { + var samples = samples + return samples.withUnsafeMutableBufferPointer { Data(buffer: $0) } + } + + private func sampleCount(_ data: Data) -> Int { + data.count / MemoryLayout<Int16>.size + } +} diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 5d34fe12bee..97f4f5df18c 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -9,8 +9,8 @@ description: "Rollout gates, feature flags, instrumentation, and rollback for th ## Scope AssemblyAI is the MVP async/background prerecorded provider. Deepgram remains -the provider for `/v4/listen`, realtime assistant streaming, Hold-to-Talk -streaming, and voice-message finalize semantics. +the provider for mobile/BLE `/v4/listen`, realtime assistant streaming, +Hold-to-Talk streaming, and voice-message finalize semantics. Eligible AssemblyAI workloads are: @@ -45,6 +45,17 @@ AssemblyAI is disabled by default. use AssemblyAI only when the main flag is enabled and the workload is in `ASSEMBLYAI_BACKGROUND_STT_WORKLOADS`; otherwise they use Deepgram.
+Desktop Audio Recording uses the background workload through: + +- `POST /v2/desktop/background-conversation/start` +- `POST /v2/desktop/background-transcribe` +- `POST /v1/conversations` for finalization + +The chunk endpoint accepts raw PCM, wraps linear16 PCM as WAV before provider +upload, applies the chunk timeline offset, maps provider speaker clusters to +stable numeric `speaker_id` values per conversation, and appends segments to the +in-progress conversation when `persist=true`. + Deepgram is the prerecorded fallback provider. If AssemblyAI fails, times out, or exhausts retries for an eligible background workload, the failed AssemblyAI run is finalized in the provider ledger and the request retries through @@ -72,8 +83,8 @@ If a user enrolls an optional Assembly fingerprint in Firestore, they must send `x-byok-assemblyai` on requests that send other BYOK headers. Clearing the Assembly key requires re-activation with fingerprints that omit `assemblyai`. -Listen/realtime STT is unchanged (Deepgram only). Assembly BYOK does not affect -`/v4/listen`. +Listen/realtime STT is unchanged for mobile/BLE (Deepgram only). Assembly BYOK +does not affect `/v4/listen`. ## Storage And Identity diff --git a/docs/doc/developer/backend/listen_pusher_pipeline.mdx b/docs/doc/developer/backend/listen_pusher_pipeline.mdx index d1bc98a8ef1..88a4ad89e0d 100644 --- a/docs/doc/developer/backend/listen_pusher_pipeline.mdx +++ b/docs/doc/developer/backend/listen_pusher_pipeline.mdx @@ -234,6 +234,14 @@ selects Deepgram by default and can select AssemblyAI only when `ASSEMBLYAI_BACKGROUND_STT_ENABLED=true` and the workload is listed in `ASSEMBLYAI_BACKGROUND_STT_WORKLOADS`. +Desktop Audio Recording can use the batch background path instead of `/v4/listen` +when the desktop batch setting is enabled and the client is pointed at a backend +with AssemblyAI background support. 
The desktop client creates a stub +conversation with `POST /v2/desktop/background-conversation/start`, sends raw +mono PCM chunks to `POST /v2/desktop/background-transcribe`, and still finalizes +with `POST /v1/conversations`. Mobile listen, BLE listen, and PTT continue to use +their existing Deepgram paths. + Eligible AssemblyAI workloads are `sync`, `background`, and `postprocess`. Latency-critical workloads (`realtime`, streaming `ptt`, and `voice_message`) continue to use Deepgram and the existing websocket/finalize @@ -250,8 +258,8 @@ sequenceDiagram participant Firestore participant Metrics as /metrics - Client->>Backend: Upload background audio or sync audio URL - Backend->>ProviderService: transcribe_url/workload=sync|background|postprocess + Client->>Backend: Upload sync audio URL or desktop PCM background chunk + Backend->>ProviderService: transcribe_url/transcribe_bytes workload=sync|background|postprocess alt AssemblyAI enabled for workload ProviderService->>AssemblyAI: Submit async prerecorded transcript request AssemblyAI-->>ProviderService: Completed transcript with utterances, words, speaker labels @@ -266,7 +274,7 @@ sequenceDiagram Deepgram-->>ProviderService: Transcript result end - ProviderService->>ProviderService: Normalize provider result and reconstruct canonical segments +ProviderService->>ProviderService: Normalize provider result and reconstruct canonical segments ProviderService->>EmbeddingAPI: Extract cluster samples for Omi speaker identity ProviderService->>Firestore: Store canonical transcript segments and provider run ledger ProviderService->>Firestore: Increment daily provider usage rollup diff --git a/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md b/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md new file mode 100644 index 00000000000..436949f4118 --- /dev/null +++ b/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md @@ -0,0 +1,168 @@ +# Agent prompt: AssemblyAI background batch E2E (no desktop app) + +## Context + +Branch 
work adds **desktop always-on Audio Recording via AssemblyAI batch chunks** instead of `/v4/listen` (Deepgram streaming). Implementation is committed on the current branch. + +**Goal for you:** Build and run the strongest **non-desktop-app** E2E proof that the pipeline works end-to-end. Do **not** launch Omi Dev or use agent-swift unless absolutely necessary. + +## What exists today + +### Backend (working — verified manually) + +- `POST /v2/desktop/background-conversation/start` → `{ conversation_id }` +- `POST /v2/desktop/background-transcribe` → `{ segments, provider, run_id }` using `STTWorkload.background` → AssemblyAI +- Helpers: [`backend/utils/conversations/desktop_background.py`](backend/utils/conversations/desktop_background.py) +- Router: [`backend/routers/desktop_background.py`](backend/routers/desktop_background.py) +- Unit tests: [`backend/tests/unit/test_desktop_background_transcribe.py`](backend/tests/unit/test_desktop_background_transcribe.py) + +### Script (partial E2E — single-chunk only) + +[`scripts/desktop_assemblyai_e2e.py`](scripts/desktop_assemblyai_e2e.py): + +```bash +# Requires: local backend on :8080, ASSEMBLYAI_BACKGROUND_STT_ENABLED=true, Omi Dev signed in +python3 scripts/desktop_assemblyai_e2e.py --background-chunk --api http://127.0.0.1:8080 +``` + +Currently uploads **one full sample MP3 as a single PCM blob** (~154s). Proves backend + AssemblyAI once; does **not** simulate desktop chunking cadence or multi-chunk timeline. 
+ +### Desktop (implemented but not in scope for your E2E) + +- Chunker/session in [`desktop/Desktop/Sources/BackgroundTranscription/`](desktop/Desktop/Sources/BackgroundTranscription/) +- AppState wiring in [`desktop/Desktop/Sources/AppState.swift`](desktop/Desktop/Sources/AppState.swift) +- Swift unit tests: [`desktop/Desktop/Tests/BackgroundTranscriptionTests.swift`](desktop/Desktop/Tests/BackgroundTranscriptionTests.swift) — chunker/merger/reducer only; **does not** exercise `AudioMixer` → AppState path + +### Known gap (why desktop failed in manual test) + +Live desktop recording started batch mode (`background-conversation/start`) but **never POSTed chunks** for 90+ seconds. Suspected cause: `AudioMixer` callback invokes `handleMixedBackgroundAudio` off MainActor. Your E2E should **not depend on fixing this** — but optionally add a Swift integration test that would catch it (see below). + +--- + +## Your mission + +Extend automated testing to get **as close as possible** to proving the full batch background path works **without running the desktop app**. + +### Tier 1 — Must deliver (Python, runnable in CI-like env) + +Extend [`scripts/desktop_assemblyai_e2e.py`](scripts/desktop_assemblyai_e2e.py) (or add `scripts/desktop_assemblyai_e2e_batch.py`) with a **`--background-batch`** mode that: + +1. **Prerequisites check** (exit with clear message if missing): + - Backend reachable at `--api` (default `http://127.0.0.1:8080`) + - Firebase token from `defaults read com.omi.desktop-dev auth_idToken` OR `--token` flag for CI + - Optional: grep backend `/docs` or a health endpoint; document required env vars + +2. 
**Simulate desktop chunking in Python** (mirror [`BackgroundTranscriptionConfiguration`](desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift)): + - 16 kHz mono int16 PCM + - Split sample audio into **15s max chunks** with **1s overlap** (same as desktop chunker defaults) + - Use silence-boundary cuts if easy; hard-cut at 15s is acceptable for v1 + +3. **Full conversation lifecycle:** + ``` + POST /v2/desktop/background-conversation/start + for each chunk i: + POST /v2/desktop/background-transcribe + ?conversation_id=... + &chunk_start_ms=<chunk_offset_ms> + &language=es # or en + POST /v1/conversations # force-process finalize + GET /v1/conversations?... # verify segments exist on conversation + ``` + +4. **Assertions (fail loudly):** + - Every chunk returns HTTP 200 + - Every chunk has `provider == "assemblyai"` + - `segments` non-empty for at least one chunk (speech sample) + - `chunk_start_ms` offsets applied: segment `start` values increase across chunks, no regression + - Multi-chunk: at least **2 chunks** uploaded from sample audio + - Finalize returns 200 or expected 404 if already processed + - Optional: fetch conversation from Firestore/API and assert `transcript_segments` length > 0 + +5. **CLI UX:** + ```bash + python3 scripts/desktop_assemblyai_e2e.py --background-batch [--api URL] [--language es] [--token TOKEN] + ``` + Print summary: chunk count, total segments, conversation_id, sample transcript snippet. + +6. **Document** in script header how to run locally: + ```bash + cd backend && DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib" ./run-local.sh + # .env: ASSEMBLYAI_BACKGROUND_STT_ENABLED=true, ASSEMBLYAI_API_KEY=...
+ python3 scripts/desktop_assemblyai_e2e.py --background-batch + ``` + +### Tier 2 — Backend pytest integration (no live AssemblyAI if possible) + +Add [`backend/tests/integration/test_desktop_background_batch_e2e.py`](backend/tests/integration/test_desktop_background_batch_e2e.py) OR extend unit tests: + +- Mock `transcribe_bytes` for most tests (fast, no API key) +- One optional `@pytest.mark.integration` test that calls real AssemblyAI when `ASSEMBLYAI_API_KEY` set (skip otherwise) +- Test multi-chunk append: 3 chunks with offsets → Firestore segments merged correctly +- Test `provider_cluster_id` → distinct `speaker_id` when mock returns two clusters + +Run: `cd backend && python3 -m pytest tests/unit/test_desktop_background_transcribe.py -v` + +### Tier 3 — Swift integration test (no desktop app launch) + +Add test in [`desktop/Desktop/Tests/BackgroundTranscriptionTests.swift`](desktop/Desktop/Tests/BackgroundTranscriptionTests.swift) or new file: + +- **`testMixerCallbackDispatchesChunksToSession`**: Feed 16+ seconds of synthetic PCM through `BackgroundAudioChunker` the same way AppState would after receiving mixer output; mock HTTP handler returns fake segments; assert `transcribeNext` called and segments merged. +- **`testFifteenSecondContinuousSpeechProducesChunk`**: Append PCM at 100ms frames for 16s → assert at least one chunk enqueued without silence gaps. + +Run: `cd desktop && xcrun swift test -c debug --package-path Desktop --filter BackgroundTranscription` + +This catches the class of bug where audio never reaches the session (won't fix MainActor alone but validates chunker + session + drain loop). + +### Tier 4 — Optional: chunker parity test (Python ↔ Swift) + +Export chunk boundaries from Python chunker and assert Swift `BackgroundAudioChunker` produces same cut points for a fixture PCM file. Low priority unless easy. 
+ +--- + +## Out of scope + +- Launching Omi Dev / agent-swift / live mic +- PTT path (`/v2/voice-message/transcribe`) +- `/v4/listen` WebSocket +- Prod Helm rollout + +--- + +## Success criteria + +You are done when: + +1. `python3 scripts/desktop_assemblyai_e2e.py --background-batch` passes against local backend with AssemblyAI enabled, printing conversation_id + segment count. +2. `cd backend && python3 -m pytest tests/unit/test_desktop_background_transcribe.py -v` passes (0 failures). +3. Swift `BackgroundTranscription` tests pass including at least one **multi-chunk / 15s boundary** test. +4. README or script docstring explains how to run without desktop app. + +--- + +## Reference files + +| File | Purpose | +|------|---------| +| [`backend/routers/desktop_background.py`](backend/routers/desktop_background.py) | HTTP endpoints | +| [`scripts/desktop_assemblyai_e2e.py`](scripts/desktop_assemblyai_e2e.py) | Extend this | +| [`desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift`](desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift) | Match chunk sizes | +| [`.cursor/plans/assemblyai_background_listen_098fc011.plan.md`](.cursor/plans/assemblyai_background_listen_098fc011.plan.md) | Full architecture plan | + +## Environment + +```bash +# backend/.env (required) +ASSEMBLYAI_BACKGROUND_STT_ENABLED=true +ASSEMBLYAI_BACKGROUND_STT_WORKLOADS=sync,background,postprocess +ASSEMBLYAI_API_KEY= +LOCAL_DEVELOPMENT=true +``` + +Auth: `LOCAL_DEVELOPMENT=true` maps failed token verify → uid `123`. + +--- + +## Notes from manual debugging + +- Single-chunk `--background-chunk` **passed** (AssemblyAI returned transcript for NBC sample). +- Desktop live path: conversation created but **no chunk POSTs** — fix is likely `Task { @MainActor in handleMixedBackgroundAudio(...) }` in mixer callback; **separate task**, not required for your Python E2E. 
diff --git a/scripts/desktop_assemblyai_e2e.py b/scripts/desktop_assemblyai_e2e.py new file mode 100755 index 00000000000..570799e6cbe --- /dev/null +++ b/scripts/desktop_assemblyai_e2e.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +"""Desktop-adjacent AssemblyAI E2E: uses Omi Dev auth + local Python backend. + +By default this script exercises the same backend path as mobile offline sync +(POST /v2/sync-local-files -> STTWorkload.sync). With --background-chunk it +exercises desktop batch listen (POST /v2/desktop/background-transcribe -> +STTWorkload.background) using raw PCM bytes. + +Usage: + python3 scripts/desktop_assemblyai_e2e.py [--api http://127.0.0.1:8080] + python3 scripts/desktop_assemblyai_e2e.py --background-chunk [--api http://127.0.0.1:8080] +""" +from __future__ import annotations + +import argparse +import json +import struct +import subprocess +import sys +import time +import urllib.error +import urllib.request +import wave +from pathlib import Path + +DEFAULTS_DOMAIN = "com.omi.desktop-dev" +SAMPLE_MP3_URL = "https://storage.googleapis.com/aai-docs-samples/nbc.mp3" + + +def read_desktop_auth_token() -> str: + result = subprocess.run( + ["defaults", "read", DEFAULTS_DOMAIN, "auth_idToken"], + capture_output=True, + text=True, + check=False, + ) + token = (result.stdout or "").strip() + if result.returncode != 0 or not token: + raise SystemExit( + "No Omi Dev auth token found. Sign in via ./run.sh --yolo first, " + f"then retry (defaults domain: {DEFAULTS_DOMAIN})." 
+ ) + return token + + +def mp3_to_pcm_bytes(mp3_path: Path, wav_path: Path, sample_rate: int = 16000) -> bytes: + subprocess.run( + [ + "ffmpeg", + "-y", + "-i", + str(mp3_path), + "-ar", + str(sample_rate), + "-ac", + "1", + "-f", + "wav", + str(wav_path), + ], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + with wave.open(str(wav_path), "rb") as wav_file: + return wav_file.readframes(wav_file.getnframes()) + + +def write_length_prefixed_pcm_bin(pcm: bytes, bin_path: Path, sample_rate: int = 16000) -> None: + # Length-prefixed PCM frames (~100ms) for sync-local-files decoder + frame_samples = sample_rate // 10 + frame_bytes = frame_samples * 2 + with bin_path.open("wb") as out: + for offset in range(0, len(pcm), frame_bytes): + chunk = pcm[offset : offset + frame_bytes] + if not chunk: + continue + out.write(struct.pack("<I", len(chunk))) + out.write(chunk) + + +def mp3_to_pcm_bin(mp3_path: Path, bin_path: Path, sample_rate: int = 16000) -> None: + wav_path = bin_path.with_suffix(".wav") + pcm = mp3_to_pcm_bytes(mp3_path, wav_path, sample_rate) + write_length_prefixed_pcm_bin(pcm, bin_path, sample_rate) + wav_path.unlink(missing_ok=True) + + +def ensure_sample_mp3(workdir: Path) -> Path: + workdir.mkdir(parents=True, exist_ok=True) + mp3_path = workdir / "sample.mp3" + if not mp3_path.exists(): + print("Downloading sample audio...") + urllib.request.urlretrieve(SAMPLE_MP3_URL, mp3_path) + return mp3_path + + +def ensure_sample_bin(workdir: Path) -> Path: + mp3_path = ensure_sample_mp3(workdir) + bin_path = workdir / "desktop_e2e_sample.bin" + if not bin_path.exists(): + print("Converting sample to sync .bin format...") + mp3_to_pcm_bin(mp3_path, bin_path) + return bin_path + + +def ensure_sample_pcm(workdir: Path) -> Path: + mp3_path = ensure_sample_mp3(workdir) + pcm_path = workdir / "desktop_e2e_sample.raw.pcm" + if not pcm_path.exists(): + print("Converting sample to raw PCM format...") + wav_path = workdir / "desktop_e2e_sample.raw.wav" + pcm_path.write_bytes(mp3_to_pcm_bytes(mp3_path, wav_path)) + wav_path.unlink(missing_ok=True) + return pcm_path
+ + +def json_request( + url: str, + token: str, + *, + method: str = "POST", + payload: dict | None = None, + timeout: int = 120, +) -> dict: + data = json.dumps(payload or {}).encode() + req = urllib.request.Request( + url, + data=data, + method=method, + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + ) + with urllib.request.urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode()) + + +def multipart_upload(url: str, token: str, bin_path: Path, filename: str) -> dict: + boundary = "----omiAssemblyAIe2e" + body = bytearray() + body.extend(f"--{boundary}\r\n".encode()) + body.extend(f'Content-Disposition: form-data; name="files"; filename="{filename}"\r\n'.encode()) + body.extend(b"Content-Type: application/octet-stream\r\n\r\n") + body.extend(bin_path.read_bytes()) + body.extend(f"\r\n--{boundary}--\r\n".encode()) + + req = urllib.request.Request( + url, + data=bytes(body), + method="POST", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": f"multipart/form-data; boundary={boundary}", + }, + ) + with urllib.request.urlopen(req, timeout=120) as resp: + return json.loads(resp.read().decode()) + + +def background_chunk_upload(api_base: str, token: str, pcm_path: Path) -> dict: + start_url = f"{api_base.rstrip('/')}/v2/desktop/background-conversation/start" + started = json_request(start_url, token, payload={"language": "en", "source": "desktop"}) + conversation_id = started.get("conversation_id") + if not conversation_id: + raise RuntimeError(f"Unexpected background-conversation response: {started}") + + transcribe_url = ( + f"{api_base.rstrip('/')}/v2/desktop/background-transcribe" + f"?conversation_id={conversation_id}&chunk_start_ms=0&sample_rate=16000&channels=1" + ) + req = urllib.request.Request( + transcribe_url, + data=pcm_path.read_bytes(), + method="POST", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/octet-stream", + }, + ) + with 
urllib.request.urlopen(req, timeout=180) as resp: + result = json.loads(resp.read().decode()) + result["_conversation_id"] = conversation_id + return result + + +def poll_job(api_base: str, token: str, job_id: str, timeout_s: int = 900) -> dict: + url = f"{api_base.rstrip('/')}/v2/sync-local-files/{job_id}" + deadline = time.time() + timeout_s + while time.time() < deadline: + req = urllib.request.Request( + url, + headers={"Authorization": f"Bearer {token}"}, + ) + with urllib.request.urlopen(req, timeout=60) as resp: + job = json.loads(resp.read().decode()) + status = job.get("status") + stage = job.get("stage") + print(f" job {job_id}: status={status} stage={stage}") + if status in {"completed", "partial_failure", "failed"}: + return job + time.sleep(3) + raise TimeoutError(f"Timed out waiting for job {job_id}") + + +def main() -> int: + parser = argparse.ArgumentParser(description="Desktop AssemblyAI E2E via sync-local-files") + parser.add_argument("--api", default="http://127.0.0.1:8080", help="Local Python backend base URL") + parser.add_argument("--workdir", default="/tmp/omi-assemblyai-e2e", help="Temp dir for sample audio") + parser.add_argument( + "--background-chunk", + action="store_true", + help="Exercise /v2/desktop/background-transcribe with raw PCM instead of sync-local-files", + ) + args = parser.parse_args() + + token = read_desktop_auth_token() + + if args.background_chunk: + pcm_path = ensure_sample_pcm(Path(args.workdir)) + print(f"Posting raw PCM chunk {pcm_path.name} to {args.api.rstrip('/')}/v2/desktop/background-transcribe ...") + try: + result = background_chunk_upload(args.api, token, pcm_path) + except urllib.error.HTTPError as exc: + body = exc.read().decode(errors="replace") + print(f"Background chunk failed: HTTP {exc.code}\n{body}", file=sys.stderr) + return 1 + except RuntimeError as exc: + print(str(exc), file=sys.stderr) + return 1 + + print(json.dumps(result, indent=2)) + if result.get("provider") != "assemblyai": + 
print(f"Expected provider=assemblyai, got {result.get('provider')!r}.", file=sys.stderr) + return 1 + if not result.get("segments"): + print("Expected non-empty segments from background chunk.", file=sys.stderr) + return 1 + + print("\nBackground chunk succeeded with provider=assemblyai.") + print(f"Conversation: {result.get('_conversation_id')}") + return 0 + + bin_path = ensure_sample_bin(Path(args.workdir)) + ts = int(time.time()) + # Filename must include _pcm16_{sampleRate}_ so sync decode uses PCM path (not Opus). + filename = f"audio_desktop_pcm16_16000_1_fs160_{ts}.bin" + + upload_url = f"{args.api.rstrip('/')}/v2/sync-local-files" + print(f"Uploading {bin_path.name} to {upload_url} ...") + try: + queued = multipart_upload(upload_url, token, bin_path, filename) + except urllib.error.HTTPError as exc: + body = exc.read().decode(errors="replace") + print(f"Upload failed: HTTP {exc.code}\n{body}", file=sys.stderr) + return 1 + + job_id = queued.get("job_id") + if not job_id: + print(f"Unexpected response: {queued}", file=sys.stderr) + return 1 + + print(f"Queued job_id={job_id}; polling...") + final = poll_job(args.api, token, job_id) + print(json.dumps(final, indent=2)) + + if final.get("status") not in {"completed", "partial_failure"}: + print("Job did not complete successfully.", file=sys.stderr) + return 1 + + print("\nCheck backend logs for provider=assemblyai on workload=sync.") + print("Firestore: transcription_provider_runs collection for your uid.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From ca550536e3ffcf03984bb0d329d7d05cd869ff8b Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 20:28:27 +0700 Subject: [PATCH 25/44] Add AssemblyAI background batch E2E coverage --- .../test_desktop_background_transcribe.py | 49 +++- .../Tests/BackgroundTranscriptionTests.swift | 35 +++ scripts/desktop_assemblyai_e2e.py | 242 +++++++++++++++++- 3 files changed, 323 insertions(+), 3 deletions(-) diff --git 
a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index f305a6e0cd9..79f45262773 100644 --- a/backend/tests/unit/test_desktop_background_transcribe.py +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -141,7 +141,7 @@ def _transcribe_bytes(audio_bytes, **_kwargs): return SimpleNamespace( result=ProviderTranscriptResult(provider='assemblyai', model='universal-2', words=[], utterances=[]), detected_language='en', - segments=default_segments, + segments=[segment.model_copy(deep=True) for segment in default_segments], run_id='run-1', ) @@ -236,6 +236,53 @@ def test_cluster_speaker_mapping_assigns_distinct_ids(monkeypatch): assert [segment['speaker'] for segment in data['segments']] == ['SPEAKER_00', 'SPEAKER_01'] +def test_background_transcribe_multi_chunk_offsets_persist_and_keep_speaker_map(monkeypatch): + client, mock_transcribe = _client(monkeypatch) + speaker_map_store = {} + + def _redis_get(key): + return speaker_map_store.get(key) + + def _redis_set(key, value, **_kwargs): + speaker_map_store[key] = value + + monkeypatch.setattr(desktop_background.redis_db.r, 'get', _redis_get) + monkeypatch.setattr(desktop_background.redis_db.r, 'set', _redis_set) + + responses = [ + [TranscriptSegment(text='First.', is_user=False, start=0.1, end=1.0, provider_cluster_id='A')], + [TranscriptSegment(text='Second.', is_user=False, start=0.1, end=1.0, provider_cluster_id='B')], + [TranscriptSegment(text='Third.', is_user=False, start=0.1, end=1.0, provider_cluster_id='A')], + ] + + def _transcribe_bytes(_audio_bytes, **_kwargs): + return SimpleNamespace( + result=ProviderTranscriptResult(provider='assemblyai', model='universal-2', words=[], utterances=[]), + detected_language='en', + segments=[segment.model_copy(deep=True) for segment in responses.pop(0)], + run_id='run-multi', + ) + + mock_transcribe.side_effect = _transcribe_bytes + + for chunk_start_ms in (0, 14000, 28000): + response = 
client.post( + f'/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms={chunk_start_ms}', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + assert response.status_code == 200 + assert response.json()['provider'] == 'assemblyai' + + appended_segments = [ + call.args[2][0] for call in desktop_background.append_segments_to_in_progress_conversation.call_args_list + ] + assert [segment.start for segment in appended_segments] == [0.1, 14.1, 28.1] + assert [segment.end for segment in appended_segments] == [1.0, 15.0, 29.0] + assert [segment.speaker_id for segment in appended_segments] == [0, 1, 0] + assert [segment.speaker for segment in appended_segments] == ['SPEAKER_00', 'SPEAKER_01', 'SPEAKER_00'] + + def test_byok_background_routing_uses_deepgram_when_only_deepgram_key(monkeypatch): from utils.stt import provider_service diff --git a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift index 98ccb2f8895..86e3daf7323 100644 --- a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift +++ b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift @@ -91,6 +91,41 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertEqual(session.snapshot().segments.count, 2) } + func testFifteenSecondContinuousSpeechProducesChunkThroughSession() async throws { + let configuration = BackgroundTranscriptionConfiguration() + var transcribedStarts: [Double] = [] + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + transcribedStarts.append(chunk.startTime) + return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", + text: "continuous speech", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let samplesPerFrame = configuration.sampleRate / 10 + let frame = pcm(samples: Array(repeating: 1_000, count: samplesPerFrame)) + var enqueuedChunks = 0 + + for frameIndex in 0..<160 { + let result 
= session.append(pcmData: frame, startTime: Double(frameIndex) / 10.0) + enqueuedChunks += result.enqueuedChunks + } + + XCTAssertGreaterThanOrEqual(enqueuedChunks, 1) + XCTAssertGreaterThanOrEqual(session.pendingChunkCount, 1) + + let first = try await session.transcribeNext() + XCTAssertNotNil(first) + XCTAssertEqual(try XCTUnwrap(first).chunk.startTime, 0, accuracy: 0.001) + XCTAssertEqual(transcribedStarts, [0]) + XCTAssertEqual(session.snapshot().processedChunkCount, 1) + XCTAssertEqual(session.snapshot().segments.count, 1) + } + func testSessionFinishFlushesTail() async throws { let configuration = BackgroundTranscriptionConfiguration( sampleRate: 10, diff --git a/scripts/desktop_assemblyai_e2e.py b/scripts/desktop_assemblyai_e2e.py index 570799e6cbe..d25d65ee881 100755 --- a/scripts/desktop_assemblyai_e2e.py +++ b/scripts/desktop_assemblyai_e2e.py @@ -4,11 +4,19 @@ By default this script exercises the same backend path as mobile offline sync (POST /v2/sync-local-files -> STTWorkload.sync). With --background-chunk it exercises desktop batch listen (POST /v2/desktop/background-transcribe -> -STTWorkload.background) using raw PCM bytes. +STTWorkload.background) using raw PCM bytes. With --background-batch it +simulates desktop background chunking without launching the desktop app. Usage: + cd backend && DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib" ./run-local.sh + # backend/.env: + # LOCAL_DEVELOPMENT=true + # ASSEMBLYAI_BACKGROUND_STT_ENABLED=true + # ASSEMBLYAI_BACKGROUND_STT_WORKLOADS=sync,background,postprocess + # ASSEMBLYAI_API_KEY=...
python3 scripts/desktop_assemblyai_e2e.py [--api http://127.0.0.1:8080] python3 scripts/desktop_assemblyai_e2e.py --background-chunk [--api http://127.0.0.1:8080] + python3 scripts/desktop_assemblyai_e2e.py --background-batch [--api http://127.0.0.1:8080] [--language en] [--token TOKEN] """ from __future__ import annotations @@ -19,12 +27,18 @@ import sys import time import urllib.error +import urllib.parse import urllib.request import wave from pathlib import Path DEFAULTS_DOMAIN = "com.omi.desktop-dev" SAMPLE_MP3_URL = "https://storage.googleapis.com/aai-docs-samples/nbc.mp3" +SAMPLE_RATE = 16000 +CHANNELS = 1 +BYTES_PER_SAMPLE = 2 +BACKGROUND_CHUNK_SECONDS = 15 +BACKGROUND_OVERLAP_SECONDS = 1 def read_desktop_auth_token() -> str: @@ -43,6 +57,23 @@ def read_desktop_auth_token() -> str: return token +def require_backend_reachable(api_base: str) -> None: + url = f"{api_base.rstrip('/')}/docs" + try: + req = urllib.request.Request(url, method="GET") + with urllib.request.urlopen(req, timeout=10) as resp: + if resp.status >= 500: + raise RuntimeError(f"Backend returned HTTP {resp.status}") + except (urllib.error.URLError, TimeoutError, RuntimeError) as exc: + raise SystemExit( + f"Backend is not reachable at {api_base}. Start the local backend first, for example:\n" + ' cd backend && DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib" ./run-local.sh\n' + "Required backend env includes LOCAL_DEVELOPMENT=true, " + "ASSEMBLYAI_BACKGROUND_STT_ENABLED=true, and ASSEMBLYAI_API_KEY.\n" + f"Reachability error: {exc}" + ) + + def mp3_to_pcm_bytes(mp3_path: Path, wav_path: Path, sample_rate: int = 16000) -> bytes: subprocess.run( [ @@ -115,6 +146,34 @@ def ensure_sample_pcm(workdir: Path) -> Path: return pcm_path +def split_background_pcm_chunks(pcm: bytes) -> list[tuple[int, bytes]]: + """Mirror desktop's 15s chunks with 1s retained overlap. + + Returns (chunk_start_ms, pcm_bytes). 
Hard cuts are intentional here; the + Swift chunker may cut earlier at silence, but hard-cut parity is enough to + prove multi-chunk backend persistence and offset handling. + """ + bytes_per_second = SAMPLE_RATE * CHANNELS * BYTES_PER_SAMPLE + chunk_bytes = BACKGROUND_CHUNK_SECONDS * bytes_per_second + overlap_bytes = BACKGROUND_OVERLAP_SECONDS * bytes_per_second + stride_bytes = chunk_bytes - overlap_bytes + if stride_bytes <= 0: + raise ValueError("background chunk stride must be positive") + + chunks: list[tuple[int, bytes]] = [] + offset = 0 + while offset < len(pcm): + chunk = pcm[offset : offset + chunk_bytes] + if not chunk: + break + start_ms = int(offset / bytes_per_second * 1000) + chunks.append((start_ms, chunk)) + if offset + chunk_bytes >= len(pcm): + break + offset += stride_bytes + return chunks + + def json_request( url: str, token: str, @@ -137,6 +196,16 @@ def json_request( return json.loads(resp.read().decode()) +def get_json_request(url: str, token: str, *, timeout: int = 60) -> dict: + req = urllib.request.Request( + url, + method="GET", + headers={"Authorization": f"Bearer {token}"}, + ) + with urllib.request.urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode()) + + def multipart_upload(url: str, token: str, bin_path: Path, filename: str) -> dict: boundary = "----omiAssemblyAIe2e" body = bytearray() @@ -185,6 +254,138 @@ def background_chunk_upload(api_base: str, token: str, pcm_path: Path) -> dict: return result +def background_transcribe_chunk( + api_base: str, + token: str, + conversation_id: str, + chunk_start_ms: int, + chunk: bytes, + language: str, +) -> dict: + transcribe_url = ( + f"{api_base.rstrip('/')}/v2/desktop/background-transcribe" + f"?conversation_id={conversation_id}" + f"&chunk_start_ms={chunk_start_ms}" + f"&sample_rate={SAMPLE_RATE}" + f"&channels={CHANNELS}" + f"&language={urllib.parse.quote(language)}" + ) + req = urllib.request.Request( + transcribe_url, + data=chunk, + method="POST", + 
headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/octet-stream", + }, + ) + with urllib.request.urlopen(req, timeout=180) as resp: + return json.loads(resp.read().decode()) + + +def background_batch_upload(api_base: str, token: str, pcm_path: Path, language: str) -> dict: + require_backend_reachable(api_base) + api_base = api_base.rstrip("/") + + start_url = f"{api_base}/v2/desktop/background-conversation/start" + started = json_request(start_url, token, payload={"language": language, "source": "desktop"}) + conversation_id = started.get("conversation_id") + if not conversation_id: + raise RuntimeError(f"Unexpected background-conversation response: {started}") + + chunks = split_background_pcm_chunks(pcm_path.read_bytes()) + if len(chunks) < 2: + raise RuntimeError(f"Expected sample to produce at least 2 chunks, got {len(chunks)}") + + total_segments = 0 + non_empty_chunks = 0 + previous_first_segment_start = None + snippet_parts: list[str] = [] + chunk_summaries: list[dict] = [] + + for index, (chunk_start_ms, chunk) in enumerate(chunks, start=1): + duration_ms = int(len(chunk) / (SAMPLE_RATE * CHANNELS * BYTES_PER_SAMPLE) * 1000) + print( + f"Posting chunk {index}/{len(chunks)} " + f"start_ms={chunk_start_ms} duration_ms={duration_ms} bytes={len(chunk)} ..." 
+ ) + result = background_transcribe_chunk(api_base, token, conversation_id, chunk_start_ms, chunk, language) + provider = result.get("provider") + if provider != "assemblyai": + raise RuntimeError(f"Chunk {index} expected provider=assemblyai, got {provider!r}: {result}") + + segments = result.get("segments") or [] + for segment in segments: + if segment.get("start") is None or segment.get("end") is None: + raise RuntimeError(f"Chunk {index} returned segment without start/end: {segment}") + if segment["start"] + 0.001 < chunk_start_ms / 1000.0: + raise RuntimeError( + f"Chunk {index} segment start {segment['start']} regressed before offset {chunk_start_ms / 1000.0}" + ) + + first_segment_start = next((segment["start"] for segment in segments if segment.get("start") is not None), None) + if first_segment_start is not None: + if previous_first_segment_start is not None and first_segment_start + 0.001 < previous_first_segment_start: + raise RuntimeError( + f"Chunk {index} first segment start regressed: " + f"{first_segment_start} < {previous_first_segment_start}" + ) + previous_first_segment_start = first_segment_start + + if segments: + non_empty_chunks += 1 + snippet_parts.extend(segment.get("text", "") for segment in segments[:2]) + total_segments += len(segments) + chunk_summaries.append( + { + "index": index, + "chunk_start_ms": chunk_start_ms, + "chunk_duration_ms": result.get("chunk_duration_ms"), + "segments": len(segments), + "run_id": result.get("run_id"), + } + ) + + if non_empty_chunks == 0 or total_segments == 0: + raise RuntimeError("Expected non-empty AssemblyAI segments from at least one chunk.") + + before_finalize = get_json_request(f"{api_base}/v1/conversations/{conversation_id}", token) + persisted_before = before_finalize.get("transcript_segments") or [] + if not persisted_before: + raise RuntimeError(f"Expected persisted transcript_segments before finalize: {before_finalize}") + + finalize_status = None + finalize_body: dict | str | None = None + 
try: + finalize_body = json_request(f"{api_base}/v1/conversations", token, payload={}, timeout=240) + finalize_status = 200 + except urllib.error.HTTPError as exc: + finalize_status = exc.code + body = exc.read().decode(errors="replace") + try: + finalize_body = json.loads(body) + except json.JSONDecodeError: + finalize_body = body + if exc.code != 404: + raise RuntimeError(f"Finalize failed: HTTP {exc.code}\n{body}") from exc + + after_finalize = get_json_request(f"{api_base}/v1/conversations/{conversation_id}", token) + persisted_after = after_finalize.get("transcript_segments") or [] + if not persisted_after: + raise RuntimeError(f"Expected transcript_segments on conversation after finalize: {after_finalize}") + + return { + "conversation_id": conversation_id, + "chunks": chunk_summaries, + "total_segments_returned": total_segments, + "persisted_segments_before_finalize": len(persisted_before), + "persisted_segments_after_finalize": len(persisted_after), + "finalize_status": finalize_status, + "finalize_body": finalize_body, + "snippet": " ".join(part.strip() for part in snippet_parts if part.strip())[:300], + } + + def poll_job(api_base: str, token: str, job_id: str, timeout_s: int = 900) -> dict: url = f"{api_base.rstrip('/')}/v2/sync-local-files/{job_id}" deadline = time.time() + timeout_s @@ -208,14 +409,25 @@ def main() -> int: parser = argparse.ArgumentParser(description="Desktop AssemblyAI E2E via sync-local-files") parser.add_argument("--api", default="http://127.0.0.1:8080", help="Local Python backend base URL") parser.add_argument("--workdir", default="/tmp/omi-assemblyai-e2e", help="Temp dir for sample audio") + parser.add_argument("--language", default="en", help="Language passed to background transcription") + parser.add_argument("--token", help="Firebase ID token; defaults to Omi Dev auth_idToken from macOS defaults") parser.add_argument( "--background-chunk", action="store_true", help="Exercise /v2/desktop/background-transcribe with raw PCM 
instead of sync-local-files", ) + parser.add_argument( + "--background-batch", + action="store_true", + help="Exercise desktop background batch lifecycle with 15s/1s-overlap raw PCM chunks", + ) args = parser.parse_args() - token = read_desktop_auth_token() + if args.background_chunk and args.background_batch: + print("Choose only one of --background-chunk or --background-batch.", file=sys.stderr) + return 2 + + token = args.token or read_desktop_auth_token() if args.background_chunk: pcm_path = ensure_sample_pcm(Path(args.workdir)) @@ -242,6 +454,32 @@ def main() -> int: print(f"Conversation: {result.get('_conversation_id')}") return 0 + if args.background_batch: + pcm_path = ensure_sample_pcm(Path(args.workdir)) + print(f"Posting simulated background batch from {pcm_path.name} to {args.api.rstrip('/')} ...") + try: + summary = background_batch_upload(args.api, token, pcm_path, args.language) + except urllib.error.HTTPError as exc: + body = exc.read().decode(errors="replace") + print(f"Background batch failed: HTTP {exc.code}\n{body}", file=sys.stderr) + return 1 + except (RuntimeError, SystemExit) as exc: + print(str(exc), file=sys.stderr) + return 1 + + print("\nBackground batch succeeded with provider=assemblyai.") + print(f"Conversation: {summary['conversation_id']}") + print(f"Chunks uploaded: {len(summary['chunks'])}") + print(f"Segments returned by chunks: {summary['total_segments_returned']}") + print(f"Segments persisted before finalize: {summary['persisted_segments_before_finalize']}") + print(f"Segments persisted after finalize: {summary['persisted_segments_after_finalize']}") + print(f"Finalize status: HTTP {summary['finalize_status']}") + if summary["snippet"]: + print(f"Transcript snippet: {summary['snippet']}") + print("\nChunk summary:") + print(json.dumps(summary["chunks"], indent=2)) + return 0 + bin_path = ensure_sample_bin(Path(args.workdir)) ts = int(time.time()) # Filename must include _pcm16_{sampleRate}_ so sync decode uses PCM path (not 
Opus). From e1ebf27a4d110c7ecbbe359ed7975e1a1a44d613 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 12:38:16 -0400 Subject: [PATCH 26/44] Fix desktop AssemblyAI background batch transcription Stabilize desktop cloud batch recording by using 15s cloud chunks, nonfatal backpressure, speech activity gating, explicit batch language, and resilient ASR queue draining. Add explicit desktop background conversation finish routing, AssemblyAI background fail-closed behavior, route/rate-limit coverage, and developer docs for the batch path. Validation: backend focused pytest suite 41 passed, 2 skipped; Swift BackgroundTranscription 15 passed; APIClient finish route 1 passed; ListenProtocol 25 passed; git diff --check clean; live Omi Dev session uploaded 15s AssemblyAI chunks, suppressed quiet-room chunks, and reconciled conversation ba94a0a9-1af8-4d51-b98b-0a0f269bef65. --- backend/routers/desktop_background.py | 18 + backend/tests/unit/test_assemblyai_adapter.py | 9 + .../unit/test_background_provider_service.py | 9 +- .../test_desktop_background_transcribe.py | 39 ++ backend/tests/unit/test_rate_limiting.py | 2 + .../utils/conversations/desktop_background.py | 34 ++ backend/utils/rate_limit_config.py | 1 + backend/utils/stt/assemblyai_adapter.py | 3 +- backend/utils/stt/providers.py | 2 + desktop/Desktop/Sources/APIClient.swift | 9 + desktop/Desktop/Sources/AppState.swift | 433 ++++++++++++------ .../BackgroundAudioChunker.swift | 47 +- .../BackgroundTranscriptMerger.swift | 58 ++- ...BackgroundTranscriptionConfiguration.swift | 25 +- .../CloudBackgroundTranscriptionSession.swift | 86 +++- .../SpeechActivityDetector.swift | 183 ++++++++ .../MainWindow/Pages/SettingsPage.swift | 2 +- .../Services/AssistantSettings.swift | 22 +- .../Sources/TranscriptionService.swift | 10 +- .../Desktop/Tests/APIClientRoutingTests.swift | 8 + .../Tests/BackgroundTranscriptionTests.swift | 288 ++++++++++-- .../backend/listen_pusher_pipeline.mdx | 18 +- 22 files changed, 1058 
insertions(+), 248 deletions(-) create mode 100644 desktop/Desktop/Sources/BackgroundTranscription/SpeechActivityDetector.swift diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py index 0242ca1fa1d..141581764fa 100644 --- a/backend/routers/desktop_background.py +++ b/backend/routers/desktop_background.py @@ -13,8 +13,10 @@ from utils.analytics import record_usage from utils.chat import resolve_voice_message_language from utils.conversations.desktop_background import ( + DesktopBackgroundConversationError, append_segments_to_in_progress_conversation, create_in_progress_desktop_conversation, + finish_desktop_background_conversation, ) from utils.executors import db_executor, run_blocking, sync_executor from utils.fair_use import is_hard_restricted, record_speech_ms @@ -63,6 +65,22 @@ async def start_background_conversation( return {"conversation_id": conversation_id} +@router.post("/background-conversation/{conversation_id}/finish") +async def finish_background_conversation( + conversation_id: str, + uid: str = Depends(auth.with_rate_limit(auth.get_current_user_uid, "desktop:background_conversation_finish")), +): + try: + return await run_blocking( + db_executor, + finish_desktop_background_conversation, + uid, + conversation_id, + ) + except DesktopBackgroundConversationError as e: + raise HTTPException(status_code=e.status_code, detail=str(e)) + + @router.post("/background-transcribe") async def background_transcribe( request: Request, diff --git a/backend/tests/unit/test_assemblyai_adapter.py b/backend/tests/unit/test_assemblyai_adapter.py index 4bf18b396cc..2c42bfde443 100644 --- a/backend/tests/unit/test_assemblyai_adapter.py +++ b/backend/tests/unit/test_assemblyai_adapter.py @@ -81,6 +81,15 @@ def test_assemblyai_result_normalizes_utterances_words_and_speaker_clusters(): assert result.words[1].provider_cluster_id == 'A' +def test_assemblyai_result_does_not_report_multi_when_detection_returns_no_language(): + transcript = 
_completed_transcript() + transcript.pop('language_code') + + result = normalize_assemblyai_transcript_result(transcript, model='universal-2', language='multi') + + assert result.language is None + + def test_assemblyai_transcribe_url_submits_diarization_and_polls_to_completion(): fake_client = FakeClient( [ diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index f69dd7640b4..a0de2de4a2b 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -238,11 +238,11 @@ def test_provider_service_records_retry_exhaustion_without_fallback(monkeypatch) with patch.object( provider_service, '_assemblyai_prerecorded_provider', return_value=assemblyai_provider - ), patch.object(provider_service, 'create_provider_run', return_value='run-aai'), patch.object( + ), patch.object(provider_service, '_deepgram_prerecorded_provider') as deepgram_provider, patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ), patch.object( provider_service, 'finalize_provider_run' - ) as finalize_run, patch.object( - provider_service, 'get_fallback_prerecorded_provider_name', return_value=None - ): + ) as finalize_run: with pytest.raises(RuntimeError, match='assemblyai transcription failed after 2 attempts'): provider_service.transcribe_url( 'https://example.test/audio.wav', @@ -253,6 +253,7 @@ def test_provider_service_records_retry_exhaustion_without_fallback(monkeypatch) ) assert assemblyai_provider.transcribe_url.call_count == 2 + deepgram_provider.assert_not_called() finalize_run.assert_called_once() assert finalize_run.call_args.kwargs['run_id'] == 'run-aai' assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index 79f45262773..ef5fa0477a1 100644 --- 
a/backend/tests/unit/test_desktop_background_transcribe.py +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -27,11 +27,21 @@ database_client_stub.document_id_from_seed = getattr(database_client_stub, 'document_id_from_seed', MagicMock()) sys.modules.setdefault('database.conversations', MagicMock()) sys.modules.setdefault('database.redis_db', SimpleNamespace(r=MagicMock())) + + +class _DesktopBackgroundConversationError(ValueError): + def __init__(self, message: str, status_code: int = 400): + super().__init__(message) + self.status_code = status_code + + sys.modules.setdefault( 'utils.conversations.desktop_background', SimpleNamespace( + DesktopBackgroundConversationError=_DesktopBackgroundConversationError, append_segments_to_in_progress_conversation=MagicMock(), create_in_progress_desktop_conversation=MagicMock(return_value='conv-1'), + finish_desktop_background_conversation=MagicMock(return_value={'id': 'conv-1', 'status': 'completed'}), ), ) sys.modules.setdefault( @@ -118,6 +128,13 @@ def _client(monkeypatch, *, segments=None): lambda _uid, _cid: {'id': _cid, 'status': 'in_progress'}, ) monkeypatch.setattr(desktop_background, 'append_segments_to_in_progress_conversation', MagicMock(return_value=[])) + monkeypatch.setattr( + desktop_background, + 'finish_desktop_background_conversation', + MagicMock(return_value={'id': 'conv-1', 'status': 'completed'}), + ) + if not hasattr(desktop_background.redis_db, 'r'): + monkeypatch.setattr(desktop_background.redis_db, 'r', MagicMock(), raising=False) monkeypatch.setattr(desktop_background.redis_db.r, 'get', lambda _key: None) monkeypatch.setattr(desktop_background.redis_db.r, 'set', lambda *_args, **_kwargs: None) @@ -169,6 +186,28 @@ def test_background_transcribe_returns_segments_with_offset(monkeypatch): assert data['segments'][0]['speaker'] == 'SPEAKER_00' +def test_finish_background_conversation_uses_explicit_conversation_id(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + 
response = client.post('/v2/desktop/background-conversation/conv-1/finish') + + assert response.status_code == 200 + assert response.json()['id'] == 'conv-1' + desktop_background.finish_desktop_background_conversation.assert_called_once_with('test-uid', 'conv-1') + + +def test_finish_background_conversation_maps_validation_error(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + desktop_background.finish_desktop_background_conversation.side_effect = ( + desktop_background.DesktopBackgroundConversationError('conversation is not in_progress', status_code=409) + ) + + response = client.post('/v2/desktop/background-conversation/conv-1/finish') + + assert response.status_code == 409 + assert response.json()['detail'] == 'conversation is not in_progress' + + def test_background_transcribe_wraps_linear16_pcm_as_wav(monkeypatch): client, mock_transcribe = _client(monkeypatch) diff --git a/backend/tests/unit/test_rate_limiting.py b/backend/tests/unit/test_rate_limiting.py index 38964d0563f..ea36db2c39b 100644 --- a/backend/tests/unit/test_rate_limiting.py +++ b/backend/tests/unit/test_rate_limiting.py @@ -331,6 +331,8 @@ def test_all_router_policies_exist(self): "chat:initial", "voice:message", "voice:transcribe", + "desktop:background_transcribe", + "desktop:background_conversation_finish", "file:upload", "agent:execute_tool", "mcp:sse", diff --git a/backend/utils/conversations/desktop_background.py b/backend/utils/conversations/desktop_background.py index 0f10df6f8d1..9d82f5132ae 100644 --- a/backend/utils/conversations/desktop_background.py +++ b/backend/utils/conversations/desktop_background.py @@ -11,10 +11,17 @@ from models.structured import Structured from models.transcript_segment import TranscriptSegment from utils.conversations.factory import deserialize_conversation +from utils.conversations.process_conversation import process_conversation logger = logging.getLogger(__name__) +class DesktopBackgroundConversationError(ValueError): + def 
__init__(self, message: str, status_code: int = 400): + super().__init__(message) + self.status_code = status_code + + def create_in_progress_desktop_conversation( uid: str, language: str, @@ -89,6 +96,33 @@ def append_segments_to_in_progress_conversation( return updated_segments +def finish_desktop_background_conversation(uid: str, conversation_id: str) -> Conversation: + """Finalize one explicit desktop background conversation by ID.""" + conversation_data = conversations_db.get_conversation(uid, conversation_id) + if not conversation_data: + raise DesktopBackgroundConversationError("conversation_id not found", status_code=404) + + conversation = deserialize_conversation(conversation_data) + if conversation.status == ConversationStatus.completed: + return conversation + if conversation.status != ConversationStatus.in_progress: + raise DesktopBackgroundConversationError("conversation is not in_progress", status_code=409) + + conversations_db.update_conversation_status(uid, conversation.id, ConversationStatus.processing) + processed_conversation = process_conversation(uid, conversation.language, conversation, force_process=True) + + if redis_db.get_in_progress_conversation_id(uid) == conversation.id: + redis_db.remove_in_progress_conversation_id(uid) + + logger.info( + "Finished desktop background conversation: %s uid=%s segments=%s", + conversation.id, + uid, + len(processed_conversation.transcript_segments), + ) + return processed_conversation + + def _detect_current_desktop_meeting(uid: str) -> Optional[str]: now = datetime.now(timezone.utc) time_window = timedelta(minutes=2) diff --git a/backend/utils/rate_limit_config.py b/backend/utils/rate_limit_config.py index 7c7533dd882..f3a587dedbe 100644 --- a/backend/utils/rate_limit_config.py +++ b/backend/utils/rate_limit_config.py @@ -47,6 +47,7 @@ "voice:transcribe_stream": (60, 3600), "voice:message": (60, 3600), "desktop:background_transcribe": (120, 3600), + "desktop:background_conversation_finish": (60, 3600), 
"file:upload": (40, 3600), # Agent/MCP — bursty tool calls "agent:execute_tool": (120, 3600), diff --git a/backend/utils/stt/assemblyai_adapter.py b/backend/utils/stt/assemblyai_adapter.py index f97cda069ac..9751b677112 100644 --- a/backend/utils/stt/assemblyai_adapter.py +++ b/backend/utils/stt/assemblyai_adapter.py @@ -47,10 +47,11 @@ def normalize_assemblyai_transcript_result( if not words and utterances: words = [word for utterance in utterances for word in (utterance.words or [])] + requested_language = None if language == 'multi' else language return ProviderTranscriptResult( provider=STTProviderName.assemblyai.value, model=result.get('speech_model_used') or result.get('speech_model') or model, - language=_normalize_language(result.get('language_code') or language), + language=_normalize_language(result.get('language_code') or requested_language), duration=_seconds_float(result.get('audio_duration')), words=words, utterances=utterances, diff --git a/backend/utils/stt/providers.py b/backend/utils/stt/providers.py index 5c83fb8533b..1d1af88e632 100644 --- a/backend/utils/stt/providers.py +++ b/backend/utils/stt/providers.py @@ -107,6 +107,8 @@ def get_fallback_prerecorded_provider_name( ) -> Optional[STTProviderName]: workload = STTWorkload(workload) provider = STTProviderName(provider) + if workload == STTWorkload.background and provider == STTProviderName.assemblyai: + return None fallback = _DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS[workload] if provider != fallback: return fallback diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 5dec9b31530..1efc69c633d 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -1324,6 +1324,15 @@ extension APIClient { ) return response.conversation_id } + + /// Finish one explicit Python-backed desktop background batch conversation. 
+ /// Endpoint: POST /v2/desktop/background-conversation/{conversation_id}/finish + func finishBackgroundConversation(conversationId: String) async throws -> ServerConversation { + try await post( + "v2/desktop/background-conversation/\(conversationId)/finish", + customBaseURL: nil + ) + } } // MARK: - Memories API diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 4e8c56f54cc..d04685db6ef 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -15,13 +15,13 @@ struct SegmentTranslation: Identifiable { struct SpeakerSegment: Identifiable { /// Stable identity — uses backend segment ID when available, otherwise speaker + start time var id: String { segmentId ?? "\(speaker)-\(start)" } - var segmentId: String? // Backend-assigned UUID + var segmentId: String? // Backend-assigned UUID var speaker: Int var text: String var start: Double var end: Double var isUser: Bool = false - var personId: String? // Backend-assigned person ID from speaker identification + var personId: String? 
// Backend-assigned person ID from speaker identification var translations: [SegmentTranslation] = [] } @@ -170,10 +170,10 @@ class AppState: ObservableObject { func fetchTrialMetadata() { #if DEBUG - if let debugMode = UserDefaults.standard.string(forKey: "debug_trial_mode") { - applyDebugTrialMode(debugMode) - return - } + if let debugMode = UserDefaults.standard.string(forKey: "debug_trial_mode") { + applyDebugTrialMode(debugMode) + return + } #endif Task { @MainActor in @@ -192,58 +192,63 @@ class AppState: ObservableObject { } #if DEBUG - private func applyDebugTrialMode(_ mode: String) { - let now = Int(Date().timeIntervalSince1970) - let features = ["unlimited_listening", "unlimited_transcription", "unlimited_memories", "unlimited_insights", "30_chat_questions_per_month"] - let dur = 3 * 24 * 3600 - - func mock(remaining: Int, expired: Bool) -> TrialMetadataResponse { - TrialMetadataResponse( - trialStartedAt: now - (dur - remaining), trialEndsAt: now + remaining, - trialRemainingSeconds: remaining, trialExpired: expired, - trialDurationSeconds: dur, trialFeatures: features, planAfterTrial: "Free" - ) - } + private func applyDebugTrialMode(_ mode: String) { + let now = Int(Date().timeIntervalSince1970) + let features = [ + "unlimited_listening", "unlimited_transcription", "unlimited_memories", + "unlimited_insights", "30_chat_questions_per_month", + ] + let dur = 3 * 24 * 3600 + + func mock(remaining: Int, expired: Bool) -> TrialMetadataResponse { + TrialMetadataResponse( + trialStartedAt: now - (dur - remaining), trialEndsAt: now + remaining, + trialRemainingSeconds: remaining, trialExpired: expired, + trialDurationSeconds: dur, trialFeatures: features, planAfterTrial: "Free" + ) + } - switch mode { - case "active": - self.trialMetadata = mock(remaining: 2 * 24 * 3600 + 3600, expired: false) - case "warning": - self.trialMetadata = mock(remaining: 12 * 3600, expired: false) - case "expiring": - self.trialMetadata = mock(remaining: 1800, expired: false) - 
case "expired": - self.trialMetadata = mock(remaining: 0, expired: true) - case "realtime": - let endKey = "debug_trial_end_time" - let rtDur = 120 - var endTime = UserDefaults.standard.integer(forKey: endKey) - if endTime == 0 { - endTime = now + rtDur - UserDefaults.standard.set(endTime, forKey: endKey) - } - let remaining = max(0, endTime - now) - self.trialMetadata = TrialMetadataResponse( - trialStartedAt: endTime - rtDur, trialEndsAt: endTime, - trialRemainingSeconds: remaining, trialExpired: remaining == 0, - trialDurationSeconds: rtDur, trialFeatures: features, planAfterTrial: "Free" - ) - if remaining == 0 && !self.isPaywalled { self.isPaywalled = true } - default: - break + switch mode { + case "active": + self.trialMetadata = mock(remaining: 2 * 24 * 3600 + 3600, expired: false) + case "warning": + self.trialMetadata = mock(remaining: 12 * 3600, expired: false) + case "expiring": + self.trialMetadata = mock(remaining: 1800, expired: false) + case "expired": + self.trialMetadata = mock(remaining: 0, expired: true) + case "realtime": + let endKey = "debug_trial_end_time" + let rtDur = 120 + var endTime = UserDefaults.standard.integer(forKey: endKey) + if endTime == 0 { + endTime = now + rtDur + UserDefaults.standard.set(endTime, forKey: endKey) + } + let remaining = max(0, endTime - now) + self.trialMetadata = TrialMetadataResponse( + trialStartedAt: endTime - rtDur, trialEndsAt: endTime, + trialRemainingSeconds: remaining, trialExpired: remaining == 0, + trialDurationSeconds: rtDur, trialFeatures: features, planAfterTrial: "Free" + ) + if remaining == 0 && !self.isPaywalled { self.isPaywalled = true } + default: + break + } } - } #endif func startTrialMetadataRefresh() { trialRefreshTimer?.invalidate() fetchTrialMetadata() #if DEBUG - let interval: TimeInterval = UserDefaults.standard.string(forKey: "debug_trial_mode") == "realtime" ? 10 : 60 + let interval: TimeInterval = + UserDefaults.standard.string(forKey: "debug_trial_mode") == "realtime" ? 
10 : 60 #else - let interval: TimeInterval = 60 + let interval: TimeInterval = 60 #endif - trialRefreshTimer = Timer.scheduledTimer(withTimeInterval: interval, repeats: true) { [weak self] _ in + trialRefreshTimer = Timer.scheduledTimer(withTimeInterval: interval, repeats: true) { + [weak self] _ in Task { @MainActor in self?.fetchTrialMetadata() } @@ -311,6 +316,8 @@ class AppState: ObservableObject { private var cloudBackgroundDrainTask: Task? private var cloudBackgroundSampleCursor = 0 private var isCloudBackgroundTranscription = false + private var isCloudBackgroundBackpressured = false + private var didLogCloudBackgroundBackpressure = false private var backgroundTranscriptMerger = BackgroundTranscriptMerger() private var speakerSegmentReducer = SpeakerSegmentReducer() @@ -915,73 +922,73 @@ class AppState: ObservableObject { DispatchQueue.main.async { [weak self] in guard let self else { return } UNUserNotificationCenter.current().getNotificationSettings { settings in - DispatchQueue.main.async { - let isNowGranted = settings.authorizationStatus == .authorized - self.hasNotificationPermission = isNowGranted - self.notificationAlertStyle = settings.alertStyle - - // Log the current notification settings - let authStatus = - switch settings.authorizationStatus { - case .notDetermined: "notDetermined" - case .denied: "denied" - case .authorized: "authorized" - case .provisional: "provisional" - case .ephemeral: "ephemeral" - @unknown default: "unknown" - } - let alertStyleName = - switch settings.alertStyle { - case .none: "NONE (no banners)" - case .banner: "BANNER" - case .alert: "ALERT" - @unknown default: "unknown" - } - log( - "Notification settings: auth=\(authStatus), alertStyle=\(alertStyleName), sound=\(settings.soundSetting.rawValue), badge=\(settings.badgeSetting.rawValue)" - ) - - // Track notification settings in analytics only when they change - let soundEnabled = settings.soundSetting == .enabled - let badgeEnabled = settings.badgeSetting == 
.enabled - let settingsChanged = - authStatus != self.lastNotificationAuthStatus - || alertStyleName != self.lastNotificationAlertStyle - || soundEnabled != self.lastNotificationSoundEnabled - || badgeEnabled != self.lastNotificationBadgeEnabled - - if settingsChanged { - AnalyticsManager.shared.notificationSettingsChecked( - authStatus: authStatus, - alertStyle: alertStyleName, - soundEnabled: soundEnabled, - badgeEnabled: badgeEnabled, - bannersDisabled: settings.alertStyle == .none + DispatchQueue.main.async { + let isNowGranted = settings.authorizationStatus == .authorized + self.hasNotificationPermission = isNowGranted + self.notificationAlertStyle = settings.alertStyle + + // Log the current notification settings + let authStatus = + switch settings.authorizationStatus { + case .notDetermined: "notDetermined" + case .denied: "denied" + case .authorized: "authorized" + case .provisional: "provisional" + case .ephemeral: "ephemeral" + @unknown default: "unknown" + } + let alertStyleName = + switch settings.alertStyle { + case .none: "NONE (no banners)" + case .banner: "BANNER" + case .alert: "ALERT" + @unknown default: "unknown" + } + log( + "Notification settings: auth=\(authStatus), alertStyle=\(alertStyleName), sound=\(settings.soundSetting.rawValue), badge=\(settings.badgeSetting.rawValue)" ) - // Detect regression: was authorized, now reverted to notDetermined - // This happens on macOS 26+ where the OS silently revokes notification permission - if self.lastNotificationAuthStatus == "authorized" && authStatus == "notDetermined" { - log( - "Notification permission REGRESSED from authorized to notDetermined — triggering auto-repair" + // Track notification settings in analytics only when they change + let soundEnabled = settings.soundSetting == .enabled + let badgeEnabled = settings.badgeSetting == .enabled + let settingsChanged = + authStatus != self.lastNotificationAuthStatus + || alertStyleName != self.lastNotificationAlertStyle + || soundEnabled != 
self.lastNotificationSoundEnabled + || badgeEnabled != self.lastNotificationBadgeEnabled + + if settingsChanged { + AnalyticsManager.shared.notificationSettingsChecked( + authStatus: authStatus, + alertStyle: alertStyleName, + soundEnabled: soundEnabled, + badgeEnabled: badgeEnabled, + bannersDisabled: settings.alertStyle == .none ) - AnalyticsManager.shared.notificationRepairTriggered( - reason: "auth_regression", - previousStatus: "authorized", - currentStatus: "notDetermined" - ) - self.repairNotificationRegistrationAndRetry() + + // Detect regression: was authorized, now reverted to notDetermined + // This happens on macOS 26+ where the OS silently revokes notification permission + if self.lastNotificationAuthStatus == "authorized" && authStatus == "notDetermined" { + log( + "Notification permission REGRESSED from authorized to notDetermined — triggering auto-repair" + ) + AnalyticsManager.shared.notificationRepairTriggered( + reason: "auth_regression", + previousStatus: "authorized", + currentStatus: "notDetermined" + ) + self.repairNotificationRegistrationAndRetry() + } + + // Update last known state + self.lastNotificationAuthStatus = authStatus + self.lastNotificationAlertStyle = alertStyleName + self.lastNotificationSoundEnabled = soundEnabled + self.lastNotificationBadgeEnabled = badgeEnabled } - // Update last known state - self.lastNotificationAuthStatus = authStatus - self.lastNotificationAlertStyle = alertStyleName - self.lastNotificationSoundEnabled = soundEnabled - self.lastNotificationBadgeEnabled = badgeEnabled } - } - } } // end DispatchQueue.main.async } @@ -1405,12 +1412,16 @@ class AppState: ObservableObject { audioSource: effectiveSource ) if routing == .cloudBatchAssembly { + let batchLanguage = AssistantSettings.shared.effectiveBatchTranscriptionLanguage + log( + "Transcription: Cloud background batch using language=\(batchLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), 
selected=\(AssistantSettings.shared.transcriptionLanguage))" + ) isStartingTranscription = true cloudBackgroundStartTask?.cancel() cloudBackgroundStartTask = Task { await self.startCloudBackgroundTranscription( source: effectiveSource, - language: effectiveLanguage + language: batchLanguage ) } return @@ -1565,11 +1576,14 @@ class AppState: ObservableObject { private func startCloudBackgroundTranscription(source: AudioSource, language: String) async { defer { isStartingTranscription = false } do { - let conversationId = try await APIClient.shared.startBackgroundConversation(language: language) + let conversationId = try await APIClient.shared.startBackgroundConversation( + language: language) guard !Task.isCancelled else { return } cloudBackgroundConversationId = conversationId cloudBackgroundSampleCursor = 0 isCloudBackgroundTranscription = true + isCloudBackgroundBackpressured = false + didLogCloudBackgroundBackpressure = false backgroundTranscriptMerger.reset() speakerSegmentReducer.reset() transcriptionService = nil @@ -1586,20 +1600,25 @@ class AppState: ObservableObject { audioMixer = AudioMixer() vadGateService = nil + // AssemblyAI batch requests use an explicit selected language and no keyword prompt. + // `multi` and keyterms have both produced provider-side 400s on low-speech chunks. 
let resolvedLanguage = language - cloudBackgroundSession = CloudBackgroundTranscriptionSession { chunk in + cloudBackgroundSession = CloudBackgroundTranscriptionSession( + configuration: .cloudBatch + ) { chunk in try await TranscriptionService.batchTranscribeSegments( audioData: chunk.pcmData, conversationId: conversationId, chunkStartMs: max(0, Int((chunk.startTime * 1000.0).rounded())), - language: resolvedLanguage, - contextKeywords: AssistantSettings.shared.effectiveVocabulary + language: resolvedLanguage ) } let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture") if systemAudioDisabled { - log("Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)") + log( + "Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)" + ) } else if #available(macOS 14.4, *) { systemAudioCaptureService = SystemAudioCaptureService() log("Transcription: System audio capture initialized for cloud background batch") @@ -1622,7 +1641,9 @@ class AppState: ObservableObject { await startAudioCapture(source: source) await startCrashSafeTranscriptionSession(language: language) - maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) { + maxRecordingTimer = Timer.scheduledTimer( + withTimeInterval: maxRecordingDuration, repeats: false + ) { [weak self] _ in Task { @MainActor in guard let self = self, self.isTranscribing else { return } @@ -1634,15 +1655,47 @@ class AppState: ObservableObject { AnalyticsManager.shared.transcriptionStarted() log("Transcription: Started cloud background batch conversation \(conversationId)") } catch { + stopAudioCapture() isCloudBackgroundTranscription = false cloudBackgroundSession = nil cloudBackgroundConversationId = nil + isCloudBackgroundBackpressured = false + didLogCloudBackgroundBackpressure = false isTranscribing = false + AssistantSettings.shared.transcriptionEnabled = false 
AnalyticsManager.shared.recordingError(error: error.localizedDescription) - showAlert(title: "Transcription Error", message: error.localizedDescription) + logError("Transcription: Cloud background batch failed to start", error: error) + showAlert(title: "Audio Recording Not Started", message: cloudBackgroundStartFailureMessage(for: error)) } } + private func cloudBackgroundStartFailureMessage(for error: Error) -> String { + if let urlError = error as? URLError { + switch urlError.code { + case .networkConnectionLost, .cannotConnectToHost, .notConnectedToInternet, .timedOut: + return + "Audio Recording could not start because the transcription backend connection was unavailable. Check the backend or network, then turn Audio Recording back on to retry." + default: + break + } + } + + if let apiError = error as? APIError { + switch apiError { + case .httpError(let statusCode): + return + "Audio Recording could not start because the transcription backend returned HTTP \(statusCode). Turn Audio Recording back on after the backend is healthy." + case .unauthorized: + return "Audio Recording could not start because your session needs to sign in again." + default: + break + } + } + + return + "Audio Recording could not start. Turn Audio Recording back on after the transcription backend is healthy. Details: \(error.localizedDescription)" + } + private func startCrashSafeTranscriptionSession(language: String) async { do { let sessionId = try await TranscriptionStorage.shared.startSession( @@ -1690,10 +1743,12 @@ class AppState: ObservableObject { // Start the mixer — it sums mic + system into a mono stream and forwards it to // the active transcription transport. 
audioMixer?.start { [weak self] monoMixed in - if self?.isCloudBackgroundTranscription == true { - self?.handleMixedBackgroundAudio(monoMixed) - } else { - self?.transcriptionService?.sendAudio(monoMixed) + Task { @MainActor in + if self?.isCloudBackgroundTranscription == true { + self?.handleMixedBackgroundAudio(monoMixed) + } else { + self?.transcriptionService?.sendAudio(monoMixed) + } } } @@ -1750,7 +1805,9 @@ class AppState: ObservableObject { silentMicFallbackInProgress = true guard let builtInID = AudioCaptureService.findBuiltInMicDeviceID() else { - log("Transcription: silent-mic detected but no built-in microphone available — leaving capture as-is") + log( + "Transcription: silent-mic detected but no built-in microphone available — leaving capture as-is" + ) silentMicFallbackInProgress = false return } @@ -1901,7 +1958,9 @@ class AppState: ObservableObject { // finalize the NEW conversation instead of the one we just stopped. // The retry service will reconcile the old session by timestamp matching. guard self.recordingGeneration == generationAtStop else { - log("Transcription: New recording started during delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? "nil")") + log( + "Transcription: New recording started during delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? "nil")" + ) return } @@ -1909,15 +1968,20 @@ class AppState: ObservableObject { if let conversation = try await APIClient.shared.forceProcessConversation() { // Validate the returned conversation matches the session we just stopped if let sessionId = capturedSessionId, let startTime = capturedStartTime, - let convStarted = conversation.startedAt, - abs(convStarted.timeIntervalSince(startTime)) < 10, - conversation.source == .desktop { + let convStarted = conversation.startedAt, + abs(convStarted.timeIntervalSince(startTime)) < 10, + conversation.source == .desktop + { try? 
await TranscriptionStorage.shared.markSessionCompleted( id: sessionId, backendId: conversation.id) - log("Transcription: Force-processed conversation \(conversation.id), session \(sessionId) completed") + log( + "Transcription: Force-processed conversation \(conversation.id), session \(sessionId) completed" + ) } else if let sessionId = capturedSessionId, let startTime = capturedStartTime { // Force-process returned a different conversation — fall back to reconciliation - log("Transcription: Force-processed conversation \(conversation.id) does not match session \(sessionId), reconciling by timestamp") + log( + "Transcription: Force-processed conversation \(conversation.id) does not match session \(sessionId), reconciling by timestamp" + ) await reconcileSession(sessionId: sessionId, startTime: startTime) } } else { @@ -1962,6 +2026,15 @@ class AppState: ObservableObject { let startTime = Double(cloudBackgroundSampleCursor) / 16000.0 cloudBackgroundSampleCursor += pcmData.count / 2 let result = session.append(pcmData: pcmData, startTime: startTime) + if result.isBackpressured { + isCloudBackgroundBackpressured = true + if !didLogCloudBackgroundBackpressure { + didLogCloudBackgroundBackpressure = true + log( + "Transcription: Cloud background ASR backpressure active (\(result.pendingChunkCount) pending chunks); dropping new audio until backlog drains" + ) + } + } if result.enqueuedChunks > 0 { drainCloudBackgroundASRQueue() } @@ -1978,9 +2051,14 @@ class AppState: ObservableObject { let merged = backgroundTranscriptMerger.merge(result.segments) _ = speakerSegmentReducer.apply(merged) handleBackendSegments(merged) + if !session.isBackpressured && isCloudBackgroundBackpressured { + isCloudBackgroundBackpressured = false + didLogCloudBackgroundBackpressure = false + log("Transcription: Cloud background ASR backlog drained; resuming normal ingest") + } } catch { logError("Transcription: Cloud background chunk failed", error: error) - break + continue } } } @@ -1995,7 
+2073,9 @@ class AppState: ObservableObject { } try? await Task.sleep(nanoseconds: 250_000_000) } - log("Transcription: Cloud background backlog wait timed out with \(cloudBackgroundSession?.pendingChunkCount ?? 0) pending chunks") + log( + "Transcription: Cloud background backlog wait timed out with \(cloudBackgroundSession?.pendingChunkCount ?? 0) pending chunks" + ) } private func stopCloudBackgroundTranscription( @@ -2004,6 +2084,7 @@ class AppState: ObservableObject { generationAtStop: UInt64, forSettingsChange: Bool = false ) async { + let stoppedConversationId = cloudBackgroundConversationId stopAudioCapture() _ = cloudBackgroundSession?.finishInput() drainCloudBackgroundASRQueue() @@ -2018,7 +2099,8 @@ class AppState: ObservableObject { Task { try? await Task.sleep(nanoseconds: 3_000_000_000) guard recordingGeneration == generationAtStop else { return } - await forceProcessStoppedConversation( + await finishStoppedCloudBackgroundConversation( + conversationId: stoppedConversationId, capturedSessionId: capturedSessionId, capturedStartTime: capturedStartTime, logPrefix: "Cloud background batch" @@ -2030,11 +2112,14 @@ class AppState: ObservableObject { try? await Task.sleep(nanoseconds: 3_000_000_000) guard recordingGeneration == generationAtStop else { - log("Transcription: New recording started during cloud batch delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? "nil")") + log( + "Transcription: New recording started during cloud batch delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? 
"nil")" + ) return } - await forceProcessStoppedConversation( + await finishStoppedCloudBackgroundConversation( + conversationId: stoppedConversationId, capturedSessionId: capturedSessionId, capturedStartTime: capturedStartTime, logPrefix: "Cloud background batch" @@ -2051,6 +2136,8 @@ class AppState: ObservableObject { cloudBackgroundConversationId = nil cloudBackgroundSampleCursor = 0 isCloudBackgroundTranscription = false + isCloudBackgroundBackpressured = false + didLogCloudBackgroundBackpressure = false backgroundTranscriptMerger.reset() speakerSegmentReducer.reset() } @@ -2069,16 +2156,56 @@ class AppState: ObservableObject { { try? await TranscriptionStorage.shared.markSessionCompleted( id: sessionId, backendId: conversation.id) - log("Transcription: \(logPrefix) force-processed conversation \(conversation.id), session \(sessionId) completed") + log( + "Transcription: \(logPrefix) force-processed conversation \(conversation.id), session \(sessionId) completed" + ) } else if let sessionId = capturedSessionId, let startTime = capturedStartTime { - log("Transcription: \(logPrefix) force-process returned different conversation \(conversation.id), reconciling session \(sessionId)") + log( + "Transcription: \(logPrefix) force-process returned different conversation \(conversation.id), reconciling session \(sessionId)" + ) await reconcileSession(sessionId: sessionId, startTime: startTime) } } else if let sessionId = capturedSessionId, let startTime = capturedStartTime { await reconcileSession(sessionId: sessionId, startTime: startTime) } } catch { - logError("Transcription: \(logPrefix) force-process failed, retry service will reconcile", error: error) + logError( + "Transcription: \(logPrefix) force-process failed, retry service will reconcile", + error: error) + } + } + + private func finishStoppedCloudBackgroundConversation( + conversationId: String?, + capturedSessionId: Int64?, + capturedStartTime: Date?, + logPrefix: String + ) async { + guard let 
conversationId else { + log("Transcription: \(logPrefix) missing backend conversation ID, reconciling") + if let sessionId = capturedSessionId, let startTime = capturedStartTime { + await reconcileSession(sessionId: sessionId, startTime: startTime) + } + return + } + + do { + let conversation = try await APIClient.shared.finishBackgroundConversation( + conversationId: conversationId) + if let sessionId = capturedSessionId { + try? await TranscriptionStorage.shared.markSessionCompleted( + id: sessionId, backendId: conversation.id) + log( + "Transcription: \(logPrefix) finalized conversation \(conversation.id), session \(sessionId) completed" + ) + } + } catch { + logError( + "Transcription: \(logPrefix) explicit finalization failed, retry service will reconcile", + error: error) + if let sessionId = capturedSessionId, let startTime = capturedStartTime { + await reconcileSession(sessionId: sessionId, startTime: startTime) + } } } @@ -2102,7 +2229,9 @@ class AppState: ObservableObject { id: sessionId, backendId: match.id) log("Transcription: Reconciled session \(sessionId) → backend conversation \(match.id)") } else { - log("Transcription: No matching backend conversation found for session \(sessionId), leaving for retry") + log( + "Transcription: No matching backend conversation found for session \(sessionId), leaving for retry" + ) } } catch { logError("Transcription: Reconciliation failed for session \(sessionId)", error: error) @@ -2121,7 +2250,9 @@ class AppState: ObservableObject { return .discarded } - log("Transcription: Finishing conversation — disconnecting WebSocket to trigger backend processing") + log( + "Transcription: Finishing conversation — disconnecting WebSocket to trigger backend processing" + ) // Capture state before rotation — memory_created event for this conversation // may arrive on the new WebSocket after currentSessionId and recordingStartTime have changed. 
@@ -2236,6 +2367,7 @@ class AppState: ObservableObject { private func finishCloudBackgroundConversation() async -> FinishConversationResult { log("Transcription: Finishing cloud background batch conversation") + let finishedConversationId = cloudBackgroundConversationId let capturedSessionId = currentSessionId let capturedStartTime = recordingStartTime finishedSessionId = capturedSessionId @@ -2256,7 +2388,8 @@ class AppState: ObservableObject { } } - await forceProcessStoppedConversation( + await finishStoppedCloudBackgroundConversation( + conversationId: finishedConversationId, capturedSessionId: capturedSessionId, capturedStartTime: capturedStartTime, logPrefix: "Cloud background batch rotation" @@ -2277,7 +2410,8 @@ class AppState: ObservableObject { recordingStartTime = Date() RecordingTimer.shared.restart() maxRecordingTimer?.invalidate() - maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) { + maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) + { [weak self] _ in Task { @MainActor in guard let self = self, self.isTranscribing else { return } @@ -2287,16 +2421,18 @@ class AppState: ObservableObject { } do { - let language = AssistantSettings.shared.effectiveTranscriptionLanguage - let conversationId = try await APIClient.shared.startBackgroundConversation(language: language) + let language = AssistantSettings.shared.effectiveBatchTranscriptionLanguage + let conversationId = try await APIClient.shared.startBackgroundConversation( + language: language) cloudBackgroundConversationId = conversationId - cloudBackgroundSession = CloudBackgroundTranscriptionSession { chunk in + cloudBackgroundSession = CloudBackgroundTranscriptionSession( + configuration: .cloudBatch + ) { chunk in try await TranscriptionService.batchTranscribeSegments( audioData: chunk.pcmData, conversationId: conversationId, chunkStartMs: max(0, Int((chunk.startTime * 1000.0).rounded())), - language: 
language, - contextKeywords: AssistantSettings.shared.effectiveVocabulary + language: language ) } await startCrashSafeTranscriptionSession(language: language) @@ -2306,7 +2442,8 @@ class AppState: ObservableObject { log("Transcription: Ready for next cloud background batch conversation \(conversationId)") return finishedHadSegments ? .saved : .discarded } catch { - logError("Transcription: Failed to start next cloud background batch conversation", error: error) + logError( + "Transcription: Failed to start next cloud background batch conversation", error: error) return .error(error.localizedDescription) } } @@ -2813,7 +2950,7 @@ class AppState: ObservableObject { let idSet = Set(segmentIds) if let idx = conversations.firstIndex(where: { $0.id == conversationId }) { for segIdx in conversations[idx].transcriptSegments.indices - where idSet.contains(conversations[idx].transcriptSegments[segIdx].id) { + where idSet.contains(conversations[idx].transcriptSegments[segIdx].id) { let old = conversations[idx].transcriptSegments[segIdx] conversations[idx].transcriptSegments[segIdx] = TranscriptSegment( id: old.id, @@ -3099,7 +3236,9 @@ class AppState: ObservableObject { // Always persist to SQLite — even if the segment was trimmed from // the in-memory window, the event payload has all fields needed if let sessionId = currentSessionId { - let mapped = newTranslations.map { TranscriptTranslation(lang: $0.lang, text: $0.text) } + let mapped = newTranslations.map { + TranscriptTranslation(lang: $0.lang, text: $0.text) + } var translationsJson: String? if let jsonData = try? 
JSONEncoder().encode(mapped) { translationsJson = String(data: jsonData, encoding: .utf8) diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift index e4e2052f429..0a881fc305f 100644 --- a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift @@ -11,7 +11,8 @@ struct BackgroundAudioChunker { private var buffer = Data() private var bufferStartTime: Double? - init(configuration: BackgroundTranscriptionConfiguration = BackgroundTranscriptionConfiguration()) { + init(configuration: BackgroundTranscriptionConfiguration = BackgroundTranscriptionConfiguration()) + { self.configuration = configuration } @@ -24,7 +25,7 @@ struct BackgroundAudioChunker { buffer.append(pcmData) var chunks: [BackgroundAudioChunk] = [] - while let chunk = nextChunk(isFinal: false) { + while chunks.count < configuration.maxChunksPerAppend, let chunk = nextChunk(isFinal: false) { chunks.append(chunk) } return chunks @@ -32,14 +33,22 @@ struct BackgroundAudioChunker { mutating func finishInput() -> [BackgroundAudioChunk] { guard !buffer.isEmpty else { return [] } + let maxBytes = configuration.alignedByteCount(for: configuration.maxChunkDuration) + var chunks: [BackgroundAudioChunk] = [] + + while buffer.count > maxBytes, let chunk = nextChunk(isFinal: false) { + chunks.append(chunk) + } + let chunk = BackgroundAudioChunk( pcmData: buffer, startTime: bufferStartTime ?? 0, isFinal: true ) + chunks.append(chunk) buffer.removeAll(keepingCapacity: false) bufferStartTime = nil - return [chunk] + return chunks } private mutating func nextChunk(isFinal: Bool) -> BackgroundAudioChunk? { @@ -48,7 +57,13 @@ struct BackgroundAudioChunker { guard buffer.count >= minBytes else { return nil } let cutBytes: Int? 
- if let silenceCut = firstSilenceCut(minBytes: minBytes, maxBytes: min(buffer.count, maxBytes)) { + let overlapBytesAtMaxCut = effectiveOverlapBytes(forCutBytes: min(buffer.count, maxBytes)) + let minimumProgressCutBytes = overlapBytesAtMaxCut + configuration.bytesPerSample + + if let silenceCut = firstSilenceCut( + minBytes: max(minBytes, minimumProgressCutBytes), + maxBytes: min(buffer.count, maxBytes) + ) { cutBytes = silenceCut } else if buffer.count >= maxBytes { cutBytes = maxBytes @@ -56,7 +71,9 @@ struct BackgroundAudioChunker { cutBytes = nil } - guard let cutBytes, cutBytes > 0 else { return nil } + guard let cutBytes, cutBytes - effectiveOverlapBytes(forCutBytes: cutBytes) > 0 else { + return nil + } return cut(at: cutBytes, isFinal: isFinal) } @@ -69,16 +86,26 @@ struct BackgroundAudioChunker { isFinal: isFinal ) - let overlapBytes = min(configuration.alignedByteCount(for: configuration.overlapDuration), cutBytes) + let overlapBytes = effectiveOverlapBytes(forCutBytes: cutBytes) let retainedStart = max(0, cutBytes - overlapBytes) let retained = buffer.suffix(buffer.count - retainedStart) buffer = Data(retained) - bufferStartTime = startTime + Double(retainedStart / configuration.bytesPerSample) / Double(configuration.sampleRate) + bufferStartTime = + startTime + Double(retainedStart / configuration.bytesPerSample) + / Double(configuration.sampleRate) return chunk } + private func effectiveOverlapBytes(forCutBytes cutBytes: Int) -> Int { + let requestedOverlap = configuration.alignedByteCount(for: configuration.overlapDuration) + let maxProgressSafeOverlap = max(0, cutBytes - configuration.bytesPerSample) + return min(requestedOverlap, maxProgressSafeOverlap).alignedToSample + } + private func firstSilenceCut(minBytes: Int, maxBytes: Int) -> Int? 
{ - let windowBytes = max(configuration.bytesPerSample, configuration.alignedByteCount(for: configuration.silenceWindowDuration)) + let windowBytes = max( + configuration.bytesPerSample, + configuration.alignedByteCount(for: configuration.silenceWindowDuration)) guard maxBytes >= minBytes + windowBytes else { return nil } var offset = minBytes.alignedToSample @@ -109,7 +136,9 @@ struct BackgroundAudioChunker { var sumSquares = 0.0 var count = 0 - for offset in stride(from: 0, to: min(endOffset, buffer.count), by: configuration.bytesPerSample) { + for offset in stride( + from: 0, to: min(endOffset, buffer.count), by: configuration.bytesPerSample) + { let amplitude = sampleAmplitude(at: offset) peak = max(peak, amplitude) sumSquares += Double(amplitude * amplitude) diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptMerger.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptMerger.swift index 9252e322000..effac6ee451 100644 --- a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptMerger.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptMerger.swift @@ -15,9 +15,12 @@ struct BackgroundTranscriptMerger { mutating func merge(_ incomingSegments: [TranscriptionService.BackendSegment]) -> [TranscriptionService.BackendSegment] { + var changedSegments: [TranscriptionService.BackendSegment] = [] for incoming in incomingSegments where !incoming.text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { - upsert(incoming) + if let changed = upsert(incoming) { + changedSegments.append(changed) + } } segments.sort { lhs, rhs in if lhs.start == rhs.start { @@ -25,28 +28,37 @@ struct BackgroundTranscriptMerger { } return lhs.start < rhs.start } - return segments + return changedSegments } - private mutating func upsert(_ incoming: TranscriptionService.BackendSegment) { + private mutating func upsert(_ incoming: TranscriptionService.BackendSegment) + -> 
TranscriptionService.BackendSegment? + { if let segmentId = incoming.id, let index = segments.firstIndex(where: { $0.id == segmentId }) { - segments[index] = preferredSegment(existing: segments[index], incoming: incoming) - return + let preferred = preferredSegment(existing: segments[index], incoming: incoming) + guard !sameSegment(preferred, segments[index]) else { return nil } + segments[index] = preferred + return preferred } if let index = segments.firstIndex(where: { isDuplicate($0, incoming) }) { - segments[index] = preferredSegment(existing: segments[index], incoming: incoming) - return + let preferred = preferredSegment(existing: segments[index], incoming: incoming) + guard !sameSegment(preferred, segments[index]) else { return nil } + segments[index] = preferred + return preferred } if let index = segments.firstIndex(where: { canMergeOverlap($0, incoming) }) { - segments[index] = mergedOverlap(segments[index], incoming) - return + let merged = mergedOverlap(segments[index], incoming) + guard !sameSegment(merged, segments[index]) else { return nil } + segments[index] = merged + return merged } segments.append(incoming) + return incoming } private func isDuplicate( @@ -55,7 +67,8 @@ struct BackgroundTranscriptMerger { ) -> Bool { guard normalizedText(existing.text) == normalizedText(incoming.text) else { return false } let intersection = max(0, min(existing.end, incoming.end) - max(existing.start, incoming.start)) - let shorterDuration = max(0.001, min(existing.end - existing.start, incoming.end - incoming.start)) + let shorterDuration = max( + 0.001, min(existing.end - existing.start, incoming.end - incoming.start)) return intersection / shorterDuration >= duplicateOverlapThreshold } @@ -96,7 +109,8 @@ struct BackgroundTranscriptMerger { provider_cluster_id: first.provider_cluster_id ?? second.provider_cluster_id, provider_speaker_label: first.provider_speaker_label ?? second.provider_speaker_label, speaker_identity_state: first.speaker_identity_state ?? 
second.speaker_identity_state, - speaker_identity_confidence: first.speaker_identity_confidence ?? second.speaker_identity_confidence, + speaker_identity_confidence: first.speaker_identity_confidence + ?? second.speaker_identity_confidence, speaker_identity_source: first.speaker_identity_source ?? second.speaker_identity_source, speaker_identity_version: first.speaker_identity_version ?? second.speaker_identity_version ) @@ -120,6 +134,28 @@ struct BackgroundTranscriptMerger { return existing } + private func sameSegment( + _ lhs: TranscriptionService.BackendSegment, + _ rhs: TranscriptionService.BackendSegment + ) -> Bool { + lhs.id == rhs.id + && lhs.text == rhs.text + && lhs.speaker == rhs.speaker + && lhs.speaker_id == rhs.speaker_id + && lhs.is_user == rhs.is_user + && lhs.person_id == rhs.person_id + && lhs.start == rhs.start + && lhs.end == rhs.end + && lhs.stt_provider == rhs.stt_provider + && lhs.stt_model == rhs.stt_model + && lhs.provider_cluster_id == rhs.provider_cluster_id + && lhs.provider_speaker_label == rhs.provider_speaker_label + && lhs.speaker_identity_state == rhs.speaker_identity_state + && lhs.speaker_identity_confidence == rhs.speaker_identity_confidence + && lhs.speaker_identity_source == rhs.speaker_identity_source + && lhs.speaker_identity_version == rhs.speaker_identity_version + } + private func normalizedText(_ value: String) -> String { value.lowercased() .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift index 51127229812..3849aa6f516 100644 --- a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift @@ -4,12 +4,14 @@ struct BackgroundTranscriptionConfiguration: Equatable { var 
sampleRate: Int = 16000 var maxChunkDuration: TimeInterval = 15.0 var minChunkDuration: TimeInterval = 1.0 - var overlapDuration: TimeInterval = 1.0 + var overlapDuration: TimeInterval = 0.5 var silenceWindowDuration: TimeInterval = 0.35 var silenceAmplitudeThreshold: Int = 256 var speechPeakAmplitudeThreshold: Int = 512 var speechRMSAmplitudeThreshold: Int = 64 var maxPendingChunks: Int = 4 + var requiresSpeechBeforeUpload: Bool = false + var speechActivityDetection = SpeechActivityDetectionConfiguration() var bytesPerSample: Int { 2 } @@ -20,6 +22,27 @@ struct BackgroundTranscriptionConfiguration: Equatable { func alignedByteCount(for duration: TimeInterval) -> Int { byteCount(for: duration).alignedToSample } + + var maxChunksPerAppend: Int { + 1 + } + + static var cloudBatch: BackgroundTranscriptionConfiguration { + BackgroundTranscriptionConfiguration( + maxChunkDuration: 15.0, + minChunkDuration: 15.0, + overlapDuration: 0.5, + maxPendingChunks: 8, + requiresSpeechBeforeUpload: true, + speechActivityDetection: SpeechActivityDetectionConfiguration( + windowDuration: 0.02, + minimumSpeechDuration: 0.75, + peakAmplitudeThreshold: 900, + rmsAmplitudeThreshold: 180, + maximumSpeechZeroCrossingRate: 0.35 + ) + ) + } } extension Int { diff --git a/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift b/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift index b91f8579bd6..7332ffb8229 100644 --- a/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift @@ -16,20 +16,26 @@ struct BackgroundTranscriptionResult { struct BackgroundTranscriptSnapshot { let pendingChunkCount: Int let processedChunkCount: Int + let droppedChunkCount: Int let isInputFinished: Bool let segments: [TranscriptionService.BackendSegment] + let lastSpeechActivityDecision: SpeechActivityDecision? 
} final class CloudBackgroundTranscriptionSession { - typealias TranscribeHandler = (BackgroundAudioChunk) async throws -> [TranscriptionService.BackendSegment] + typealias TranscribeHandler = (BackgroundAudioChunk) async throws -> [TranscriptionService + .BackendSegment] private let configuration: BackgroundTranscriptionConfiguration private let transcribe: TranscribeHandler + private let speechActivityDetector: SpeechActivityDetector private var chunker: BackgroundAudioChunker private var pendingChunks: [BackgroundAudioChunk] = [] private var processedChunkCount = 0 + private var droppedChunkCount = 0 private var isInputFinished = false private var processedSegments: [TranscriptionService.BackendSegment] = [] + private var lastSpeechActivityDecision: SpeechActivityDecision? init( configuration: BackgroundTranscriptionConfiguration = BackgroundTranscriptionConfiguration(), @@ -37,6 +43,11 @@ final class CloudBackgroundTranscriptionSession { ) { self.configuration = configuration self.transcribe = transcribe + self.speechActivityDetector = SpeechActivityDetector( + configuration: configuration.speechActivityDetection, + sampleRate: configuration.sampleRate, + bytesPerSample: configuration.bytesPerSample + ) self.chunker = BackgroundAudioChunker(configuration: configuration) } @@ -44,6 +55,10 @@ final class CloudBackgroundTranscriptionSession { pendingChunks.count } + var isBackpressured: Bool { + pendingChunks.count >= configuration.maxPendingChunks + } + func append(pcmData: Data, startTime: Double) -> BackgroundIngestResult { guard !isInputFinished else { return BackgroundIngestResult( @@ -51,18 +66,35 @@ final class CloudBackgroundTranscriptionSession { pendingChunkCount: pendingChunks.count, acceptedInputBytes: 0, didFinishInput: true, - isBackpressured: pendingChunks.count >= configuration.maxPendingChunks + isBackpressured: isBackpressured + ) + } + guard !isBackpressured else { + return BackgroundIngestResult( + enqueuedChunks: 0, + pendingChunkCount: 
pendingChunks.count, + acceptedInputBytes: 0, + didFinishInput: false, + isBackpressured: true ) } let chunks = chunker.append(pcmData: pcmData, startTime: startTime) - pendingChunks.append(contentsOf: chunks) + var enqueuedChunks = 0 + for chunk in chunks where pendingChunks.count < configuration.maxPendingChunks { + guard shouldUpload(chunk) else { + droppedChunkCount += 1 + continue + } + pendingChunks.append(chunk) + enqueuedChunks += 1 + } return BackgroundIngestResult( - enqueuedChunks: chunks.count, + enqueuedChunks: enqueuedChunks, pendingChunkCount: pendingChunks.count, acceptedInputBytes: pcmData.count, didFinishInput: false, - isBackpressured: pendingChunks.count >= configuration.maxPendingChunks + isBackpressured: isBackpressured ) } @@ -73,38 +105,58 @@ final class CloudBackgroundTranscriptionSession { pendingChunkCount: pendingChunks.count, acceptedInputBytes: 0, didFinishInput: true, - isBackpressured: pendingChunks.count >= configuration.maxPendingChunks + isBackpressured: isBackpressured ) } - let chunks = chunker.finishInput() - pendingChunks.append(contentsOf: chunks) + var enqueuedChunks = 0 + for chunk in chunker.finishInput() { + guard shouldUpload(chunk) else { + droppedChunkCount += 1 + continue + } + pendingChunks.append(chunk) + enqueuedChunks += 1 + } isInputFinished = true return BackgroundIngestResult( - enqueuedChunks: chunks.count, + enqueuedChunks: enqueuedChunks, pendingChunkCount: pendingChunks.count, acceptedInputBytes: 0, didFinishInput: true, - isBackpressured: pendingChunks.count >= configuration.maxPendingChunks + isBackpressured: isBackpressured ) } func transcribeNext() async throws -> BackgroundTranscriptionResult? 
{ guard !pendingChunks.isEmpty else { return nil } - let chunk = pendingChunks[0] - let segments = try await transcribe(chunk) - pendingChunks.removeFirst() - processedChunkCount += 1 - processedSegments.append(contentsOf: segments) - return BackgroundTranscriptionResult(chunk: chunk, segments: segments) + let chunk = pendingChunks.removeFirst() + do { + let segments = try await transcribe(chunk) + processedChunkCount += 1 + processedSegments.append(contentsOf: segments) + return BackgroundTranscriptionResult(chunk: chunk, segments: segments) + } catch { + droppedChunkCount += 1 + throw error + } } func snapshot() -> BackgroundTranscriptSnapshot { BackgroundTranscriptSnapshot( pendingChunkCount: pendingChunks.count, processedChunkCount: processedChunkCount, + droppedChunkCount: droppedChunkCount, isInputFinished: isInputFinished, - segments: processedSegments + segments: processedSegments, + lastSpeechActivityDecision: lastSpeechActivityDecision ) } + + private func shouldUpload(_ chunk: BackgroundAudioChunk) -> Bool { + guard configuration.requiresSpeechBeforeUpload else { return true } + let decision = speechActivityDetector.evaluate(pcmData: chunk.pcmData) + lastSpeechActivityDecision = decision + return decision.shouldUpload + } } diff --git a/desktop/Desktop/Sources/BackgroundTranscription/SpeechActivityDetector.swift b/desktop/Desktop/Sources/BackgroundTranscription/SpeechActivityDetector.swift new file mode 100644 index 00000000000..4908c07ca5e --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/SpeechActivityDetector.swift @@ -0,0 +1,183 @@ +import Foundation + +struct SpeechActivityDetectionConfiguration: Equatable { + var windowDuration: TimeInterval = 0.02 + var minimumSpeechDuration: TimeInterval = 0.25 + var peakAmplitudeThreshold: Int = 512 + var rmsAmplitudeThreshold: Int = 64 + var maximumSpeechZeroCrossingRate: Double = 0.35 +} + +struct SpeechActivityDecision: Equatable { + enum Reason: String, Equatable { + case speechDetected + 
case emptyAudio + case insufficientSpeech + case energeticNonSpeech + } + + let shouldUpload: Bool + let reason: Reason + let totalWindows: Int + let energeticWindows: Int + let speechLikeWindows: Int + let rejectedHighZeroCrossingWindows: Int + let maxPeakAmplitude: Int + let maxRMSAmplitude: Double +} + +struct SpeechActivityDetector { + private let configuration: SpeechActivityDetectionConfiguration + private let sampleRate: Int + private let bytesPerSample: Int + + init( + configuration: SpeechActivityDetectionConfiguration, + sampleRate: Int, + bytesPerSample: Int = 2 + ) { + self.configuration = configuration + self.sampleRate = sampleRate + self.bytesPerSample = bytesPerSample + } + + func evaluate(pcmData: Data) -> SpeechActivityDecision { + guard !pcmData.isEmpty else { + return SpeechActivityDecision( + shouldUpload: false, + reason: .emptyAudio, + totalWindows: 0, + energeticWindows: 0, + speechLikeWindows: 0, + rejectedHighZeroCrossingWindows: 0, + maxPeakAmplitude: 0, + maxRMSAmplitude: 0 + ) + } + + let windowBytes = max(bytesPerSample, alignedByteCount(for: configuration.windowDuration)) + let requiredSpeechWindows = max( + 1, + Int(ceil(configuration.minimumSpeechDuration / configuration.windowDuration)) + ) + + var totalWindows = 0 + var energeticWindows = 0 + var speechLikeWindows = 0 + var consecutiveSpeechLikeWindows = 0 + var rejectedHighZeroCrossingWindows = 0 + var maxPeakAmplitude = 0 + var maxRMSAmplitude = 0.0 + + var offset = 0 + while offset < pcmData.count { + let endOffset = min(offset + windowBytes, pcmData.count) + let window = analyzeWindow(pcmData, start: offset, end: endOffset) + totalWindows += 1 + maxPeakAmplitude = max(maxPeakAmplitude, window.peakAmplitude) + maxRMSAmplitude = max(maxRMSAmplitude, window.rmsAmplitude) + + if window.isEnergetic( + peakThreshold: configuration.peakAmplitudeThreshold, + rmsThreshold: configuration.rmsAmplitudeThreshold + ) { + energeticWindows += 1 + if window.zeroCrossingRate <= 
configuration.maximumSpeechZeroCrossingRate { + speechLikeWindows += 1 + consecutiveSpeechLikeWindows += 1 + if consecutiveSpeechLikeWindows >= requiredSpeechWindows { + return SpeechActivityDecision( + shouldUpload: true, + reason: .speechDetected, + totalWindows: totalWindows, + energeticWindows: energeticWindows, + speechLikeWindows: speechLikeWindows, + rejectedHighZeroCrossingWindows: rejectedHighZeroCrossingWindows, + maxPeakAmplitude: maxPeakAmplitude, + maxRMSAmplitude: maxRMSAmplitude + ) + } + } else { + consecutiveSpeechLikeWindows = 0 + rejectedHighZeroCrossingWindows += 1 + } + } else { + consecutiveSpeechLikeWindows = 0 + } + + offset += windowBytes + } + + let reason: SpeechActivityDecision.Reason = + energeticWindows > 0 && rejectedHighZeroCrossingWindows >= energeticWindows + ? .energeticNonSpeech + : .insufficientSpeech + + return SpeechActivityDecision( + shouldUpload: false, + reason: reason, + totalWindows: totalWindows, + energeticWindows: energeticWindows, + speechLikeWindows: speechLikeWindows, + rejectedHighZeroCrossingWindows: rejectedHighZeroCrossingWindows, + maxPeakAmplitude: maxPeakAmplitude, + maxRMSAmplitude: maxRMSAmplitude + ) + } + + private func analyzeWindow(_ pcmData: Data, start: Int, end: Int) -> WindowActivity { + var peak = 0 + var sumSquares = 0.0 + var sampleCount = 0 + var zeroCrossings = 0 + var previousSample: Int16? 
+ + for offset in stride(from: start, to: end, by: bytesPerSample) { + guard offset + 1 < pcmData.count else { break } + let sample = sampleValue(in: pcmData, at: offset) + let amplitude = abs(Int(sample)) + peak = max(peak, amplitude) + sumSquares += Double(amplitude * amplitude) + if let previousSample, crossesZero(previousSample, sample) { + zeroCrossings += 1 + } + previousSample = sample + sampleCount += 1 + } + + guard sampleCount > 0 else { + return WindowActivity(peakAmplitude: 0, rmsAmplitude: 0, zeroCrossingRate: 0) + } + let rms = sqrt(sumSquares / Double(sampleCount)) + let denominator = max(1, sampleCount - 1) + return WindowActivity( + peakAmplitude: peak, + rmsAmplitude: rms, + zeroCrossingRate: Double(zeroCrossings) / Double(denominator) + ) + } + + private func alignedByteCount(for duration: TimeInterval) -> Int { + (Int(duration * Double(sampleRate)) * bytesPerSample).alignedToSample + } + + private func sampleValue(in pcmData: Data, at offset: Int) -> Int16 { + let low = UInt16(pcmData[offset]) + let high = UInt16(pcmData[offset + 1]) << 8 + return Int16(bitPattern: low | high) + } + + private func crossesZero(_ lhs: Int16, _ rhs: Int16) -> Bool { + (lhs < 0 && rhs > 0) || (lhs > 0 && rhs < 0) + } +} + +private struct WindowActivity { + let peakAmplitude: Int + let rmsAmplitude: Double + let zeroCrossingRate: Double + + func isEnergetic(peakThreshold: Int, rmsThreshold: Int) -> Bool { + peakAmplitude >= peakThreshold || rmsAmplitude >= Double(rmsThreshold) + } +} diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index eab9773b534..8c1fdf32f34 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -1253,7 +1253,7 @@ struct SettingsContentView: View { .foregroundColor(OmiColors.textPrimary) Text( - "Transcribe microphone audio in chunks instead of live streaming. 
Requires server-side AssemblyAI." + "Transcribe microphone audio in selected-language chunks instead of live streaming. Requires server-side AssemblyAI." ) .scaledFont(size: 13) .foregroundColor(OmiColors.textTertiary) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift index 38d8bf229fc..87343968d98 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift @@ -156,6 +156,12 @@ class AssistantSettings { return transcriptionLanguage } + /// AssemblyAI prerecorded language detection is brittle on short or low-speech chunks. + /// Cloud batch uses the selected single language even when streaming auto-detect is enabled. + var effectiveBatchTranscriptionLanguage: String { + return transcriptionLanguage + } + /// Custom vocabulary for improved transcription accuracy /// Array of words/terms that DeepGram should recognize (Nova-3 limit: 500 tokens total) var transcriptionVocabulary: [String] { @@ -183,20 +189,20 @@ class AssistantSettings { } } - /// Whether batch transcription mode is enabled (transcribes audio in chunks at silence boundaries) - var batchTranscriptionEnabled: Bool { - get { UserDefaults.standard.bool(forKey: batchTranscriptionEnabledKey) } + /// Whether local VAD gate is enabled to skip silence and reduce Deepgram usage + var vadGateEnabled: Bool { + get { UserDefaults.standard.bool(forKey: vadGateEnabledKey) } set { - UserDefaults.standard.set(newValue, forKey: batchTranscriptionEnabledKey) + UserDefaults.standard.set(newValue, forKey: vadGateEnabledKey) NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) } } - /// Whether local VAD gate is enabled to skip silence and reduce Deepgram usage - var vadGateEnabled: Bool { - get { UserDefaults.standard.bool(forKey: vadGateEnabledKey) } + /// Whether 
cloud batch transcription mode is enabled for microphone background audio. + var batchTranscriptionEnabled: Bool { + get { UserDefaults.standard.bool(forKey: batchTranscriptionEnabledKey) } set { - UserDefaults.standard.set(newValue, forKey: vadGateEnabledKey) + UserDefaults.standard.set(newValue, forKey: batchTranscriptionEnabledKey) NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) } } diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index a72417d4f67..a739918855d 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -113,7 +113,10 @@ class TranscriptionService { /// (Cloud Run), which does not have /v2/voice-message/* or /v4/listen endpoints. private static let pythonBackendBaseURL: String = DesktopBackendEnvironment.pythonBaseURL() - private static func sanitizedContextKeywords(_ keywords: [String]) -> [String] { + private static func sanitizedContextKeywords( + _ keywords: [String], + includeDefaultOmi: Bool = true + ) -> [String] { let stopWords: Set = [ "about", "after", "again", "all", "also", "and", "app", "are", "ask", "back", "browser", "but", "can", "chat", "code", "done", "each", "for", "from", "get", "has", "have", "help", "here", "home", "how", @@ -123,7 +126,8 @@ class TranscriptionService { ] var seen = Set() var result: [String] = [] - for keyword in ["Omi", "OMI"] + keywords { + let sourceKeywords = (includeDefaultOmi ? 
["Omi", "OMI"] : []) + keywords + for keyword in sourceKeywords { let normalized = keyword .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) .trimmingCharacters(in: .whitespacesAndNewlines) @@ -659,7 +663,7 @@ extension TranscriptionService { guard var components = URLComponents(string: baseURLString) else { throw TranscriptionError.connectionFailed(NSError(domain: "Invalid backend URL", code: -1)) } - let sanitizedKeywords = sanitizedContextKeywords(contextKeywords) + let sanitizedKeywords = sanitizedContextKeywords(contextKeywords, includeDefaultOmi: false) var queryItems = [ URLQueryItem(name: "conversation_id", value: conversationId), URLQueryItem(name: "chunk_start_ms", value: String(chunkStartMs)), diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 788c9cd0b4e..ecab2929f6d 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -299,6 +299,14 @@ final class APIClientRoutingTests: XCTestCase { label: "deleteConversation") } + func testFinishBackgroundConversationRoutesToExplicitPythonConversation() async { + let client = await makeTestClient() + _ = try? await client.finishBackgroundConversation(conversationId: "batch-123") as ServerConversation + assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/desktop/background-conversation/batch-123/finish", method: "POST", + label: "finishBackgroundConversation") + } + // -- Conversations: manual URL(string: baseURL + ...) 
paths (PATCH → Python) -- func testSetConversationStarredRoutesToPython() async { diff --git a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift index 86e3daf7323..a1c83a2dbf9 100644 --- a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift +++ b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift @@ -1,4 +1,5 @@ import XCTest + @testable import Omi_Computer final class BackgroundTranscriptionTests: XCTestCase { @@ -17,7 +18,8 @@ final class BackgroundTranscriptionTests: XCTestCase { ) ) - let chunks = chunker.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 10) + [0, 0, 0]), startTime: 0) + let chunks = chunker.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 10) + [0, 0, 0]), startTime: 0) XCTAssertEqual(chunks.count, 1) XCTAssertEqual(chunks[0].startTime, 0) @@ -46,7 +48,8 @@ final class BackgroundTranscriptionTests: XCTestCase { ) ) - let chunks = chunker.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 3) + let chunks = chunker.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 3) XCTAssertEqual(chunks.count, 1) XCTAssertEqual(chunks[0].startTime, 3) @@ -54,7 +57,7 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertEqual(sampleCount(chunker.finishInput()[0].pcmData), 4) } - func testSessionQueuesChunksAndSignalsBackpressureWithoutDropping() async throws { + func testSessionBackpressuresWhenPendingQueueIsFull() async throws { let configuration = BackgroundTranscriptionConfiguration( sampleRate: 10, maxChunkDuration: 1.0, @@ -69,30 +72,56 @@ final class BackgroundTranscriptionTests: XCTestCase { var transcribedStarts: [Double] = [] let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in transcribedStarts.append(chunk.startTime) - return [Self.backendSegment(id: "chunk-\(chunk.startTime)", text: "hello", start: chunk.startTime, end: chunk.startTime + 1)] + 
return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", text: "hello", start: chunk.startTime, + end: chunk.startTime + 1) + ] } - let result = session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 22)), startTime: 0) + let first = session.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 0) + let second = session.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 1.2) + + XCTAssertEqual(first.enqueuedChunks, 1) + XCTAssertEqual(first.pendingChunkCount, 1) + XCTAssertTrue(first.isBackpressured) + XCTAssertEqual(second.enqueuedChunks, 0) + XCTAssertEqual(second.pendingChunkCount, 1) + XCTAssertEqual(second.acceptedInputBytes, 0) + XCTAssertTrue(second.isBackpressured) + XCTAssertEqual(session.pendingChunkCount, 1) + XCTAssertTrue(transcribedStarts.isEmpty) + } - XCTAssertEqual(result.enqueuedChunks, 2) - XCTAssertEqual(result.pendingChunkCount, 2) - XCTAssertTrue(result.isBackpressured) - XCTAssertEqual(session.pendingChunkCount, 2) + func testChunkerDoesNotLoopWhenOverlapEqualsMinimumCut() { + var chunker = BackgroundAudioChunker( + configuration: BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 3.0, + minChunkDuration: 1.0, + overlapDuration: 1.0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4 + ) + ) - let first = try await session.transcribeNext() - let second = try await session.transcribeNext() + let chunks = chunker.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 10) + [0, 0, 0]), startTime: 0) - XCTAssertEqual(first?.chunk.startTime, 0) - XCTAssertEqual(second?.chunk.startTime, 1) - XCTAssertEqual(transcribedStarts, [0, 1]) - let empty = try await session.transcribeNext() - XCTAssertNil(empty) - XCTAssertEqual(session.snapshot().processedChunkCount, 2) - XCTAssertEqual(session.snapshot().segments.count, 2) + 
XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(sampleCount(chunks[0].pcmData), 11) + let final = chunker.finishInput() + XCTAssertEqual(final.count, 1) + XCTAssertEqual(sampleCount(final[0].pcmData), 12) } func testFifteenSecondContinuousSpeechProducesChunkThroughSession() async throws { - let configuration = BackgroundTranscriptionConfiguration() + let configuration = BackgroundTranscriptionConfiguration.cloudBatch var transcribedStarts: [Double] = [] let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in transcribedStarts.append(chunk.startTime) @@ -115,17 +144,171 @@ final class BackgroundTranscriptionTests: XCTestCase { enqueuedChunks += result.enqueuedChunks } - XCTAssertGreaterThanOrEqual(enqueuedChunks, 1) - XCTAssertGreaterThanOrEqual(session.pendingChunkCount, 1) + XCTAssertEqual(enqueuedChunks, 1) + XCTAssertEqual(session.pendingChunkCount, 1) let first = try await session.transcribeNext() XCTAssertNotNil(first) - XCTAssertEqual(first?.chunk.startTime, 0, accuracy: 0.001) + XCTAssertEqual(first!.chunk.startTime, 0, accuracy: 0.001) + XCTAssertEqual(sampleCount(first!.chunk.pcmData), configuration.sampleRate * 15) XCTAssertEqual(transcribedStarts, [0]) XCTAssertEqual(session.snapshot().processedChunkCount, 1) XCTAssertEqual(session.snapshot().segments.count, 1) } + func testCloudBatchDropsSilentChunksBeforeUpload() async throws { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var uploadCount = 0 + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + uploadCount += 1 + return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", + text: "should not upload", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let samplesPerFrame = configuration.sampleRate / 10 + let frame = pcm(samples: Array(repeating: 0, count: samplesPerFrame)) + var enqueuedChunks = 0 + + for frameIndex in 0..<160 { + let result = session.append(pcmData: frame, 
startTime: Double(frameIndex) / 10.0) + enqueuedChunks += result.enqueuedChunks + } + + XCTAssertEqual(enqueuedChunks, 0) + XCTAssertEqual(session.pendingChunkCount, 0) + let uploaded = try await session.transcribeNext() + XCTAssertNil(uploaded) + XCTAssertEqual(uploadCount, 0) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) + XCTAssertEqual(session.snapshot().lastSpeechActivityDecision?.reason, .insufficientSpeech) + } + + func testCloudBatchUploadsChunkWithMinimumSpeechEnergy() async throws { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var uploadCount = 0 + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + uploadCount += 1 + return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", + text: "speech", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let speechSamples = Array( + repeating: Int16(1_000), count: Int(Double(configuration.sampleRate) * 1.0)) + let silenceSamples = Array( + repeating: Int16(0), + count: configuration.sampleRate * 16 - speechSamples.count + ) + let result = session.append( + pcmData: pcm(samples: speechSamples + silenceSamples), + startTime: 0 + ) + + XCTAssertEqual(result.enqueuedChunks, 1) + XCTAssertEqual(session.pendingChunkCount, 1) + + let uploaded = try await session.transcribeNext() + XCTAssertEqual(uploaded?.chunk.startTime, 0) + XCTAssertEqual(uploadCount, 1) + XCTAssertEqual(session.snapshot().lastSpeechActivityDecision?.reason, .speechDetected) + } + + func testCloudBatchRejectsEnergeticNonSpeechNoiseBeforeUpload() async throws { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var uploadCount = 0 + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + uploadCount += 1 + return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", + text: "noise", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let samples = 
(0..<(configuration.sampleRate * 16)).map { index -> Int16 in + index.isMultiple(of: 2) ? 2_000 : -2_000 + } + let result = session.append(pcmData: pcm(samples: samples), startTime: 0) + + XCTAssertEqual(result.enqueuedChunks, 0) + XCTAssertEqual(session.pendingChunkCount, 0) + let uploaded = try await session.transcribeNext() + XCTAssertNil(uploaded) + XCTAssertEqual(uploadCount, 0) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) + XCTAssertEqual(session.snapshot().lastSpeechActivityDecision?.reason, .energeticNonSpeech) + XCTAssertGreaterThan( + session.snapshot().lastSpeechActivityDecision?.rejectedHighZeroCrossingWindows ?? 0, + 0 + ) + } + + func testCloudBatchFinishSplitsLongTailAtFifteenSecondWindows() { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var chunker = BackgroundAudioChunker(configuration: configuration) + + let samples = Array(repeating: Int16(1_000), count: configuration.sampleRate * 31) + let chunks = chunker.append(pcmData: pcm(samples: samples), startTime: 0) + let final = chunker.finishInput() + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(sampleCount(chunks[0].pcmData), configuration.sampleRate * 15) + XCTAssertEqual(final.count, 2) + XCTAssertEqual(sampleCount(final[0].pcmData), configuration.sampleRate * 15) + XCTAssertLessThanOrEqual(sampleCount(final[1].pcmData), configuration.sampleRate * 2) + XCTAssertTrue(final[1].isFinal) + } + + func testSessionFinishEnqueuesTailEvenWhenLiveQueueIsFull() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 1.0, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 1 + ) + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", text: "chunk", start: 
chunk.startTime, + end: chunk.startTime + 1) + ] + } + + let append = session.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 16)), startTime: 0) + let finish = session.finishInput() + + XCTAssertEqual(append.enqueuedChunks, 1) + XCTAssertEqual(finish.enqueuedChunks, 1) + XCTAssertEqual(finish.pendingChunkCount, 2) + XCTAssertTrue(finish.isBackpressured) + + let first = try await session.transcribeNext() + let tail = try await session.transcribeNext() + XCTAssertEqual(first?.chunk.startTime, 0) + XCTAssertEqual(tail!.chunk.startTime, 1, accuracy: 0.001) + XCTAssertTrue(tail?.chunk.isFinal ?? false) + } + func testSessionFinishFlushesTail() async throws { let configuration = BackgroundTranscriptionConfiguration( sampleRate: 10, @@ -139,10 +322,15 @@ final class BackgroundTranscriptionTests: XCTestCase { maxPendingChunks: 4 ) let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in - [Self.backendSegment(id: "final", text: chunk.isFinal ? "final" : "not final", start: chunk.startTime, end: 1)] + [ + Self.backendSegment( + id: "final", text: chunk.isFinal ? "final" : "not final", start: chunk.startTime, end: 1) + ] } - XCTAssertEqual(session.append(pcmData: pcm(samples: [1_000, 1_000, 1_000]), startTime: 2).enqueuedChunks, 0) + XCTAssertEqual( + session.append(pcmData: pcm(samples: [1_000, 1_000, 1_000]), startTime: 2).enqueuedChunks, + 0) let finish = session.finishInput() XCTAssertTrue(finish.didFinishInput) @@ -152,7 +340,7 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertTrue(result?.chunk.isFinal ?? 
false) } - func testSessionRetainsChunkWhenTranscriptionFails() async throws { + func testSessionDropsFailedChunkSoDrainCanContinue() async throws { let configuration = BackgroundTranscriptionConfiguration( sampleRate: 10, maxChunkDuration: 1.0, @@ -164,16 +352,22 @@ final class BackgroundTranscriptionTests: XCTestCase { speechRMSAmplitudeThreshold: 20, maxPendingChunks: 4 ) - var shouldFail = true let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in - if shouldFail { - shouldFail = false + if chunk.startTime == 0 { throw NSError(domain: "test", code: 1) } - return [Self.backendSegment(id: "retry", text: "retried", start: chunk.startTime, end: chunk.startTime + 1)] + return [ + Self.backendSegment( + id: "retry", text: "retried", start: chunk.startTime, end: chunk.startTime + 1) + ] } - XCTAssertEqual(session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 0).enqueuedChunks, 1) + XCTAssertEqual( + session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 0) + .enqueuedChunks, 1) + XCTAssertEqual( + session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 1.2) + .enqueuedChunks, 1) do { _ = try await session.transcribeNext() @@ -182,22 +376,27 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertEqual(session.pendingChunkCount, 1) } - let retried = try await session.transcribeNext() - XCTAssertEqual(retried?.chunk.startTime, 0) + let next = try await session.transcribeNext() + XCTAssertEqual(next?.chunk.startTime, 1.0) XCTAssertEqual(session.pendingChunkCount, 0) XCTAssertEqual(session.snapshot().processedChunkCount, 1) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) } func testTranscriptMergerDeduplicatesAndMergesOverlap() { var merger = BackgroundTranscriptMerger() let first = Self.backendSegment(id: "a", text: "hello world", speakerId: 0, start: 0, end: 2) - let duplicate = Self.backendSegment(id: nil, text: "hello 
world", speakerId: 0, start: 0.5, end: 1.8) - let overlap = Self.backendSegment(id: "b", text: "world again", speakerId: 0, start: 1.5, end: 3) + let duplicate = Self.backendSegment( + id: nil, text: "hello world", speakerId: 0, start: 0.5, end: 1.8) + let overlap = Self.backendSegment( + id: "b", text: "world again", speakerId: 0, start: 1.5, end: 3) XCTAssertEqual(merger.merge([first]).count, 1) - XCTAssertEqual(merger.merge([duplicate]).count, 1) - let merged = merger.merge([overlap]) + XCTAssertEqual(merger.merge([duplicate]).count, 0) + let changed = merger.merge([overlap]) + let merged = merger.segments + XCTAssertEqual(changed.count, 1) XCTAssertEqual(merged.count, 1) XCTAssertEqual(merged[0].text, "hello world again") XCTAssertEqual(merged[0].start, 0) @@ -216,7 +415,8 @@ final class BackgroundTranscriptionTests: XCTestCase { ) _ = reducer.apply([original]) - let update = Self.backendSegment(id: "seg-1", text: "hello again", speakerId: 1, start: 0, end: 2) + let update = Self.backendSegment( + id: "seg-1", text: "hello again", speakerId: 1, start: 0, end: 2) let result = reducer.apply([update]) XCTAssertEqual(result.added, 0) @@ -226,23 +426,27 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertEqual(reducer.segments[0].translations.first?.text, "hola") } - func testRoutingGuardOnlyAllowsCloudBatchForEnabledMicrophone() { + func testRoutingGuardUsesCloudBatchForMicrophoneWhenServerEnabled() { let guardrail = BackgroundTranscriptionRoutingGuard() XCTAssertEqual( - guardrail.decide(batchEnabled: true, serverAssemblyBackgroundEnabled: true, audioSource: .microphone), + guardrail.decide( + batchEnabled: true, serverAssemblyBackgroundEnabled: true, audioSource: .microphone), .cloudBatchAssembly ) XCTAssertEqual( - guardrail.decide(batchEnabled: true, serverAssemblyBackgroundEnabled: true, audioSource: .bleDevice), + guardrail.decide( + batchEnabled: true, serverAssemblyBackgroundEnabled: true, audioSource: .bleDevice), 
.cloudListenStreaming(reason: "batch_microphone_only") ) XCTAssertEqual( - guardrail.decide(batchEnabled: false, serverAssemblyBackgroundEnabled: true, audioSource: .microphone), + guardrail.decide( + batchEnabled: false, serverAssemblyBackgroundEnabled: true, audioSource: .microphone), .cloudListenStreaming(reason: "batch_disabled") ) XCTAssertEqual( - guardrail.decide(batchEnabled: true, serverAssemblyBackgroundEnabled: false, audioSource: .microphone), + guardrail.decide( + batchEnabled: true, serverAssemblyBackgroundEnabled: false, audioSource: .microphone), .cloudListenStreaming(reason: "server_background_batch_disabled") ) } diff --git a/docs/doc/developer/backend/listen_pusher_pipeline.mdx b/docs/doc/developer/backend/listen_pusher_pipeline.mdx index 88a4ad89e0d..616dbe41146 100644 --- a/docs/doc/developer/backend/listen_pusher_pipeline.mdx +++ b/docs/doc/developer/backend/listen_pusher_pipeline.mdx @@ -238,9 +238,17 @@ Desktop Audio Recording can use the batch background path instead of `/v4/listen when the desktop batch setting is enabled and the client is pointed at a backend with AssemblyAI background support. The desktop client creates a stub conversation with `POST /v2/desktop/background-conversation/start`, sends raw -mono PCM chunks to `POST /v2/desktop/background-transcribe`, and still finalizes -with `POST /v1/conversations`. Mobile listen, BLE listen, and PTT continue to use -their existing Deepgram paths. +mono PCM chunks to `POST /v2/desktop/background-transcribe`, and finalizes that +explicit conversation with +`POST /v2/desktop/background-conversation/{conversation_id}/finish`. Mobile +listen, BLE listen, and PTT continue to use their existing Deepgram paths. + +For cloud batch, the desktop client buffers roughly 15s PCM chunks with 0.5s +overlap, drops chunks that do not contain sustained speech activity before +uploading, and uses backpressure instead of stopping recording when AssemblyAI +is slower than realtime. 
AssemblyAI batch requests use the selected single +language rather than `multi` to avoid provider-side language-detection failures +on low-speech audio. Eligible AssemblyAI workloads are `sync`, `background`, and `postprocess`. Latency-critical workloads (`realtime`, streaming `ptt`, and @@ -268,12 +276,14 @@ sequenceDiagram Deepgram-->>ProviderService: Transcript with words and speaker labels end - alt AssemblyAI fails, times out, or exhausts retries + alt AssemblyAI fails, times out, or exhausts retries for sync/postprocess ProviderService->>Firestore: Finalize failed AssemblyAI provider run ProviderService->>Deepgram: Fallback prerecorded transcript request Deepgram-->>ProviderService: Transcript result end + Note over ProviderService: AssemblyAI background workload fails closed;
it does not fall back to Deepgram during desktop cloud batch. + ProviderService->>ProviderService: Normalize provider result and reconstruct canonical segments ProviderService->>EmbeddingAPI: Extract cluster samples for Omi speaker identity ProviderService->>Firestore: Store canonical transcript segments and provider run ledger From 047bf977a1437dc9262b856273a0fa5e2e2bd37b Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 13:00:43 -0400 Subject: [PATCH 27/44] Wire desktop background speaker identity Run Omi speaker identity matching on AssemblyAI desktop background chunks before applying global chunk offsets, update provider run identity metrics, and cover the Omi user match path in desktop background transcription tests. Validation: backend focused pytest suite 42 passed, 2 skipped; speaker identity focused suite 23 passed; pre-commit Python formatting clean; git diff --check clean. --- backend/routers/desktop_background.py | 84 ++++++++++++++++++- .../test_desktop_background_transcribe.py | 58 ++++++++++++- 2 files changed, 139 insertions(+), 3 deletions(-) diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py index 141581764fa..f61961c8953 100644 --- a/backend/routers/desktop_background.py +++ b/backend/routers/desktop_background.py @@ -3,10 +3,12 @@ from datetime import datetime, timezone from typing import Dict, List, Optional +import numpy as np from fastapi import APIRouter, Depends, Header, HTTPException, Request from pydantic import BaseModel import database.conversations as conversations_db +import database.users as users_db from database import redis_db from models.conversation_enums import ConversationSource, ConversationStatus from models.transcript_segment import TranscriptSegment @@ -22,7 +24,9 @@ from utils.fair_use import is_hard_restricted, record_speech_ms from utils.other import endpoints as auth from utils.speaker_identification import _pcm_to_wav_bytes -from utils.stt.provider_service import 
transcribe_bytes +from utils.stt.background_speaker_identity import USER_SELF_PERSON_ID, identify_background_speaker_clusters +from utils.stt.provider_service import transcribe_bytes, update_provider_run_identity_metrics +from utils.stt.speaker_embedding import extract_embedding_from_bytes from utils.stt.providers import STTWorkload from utils.subscription import has_transcription_credits, is_trial_paywalled from utils.voice_duration_limiter import compute_pcm_duration_ms @@ -176,6 +180,16 @@ async def background_transcribe( del audio_bytes segments = response.segments + if conversation_id and segments: + await _identify_speakers( + uid=uid, + conversation_id=conversation_id, + audio_bytes=audio_for_stt, + segments=segments, + provider=response.result.provider if response.result else None, + model=response.result.model if response.result else None, + run_id=response.run_id, + ) _apply_chunk_offset(segments, chunk_start_ms / 1000.0) if conversation_id: _apply_speaker_ids(conversation_id, segments) @@ -255,6 +269,74 @@ def _apply_chunk_offset(segments: List[TranscriptSegment], offset_sec: float) -> segment.end += offset_sec +async def _identify_speakers( + uid: str, + conversation_id: str, + audio_bytes: bytes, + segments: List[TranscriptSegment], + provider: Optional[str], + model: Optional[str], + run_id: Optional[str], +) -> None: + """Apply Omi speaker identity to AssemblyAI background chunk-local segments.""" + try: + person_embeddings_cache = await run_blocking(db_executor, _build_person_embeddings_cache, uid) + if not person_embeddings_cache: + return + assignments = await run_blocking( + sync_executor, + identify_background_speaker_clusters, + segments, + audio_bytes, + person_embeddings_cache, + extract_embedding_from_bytes, + ) + identified_count = sum(1 for assignment in assignments.values() if assignment.state in ('identified', 'user')) + logger.info( + "Speaker ID (desktop background): cluster assignments=%s identified=%s uid=%s conversation_id=%s", + 
len(assignments), + identified_count, + uid, + conversation_id, + ) + await run_blocking( + db_executor, + update_provider_run_identity_metrics, + run_id, + provider or 'unknown', + model or 'unknown', + STTWorkload.background, + segments, + ) + except Exception as e: + logger.warning( + "Speaker ID (desktop background): identification failed uid=%s conversation_id=%s: %s", + uid, + conversation_id, + e, + ) + + +def _build_person_embeddings_cache(uid: str) -> Dict[str, dict]: + cache: Dict[str, dict] = {} + + embedding_list = users_db.get_user_speaker_embedding(uid) + if embedding_list: + user_embedding = np.array(embedding_list, dtype=np.float32).reshape(1, -1) + cache[USER_SELF_PERSON_ID] = {'embedding': user_embedding, 'name': 'User'} + + people = users_db.get_people(uid) + for person in people or []: + embedding = person.get('speaker_embedding') + if embedding and person.get('speech_samples'): + cache[person['id']] = { + 'embedding': np.array(embedding, dtype=np.float32).reshape(1, -1), + 'name': person['name'], + } + + return cache + + def _apply_speaker_ids(conversation_id: str, segments: List[TranscriptSegment]) -> None: speaker_map = _load_speaker_map(conversation_id) changed = False diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index ef5fa0477a1..164bbb6b54c 100644 --- a/backend/tests/unit/test_desktop_background_transcribe.py +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -5,6 +5,8 @@ import wave from unittest.mock import MagicMock +import numpy as np + os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: @@ -26,6 +28,7 @@ database_client_stub.db = getattr(database_client_stub, 'db', MagicMock()) database_client_stub.document_id_from_seed = getattr(database_client_stub, 'document_id_from_seed', MagicMock()) 
sys.modules.setdefault('database.conversations', MagicMock()) +sys.modules.setdefault('database.users', MagicMock()) sys.modules.setdefault('database.redis_db', SimpleNamespace(r=MagicMock())) @@ -75,7 +78,10 @@ def _pcm_to_wav_bytes(pcm_data: bytes, sample_rate: int) -> bytes: sys.modules.setdefault('utils.speaker_identification', SimpleNamespace(_pcm_to_wav_bytes=_pcm_to_wav_bytes)) -sys.modules.setdefault('utils.stt.provider_service', SimpleNamespace(transcribe_bytes=MagicMock())) +sys.modules.setdefault( + 'utils.stt.provider_service', + SimpleNamespace(transcribe_bytes=MagicMock(), update_provider_run_identity_metrics=MagicMock()), +) sys.modules.setdefault( 'utils.voice_duration_limiter', SimpleNamespace( @@ -107,7 +113,7 @@ def _pcm_to_wav_bytes(pcm_data: bytes, sample_rate: int) -> bytes: delattr(sys.modules['utils.stt'], 'provider_service') -def _client(monkeypatch, *, segments=None): +def _client(monkeypatch, *, segments=None, person_embeddings_cache=None): app = FastAPI() app.include_router(desktop_background.router) for route in app.routes: @@ -137,6 +143,12 @@ def _client(monkeypatch, *, segments=None): monkeypatch.setattr(desktop_background.redis_db, 'r', MagicMock(), raising=False) monkeypatch.setattr(desktop_background.redis_db.r, 'get', lambda _key: None) monkeypatch.setattr(desktop_background.redis_db.r, 'set', lambda *_args, **_kwargs: None) + monkeypatch.setattr( + desktop_background, + '_build_person_embeddings_cache', + lambda _uid: person_embeddings_cache or {}, + ) + monkeypatch.setattr(desktop_background, 'update_provider_run_identity_metrics', MagicMock()) default_segments = segments or [ TranscriptSegment( @@ -322,6 +334,48 @@ def _transcribe_bytes(_audio_bytes, **_kwargs): assert [segment.speaker for segment in appended_segments] == ['SPEAKER_00', 'SPEAKER_01', 'SPEAKER_00'] +def test_background_transcribe_identifies_assemblyai_speaker_with_omi_user_embedding(monkeypatch): + user_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + 
segments = [ + TranscriptSegment( + text='This is my voice.', + is_user=False, + start=0.0, + end=3.0, + provider_cluster_id='A', + provider_speaker_label='ASSEMBLYAI_SPEAKER_A', + speaker_identity_state='unassigned', + stt_provider='assemblyai', + stt_model='universal-2', + ) + ] + client, _mock_transcribe = _client( + monkeypatch, + segments=segments, + person_embeddings_cache={'user': {'embedding': user_embedding, 'name': 'User'}}, + ) + monkeypatch.setattr(desktop_background, 'extract_embedding_from_bytes', lambda _audio, _filename: user_embedding) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=12000', + content=b'\x01\x00' * 16000 * 4, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + segment = data['segments'][0] + assert segment['stt_provider'] == 'assemblyai' + assert segment['provider_cluster_id'] == 'A' + assert segment['is_user'] is True + assert segment['person_id'] is None + assert segment['speaker_identity_state'] == 'user' + assert segment['speaker_identity_source'] == 'omi_speaker_embedding' + assert segment['speaker_identity_provenance']['provider_cluster_id'] == 'A' + assert segment['start'] == 12.0 + desktop_background.update_provider_run_identity_metrics.assert_called_once() + + def test_byok_background_routing_uses_deepgram_when_only_deepgram_key(monkeypatch): from utils.stt import provider_service From 23cb2440352d7ba8cd76e13f6a72c370c9e7709e Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 13:15:09 -0400 Subject: [PATCH 28/44] Isolate desktop AssemblyAI e2e user --- scripts/desktop_assemblyai_e2e.py | 59 +++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/scripts/desktop_assemblyai_e2e.py b/scripts/desktop_assemblyai_e2e.py index d25d65ee881..2c030763e51 100755 --- a/scripts/desktop_assemblyai_e2e.py +++ b/scripts/desktop_assemblyai_e2e.py @@ -22,6 +22,7 
@@ import argparse import json +import os import struct import subprocess import sys @@ -57,6 +58,44 @@ def read_desktop_auth_token() -> str: return token +def read_backend_admin_key() -> str | None: + admin_key = os.getenv("ADMIN_KEY") + if admin_key: + return admin_key + + env_path = Path(__file__).resolve().parents[1] / "backend" / ".env" + if not env_path.exists(): + return None + + for line in env_path.read_text().splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#") or not stripped.startswith("ADMIN_KEY="): + continue + return stripped.split("=", 1)[1].strip().strip('"').strip("'") or None + return None + + +def resolve_auth_token(args: argparse.Namespace) -> str: + if args.token: + return args.token + + if args.background_chunk or args.background_batch: + if args.use_desktop_auth: + print("Using Omi Dev desktop auth; this will persist sample transcripts to the signed-in local account.") + return read_desktop_auth_token() + + admin_key = read_backend_admin_key() + if not admin_key: + raise SystemExit( + "Background e2e persists transcript segments. Set ADMIN_KEY in backend/.env, pass --token, " + "or pass --use-desktop-auth to explicitly use the signed-in Omi Dev account." 
+ ) + print(f"Using isolated local e2e uid={args.e2e_uid}.") + return f"{admin_key}{args.e2e_uid}" + + return read_desktop_auth_token() + + def require_backend_reachable(api_base: str) -> None: url = f"{api_base.rstrip('/')}/docs" try: @@ -410,7 +449,23 @@ def main() -> int: parser.add_argument("--api", default="http://127.0.0.1:8080", help="Local Python backend base URL") parser.add_argument("--workdir", default="/tmp/omi-assemblyai-e2e", help="Temp dir for sample audio") parser.add_argument("--language", default="en", help="Language passed to background transcription") - parser.add_argument("--token", help="Firebase ID token; defaults to Omi Dev auth_idToken from macOS defaults") + parser.add_argument( + "--token", + help=( + "Firebase ID token or ADMIN_KEY-prefixed uid. Background modes default to an isolated local e2e uid; " + "sync mode defaults to Omi Dev auth_idToken from macOS defaults." + ), + ) + parser.add_argument( + "--e2e-uid", + default="desktop-assemblyai-e2e", + help="Isolated local uid used by background modes when ADMIN_KEY is available and --token is omitted.", + ) + parser.add_argument( + "--use-desktop-auth", + action="store_true", + help="For background modes, explicitly persist sample transcripts to the signed-in Omi Dev account.", + ) parser.add_argument( "--background-chunk", action="store_true", @@ -427,7 +482,7 @@ def main() -> int: print("Choose only one of --background-chunk or --background-batch.", file=sys.stderr) return 2 - token = args.token or read_desktop_auth_token() + token = resolve_auth_token(args) if args.background_chunk: pcm_path = ensure_sample_pcm(Path(args.workdir)) From 3bfe844ab6e785a752e2186d61afcd0efc0d3f9c Mon Sep 17 00:00:00 2001 From: David Zhang Date: Thu, 21 May 2026 13:20:48 -0400 Subject: [PATCH 29/44] Fix lint CI for Next 16 --- .github/workflows/lint.yml | 2 +- web/frontend/package.json | 2 +- web/frontend/src/app/globals.css | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git 
a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 58c3fe2ab01..224249a1cb2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -108,7 +108,7 @@ jobs: if: steps.changed.outputs.has_frontend == 'true' || steps.changed.outputs.has_personas == 'true' uses: actions/setup-node@v4 with: - node-version: '18' + node-version: '22' cache: 'npm' cache-dependency-path: | web/frontend/package-lock.json diff --git a/web/frontend/package.json b/web/frontend/package.json index c55ff317aaf..33016c30ac7 100644 --- a/web/frontend/package.json +++ b/web/frontend/package.json @@ -6,7 +6,7 @@ "dev": "next dev --turbo", "build": "next build", "start": "next start", - "lint": "next lint", + "lint": "eslint ./src --ext .jsx,.js,.ts,.tsx --quiet --ignore-path ./.eslintignore", "lint:fix": "eslint ./src --ext .jsx,.js,.ts,.tsx --quiet --fix --ignore-path ./.eslintignore", "lint:format": "prettier --loglevel warn --write \"./**/*.{js,jsx,ts,tsx,css,md,json}\" " }, diff --git a/web/frontend/src/app/globals.css b/web/frontend/src/app/globals.css index 45b289a6989..deea1d98934 100644 --- a/web/frontend/src/app/globals.css +++ b/web/frontend/src/app/globals.css @@ -35,5 +35,6 @@ } .font-system-ui { - font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; + font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, + 'Helvetica Neue', Arial, sans-serif; } From 949ba39a78b555e2d641adfca27df3ce2e7098fe Mon Sep 17 00:00:00 2001 From: David Zhang Date: Fri, 22 May 2026 04:39:55 +0700 Subject: [PATCH 30/44] Finalize prerecorded STT provider policy --- backend/.env.template | 7 +- backend/routers/desktop_background.py | 60 ++++++- .../unit/test_background_provider_service.py | 165 ++++++++++++++++-- .../unit/test_byok_assemblyai_routing.py | 26 ++- .../test_desktop_background_transcribe.py | 81 ++++++++- .../tests/unit/test_stt_provider_facade.py | 11 +- 
backend/utils/stt/provider_service.py | 32 +++- backend/utils/stt/providers.py | 18 +- desktop/Desktop/Sources/APIClient.swift | 32 ++++ desktop/Desktop/Sources/AppState.swift | 26 +-- ...BackgroundTranscriptionConfiguration.swift | 2 + .../BackgroundTranscriptionRoutingGuard.swift | 4 - .../CloudBackgroundTranscriptionSession.swift | 16 +- .../MainWindow/Pages/SettingsPage.swift | 36 ---- .../Services/AssistantSettings.swift | 4 +- .../Desktop/Tests/APIClientRoutingTests.swift | 8 + .../Tests/BackgroundTranscriptionTests.swift | 69 ++++++-- .../backend/assemblyai_background_rollout.mdx | 24 +-- .../backend/listen_pusher_pipeline.mdx | 13 +- .../ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md | 8 +- scripts/desktop_assemblyai_e2e.py | 6 +- 21 files changed, 522 insertions(+), 126 deletions(-) diff --git a/backend/.env.template b/backend/.env.template index aea4e371039..fefc456fab5 100644 --- a/backend/.env.template +++ b/backend/.env.template @@ -14,10 +14,11 @@ REDIS_DB_PASSWORD= DEEPGRAM_API_KEY= -# AssemblyAI async/background STT. Disabled by default; eligible workloads are sync, background, postprocess. +# AssemblyAI async prerecorded STT. Enabled by default for eligible workloads: sync, background, postprocess. 
ASSEMBLYAI_API_KEY= -ASSEMBLYAI_BACKGROUND_STT_ENABLED=false -ASSEMBLYAI_BACKGROUND_STT_WORKLOADS=sync,background,postprocess +ASSEMBLYAI_PRERECORDED_STT_ENABLED=true +ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess +ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true ASSEMBLYAI_STT_MODEL=universal-2 ASSEMBLYAI_BASE_URL=https://api.assemblyai.com ASSEMBLYAI_POLL_INTERVAL_SECONDS=3 diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py index f61961c8953..7999090517c 100644 --- a/backend/routers/desktop_background.py +++ b/backend/routers/desktop_background.py @@ -1,5 +1,6 @@ import json import logging +import os from datetime import datetime, timezone from typing import Dict, List, Optional @@ -23,11 +24,16 @@ from utils.executors import db_executor, run_blocking, sync_executor from utils.fair_use import is_hard_restricted, record_speech_ms from utils.other import endpoints as auth +from utils.byok import get_byok_key from utils.speaker_identification import _pcm_to_wav_bytes from utils.stt.background_speaker_identity import USER_SELF_PERSON_ID, identify_background_speaker_clusters -from utils.stt.provider_service import transcribe_bytes, update_provider_run_identity_metrics +from utils.stt.provider_service import ( + resolve_prerecorded_provider_for_request, + transcribe_bytes, + update_provider_run_identity_metrics, +) from utils.stt.speaker_embedding import extract_embedding_from_bytes -from utils.stt.providers import STTWorkload +from utils.stt.providers import STTProviderName, STTWorkload, get_fallback_prerecorded_provider_name from utils.subscription import has_transcription_credits, is_trial_paywalled from utils.voice_duration_limiter import compute_pcm_duration_ms @@ -44,6 +50,35 @@ class BackgroundConversationStartRequest(BaseModel): source: Optional[str] = "desktop" +@router.get("/capabilities") +async def desktop_capabilities(uid: str = Depends(auth.get_current_user_uid)): + background_provider = 
resolve_prerecorded_provider_for_request(STTWorkload.background) + assemblyai_key_available = bool(os.getenv('ASSEMBLYAI_API_KEY') or get_byok_key('assemblyai')) + fallback_provider = get_fallback_prerecorded_provider_name(background_provider, STTWorkload.background) + fallback_available = fallback_provider == STTProviderName.deepgram + enabled = background_provider == STTProviderName.assemblyai and (assemblyai_key_available or fallback_available) + reason = None + if background_provider != STTProviderName.assemblyai: + reason = f'provider_{background_provider.value}' + elif not assemblyai_key_available and fallback_available: + reason = 'fallback_deepgram_available' + elif not assemblyai_key_available: + reason = 'missing_assemblyai_api_key' + return { + "background_batch": { + "enabled": enabled, + "provider": background_provider.value, + "fallback_provider": fallback_provider.value if fallback_provider else None, + "workload": STTWorkload.background.value, + "reason": reason, + "sample_rate": 16000, + "channels": 1, + "encoding": "linear16", + "max_chunk_seconds": 15, + } + } + + @router.post("/background-conversation/start") async def start_background_conversation( body: BackgroundConversationStartRequest, @@ -180,6 +215,7 @@ async def background_transcribe( del audio_bytes segments = response.segments + speaker_diagnostics = _speaker_diagnostics(segments) if conversation_id and segments: await _identify_speakers( uid=uid, @@ -193,6 +229,7 @@ async def background_transcribe( _apply_chunk_offset(segments, chunk_start_ms / 1000.0) if conversation_id: _apply_speaker_ids(conversation_id, segments) + speaker_diagnostics.update(_speaker_diagnostics(segments, prefix="mapped_")) finished_at = datetime.now(timezone.utc) if persist and conversation_id: @@ -227,6 +264,7 @@ async def background_transcribe( "provider": provider, "run_id": response.run_id, "chunk_duration_ms": duration_ms, + "speaker_diagnostics": speaker_diagnostics, } @@ -374,3 +412,21 @@ def 
_load_speaker_map(conversation_id: str) -> Dict[str, int]: def _store_speaker_map(conversation_id: str, speaker_map: Dict[str, int]) -> None: redis_db.r.set(_speaker_map_key(conversation_id), json.dumps(speaker_map), ex=_SPEAKER_MAP_TTL_SECONDS) + + +def _speaker_diagnostics(segments: List[TranscriptSegment], prefix: str = "") -> Dict[str, object]: + provider_clusters = sorted( + {str(segment.provider_cluster_id) for segment in segments if segment.provider_cluster_id is not None} + ) + provider_labels = sorted( + {str(segment.provider_speaker_label) for segment in segments if segment.provider_speaker_label is not None} + ) + mapped_speakers = sorted({int(segment.speaker_id) for segment in segments if segment.speaker_id is not None}) + return { + f"{prefix}provider_cluster_count": len(provider_clusters), + f"{prefix}provider_clusters": provider_clusters[:20], + f"{prefix}provider_speaker_label_count": len(provider_labels), + f"{prefix}provider_speaker_labels": provider_labels[:20], + f"{prefix}speaker_id_count": len(mapped_speakers), + f"{prefix}speaker_ids": mapped_speakers[:20], + } diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index a0de2de4a2b..ee5948e0f23 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -15,6 +15,7 @@ sys.modules.setdefault('database._client', types.SimpleNamespace(db=MagicMock())) os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') +os.environ.setdefault('ASSEMBLYAI_API_KEY', 'fake-for-test') from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord # noqa: E402 from utils.stt import provider_service # noqa: E402 @@ -47,7 +48,8 @@ def _provider_result(provider='deepgram', model='nova-3'): ) -def test_provider_service_transcribes_sync_upload_and_finalizes_deepgram_run(): +def 
test_provider_service_transcribes_sync_upload_and_finalizes_deepgram_run(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'false') fake_provider = MagicMock() fake_provider.provider_name = STTProviderName.deepgram fake_provider.transcribe_url.return_value = (_provider_result(), 'en') @@ -82,7 +84,8 @@ def test_provider_service_transcribes_sync_upload_and_finalizes_deepgram_run(): assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00016 -def test_provider_service_finalizes_background_run_on_deepgram_default(): +def test_provider_service_finalizes_background_run_on_deepgram_when_assemblyai_disabled(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'false') fake_provider = MagicMock() fake_provider.provider_name = STTProviderName.deepgram fake_provider.transcribe_url.return_value = _provider_result() @@ -106,14 +109,23 @@ def test_provider_service_finalizes_background_run_on_deepgram_default(): assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00076 +def test_background_routing_selects_assemblyai_by_default(monkeypatch): + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', raising=False) + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', raising=False) + + assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai + assert get_prerecorded_provider_name(STTWorkload.postprocess) == STTProviderName.assemblyai + + def test_prerecorded_ptt_and_realtime_related_workloads_stay_deepgram(): assert get_prerecorded_provider_name(STTWorkload.ptt) == STTProviderName.deepgram assert get_prerecorded_provider_name(STTWorkload.voice_message) == STTProviderName.deepgram def test_background_routing_can_select_assemblyai_without_moving_latency_critical_workloads(monkeypatch): - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') - 
monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync,background,postprocess,ptt,realtime') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess,ptt,realtime') assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai @@ -145,8 +157,8 @@ def test_background_call_sites_use_provider_service_layer(): def test_provider_service_uses_assemblyai_for_enabled_sync_workload(monkeypatch): - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') fake_provider = MagicMock() fake_provider.provider_name = STTProviderName.assemblyai @@ -182,8 +194,8 @@ def test_provider_service_uses_assemblyai_for_enabled_sync_workload(monkeypatch) def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypatch): - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') assemblyai_provider = MagicMock() assemblyai_provider.provider_name = STTProviderName.assemblyai @@ -228,9 +240,128 @@ def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypat assert finalize_run.call_args_list[1].kwargs['estimated_cost_usd'] == 0.00016 -def test_provider_service_records_retry_exhaustion_without_fallback(monkeypatch): - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'background') +def 
test_provider_service_falls_back_to_deepgram_for_background_when_assemblyai_fails(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', raising=False) + + assemblyai_provider = MagicMock() + assemblyai_provider.provider_name = STTProviderName.assemblyai + assemblyai_provider.transcribe_url.side_effect = RuntimeError('AssemblyAI failed') + + deepgram_provider = MagicMock() + deepgram_provider.provider_name = STTProviderName.deepgram + deepgram_provider.transcribe_url.return_value = _provider_result() + + with patch.object( + provider_service, '_assemblyai_prerecorded_provider', return_value=assemblyai_provider + ), patch.object(provider_service, '_deepgram_prerecorded_provider', return_value=deepgram_provider), patch.object( + provider_service, 'create_provider_run', side_effect=['run-aai', 'run-dg'] + ), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/background.wav', + workload=STTWorkload.background, + uid='uid-1', + language='multi', + model='nova-3', + raw_audio_seconds=2.0, + ) + + assert response.result.provider == 'deepgram' + assert assemblyai_provider.transcribe_url.call_count == 2 + deepgram_provider.transcribe_url.assert_called_once() + assert finalize_run.call_args_list[0].kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args_list[0].kwargs['status'] == 'failed' + assert finalize_run.call_args_list[1].kwargs['provider'] == 'deepgram' + assert finalize_run.call_args_list[1].kwargs['fallback_count'] == 1 + assert finalize_run.call_args_list[1].kwargs['fallback_provider'] == 'assemblyai' + + +def test_provider_service_skips_missing_assemblyai_key_when_deepgram_fallback_is_usable(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + 
monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + + deepgram_provider = MagicMock() + deepgram_provider.provider_name = STTProviderName.deepgram + deepgram_provider.transcribe_url.return_value = _provider_result() + + with patch.object(provider_service, '_assemblyai_prerecorded_provider') as assemblyai_provider, patch.object( + provider_service, '_deepgram_prerecorded_provider', return_value=deepgram_provider + ), patch.object(provider_service, 'create_provider_run', return_value='run-dg'), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + language='multi', + model='nova-3', + ) + + assert response.result.provider == 'deepgram' + assemblyai_provider.assert_not_called() + deepgram_provider.transcribe_url.assert_called_once() + assert finalize_run.call_args.kwargs['provider'] == 'deepgram' + assert finalize_run.call_args.kwargs['fallback_count'] == 0 + + +def test_provider_service_reports_missing_assemblyai_key_when_fallback_disabled(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + + with patch.object(provider_service, '_deepgram_prerecorded_provider') as deepgram_provider, patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + with pytest.raises(RuntimeError, match='ASSEMBLYAI_API_KEY is not configured'): + provider_service.transcribe_url( + 'https://example.test/audio.wav', + 
workload=STTWorkload.sync, + uid='uid-1', + language='multi', + model='nova-3', + ) + + deepgram_provider.assert_not_called() + assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['status'] == 'failed' + assert finalize_run.call_args.kwargs['error_class'] == 'AssemblyAIProviderError' + + +def test_provider_service_reports_missing_assemblyai_key_when_no_fallback_key_is_usable(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.delenv('DEEPGRAM_API_KEY', raising=False) + + with patch.object(provider_service, '_deepgram_prerecorded_provider') as deepgram_provider, patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + with pytest.raises(RuntimeError, match='ASSEMBLYAI_API_KEY is not configured'): + provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + language='multi', + model='nova-3', + ) + + deepgram_provider.assert_not_called() + assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['status'] == 'failed' + assert finalize_run.call_args.kwargs['error_class'] == 'AssemblyAIProviderError' + + +def test_provider_service_records_background_retry_exhaustion_when_fallback_disabled(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') assemblyai_provider = MagicMock() assemblyai_provider.provider_name = STTProviderName.assemblyai @@ -264,8 +395,8 @@ def test_provider_service_records_retry_exhaustion_without_fallback(monkeypatch) def 
test_provider_service_records_successful_after_retry(monkeypatch): - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') fake_provider = MagicMock() fake_provider.provider_name = STTProviderName.assemblyai @@ -295,8 +426,8 @@ def test_provider_service_records_successful_after_retry(monkeypatch): def test_provider_service_records_zero_cost_for_zero_duration_success(monkeypatch): - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') fake_provider = MagicMock() fake_provider.provider_name = STTProviderName.assemblyai @@ -406,8 +537,8 @@ def test_provider_service_live_assemblyai_smoke_records_ledger_when_credentials_ if not api_key or not audio_url: pytest.skip('ASSEMBLYAI_API_KEY and ASSEMBLYAI_SMOKE_AUDIO_URL are required for live smoke') - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') with patch.object(provider_service, 'create_provider_run', return_value='run-aai-live') as create_run, patch.object( provider_service, 'finalize_provider_run' diff --git a/backend/tests/unit/test_byok_assemblyai_routing.py b/backend/tests/unit/test_byok_assemblyai_routing.py index 08e56083c08..5d73e555502 100644 --- a/backend/tests/unit/test_byok_assemblyai_routing.py +++ b/backend/tests/unit/test_byok_assemblyai_routing.py @@ -21,8 +21,9 @@ @pytest.fixture(autouse=True) def _enable_assemblyai_routing(monkeypatch): - 
monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.setenv('ASSEMBLYAI_API_KEY', 'aa-server-key') def test_env_selects_assemblyai_for_sync(): @@ -70,6 +71,27 @@ def test_resolve_uses_server_assembly_when_no_byok_headers(_mock_get_key): assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai +@patch('utils.stt.provider_service.get_byok_key') +def test_resolve_uses_server_deepgram_when_server_assembly_missing_and_fallback_enabled(mock_get_key, monkeypatch): + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + mock_get_key.return_value = None + + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.deepgram + + +@patch('utils.stt.provider_service.get_byok_key') +def test_resolve_keeps_assemblyai_selected_when_server_assembly_missing_and_fallback_disabled( + mock_get_key, monkeypatch +): + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') + mock_get_key.return_value = None + + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai + + @patch('utils.stt.provider_service.get_byok_key') def test_assemblyai_provider_passes_byok_api_key(mock_get_key): mock_get_key.return_value = 'aa-user-key' diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index 164bbb6b54c..cba0cdd5949 100644 --- a/backend/tests/unit/test_desktop_background_transcribe.py +++ 
b/backend/tests/unit/test_desktop_background_transcribe.py @@ -51,6 +51,7 @@ def __init__(self, message: str, status_code: int = 400): 'utils.chat', SimpleNamespace(resolve_voice_message_language=lambda _uid, language: language or 'en') ) sys.modules.setdefault('utils.analytics', SimpleNamespace(record_usage=lambda *_args, **_kwargs: None)) +sys.modules.setdefault('utils.byok', SimpleNamespace(get_byok_key=lambda _provider: None)) sys.modules.setdefault( 'utils.fair_use', SimpleNamespace( @@ -78,9 +79,14 @@ def _pcm_to_wav_bytes(pcm_data: bytes, sample_rate: int) -> bytes: sys.modules.setdefault('utils.speaker_identification', SimpleNamespace(_pcm_to_wav_bytes=_pcm_to_wav_bytes)) +sys.modules.setdefault('utils.stt.speaker_embedding', SimpleNamespace(extract_embedding_from_bytes=MagicMock())) sys.modules.setdefault( 'utils.stt.provider_service', - SimpleNamespace(transcribe_bytes=MagicMock(), update_provider_run_identity_metrics=MagicMock()), + SimpleNamespace( + resolve_prerecorded_provider_for_request=MagicMock(), + transcribe_bytes=MagicMock(), + update_provider_run_identity_metrics=MagicMock(), + ), ) sys.modules.setdefault( 'utils.voice_duration_limiter', @@ -127,6 +133,7 @@ def _client(monkeypatch, *, segments=None, person_embeddings_cache=None): monkeypatch.setattr(desktop_background, 'has_transcription_credits', lambda *_args, **_kwargs: True) monkeypatch.setattr(desktop_background, 'record_speech_ms', lambda *_args, **_kwargs: None) monkeypatch.setattr(desktop_background, 'record_usage', lambda *_args, **_kwargs: None) + monkeypatch.setattr(desktop_background, 'get_byok_key', lambda _provider: None) monkeypatch.setattr(desktop_background, 'resolve_voice_message_language', lambda _uid, language: language or 'en') monkeypatch.setattr( desktop_background.conversations_db, @@ -179,6 +186,72 @@ def _transcribe_bytes(audio_bytes, **_kwargs): return TestClient(app), mock_transcribe +def 
test_desktop_capabilities_reports_assemblyai_background_when_key_available(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.setenv('ASSEMBLYAI_API_KEY', 'server-aai-key') + monkeypatch.setattr( + desktop_background, + 'resolve_prerecorded_provider_for_request', + lambda _workload: STTProviderName.assemblyai, + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['provider'] == 'assemblyai' + assert data['fallback_provider'] == 'deepgram' + assert data['workload'] == 'background' + assert data['reason'] is None + + +def test_desktop_capabilities_allows_background_batch_with_deepgram_fallback(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', raising=False) + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setattr( + desktop_background, + 'resolve_prerecorded_provider_for_request', + lambda _workload: STTProviderName.assemblyai, + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['provider'] == 'assemblyai' + assert data['fallback_provider'] == 'deepgram' + assert data['reason'] == 'fallback_deepgram_available' + + +def test_desktop_capabilities_reports_missing_assemblyai_key_when_fallback_disabled(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + 
monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setattr( + desktop_background, + 'resolve_prerecorded_provider_for_request', + lambda _workload: STTProviderName.assemblyai, + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is False + assert data['provider'] == 'assemblyai' + assert data['fallback_provider'] is None + assert data['reason'] == 'missing_assemblyai_api_key' + + def test_background_transcribe_returns_segments_with_offset(monkeypatch): client, _mock_transcribe = _client(monkeypatch) @@ -285,6 +358,8 @@ def test_cluster_speaker_mapping_assigns_distinct_ids(monkeypatch): data = response.json() assert [segment['speaker_id'] for segment in data['segments']] == [0, 1] assert [segment['speaker'] for segment in data['segments']] == ['SPEAKER_00', 'SPEAKER_01'] + assert data['speaker_diagnostics']['provider_cluster_count'] == 2 + assert data['speaker_diagnostics']['mapped_speaker_ids'] == [0, 1] def test_background_transcribe_multi_chunk_offsets_persist_and_keep_speaker_map(monkeypatch): @@ -379,8 +454,8 @@ def test_background_transcribe_identifies_assemblyai_speaker_with_omi_user_embed def test_byok_background_routing_uses_deepgram_when_only_deepgram_key(monkeypatch): from utils.stt import provider_service - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') monkeypatch.setattr(provider_service, 'get_byok_key', lambda provider: {'deepgram': 'dg-user-key'}.get(provider)) provider = 
provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) diff --git a/backend/tests/unit/test_stt_provider_facade.py b/backend/tests/unit/test_stt_provider_facade.py index f4d6d691bf6..eb76fa55f71 100644 --- a/backend/tests/unit/test_stt_provider_facade.py +++ b/backend/tests/unit/test_stt_provider_facade.py @@ -145,15 +145,22 @@ def test_deepgram_adapter_preserves_prerecorded_request_options(): assert detected_language == 'en' -def test_provider_routing_keeps_all_current_workloads_on_deepgram(): +def test_provider_routing_uses_assemblyai_only_for_passive_prerecorded_workloads_by_default(monkeypatch): + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', raising=False) + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', raising=False) + assert get_streaming_provider_name(STTWorkload.ptt) == STTProviderName.deepgram assert get_streaming_provider_name(STTWorkload.realtime) == STTProviderName.deepgram for workload in [ STTWorkload.background, STTWorkload.postprocess, - STTWorkload.ptt, STTWorkload.sync, + ]: + assert get_prerecorded_provider_name(workload) == STTProviderName.assemblyai + + for workload in [ + STTWorkload.ptt, STTWorkload.voice_message, ]: assert get_prerecorded_provider_name(workload) == STTProviderName.deepgram diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index bd38ce79c31..9ad960c8076 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -16,6 +16,7 @@ from utils.stt.providers import ( STTProviderName, STTWorkload, + assemblyai_prerecorded_fallback_enabled, get_fallback_prerecorded_provider_name, get_prerecorded_provider_name, ) @@ -92,18 +93,30 @@ def resolve_prerecorded_provider_for_request(workload: STTWorkload) -> STTProvid """Pick prerecorded STT provider for this request, respecting BYOK headers. 
When env flags select AssemblyAI but the BYOK user did not supply an Assembly - key, use Deepgram BYOK instead of Omi's server Assembly key. + key, use Deepgram BYOK instead of Omi's server Assembly key. When server + AssemblyAI credentials are absent, skip directly to Deepgram if fallback is + enabled and a usable Deepgram key is available. """ selected = get_prerecorded_provider_name(workload) if selected != STTProviderName.assemblyai: return selected - if get_byok_key and get_byok_key('assemblyai'): + assemblyai_byok = get_byok_key('assemblyai') if get_byok_key else None + deepgram_byok = get_byok_key('deepgram') if get_byok_key else None + if assemblyai_byok: return STTProviderName.assemblyai - if get_byok_key and get_byok_key('deepgram'): + if deepgram_byok and assemblyai_prerecorded_fallback_enabled(): + return STTProviderName.deepgram + if os.getenv('ASSEMBLYAI_API_KEY'): + return STTProviderName.assemblyai + if assemblyai_prerecorded_fallback_enabled() and _has_deepgram_key_for_request(): return STTProviderName.deepgram return STTProviderName.assemblyai +def _has_deepgram_key_for_request() -> bool: + return bool((get_byok_key('deepgram') if get_byok_key else None) or os.getenv('DEEPGRAM_API_KEY')) + + def _deepgram_client_for_request() -> DeepgramClient: byok = get_byok_key('deepgram') if get_byok_key else None if byok: @@ -193,7 +206,7 @@ def transcribe_url( raw_audio_seconds, retry_count=_retry_count_from_exception(e), ) - fallback_provider_name = get_fallback_prerecorded_provider_name(provider_name, workload) + fallback_provider_name = _resolve_usable_fallback_prerecorded_provider(provider_name, workload) if fallback_provider_name: logger.warning( 'provider prerecorded url transcription falling back workload=%s from_provider=%s to_provider=%s: %s', @@ -287,7 +300,7 @@ def transcribe_bytes( raw_audio_seconds, retry_count=_retry_count_from_exception(e), ) - fallback_provider_name = get_fallback_prerecorded_provider_name(provider_name, workload) + 
fallback_provider_name = _resolve_usable_fallback_prerecorded_provider(provider_name, workload) if fallback_provider_name: logger.warning( 'provider prerecorded bytes transcription falling back workload=%s from_provider=%s to_provider=%s: %s', @@ -325,6 +338,15 @@ def _get_prerecorded_provider(provider_name: STTProviderName): raise ValueError(f'Unsupported prerecorded STT provider: {provider_name}') +def _resolve_usable_fallback_prerecorded_provider( + provider_name: STTProviderName, workload: STTWorkload +) -> Optional[STTProviderName]: + fallback_provider_name = get_fallback_prerecorded_provider_name(provider_name, workload) + if fallback_provider_name == STTProviderName.deepgram and not _has_deepgram_key_for_request(): + return None + return fallback_provider_name + + def _model_for_provider(provider_name: STTProviderName, requested_model: str) -> str: if provider_name == STTProviderName.assemblyai and str(requested_model or '').startswith('nova-'): return os.getenv('ASSEMBLYAI_STT_MODEL', 'universal-2') diff --git a/backend/utils/stt/providers.py b/backend/utils/stt/providers.py index 1d1af88e632..ee68555fa60 100644 --- a/backend/utils/stt/providers.py +++ b/backend/utils/stt/providers.py @@ -93,7 +93,7 @@ class SpeakerIdentityProvider(Protocol): def get_prerecorded_provider_name(workload: STTWorkload) -> STTProviderName: workload = STTWorkload(workload) - if _assemblyai_background_enabled() and workload in _assemblyai_enabled_workloads(): + if _assemblyai_prerecorded_enabled() and workload in _assemblyai_enabled_workloads(): return STTProviderName.assemblyai return _DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS[workload] @@ -107,7 +107,11 @@ def get_fallback_prerecorded_provider_name( ) -> Optional[STTProviderName]: workload = STTWorkload(workload) provider = STTProviderName(provider) - if workload == STTWorkload.background and provider == STTProviderName.assemblyai: + if ( + workload in _ASSEMBLYAI_ELIGIBLE_WORKLOADS + and provider == STTProviderName.assemblyai + 
and not assemblyai_prerecorded_fallback_enabled() + ): return None fallback = _DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS[workload] if provider != fallback: @@ -115,12 +119,16 @@ def get_fallback_prerecorded_provider_name( return None -def _assemblyai_background_enabled() -> bool: - return os.getenv('ASSEMBLYAI_BACKGROUND_STT_ENABLED', 'false').lower() == 'true' +def _assemblyai_prerecorded_enabled() -> bool: + return os.getenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true').lower() == 'true' + + +def assemblyai_prerecorded_fallback_enabled() -> bool: + return os.getenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'true').lower() == 'true' def _assemblyai_enabled_workloads() -> set[STTWorkload]: - configured = os.getenv('ASSEMBLYAI_BACKGROUND_STT_WORKLOADS', 'sync,background,postprocess') + configured = os.getenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') workloads = set() for raw_value in configured.split(','): value = raw_value.strip() diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 83019dc18b7..13f0af97485 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -1333,6 +1333,38 @@ extension APIClient { customBaseURL: nil ) } + + /// Fetch Python backend desktop transcription capabilities. 
+ /// Endpoint: GET /v2/desktop/capabilities + func getDesktopCapabilities() async throws -> DesktopCapabilitiesResponse { + try await get("v2/desktop/capabilities", customBaseURL: nil) + } +} + +struct DesktopCapabilitiesResponse: Codable { + let backgroundBatch: DesktopBackgroundBatchCapability + + enum CodingKeys: String, CodingKey { + case backgroundBatch = "background_batch" + } +} + +struct DesktopBackgroundBatchCapability: Codable { + let enabled: Bool + let provider: String + let sampleRate: Int + let channels: Int + let encoding: String + let maxChunkSeconds: Int + + enum CodingKeys: String, CodingKey { + case enabled + case provider + case sampleRate = "sample_rate" + case channels + case encoding + case maxChunkSeconds = "max_chunk_seconds" + } } // MARK: - Memories API diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index d461022ef7a..b943017e9b4 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -1432,8 +1432,7 @@ class AppState: ObservableObject { ) let routing = BackgroundTranscriptionRoutingGuard().decide( - batchEnabled: AssistantSettings.shared.batchTranscriptionEnabled, - serverAssemblyBackgroundEnabled: Self.isServerBackgroundBatchEnabled, + serverAssemblyBackgroundEnabled: true, audioSource: effectiveSource ) if routing == .cloudBatchAssembly { @@ -1591,16 +1590,23 @@ class AppState: ObservableObject { } } - private static var isServerBackgroundBatchEnabled: Bool { - let baseURL = DesktopBackendEnvironment.pythonBaseURL().lowercased() - return baseURL.contains("127.0.0.1") - || baseURL.contains("localhost") - || baseURL.contains("omiapi.com") - } - private func startCloudBackgroundTranscription(source: AudioSource, language: String) async { defer { isStartingTranscription = false } do { + let capabilities = try await APIClient.shared.getDesktopCapabilities() + guard capabilities.backgroundBatch.enabled + && 
capabilities.backgroundBatch.provider.lowercased() == "assemblyai" + else { + throw NSError( + domain: "Omi.CloudBackgroundTranscription", + code: 1, + userInfo: [ + NSLocalizedDescriptionKey: + "Server background batch transcription is not enabled for AssemblyAI." + ] + ) + } + let conversationId = try await APIClient.shared.startBackgroundConversation( language: language) guard !Task.isCancelled else { return } @@ -2083,7 +2089,7 @@ class AppState: ObservableObject { } } catch { logError("Transcription: Cloud background chunk failed", error: error) - continue + try? await Task.sleep(nanoseconds: 1_000_000_000) } } } diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift index 3849aa6f516..b7eb8279db2 100644 --- a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift @@ -10,6 +10,7 @@ struct BackgroundTranscriptionConfiguration: Equatable { var speechPeakAmplitudeThreshold: Int = 512 var speechRMSAmplitudeThreshold: Int = 64 var maxPendingChunks: Int = 4 + var maxChunkTranscriptionAttempts: Int = 3 var requiresSpeechBeforeUpload: Bool = false var speechActivityDetection = SpeechActivityDetectionConfiguration() @@ -33,6 +34,7 @@ struct BackgroundTranscriptionConfiguration: Equatable { minChunkDuration: 15.0, overlapDuration: 0.5, maxPendingChunks: 8, + maxChunkTranscriptionAttempts: 3, requiresSpeechBeforeUpload: true, speechActivityDetection: SpeechActivityDetectionConfiguration( windowDuration: 0.02, diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift index 13b97de5bcf..89c0c518374 100644 --- 
a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift @@ -7,13 +7,9 @@ enum BackgroundTranscriptionRoutingDecision: Equatable { struct BackgroundTranscriptionRoutingGuard { func decide( - batchEnabled: Bool, serverAssemblyBackgroundEnabled: Bool, audioSource: AudioSource ) -> BackgroundTranscriptionRoutingDecision { - guard batchEnabled else { - return .cloudListenStreaming(reason: "batch_disabled") - } guard serverAssemblyBackgroundEnabled else { return .cloudListenStreaming(reason: "server_background_batch_disabled") } diff --git a/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift b/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift index 7332ffb8229..af7a7d5da99 100644 --- a/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift @@ -31,6 +31,7 @@ final class CloudBackgroundTranscriptionSession { private let speechActivityDetector: SpeechActivityDetector private var chunker: BackgroundAudioChunker private var pendingChunks: [BackgroundAudioChunk] = [] + private var pendingChunkAttempts: [Int] = [] private var processedChunkCount = 0 private var droppedChunkCount = 0 private var isInputFinished = false @@ -87,6 +88,7 @@ final class CloudBackgroundTranscriptionSession { continue } pendingChunks.append(chunk) + pendingChunkAttempts.append(0) enqueuedChunks += 1 } return BackgroundIngestResult( @@ -116,6 +118,7 @@ final class CloudBackgroundTranscriptionSession { continue } pendingChunks.append(chunk) + pendingChunkAttempts.append(0) enqueuedChunks += 1 } isInputFinished = true @@ -130,14 +133,23 @@ final class CloudBackgroundTranscriptionSession { func transcribeNext() async throws -> BackgroundTranscriptionResult? 
{ guard !pendingChunks.isEmpty else { return nil } - let chunk = pendingChunks.removeFirst() + let chunk = pendingChunks[0] do { let segments = try await transcribe(chunk) + pendingChunks.removeFirst() + pendingChunkAttempts.removeFirst() processedChunkCount += 1 processedSegments.append(contentsOf: segments) return BackgroundTranscriptionResult(chunk: chunk, segments: segments) } catch { - droppedChunkCount += 1 + let attempts = (pendingChunkAttempts.first ?? 0) + 1 + if attempts >= configuration.maxChunkTranscriptionAttempts { + pendingChunks.removeFirst() + pendingChunkAttempts.removeFirst() + droppedChunkCount += 1 + } else { + pendingChunkAttempts[0] = attempts + } throw error } } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 8c1fdf32f34..fe5128e0e57 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -261,7 +261,6 @@ struct SettingsContentView: View { @State private var transcriptionAutoDetect: Bool = true @State private var transcriptionLanguage: String = "en" @State private var vadGateEnabled: Bool = false - @State private var batchTranscriptionEnabled: Bool = false // Multi-chat mode setting @AppStorage("multiChatEnabled") private var multiChatEnabled = false @@ -426,7 +425,6 @@ struct SettingsContentView: View { initialValue: MemoryAssistantSettings.shared.notificationsEnabled) _memoryExcludedApps = State(initialValue: MemoryAssistantSettings.shared.excludedApps) _vadGateEnabled = State(initialValue: settings.vadGateEnabled) - _batchTranscriptionEnabled = State(initialValue: settings.batchTranscriptionEnabled) _transcriptionLanguage = State(initialValue: settings.transcriptionLanguage) _transcriptionAutoDetect = State(initialValue: settings.transcriptionAutoDetect) } @@ -1239,39 +1237,6 @@ struct SettingsContentView: View { } } - // Cloud batch transcription - 
settingsCard(settingId: "transcription.batch") { - VStack(alignment: .leading, spacing: 12) { - HStack { - Image(systemName: "waveform.path.ecg.rectangle") - .scaledFont(size: 16) - .foregroundColor(OmiColors.purplePrimary) - - VStack(alignment: .leading, spacing: 4) { - Text("Batch transcription (AssemblyAI)") - .scaledFont(size: 15, weight: .medium) - .foregroundColor(OmiColors.textPrimary) - - Text( - "Transcribe microphone audio in selected-language chunks instead of live streaming. Requires server-side AssemblyAI." - ) - .scaledFont(size: 13) - .foregroundColor(OmiColors.textTertiary) - .fixedSize(horizontal: false, vertical: true) - } - - Spacer() - - Toggle("", isOn: $batchTranscriptionEnabled) - .toggleStyle(.switch) - .onChange(of: batchTranscriptionEnabled) { _, newValue in - AssistantSettings.shared.batchTranscriptionEnabled = newValue - restartTranscriptionIfNeeded() - } - } - } - } - // Local VAD Gate settingsCard(settingId: "transcription.vadgate") { VStack(alignment: .leading, spacing: 12) { @@ -6964,7 +6929,6 @@ struct SettingsContentView: View { transcriptionAutoDetect = AssistantSettings.shared.transcriptionAutoDetect vocabularyList = AssistantSettings.shared.transcriptionVocabulary vadGateEnabled = AssistantSettings.shared.vadGateEnabled - batchTranscriptionEnabled = AssistantSettings.shared.batchTranscriptionEnabled Task { do { diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift index 87343968d98..cc7a876607f 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift @@ -29,7 +29,7 @@ class AssistantSettings { private let defaultTranscriptionAutoDetect = true private let defaultTranscriptionVocabulary: [String] = [] private let defaultVadGateEnabled = false - private let defaultBatchTranscriptionEnabled = false + 
private let defaultBatchTranscriptionEnabled = true private init() { // Register defaults @@ -200,7 +200,7 @@ class AssistantSettings { /// Whether cloud batch transcription mode is enabled for microphone background audio. var batchTranscriptionEnabled: Bool { - get { UserDefaults.standard.bool(forKey: batchTranscriptionEnabledKey) } + get { true } set { UserDefaults.standard.set(newValue, forKey: batchTranscriptionEnabledKey) NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index ecab2929f6d..3e20e88c204 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -307,6 +307,14 @@ final class APIClientRoutingTests: XCTestCase { label: "finishBackgroundConversation") } + func testDesktopCapabilitiesRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getDesktopCapabilities() as DesktopCapabilitiesResponse + assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/desktop/capabilities", method: "GET", + label: "getDesktopCapabilities") + } + // -- Conversations: manual URL(string: baseURL + ...) paths (PATCH → Python) -- func testSetConversationStarredRoutesToPython() async { diff --git a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift index a1c83a2dbf9..09641321aeb 100644 --- a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift +++ b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift @@ -340,7 +340,7 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertTrue(result?.chunk.isFinal ?? 
false) } - func testSessionDropsFailedChunkSoDrainCanContinue() async throws { + func testSessionRetainsFailedChunkForRetry() async throws { let configuration = BackgroundTranscriptionConfiguration( sampleRate: 10, maxChunkDuration: 1.0, @@ -350,7 +350,52 @@ final class BackgroundTranscriptionTests: XCTestCase { silenceAmplitudeThreshold: 10, speechPeakAmplitudeThreshold: 100, speechRMSAmplitudeThreshold: 20, - maxPendingChunks: 4 + maxPendingChunks: 4, + maxChunkTranscriptionAttempts: 3 + ) + var shouldFail = true + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + if shouldFail { + shouldFail = false + throw NSError(domain: "test", code: 1) + } + return [ + Self.backendSegment( + id: "retry", text: "retried", start: chunk.startTime, end: chunk.startTime + 1) + ] + } + + XCTAssertEqual( + session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 0) + .enqueuedChunks, 1) + + do { + _ = try await session.transcribeNext() + XCTFail("Expected first transcription attempt to fail") + } catch { + XCTAssertEqual(session.pendingChunkCount, 1) + XCTAssertEqual(session.snapshot().droppedChunkCount, 0) + } + + let retried = try await session.transcribeNext() + XCTAssertEqual(retried?.chunk.startTime, 0) + XCTAssertEqual(session.pendingChunkCount, 0) + XCTAssertEqual(session.snapshot().processedChunkCount, 1) + XCTAssertEqual(session.snapshot().droppedChunkCount, 0) + } + + func testSessionRetriesFailedChunkBeforeDroppingSoDrainCanContinue() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 0.5, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4, + maxChunkTranscriptionAttempts: 2 ) let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in if chunk.startTime == 0 
{ @@ -372,8 +417,17 @@ final class BackgroundTranscriptionTests: XCTestCase { do { _ = try await session.transcribeNext() XCTFail("Expected first transcription attempt to fail") + } catch { + XCTAssertEqual(session.pendingChunkCount, 2) + XCTAssertEqual(session.snapshot().droppedChunkCount, 0) + } + + do { + _ = try await session.transcribeNext() + XCTFail("Expected second transcription attempt to fail and drop the chunk") } catch { XCTAssertEqual(session.pendingChunkCount, 1) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) } let next = try await session.transcribeNext() @@ -431,22 +485,17 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertEqual( guardrail.decide( - batchEnabled: true, serverAssemblyBackgroundEnabled: true, audioSource: .microphone), + serverAssemblyBackgroundEnabled: true, audioSource: .microphone), .cloudBatchAssembly ) XCTAssertEqual( guardrail.decide( - batchEnabled: true, serverAssemblyBackgroundEnabled: true, audioSource: .bleDevice), + serverAssemblyBackgroundEnabled: true, audioSource: .bleDevice), .cloudListenStreaming(reason: "batch_microphone_only") ) XCTAssertEqual( guardrail.decide( - batchEnabled: false, serverAssemblyBackgroundEnabled: true, audioSource: .microphone), - .cloudListenStreaming(reason: "batch_disabled") - ) - XCTAssertEqual( - guardrail.decide( - batchEnabled: true, serverAssemblyBackgroundEnabled: false, audioSource: .microphone), + serverAssemblyBackgroundEnabled: false, audioSource: .microphone), .cloudListenStreaming(reason: "server_background_batch_disabled") ) } diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 97f4f5df18c..7f7f55617e4 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -8,7 +8,7 @@ description: "Rollout gates, feature flags, instrumentation, and rollback for th ## Scope -AssemblyAI is the 
MVP async/background prerecorded provider. Deepgram remains +AssemblyAI is the MVP async prerecorded provider. Deepgram remains the provider for mobile/BLE `/v4/listen`, realtime assistant streaming, Hold-to-Talk streaming, and voice-message finalize semantics. @@ -26,13 +26,15 @@ Ineligible latency-sensitive workloads stay on Deepgram: ## Feature Flags And Environment -AssemblyAI is disabled by default. +AssemblyAI is enabled by default for eligible prerecorded workloads when +credentials are configured. | Variable | Default | Purpose | | --- | --- | --- | | `ASSEMBLYAI_API_KEY` | unset | Required before any AssemblyAI request can run. | -| `ASSEMBLYAI_BACKGROUND_STT_ENABLED` | `false` | Main rollout switch. Set to `true` only for canary/smoke cohorts first. | -| `ASSEMBLYAI_BACKGROUND_STT_WORKLOADS` | `sync,background,postprocess` | Comma-separated eligible background workloads. Unknown or ineligible values are ignored. | +| `ASSEMBLYAI_PRERECORDED_STT_ENABLED` | `true` | Main rollout switch for passive prerecorded workloads. | +| `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS` | `sync,background,postprocess` | Comma-separated eligible prerecorded workloads. Unknown or ineligible values are ignored. | +| `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED` | `true` | Use Deepgram when AssemblyAI is unavailable or fails and Deepgram credentials are usable. | | `ASSEMBLYAI_STT_MODEL` | `universal-2` | AssemblyAI model used when a Deepgram `nova-*` model request is routed to AssemblyAI. | | `ASSEMBLYAI_BASE_URL` | `https://api.assemblyai.com` | Provider API base URL. | | `ASSEMBLYAI_POLL_INTERVAL_SECONDS` | `3` | Poll interval for async transcript completion. | @@ -41,9 +43,9 @@ AssemblyAI is disabled by default. ## Routing And Fallback -`backend/utils/stt/providers.py` owns provider selection. Background workloads +`backend/utils/stt/providers.py` owns provider selection. 
Eligible workloads use AssemblyAI only when the main flag is enabled and the workload is in -`ASSEMBLYAI_BACKGROUND_STT_WORKLOADS`; otherwise they use Deepgram. +`ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`; otherwise they use Deepgram. Desktop Audio Recording uses the background workload through: @@ -57,12 +59,14 @@ stable numeric `speaker_id` values per conversation, and appends segments to the in-progress conversation when `persist=true`. Deepgram is the prerecorded fallback provider. If AssemblyAI fails, times out, -or exhausts retries for an eligible background workload, the failed AssemblyAI +or exhausts retries for an eligible prerecorded workload, the failed AssemblyAI run is finalized in the provider ledger and the request retries through -Deepgram with fallback metadata. +Deepgram with fallback metadata. If AssemblyAI credentials are missing before +the request starts, provider resolution skips directly to Deepgram when +fallback is enabled and Deepgram credentials are usable. -Rollback is to set `ASSEMBLYAI_BACKGROUND_STT_ENABLED=false` or remove the -affected workload from `ASSEMBLYAI_BACKGROUND_STT_WORKLOADS`. No client deploy +Rollback is to set `ASSEMBLYAI_PRERECORDED_STT_ENABLED=false` or remove the +affected workload from `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`. No client deploy is required for rollback. 
## BYOK (Bring Your Own Keys) diff --git a/docs/doc/developer/backend/listen_pusher_pipeline.mdx b/docs/doc/developer/backend/listen_pusher_pipeline.mdx index 616dbe41146..6133718d500 100644 --- a/docs/doc/developer/backend/listen_pusher_pipeline.mdx +++ b/docs/doc/developer/backend/listen_pusher_pipeline.mdx @@ -6,7 +6,7 @@ description: "Sequence diagrams for the /v4/listen WebSocket and Pusher processi # Listen + Pusher Pipeline — Sequence Diagrams -> Last updated: 2026-05-21 (AssemblyAI background STT MVP) +> Last updated: 2026-05-21 (AssemblyAI prerecorded STT MVP) > > These diagrams document the real behavior observed during E2E testing with > live services (backend, pusher, Deepgram, AssemblyAI, embedding API). Update @@ -230,9 +230,10 @@ sequenceDiagram Realtime listen and Hold-to-Talk streaming remain Deepgram paths. Background prerecorded transcription routes through `utils/stt/provider_service.py`, which -selects Deepgram by default and can select AssemblyAI only when -`ASSEMBLYAI_BACKGROUND_STT_ENABLED=true` and the workload is listed in -`ASSEMBLYAI_BACKGROUND_STT_WORKLOADS`. +selects AssemblyAI by default for eligible prerecorded workloads when +`ASSEMBLYAI_PRERECORDED_STT_ENABLED=true` and the workload is listed in +`ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`. Deepgram remains the fallback when +AssemblyAI is disabled, unavailable, or fails and fallback is enabled. 
Desktop Audio Recording can use the batch background path instead of `/v4/listen` when the desktop batch setting is enabled and the client is pointed at a backend @@ -276,13 +277,13 @@ sequenceDiagram Deepgram-->>ProviderService: Transcript with words and speaker labels end - alt AssemblyAI fails, times out, or exhausts retries for sync/postprocess + alt AssemblyAI fails, times out, or exhausts retries ProviderService->>Firestore: Finalize failed AssemblyAI provider run ProviderService->>Deepgram: Fallback prerecorded transcript request Deepgram-->>ProviderService: Transcript result end - Note over ProviderService: AssemblyAI background workload fails closed;
it does not fall back to Deepgram during desktop cloud batch. + Note over ProviderService: Missing AssemblyAI credentials skip directly
to Deepgram when fallback credentials are usable. ProviderService->>ProviderService: Normalize provider result and reconstruct canonical segments ProviderService->>EmbeddingAPI: Extract cluster samples for Omi speaker identity diff --git a/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md b/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md index 436949f4118..99ee1787553 100644 --- a/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md +++ b/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md @@ -21,7 +21,7 @@ Branch work adds **desktop always-on Audio Recording via AssemblyAI batch chunks [`scripts/desktop_assemblyai_e2e.py`](scripts/desktop_assemblyai_e2e.py): ```bash -# Requires: local backend on :8080, ASSEMBLYAI_BACKGROUND_STT_ENABLED=true, Omi Dev signed in +# Requires: local backend on :8080, ASSEMBLYAI_PRERECORDED_STT_ENABLED=true, Omi Dev signed in python3 scripts/desktop_assemblyai_e2e.py --background-chunk --api http://127.0.0.1:8080 ``` @@ -87,7 +87,7 @@ Extend [`scripts/desktop_assemblyai_e2e.py`](scripts/desktop_assemblyai_e2e.py) 6. **Document** in script header how to run locally: ```bash cd backend && DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib" ./run-local.sh - # .env: ASSEMBLYAI_BACKGROUND_STT_ENABLED=true, ASSEMBLYAI_API_KEY=... + # .env: ASSEMBLYAI_PRERECORDED_STT_ENABLED=true, ASSEMBLYAI_API_KEY=... 
python3 scripts/desktop_assemblyai_e2e.py --background-batch ``` @@ -152,8 +152,8 @@ You are done when: ```bash # backend/.env (required) -ASSEMBLYAI_BACKGROUND_STT_ENABLED=true -ASSEMBLYAI_BACKGROUND_STT_WORKLOADS=sync,background,postprocess +ASSEMBLYAI_PRERECORDED_STT_ENABLED=true +ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess ASSEMBLYAI_API_KEY= LOCAL_DEVELOPMENT=true ``` diff --git a/scripts/desktop_assemblyai_e2e.py b/scripts/desktop_assemblyai_e2e.py index 2c030763e51..9e7b0c23b3a 100755 --- a/scripts/desktop_assemblyai_e2e.py +++ b/scripts/desktop_assemblyai_e2e.py @@ -11,8 +11,8 @@ cd backend && DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib" ./run-local.sh # backend/.env: # LOCAL_DEVELOPMENT=true - # ASSEMBLYAI_BACKGROUND_STT_ENABLED=true - # ASSEMBLYAI_BACKGROUND_STT_WORKLOADS=sync,background,postprocess + # ASSEMBLYAI_PRERECORDED_STT_ENABLED=true + # ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess # ASSEMBLYAI_API_KEY=... python3 scripts/desktop_assemblyai_e2e.py [--api http://127.0.0.1:8080] python3 scripts/desktop_assemblyai_e2e.py --background-chunk [--api http://127.0.0.1:8080] @@ -108,7 +108,7 @@ def require_backend_reachable(api_base: str) -> None: f"Backend is not reachable at {api_base}. 
Start the local backend first, for example:\n" ' cd backend && DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib" ./run-local.sh\n' "Required backend env includes LOCAL_DEVELOPMENT=true, " - "ASSEMBLYAI_BACKGROUND_STT_ENABLED=true, and ASSEMBLYAI_API_KEY.\n" + "ASSEMBLYAI_PRERECORDED_STT_ENABLED=true, and ASSEMBLYAI_API_KEY.\n" f"Reachability error: {exc}" ) From be5a8f9f437db3e0ef47d6cb2f207524c5f2a84f Mon Sep 17 00:00:00 2001 From: David Zhang Date: Fri, 22 May 2026 04:46:16 +0700 Subject: [PATCH 31/44] Make desktop background batch resilient --- backend/routers/desktop_background.py | 53 +++++-- .../test_desktop_background_transcribe.py | 84 +++++++++++ desktop/Desktop/Sources/APIClient.swift | 14 ++ desktop/Desktop/Sources/AppState.swift | 33 +++-- .../BackgroundTranscriptionRoutingGuard.swift | 24 +++- .../Tests/BackgroundTranscriptionTests.swift | 134 +++++++++++++++++- 6 files changed, 312 insertions(+), 30 deletions(-) diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py index 7999090517c..91b35ba8be0 100644 --- a/backend/routers/desktop_background.py +++ b/backend/routers/desktop_background.py @@ -33,7 +33,12 @@ update_provider_run_identity_metrics, ) from utils.stt.speaker_embedding import extract_embedding_from_bytes -from utils.stt.providers import STTProviderName, STTWorkload, get_fallback_prerecorded_provider_name +from utils.stt.providers import ( + STTProviderName, + STTWorkload, + get_fallback_prerecorded_provider_name, + get_prerecorded_provider_name, +) from utils.subscription import has_transcription_credits, is_trial_paywalled from utils.voice_duration_limiter import compute_pcm_duration_ms @@ -52,23 +57,49 @@ class BackgroundConversationStartRequest(BaseModel): @router.get("/capabilities") async def desktop_capabilities(uid: str = Depends(auth.get_current_user_uid)): - background_provider = resolve_prerecorded_provider_for_request(STTWorkload.background) - assemblyai_key_available = 
bool(os.getenv('ASSEMBLYAI_API_KEY') or get_byok_key('assemblyai')) - fallback_provider = get_fallback_prerecorded_provider_name(background_provider, STTWorkload.background) - fallback_available = fallback_provider == STTProviderName.deepgram - enabled = background_provider == STTProviderName.assemblyai and (assemblyai_key_available or fallback_available) + primary_provider = get_prerecorded_provider_name(STTWorkload.background) + effective_provider = resolve_prerecorded_provider_for_request(STTWorkload.background) + fallback_provider = get_fallback_prerecorded_provider_name(primary_provider, STTWorkload.background) + + assemblyai_key_available = bool(get_byok_key('assemblyai') or os.getenv('ASSEMBLYAI_API_KEY')) + deepgram_key_available = bool(get_byok_key('deepgram') or os.getenv('DEEPGRAM_API_KEY')) + fallback_available = fallback_provider == STTProviderName.deepgram and deepgram_key_available + + if effective_provider == STTProviderName.assemblyai and not assemblyai_key_available and fallback_available: + effective_provider = STTProviderName.deepgram + + usable_provider = None + if effective_provider == STTProviderName.assemblyai and assemblyai_key_available: + usable_provider = STTProviderName.assemblyai + elif effective_provider == STTProviderName.deepgram and deepgram_key_available: + usable_provider = STTProviderName.deepgram + + enabled = usable_provider is not None + mode = 'disabled' reason = None - if background_provider != STTProviderName.assemblyai: - reason = f'provider_{background_provider.value}' - elif not assemblyai_key_available and fallback_available: + if usable_provider == STTProviderName.assemblyai: + mode = 'assemblyai_primary' + elif usable_provider == STTProviderName.deepgram and primary_provider == STTProviderName.assemblyai: + mode = 'deepgram_fallback' reason = 'fallback_deepgram_available' - elif not assemblyai_key_available: + elif usable_provider == STTProviderName.deepgram: + mode = 'deepgram_primary' + elif primary_provider == 
STTProviderName.assemblyai and not assemblyai_key_available: reason = 'missing_assemblyai_api_key' + elif primary_provider == STTProviderName.deepgram and not deepgram_key_available: + reason = 'missing_deepgram_api_key' + else: + reason = 'no_usable_batch_provider' return { "background_batch": { "enabled": enabled, - "provider": background_provider.value, + "mode": mode, + "provider": primary_provider.value, + "primary_provider": primary_provider.value, + "effective_provider": usable_provider.value if usable_provider else None, "fallback_provider": fallback_provider.value if fallback_provider else None, + "fallback_enabled": fallback_provider is not None, + "fallback_available": fallback_available, "workload": STTWorkload.background.value, "reason": reason, "sample_rate": 16000, diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index cba0cdd5949..208edd75985 100644 --- a/backend/tests/unit/test_desktop_background_transcribe.py +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -203,7 +203,12 @@ def test_desktop_capabilities_reports_assemblyai_background_when_key_available(m data = response.json()['background_batch'] assert data['enabled'] is True assert data['provider'] == 'assemblyai' + assert data['primary_provider'] == 'assemblyai' + assert data['effective_provider'] == 'assemblyai' + assert data['mode'] == 'assemblyai_primary' assert data['fallback_provider'] == 'deepgram' + assert data['fallback_enabled'] is True + assert data['fallback_available'] is True assert data['workload'] == 'background' assert data['reason'] is None @@ -226,7 +231,12 @@ def test_desktop_capabilities_allows_background_batch_with_deepgram_fallback(mon data = response.json()['background_batch'] assert data['enabled'] is True assert data['provider'] == 'assemblyai' + assert data['primary_provider'] == 'assemblyai' + assert data['effective_provider'] == 'deepgram' + assert data['mode'] == 
'deepgram_fallback' assert data['fallback_provider'] == 'deepgram' + assert data['fallback_enabled'] is True + assert data['fallback_available'] is True assert data['reason'] == 'fallback_deepgram_available' @@ -248,10 +258,84 @@ def test_desktop_capabilities_reports_missing_assemblyai_key_when_fallback_disab data = response.json()['background_batch'] assert data['enabled'] is False assert data['provider'] == 'assemblyai' + assert data['effective_provider'] is None + assert data['mode'] == 'disabled' assert data['fallback_provider'] is None + assert data['fallback_enabled'] is False assert data['reason'] == 'missing_assemblyai_api_key' +def test_desktop_capabilities_reports_no_usable_batch_provider_without_any_key(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.delenv('DEEPGRAM_API_KEY', raising=False) + monkeypatch.setattr( + desktop_background, + 'resolve_prerecorded_provider_for_request', + lambda _workload: STTProviderName.assemblyai, + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is False + assert data['primary_provider'] == 'assemblyai' + assert data['effective_provider'] is None + assert data['fallback_provider'] == 'deepgram' + assert data['fallback_available'] is False + assert data['reason'] == 'missing_assemblyai_api_key' + + +def test_desktop_capabilities_uses_byok_assemblyai(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setattr( + 
desktop_background, 'get_byok_key', lambda provider: 'aai-user-key' if provider == 'assemblyai' else None + ) + monkeypatch.setattr( + desktop_background, + 'resolve_prerecorded_provider_for_request', + lambda _workload: STTProviderName.assemblyai, + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['effective_provider'] == 'assemblyai' + assert data['mode'] == 'assemblyai_primary' + + +def test_desktop_capabilities_uses_byok_deepgram_fallback(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.delenv('DEEPGRAM_API_KEY', raising=False) + monkeypatch.setattr( + desktop_background, 'get_byok_key', lambda provider: 'dg-user-key' if provider == 'deepgram' else None + ) + monkeypatch.setattr( + desktop_background, + 'resolve_prerecorded_provider_for_request', + lambda _workload: STTProviderName.deepgram, + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['effective_provider'] == 'deepgram' + assert data['mode'] == 'deepgram_fallback' + + def test_background_transcribe_returns_segments_with_offset(monkeypatch): client, _mock_transcribe = _client(monkeypatch) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 13f0af97485..997efbac8f6 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -1351,7 +1351,14 @@ struct DesktopCapabilitiesResponse: Codable { struct DesktopBackgroundBatchCapability: Codable { let enabled: Bool + let mode: String? 
let provider: String + let primaryProvider: String? + let effectiveProvider: String? + let fallbackProvider: String? + let fallbackEnabled: Bool? + let fallbackAvailable: Bool? + let reason: String? let sampleRate: Int let channels: Int let encoding: String @@ -1359,7 +1366,14 @@ struct DesktopBackgroundBatchCapability: Codable { enum CodingKeys: String, CodingKey { case enabled + case mode case provider + case primaryProvider = "primary_provider" + case effectiveProvider = "effective_provider" + case fallbackProvider = "fallback_provider" + case fallbackEnabled = "fallback_enabled" + case fallbackAvailable = "fallback_available" + case reason case sampleRate = "sample_rate" case channels case encoding diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index b943017e9b4..dd4932647c3 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -337,6 +337,7 @@ class AppState: ObservableObject { private var cloudBackgroundStartTask: Task? private var cloudBackgroundDrainTask: Task? 
private var cloudBackgroundSampleCursor = 0 + private var forceNextTranscriptionStartStreaming = false private var isCloudBackgroundTranscription = false private var isCloudBackgroundBackpressured = false private var didLogCloudBackgroundBackpressure = false @@ -1431,11 +1432,9 @@ class AppState: ObservableObject { "Transcription: Using language=\(effectiveLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), selected=\(AssistantSettings.shared.transcriptionLanguage))" ) - let routing = BackgroundTranscriptionRoutingGuard().decide( - serverAssemblyBackgroundEnabled: true, - audioSource: effectiveSource - ) - if routing == .cloudBatchAssembly { + let shouldAttemptCloudBatch = effectiveSource == .microphone && !forceNextTranscriptionStartStreaming + forceNextTranscriptionStartStreaming = false + if shouldAttemptCloudBatch { let batchLanguage = AssistantSettings.shared.effectiveBatchTranscriptionLanguage log( "Transcription: Cloud background batch using language=\(batchLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), selected=\(AssistantSettings.shared.transcriptionLanguage))" @@ -1594,15 +1593,17 @@ class AppState: ObservableObject { defer { isStartingTranscription = false } do { let capabilities = try await APIClient.shared.getDesktopCapabilities() - guard capabilities.backgroundBatch.enabled - && capabilities.backgroundBatch.provider.lowercased() == "assemblyai" - else { + let routing = BackgroundTranscriptionRoutingGuard().decide( + backgroundBatchCapability: capabilities.backgroundBatch, + audioSource: source + ) + guard routing == .cloudBatchAssembly else { throw NSError( domain: "Omi.CloudBackgroundTranscription", code: 1, userInfo: [ NSLocalizedDescriptionKey: - "Server background batch transcription is not enabled for AssemblyAI." + "Server background batch transcription is not available." 
] ) } @@ -1694,9 +1695,19 @@ class AppState: ObservableObject { didLogCloudBackgroundBackpressure = false isTranscribing = false AssistantSettings.shared.transcriptionEnabled = false - AnalyticsManager.shared.recordingError(error: error.localizedDescription) logError("Transcription: Cloud background batch failed to start", error: error) - showAlert(title: "Audio Recording Not Started", message: cloudBackgroundStartFailureMessage(for: error)) + if BackgroundTranscriptionRoutingGuard().shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: source, + captureStarted: false + ) { + isStartingTranscription = false + forceNextTranscriptionStartStreaming = true + startTranscription(source: source) + log("Transcription: Fell back to streaming after cloud background batch startup failed") + } else { + AnalyticsManager.shared.recordingError(error: error.localizedDescription) + showAlert(title: "Audio Recording Not Started", message: cloudBackgroundStartFailureMessage(for: error)) + } } } diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift index 89c0c518374..6d5e114f68a 100644 --- a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift @@ -7,15 +7,31 @@ enum BackgroundTranscriptionRoutingDecision: Equatable { struct BackgroundTranscriptionRoutingGuard { func decide( - serverAssemblyBackgroundEnabled: Bool, + backgroundBatchCapability: DesktopBackgroundBatchCapability?, audioSource: AudioSource ) -> BackgroundTranscriptionRoutingDecision { - guard serverAssemblyBackgroundEnabled else { - return .cloudListenStreaming(reason: "server_background_batch_disabled") - } guard audioSource == .microphone else { return .cloudListenStreaming(reason: "batch_microphone_only") } + guard let capability = 
backgroundBatchCapability else { + return .cloudListenStreaming(reason: "server_background_batch_capability_unavailable") + } + guard capability.enabled else { + return .cloudListenStreaming(reason: capability.reason ?? "server_background_batch_disabled") + } + guard let effectiveProvider = capability.effectiveProvider?.lowercased() else { + return .cloudListenStreaming(reason: capability.reason ?? "server_background_batch_provider_unavailable") + } + guard effectiveProvider == "assemblyai" || effectiveProvider == "deepgram" else { + return .cloudListenStreaming(reason: capability.reason ?? "server_background_batch_provider_unsupported") + } return .cloudBatchAssembly } + + func shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: AudioSource, + captureStarted: Bool + ) -> Bool { + audioSource == .microphone && !captureStarted + } } diff --git a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift index 09641321aeb..147b5735dac 100644 --- a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift +++ b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift @@ -480,26 +480,152 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertEqual(reducer.segments[0].translations.first?.text, "hola") } - func testRoutingGuardUsesCloudBatchForMicrophoneWhenServerEnabled() { + func testRoutingGuardUsesCloudBatchForMicrophoneWhenCapabilityUsable() { let guardrail = BackgroundTranscriptionRoutingGuard() XCTAssertEqual( guardrail.decide( - serverAssemblyBackgroundEnabled: true, audioSource: .microphone), + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: true, effectiveProvider: "assemblyai"), + audioSource: .microphone), .cloudBatchAssembly ) XCTAssertEqual( guardrail.decide( - serverAssemblyBackgroundEnabled: true, audioSource: .bleDevice), + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: true, effectiveProvider: "assemblyai"), + audioSource: 
.bleDevice), .cloudListenStreaming(reason: "batch_microphone_only") ) XCTAssertEqual( guardrail.decide( - serverAssemblyBackgroundEnabled: false, audioSource: .microphone), + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: false, effectiveProvider: nil, reason: "missing_assemblyai_api_key"), + audioSource: .microphone), + .cloudListenStreaming(reason: "missing_assemblyai_api_key") + ) + XCTAssertEqual( + guardrail.decide( + backgroundBatchCapability: nil, + audioSource: .microphone), + .cloudListenStreaming(reason: "server_background_batch_capability_unavailable") + ) + XCTAssertEqual( + guardrail.decide( + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: true, effectiveProvider: nil, reason: "no_usable_batch_provider"), + audioSource: .microphone), + .cloudListenStreaming(reason: "no_usable_batch_provider") + ) + XCTAssertEqual( + guardrail.decide( + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: true, effectiveProvider: "unknown-provider"), + audioSource: .microphone), + .cloudListenStreaming(reason: "server_background_batch_provider_unsupported") + ) + XCTAssertTrue( + guardrail.shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: .microphone, + captureStarted: false + ) + ) + XCTAssertFalse( + guardrail.shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: .microphone, + captureStarted: true + ) + ) + XCTAssertFalse( + guardrail.shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: .bleDevice, + captureStarted: false + ) + ) + } + + func testDesktopCapabilitiesDecodeEffectiveProviderFields() throws { + let json = Data( + """ + { + "background_batch": { + "enabled": true, + "mode": "deepgram_fallback", + "provider": "assemblyai", + "primary_provider": "assemblyai", + "effective_provider": "deepgram", + "fallback_provider": "deepgram", + "fallback_enabled": true, + "fallback_available": true, + "workload": "background", + "reason": "fallback_deepgram_available", 
+ "sample_rate": 16000, + "channels": 1, + "encoding": "linear16", + "max_chunk_seconds": 15 + } + } + """.utf8) + + let decoded = try JSONDecoder().decode(DesktopCapabilitiesResponse.self, from: json) + + XCTAssertTrue(decoded.backgroundBatch.enabled) + XCTAssertEqual(decoded.backgroundBatch.mode, "deepgram_fallback") + XCTAssertEqual(decoded.backgroundBatch.primaryProvider, "assemblyai") + XCTAssertEqual(decoded.backgroundBatch.effectiveProvider, "deepgram") + XCTAssertEqual(decoded.backgroundBatch.fallbackProvider, "deepgram") + XCTAssertEqual(decoded.backgroundBatch.reason, "fallback_deepgram_available") + } + + func testOlderDesktopCapabilitiesDecodeWithoutEffectiveProviderFields() throws { + let json = Data( + """ + { + "background_batch": { + "enabled": false, + "provider": "assemblyai", + "sample_rate": 16000, + "channels": 1, + "encoding": "linear16", + "max_chunk_seconds": 15 + } + } + """.utf8) + + let decoded = try JSONDecoder().decode(DesktopCapabilitiesResponse.self, from: json) + + XCTAssertFalse(decoded.backgroundBatch.enabled) + XCTAssertNil(decoded.backgroundBatch.effectiveProvider) + XCTAssertEqual( + BackgroundTranscriptionRoutingGuard().decide( + backgroundBatchCapability: decoded.backgroundBatch, + audioSource: .microphone), .cloudListenStreaming(reason: "server_background_batch_disabled") ) } + private static func backgroundBatchCapability( + enabled: Bool, + effectiveProvider: String?, + reason: String? = nil + ) -> DesktopBackgroundBatchCapability { + DesktopBackgroundBatchCapability( + enabled: enabled, + mode: enabled ? 
"assemblyai_primary" : "disabled", + provider: "assemblyai", + primaryProvider: "assemblyai", + effectiveProvider: effectiveProvider, + fallbackProvider: "deepgram", + fallbackEnabled: true, + fallbackAvailable: true, + reason: reason, + sampleRate: 16000, + channels: 1, + encoding: "linear16", + maxChunkSeconds: 15 + ) + } + private static func backendSegment( id: String?, text: String, From 4efe0e6e800b94364238b574e8babf1c31d58f5f Mon Sep 17 00:00:00 2001 From: David Zhang Date: Fri, 22 May 2026 04:55:20 +0700 Subject: [PATCH 32/44] Make desktop background chunks idempotent --- backend/database/conversations.py | 30 +++++ backend/models/conversation.py | 1 + backend/routers/desktop_background.py | 108 ++++++++++++++++-- .../test_desktop_background_transcribe.py | 100 +++++++++++++--- .../utils/conversations/desktop_background.py | 108 +++++++++++++++++- .../Sources/TranscriptionService.swift | 20 ++++ .../Tests/BackgroundTranscriptionTests.swift | 23 ++++ 7 files changed, 365 insertions(+), 25 deletions(-) diff --git a/backend/database/conversations.py b/backend/database/conversations.py index 3a58052a58e..47996c5b3e5 100644 --- a/backend/database/conversations.py +++ b/backend/database/conversations.py @@ -874,6 +874,36 @@ def update_conversation_segments( return +def update_conversation_segments_and_background_chunks( + uid: str, + conversation_id: str, + segments: List[dict], + background_processed_chunks: Dict[str, dict], + finished_at: datetime = None, + data_protection_level: str = None, +): + doc_ref = db.collection('users').document(uid).collection(conversations_collection).document(conversation_id) + if data_protection_level is not None: + doc_level = data_protection_level + else: + doc_snapshot = doc_ref.get(field_paths=['data_protection_level']) + if not doc_snapshot.exists: + return + doc_level = doc_snapshot.to_dict().get('data_protection_level', 'standard') + update_payload = { + 'transcript_segments': segments, + 'background_processed_chunks': 
background_processed_chunks, + } + if finished_at: + update_payload['finished_at'] = finished_at + prepared_payload = _prepare_conversation_for_write(update_payload, uid, doc_level) + try: + doc_ref.update(prepared_payload) + except NotFound: + # Document was deleted between cache read and write — safe to skip + return + + # *********************************** # ********** VISIBILITY ************* # *********************************** diff --git a/backend/models/conversation.py b/backend/models/conversation.py index 7bcce5517f3..214107946e8 100644 --- a/backend/models/conversation.py +++ b/backend/models/conversation.py @@ -92,6 +92,7 @@ class Conversation(BaseModel): plugins_results: List[PluginResult] = [] external_data: Optional[Dict] = None + background_processed_chunks: Dict[str, Dict] = Field(default_factory=dict) app_id: Optional[str] = None discarded: bool = False diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py index 91b35ba8be0..a11994a5962 100644 --- a/backend/routers/desktop_background.py +++ b/backend/routers/desktop_background.py @@ -1,3 +1,4 @@ +import hashlib import json import logging import os @@ -17,9 +18,10 @@ from utils.chat import resolve_voice_message_language from utils.conversations.desktop_background import ( DesktopBackgroundConversationError, - append_segments_to_in_progress_conversation, + append_background_chunk_to_in_progress_conversation, create_in_progress_desktop_conversation, finish_desktop_background_conversation, + get_background_chunk_record, ) from utils.executors import db_executor, run_blocking, sync_executor from utils.fair_use import is_hard_restricted, record_speech_ms @@ -209,15 +211,59 @@ async def background_transcribe( conversation_id = request.query_params.get("conversation_id") persist = _parse_persist(request.query_params.get("persist"), default=bool(conversation_id)) + chunk_id = request.query_params.get("chunk_id") + encoding = request.query_params.get("encoding", 
"linear16") if persist: if not conversation_id: del audio_bytes raise HTTPException(status_code=400, detail='conversation_id is required when persist=true') + if not chunk_id: + del audio_bytes + raise HTTPException(status_code=400, detail='chunk_id is required when persist=true') + if not _is_valid_chunk_id(chunk_id): + del audio_bytes + raise HTTPException(status_code=422, detail='chunk_id is invalid') await _validate_in_progress_conversation(uid, conversation_id) + payload_hash = _background_chunk_payload_hash(audio_bytes, sample_rate, channels, encoding, chunk_start_ms) + if persist and conversation_id and chunk_id: + try: + existing_chunk = await run_blocking( + db_executor, get_background_chunk_record, uid, conversation_id, chunk_id + ) + except DesktopBackgroundConversationError as e: + del audio_bytes + raise HTTPException(status_code=e.status_code, detail=str(e)) + if existing_chunk: + del audio_bytes + if existing_chunk.get('payload_hash') != payload_hash: + logger.warning( + "desktop_background_transcribe chunk payload conflict uid=%s conversation_id=%s chunk_id=%s", + uid, + conversation_id, + chunk_id, + ) + raise HTTPException(status_code=409, detail='chunk_id payload mismatch') + logger.info( + "desktop_background_transcribe duplicate chunk uid=%s conversation_id=%s chunk_id=%s provider=%s", + uid, + conversation_id, + chunk_id, + existing_chunk.get('provider'), + ) + return { + "segments": [], + "language": None, + "provider": existing_chunk.get('provider'), + "run_id": existing_chunk.get('run_id'), + "chunk_duration_ms": existing_chunk.get('chunk_duration_ms'), + "chunk_id": chunk_id, + "duplicate": True, + "speaker_diagnostics": _speaker_diagnostics([]), + } + language = resolve_voice_message_language(uid, request.query_params.get("language")) keywords = _parse_context_keywords(request.query_params.get("keywords")) - encoding = request.query_params.get("encoding", "linear16") duration_ms = compute_pcm_duration_ms(len(audio_bytes), sample_rate, 
channels) duration_sec = duration_ms / 1000.0 @@ -264,25 +310,37 @@ async def background_transcribe( finished_at = datetime.now(timezone.utc) if persist and conversation_id: - await run_blocking( - db_executor, - append_segments_to_in_progress_conversation, - uid, - conversation_id, - segments, - finished_at, - ) + try: + append_result = await run_blocking( + db_executor, + append_background_chunk_to_in_progress_conversation, + uid, + conversation_id, + chunk_id, + payload_hash, + segments, + finished_at, + response.result.provider if response.result else None, + response.run_id, + chunk_start_ms, + duration_ms, + ) + except DesktopBackgroundConversationError as e: + raise HTTPException(status_code=e.status_code, detail=str(e)) + if append_result.duplicate: + segments = append_result.segments await run_blocking(db_executor, record_speech_ms, uid, duration_ms, source='background') await run_blocking(db_executor, record_usage, uid, transcription_seconds=duration_sec, speech_seconds=duration_sec) provider = response.result.provider if response.result else None logger.info( "desktop_background_transcribe completed uid=%s conversation_id=%s workload=background provider=%s run_id=%s " - "chunk_start_ms=%s chunk_duration_ms=%s segments=%s persisted=%s", + "chunk_id=%s chunk_start_ms=%s chunk_duration_ms=%s segments=%s persisted=%s", uid, conversation_id, provider, response.run_id, + chunk_id, chunk_start_ms, duration_ms, len(segments), @@ -295,6 +353,8 @@ async def background_transcribe( "provider": provider, "run_id": response.run_id, "chunk_duration_ms": duration_ms, + "chunk_id": chunk_id, + "duplicate": False, "speaker_diagnostics": speaker_diagnostics, } @@ -324,6 +384,32 @@ def _parse_context_keywords(raw: Optional[str]) -> List[str]: return keywords +def _is_valid_chunk_id(chunk_id: str) -> bool: + if len(chunk_id) < 8 or len(chunk_id) > 160: + return False + return all(char.isalnum() or char in ('-', '_') for char in chunk_id) + + +def 
_background_chunk_payload_hash( + audio_bytes: bytes, + sample_rate: int, + channels: int, + encoding: str, + chunk_start_ms: int, +) -> str: + hasher = hashlib.sha256() + hasher.update(str(sample_rate).encode('ascii')) + hasher.update(b':') + hasher.update(str(channels).encode('ascii')) + hasher.update(b':') + hasher.update(encoding.encode('utf-8')) + hasher.update(b':') + hasher.update(str(chunk_start_ms).encode('ascii')) + hasher.update(b':') + hasher.update(audio_bytes) + return hasher.hexdigest() + + async def _validate_in_progress_conversation(uid: str, conversation_id: str) -> None: conversation = await run_blocking(db_executor, conversations_db.get_conversation, uid, conversation_id) if not conversation: diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index 208edd75985..85261003af2 100644 --- a/backend/tests/unit/test_desktop_background_transcribe.py +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -42,9 +42,11 @@ def __init__(self, message: str, status_code: int = 400): 'utils.conversations.desktop_background', SimpleNamespace( DesktopBackgroundConversationError=_DesktopBackgroundConversationError, + append_background_chunk_to_in_progress_conversation=MagicMock(), append_segments_to_in_progress_conversation=MagicMock(), create_in_progress_desktop_conversation=MagicMock(return_value='conv-1'), finish_desktop_background_conversation=MagicMock(return_value={'id': 'conv-1', 'status': 'completed'}), + get_background_chunk_record=MagicMock(return_value=None), ), ) sys.modules.setdefault( @@ -140,7 +142,12 @@ def _client(monkeypatch, *, segments=None, person_embeddings_cache=None): 'get_conversation', lambda _uid, _cid: {'id': _cid, 'status': 'in_progress'}, ) - monkeypatch.setattr(desktop_background, 'append_segments_to_in_progress_conversation', MagicMock(return_value=[])) + monkeypatch.setattr(desktop_background, 'get_background_chunk_record', 
MagicMock(return_value=None)) + monkeypatch.setattr( + desktop_background, + 'append_background_chunk_to_in_progress_conversation', + MagicMock(return_value=SimpleNamespace(appended=True, duplicate=False, segments=[], chunk_record=None)), + ) monkeypatch.setattr( desktop_background, 'finish_desktop_background_conversation', @@ -340,7 +347,7 @@ def test_background_transcribe_returns_segments_with_offset(monkeypatch): client, _mock_transcribe = _client(monkeypatch) response = client.post( - '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=12000', + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=12000', content=b'\x01\x00' * 1600, headers={'Content-Type': 'application/octet-stream'}, ) @@ -381,7 +388,7 @@ def test_background_transcribe_wraps_linear16_pcm_as_wav(monkeypatch): client, mock_transcribe = _client(monkeypatch) response = client.post( - '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0&sample_rate=16000', + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0&sample_rate=16000', content=b'\x01\x00' * 1600, headers={'Content-Type': 'application/octet-stream'}, ) @@ -396,14 +403,80 @@ def test_background_transcribe_wraps_linear16_pcm_as_wav(monkeypatch): def test_background_transcribe_persists_segments(monkeypatch): client, _mock_transcribe = _client(monkeypatch) + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + desktop_background.append_background_chunk_to_in_progress_conversation.assert_called_once() + + +def test_background_transcribe_requires_chunk_id_when_persisting(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + response = client.post( 
'/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0', content=b'\x01\x00' * 1600, headers={'Content-Type': 'application/octet-stream'}, ) + assert response.status_code == 400 + assert response.json()['detail'] == 'chunk_id is required when persist=true' + + +def test_background_transcribe_duplicate_chunk_returns_success_without_appending(monkeypatch): + client, mock_transcribe = _client(monkeypatch) + payload = b'\x01\x00' * 1600 + payload_hash = desktop_background._background_chunk_payload_hash(payload, 16000, 1, 'linear16', 0) + monkeypatch.setattr( + desktop_background, + 'get_background_chunk_record', + MagicMock( + return_value={ + 'payload_hash': payload_hash, + 'provider': 'assemblyai', + 'run_id': 'run-previous', + 'chunk_duration_ms': 100, + } + ), + ) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=payload, + headers={'Content-Type': 'application/octet-stream'}, + ) + assert response.status_code == 200 - desktop_background.append_segments_to_in_progress_conversation.assert_called_once() + data = response.json() + assert data['duplicate'] is True + assert data['segments'] == [] + assert data['provider'] == 'assemblyai' + assert data['run_id'] == 'run-previous' + mock_transcribe.assert_not_called() + desktop_background.append_background_chunk_to_in_progress_conversation.assert_not_called() + + +def test_background_transcribe_rejects_conflicting_duplicate_chunk(monkeypatch): + client, mock_transcribe = _client(monkeypatch) + monkeypatch.setattr( + desktop_background, + 'get_background_chunk_record', + MagicMock(return_value={'payload_hash': 'different', 'provider': 'assemblyai'}), + ) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 409 + assert 
response.json()['detail'] == 'chunk_id payload mismatch' + mock_transcribe.assert_not_called() + desktop_background.append_background_chunk_to_in_progress_conversation.assert_not_called() def test_background_transcribe_can_skip_persist_without_conversation(monkeypatch): @@ -420,7 +493,7 @@ def test_background_transcribe_can_skip_persist_without_conversation(monkeypatch ) assert response.status_code == 200 - desktop_background.append_segments_to_in_progress_conversation.assert_not_called() + desktop_background.append_background_chunk_to_in_progress_conversation.assert_not_called() record_speech_ms.assert_called_once() record_usage.assert_called_once() @@ -433,7 +506,7 @@ def test_cluster_speaker_mapping_assigns_distinct_ids(monkeypatch): client, _mock_transcribe = _client(monkeypatch, segments=segments) response = client.post( - '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0', + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', content=b'\x01\x00' * 1600, headers={'Content-Type': 'application/octet-stream'}, ) @@ -477,7 +550,7 @@ def _transcribe_bytes(_audio_bytes, **_kwargs): for chunk_start_ms in (0, 14000, 28000): response = client.post( - f'/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms={chunk_start_ms}', + f'/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-id-{chunk_start_ms}&chunk_start_ms={chunk_start_ms}', content=b'\x01\x00' * 1600, headers={'Content-Type': 'application/octet-stream'}, ) @@ -485,7 +558,8 @@ def _transcribe_bytes(_audio_bytes, **_kwargs): assert response.json()['provider'] == 'assemblyai' appended_segments = [ - call.args[2][0] for call in desktop_background.append_segments_to_in_progress_conversation.call_args_list + call.args[4][0] + for call in desktop_background.append_background_chunk_to_in_progress_conversation.call_args_list ] assert [segment.start for segment in appended_segments] == [0.1, 14.1, 28.1] assert 
[segment.end for segment in appended_segments] == [1.0, 15.0, 29.0] @@ -516,7 +590,7 @@ def test_background_transcribe_identifies_assemblyai_speaker_with_omi_user_embed monkeypatch.setattr(desktop_background, 'extract_embedding_from_bytes', lambda _audio, _filename: user_embedding) response = client.post( - '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=12000', + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=12000', content=b'\x01\x00' * 16000 * 4, headers={'Content-Type': 'application/octet-stream'}, ) @@ -551,7 +625,7 @@ def test_background_transcribe_rejects_empty_body(monkeypatch): client, _mock_transcribe = _client(monkeypatch) response = client.post( - '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0', + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', content=b'', headers={'Content-Type': 'application/octet-stream'}, ) @@ -563,7 +637,7 @@ def test_background_transcribe_rejects_stereo_pcm_for_v1(monkeypatch): client, _mock_transcribe = _client(monkeypatch) response = client.post( - '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0&channels=2', + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0&channels=2', content=b'\x01\x00' * 1600, headers={'Content-Type': 'application/octet-stream'}, ) @@ -576,7 +650,7 @@ def test_background_transcribe_rejects_malformed_content_length(monkeypatch): client, _mock_transcribe = _client(monkeypatch) response = client.post( - '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0', + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', content=b'\x01\x00' * 1600, headers={'Content-Type': 'application/octet-stream', 'Content-Length': 'not-an-int'}, ) @@ -589,7 +663,7 @@ def test_background_transcribe_rejects_invalid_conversation(monkeypatch): 
monkeypatch.setattr(desktop_background.conversations_db, 'get_conversation', lambda _uid, _cid: None) response = client.post( - '/v2/desktop/background-transcribe?conversation_id=missing&chunk_start_ms=0', + '/v2/desktop/background-transcribe?conversation_id=missing&chunk_id=chunk-001&chunk_start_ms=0', content=b'\x01\x00' * 1600, headers={'Content-Type': 'application/octet-stream'}, ) diff --git a/backend/utils/conversations/desktop_background.py b/backend/utils/conversations/desktop_background.py index 9d82f5132ae..9f9b8bb9f0e 100644 --- a/backend/utils/conversations/desktop_background.py +++ b/backend/utils/conversations/desktop_background.py @@ -1,7 +1,8 @@ import logging import uuid +from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import List, Optional +from typing import Dict, List, Optional import database.calendar_meetings as calendar_db import database.conversations as conversations_db @@ -14,6 +15,7 @@ from utils.conversations.process_conversation import process_conversation logger = logging.getLogger(__name__) +_MAX_BACKGROUND_CHUNK_RECORDS = 1000 class DesktopBackgroundConversationError(ValueError): @@ -22,6 +24,14 @@ def __init__(self, message: str, status_code: int = 400): self.status_code = status_code +@dataclass +class DesktopBackgroundAppendResult: + appended: bool + duplicate: bool + segments: List[TranscriptSegment] + chunk_record: Optional[Dict] + + def create_in_progress_desktop_conversation( uid: str, language: str, @@ -96,6 +106,102 @@ def append_segments_to_in_progress_conversation( return updated_segments +def append_background_chunk_to_in_progress_conversation( + uid: str, + conversation_id: str, + chunk_id: str, + payload_hash: str, + segments: List[TranscriptSegment], + finished_at: datetime, + provider: Optional[str], + run_id: Optional[str], + chunk_start_ms: int, + chunk_duration_ms: int, +) -> DesktopBackgroundAppendResult: + """Append one desktop background chunk once, keyed by 
stable client chunk_id.""" + conversation_data = conversations_db.get_conversation(uid, conversation_id) + if not conversation_data: + raise DesktopBackgroundConversationError("conversation_id not found", status_code=404) + + conversation = deserialize_conversation(conversation_data) + if conversation.status != ConversationStatus.in_progress: + raise DesktopBackgroundConversationError("conversation is not in_progress", status_code=409) + + processed_chunks = dict(conversation.background_processed_chunks or {}) + existing_record = processed_chunks.get(chunk_id) + if existing_record: + if existing_record.get('payload_hash') != payload_hash: + raise DesktopBackgroundConversationError("chunk_id payload mismatch", status_code=409) + conversations_db.update_conversation_finished_at(uid, conversation_id, finished_at) + logger.info( + "Duplicate desktop background chunk ignored uid=%s conversation_id=%s chunk_id=%s provider=%s", + uid, + conversation_id, + chunk_id, + existing_record.get('provider'), + ) + return DesktopBackgroundAppendResult( + appended=False, + duplicate=True, + segments=[], + chunk_record=existing_record, + ) + + if segments: + conversation.transcript_segments, updated_segments, _removed_ids = TranscriptSegment.combine_segments( + conversation.transcript_segments, + segments, + ) + else: + updated_segments = [] + + processed_chunks[chunk_id] = { + 'chunk_id': chunk_id, + 'payload_hash': payload_hash, + 'provider': provider, + 'run_id': run_id, + 'segment_count': len(segments), + 'chunk_start_ms': chunk_start_ms, + 'chunk_duration_ms': chunk_duration_ms, + 'accepted_at': finished_at.isoformat(), + } + processed_chunks = _prune_background_chunk_records(processed_chunks) + + conversations_db.update_conversation_segments_and_background_chunks( + uid, + conversation.id, + [segment.dict() for segment in conversation.transcript_segments], + processed_chunks, + finished_at=finished_at, + ) + return DesktopBackgroundAppendResult( + appended=True, + 
duplicate=False, + segments=updated_segments, + chunk_record=processed_chunks.get(chunk_id), + ) + + +def get_background_chunk_record(uid: str, conversation_id: str, chunk_id: str) -> Optional[Dict]: + conversation_data = conversations_db.get_conversation(uid, conversation_id) + if not conversation_data: + raise DesktopBackgroundConversationError("conversation_id not found", status_code=404) + if conversation_data.get('status') != ConversationStatus.in_progress: + raise DesktopBackgroundConversationError("conversation is not in_progress", status_code=409) + return (conversation_data.get('background_processed_chunks') or {}).get(chunk_id) + + +def _prune_background_chunk_records(processed_chunks: Dict[str, Dict]) -> Dict[str, Dict]: + if len(processed_chunks) <= _MAX_BACKGROUND_CHUNK_RECORDS: + return processed_chunks + + ordered = sorted( + processed_chunks.items(), + key=lambda item: (item[1].get('accepted_at') or '', item[0]), + ) + return dict(ordered[-_MAX_BACKGROUND_CHUNK_RECORDS:]) + + def finish_desktop_background_conversation(uid: str, conversation_id: str) -> Conversation: """Finalize one explicit desktop background conversation by ID.""" conversation_data = conversations_db.get_conversation(uid, conversation_id) diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index a739918855d..cdfd728f7ff 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -1,4 +1,5 @@ import Foundation +import CryptoKit /// Service for real-time speech-to-text transcription. /// Conversation capture: Python backend `/v4/listen` WebSocket (speech profiles, speaker assignment, memory events). 
@@ -666,6 +667,14 @@ extension TranscriptionService { let sanitizedKeywords = sanitizedContextKeywords(contextKeywords, includeDefaultOmi: false) var queryItems = [ URLQueryItem(name: "conversation_id", value: conversationId), + URLQueryItem( + name: "chunk_id", + value: backgroundChunkId( + conversationId: conversationId, + chunkStartMs: chunkStartMs, + audioData: audioData + ) + ), URLQueryItem(name: "chunk_start_ms", value: String(chunkStartMs)), URLQueryItem(name: "language", value: language), URLQueryItem(name: "sample_rate", value: "16000"), @@ -739,6 +748,17 @@ extension TranscriptionService { throw lastError ?? TranscriptionError.invalidResponse } + static func backgroundChunkId( + conversationId: String, + chunkStartMs: Int, + audioData: Data + ) -> String { + let digest = SHA256.hash(data: audioData) + let hashPrefix = digest.prefix(12).map { String(format: "%02x", $0) }.joined() + return "\(conversationId)-\(chunkStartMs)-\(audioData.count)-\(hashPrefix)" + .replacingOccurrences(of: #"[^A-Za-z0-9_-]"#, with: "_", options: .regularExpression) + } + } /// Response model for Python backend `/v2/voice-message/transcribe` (batch PTT) diff --git a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift index 147b5735dac..95ad0d9a59e 100644 --- a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift +++ b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift @@ -437,6 +437,29 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertEqual(session.snapshot().droppedChunkCount, 1) } + func testBackgroundChunkIdIsStableAndPayloadSensitive() { + let first = TranscriptionService.backgroundChunkId( + conversationId: "conv.123", + chunkStartMs: 15_000, + audioData: Data([1, 2, 3, 4]) + ) + let retry = TranscriptionService.backgroundChunkId( + conversationId: "conv.123", + chunkStartMs: 15_000, + audioData: Data([1, 2, 3, 4]) + ) + let changedPayload = TranscriptionService.backgroundChunkId( + 
conversationId: "conv.123", + chunkStartMs: 15_000, + audioData: Data([1, 2, 3, 5]) + ) + + XCTAssertEqual(first, retry) + XCTAssertNotEqual(first, changedPayload) + XCTAssertTrue(first.hasPrefix("conv_123-15000-4-")) + XCTAssertNil(first.range(of: #"[^A-Za-z0-9_-]"#, options: .regularExpression)) + } + func testTranscriptMergerDeduplicatesAndMergesOverlap() { var merger = BackgroundTranscriptMerger() let first = Self.backendSegment(id: "a", text: "hello world", speakerId: 0, start: 0, end: 2) From 83c53f63ec306d97e8df8418b32ffa2f017a906e Mon Sep 17 00:00:00 2001 From: David Zhang Date: Fri, 22 May 2026 05:00:30 +0700 Subject: [PATCH 33/44] Add AssemblyAI speaker identity diagnostics --- .../database/transcription_provider_usage.py | 74 +++++++++++++++ backend/routers/desktop_background.py | 68 ++++++++------ backend/routers/sync.py | 2 + .../unit/test_background_provider_service.py | 23 +++++ .../test_desktop_background_transcribe.py | 89 +++++++++++++++++++ .../unit/test_transcription_provider_usage.py | 20 +++++ backend/utils/stt/provider_service.py | 44 +++++++++ 7 files changed, 294 insertions(+), 26 deletions(-) diff --git a/backend/database/transcription_provider_usage.py b/backend/database/transcription_provider_usage.py index ea5511f02c5..ef96697295e 100644 --- a/backend/database/transcription_provider_usage.py +++ b/backend/database/transcription_provider_usage.py @@ -131,6 +131,16 @@ def create_provider_run( 'transcript_word_count': 0, 'speaker_cluster_count': 0, 'identified_speaker_cluster_count': 0, + 'provider_speaker_count': 0, + 'mapped_speaker_count': 0, + 'mapped_person_count': 0, + 'unmapped_speaker_count': 0, + 'embedding_extraction_failure_count': 0, + 'identity_metric_update': { + 'status': 'pending', + 'skipped_reason': None, + 'updated_at': None, + }, 'identity_confidence_summary': {}, 'error_class': None, 'created_at': now, @@ -160,6 +170,11 @@ def finalize_provider_run( transcript_word_count: int = 0, speaker_cluster_count: int = 0, 
identified_speaker_cluster_count: int = 0, + provider_speaker_count: int = 0, + mapped_speaker_count: int = 0, + mapped_person_count: int = 0, + unmapped_speaker_count: int = 0, + embedding_extraction_failure_count: int = 0, identity_confidence_summary: Optional[dict[str, Any]] = None, error_class: Optional[str] = None, artifact_refs: Optional[dict[str, str]] = None, @@ -190,6 +205,11 @@ def finalize_provider_run( 'transcript_word_count': transcript_word_count, 'speaker_cluster_count': speaker_cluster_count, 'identified_speaker_cluster_count': identified_speaker_cluster_count, + 'provider_speaker_count': provider_speaker_count, + 'mapped_speaker_count': mapped_speaker_count, + 'mapped_person_count': mapped_person_count, + 'unmapped_speaker_count': unmapped_speaker_count, + 'embedding_extraction_failure_count': embedding_extraction_failure_count, 'identity_confidence_summary': summary, 'error_class': error_class, 'artifact_refs': artifact_refs or {}, @@ -219,6 +239,11 @@ def finalize_provider_run( transcript_word_count=transcript_word_count, speaker_cluster_count=speaker_cluster_count, identified_speaker_cluster_count=identified_speaker_cluster_count, + provider_speaker_count=provider_speaker_count, + mapped_speaker_count=mapped_speaker_count, + mapped_person_count=mapped_person_count, + unmapped_speaker_count=unmapped_speaker_count, + embedding_extraction_failure_count=embedding_extraction_failure_count, identity_confidence_summary=summary, ) emit_provider_run_metrics( @@ -271,6 +296,11 @@ def increment_daily_rollup( transcript_word_count: int = 0, speaker_cluster_count: int = 0, identified_speaker_cluster_count: int = 0, + provider_speaker_count: int = 0, + mapped_speaker_count: int = 0, + mapped_person_count: int = 0, + unmapped_speaker_count: int = 0, + embedding_extraction_failure_count: int = 0, identity_confidence_summary: Optional[dict[str, Any]] = None, ) -> None: update = { @@ -290,6 +320,11 @@ def increment_daily_rollup( 'transcript_word_count': 
firestore.Increment(transcript_word_count), 'speaker_cluster_count': firestore.Increment(speaker_cluster_count), 'identified_speaker_cluster_count': firestore.Increment(identified_speaker_cluster_count), + 'provider_speaker_count': firestore.Increment(provider_speaker_count), + 'mapped_speaker_count': firestore.Increment(mapped_speaker_count), + 'mapped_person_count': firestore.Increment(mapped_person_count), + 'unmapped_speaker_count': firestore.Increment(unmapped_speaker_count), + 'embedding_extraction_failure_count': firestore.Increment(embedding_extraction_failure_count), 'last_updated': _utc_now(), } for bucket, count in (identity_confidence_summary or {}).items(): @@ -341,6 +376,13 @@ def update_provider_run_identity_metrics( workload: str, identified_speaker_cluster_count: int, identity_confidence_summary: Optional[dict[str, Any]] = None, + provider_speaker_count: int = 0, + mapped_speaker_count: int = 0, + mapped_person_count: int = 0, + unmapped_speaker_count: int = 0, + embedding_extraction_failure_count: int = 0, + identity_metric_update_status: str = 'succeeded', + identity_metric_update_skipped_reason: Optional[str] = None, ) -> None: ref = _run_ref(run_id) snapshot = ref.get() @@ -348,6 +390,11 @@ def update_provider_run_identity_metrics( return data = snapshot.to_dict() or {} previous_identified = data.get('identified_speaker_cluster_count', 0) or 0 + previous_provider_speakers = data.get('provider_speaker_count', data.get('speaker_cluster_count', 0)) or 0 + previous_mapped_speakers = data.get('mapped_speaker_count', previous_identified) or 0 + previous_mapped_people = data.get('mapped_person_count', 0) or 0 + previous_unmapped_speakers = data.get('unmapped_speaker_count', 0) or 0 + previous_embedding_failures = data.get('embedding_extraction_failure_count', 0) or 0 previous_summary = data.get('identity_confidence_summary') or {} summary = identity_confidence_summary or {} completed_at = (data.get('timing') or {}).get('completed_at') or _utc_now() @@ 
-355,6 +402,16 @@ def update_provider_run_identity_metrics( ref.set( { 'identified_speaker_cluster_count': identified_speaker_cluster_count, + 'provider_speaker_count': provider_speaker_count, + 'mapped_speaker_count': mapped_speaker_count, + 'mapped_person_count': mapped_person_count, + 'unmapped_speaker_count': unmapped_speaker_count, + 'embedding_extraction_failure_count': embedding_extraction_failure_count, + 'identity_metric_update': { + 'status': identity_metric_update_status, + 'skipped_reason': identity_metric_update_skipped_reason, + 'updated_at': _utc_now(), + }, 'identity_confidence_summary': summary, 'updated_at': _utc_now(), }, @@ -363,6 +420,13 @@ def update_provider_run_identity_metrics( rollup_update = { 'identified_speaker_cluster_count': firestore.Increment(identified_speaker_cluster_count - previous_identified), + 'provider_speaker_count': firestore.Increment(provider_speaker_count - previous_provider_speakers), + 'mapped_speaker_count': firestore.Increment(mapped_speaker_count - previous_mapped_speakers), + 'mapped_person_count': firestore.Increment(mapped_person_count - previous_mapped_people), + 'unmapped_speaker_count': firestore.Increment(unmapped_speaker_count - previous_unmapped_speakers), + 'embedding_extraction_failure_count': firestore.Increment( + embedding_extraction_failure_count - previous_embedding_failures + ), 'last_updated': _utc_now(), } for bucket in set(previous_summary) | set(summary): @@ -444,6 +508,11 @@ def _empty_rollup(day: str, provider: str, model: str, workload: str) -> dict[st 'transcript_word_count': 0, 'speaker_cluster_count': 0, 'identified_speaker_cluster_count': 0, + 'provider_speaker_count': 0, + 'mapped_speaker_count': 0, + 'mapped_person_count': 0, + 'unmapped_speaker_count': 0, + 'embedding_extraction_failure_count': 0, 'identity_confidence_counts': {}, 'last_updated': _utc_now(), } @@ -464,6 +533,11 @@ def _add_run_to_rollup(rollup: dict[str, Any], data: dict[str, Any]) -> None: 'transcript_word_count', 
'speaker_cluster_count', 'identified_speaker_cluster_count', + 'provider_speaker_count', + 'mapped_speaker_count', + 'mapped_person_count', + 'unmapped_speaker_count', + 'embedding_extraction_failure_count', ): rollup[field] += data.get(field, 0) or 0 for bucket, count in (data.get('identity_confidence_summary') or {}).items(): diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py index a11994a5962..93f4df39dfd 100644 --- a/backend/routers/desktop_background.py +++ b/backend/routers/desktop_background.py @@ -31,6 +31,7 @@ from utils.stt.background_speaker_identity import USER_SELF_PERSON_ID, identify_background_speaker_clusters from utils.stt.provider_service import ( resolve_prerecorded_provider_for_request, + speaker_identity_metrics, transcribe_bytes, update_provider_run_identity_metrics, ) @@ -434,42 +435,51 @@ async def _identify_speakers( run_id: Optional[str], ) -> None: """Apply Omi speaker identity to AssemblyAI background chunk-local segments.""" + identity_metric_update_status = 'succeeded' + identity_metric_update_skipped_reason = None try: person_embeddings_cache = await run_blocking(db_executor, _build_person_embeddings_cache, uid) if not person_embeddings_cache: - return - assignments = await run_blocking( - sync_executor, - identify_background_speaker_clusters, - segments, - audio_bytes, - person_embeddings_cache, - extract_embedding_from_bytes, - ) - identified_count = sum(1 for assignment in assignments.values() if assignment.state in ('identified', 'user')) - logger.info( - "Speaker ID (desktop background): cluster assignments=%s identified=%s uid=%s conversation_id=%s", - len(assignments), - identified_count, - uid, - conversation_id, - ) - await run_blocking( - db_executor, - update_provider_run_identity_metrics, - run_id, - provider or 'unknown', - model or 'unknown', - STTWorkload.background, - segments, - ) + identity_metric_update_status = 'skipped' + identity_metric_update_skipped_reason = 
'missing_candidate_embeddings' + else: + assignments = await run_blocking( + sync_executor, + identify_background_speaker_clusters, + segments, + audio_bytes, + person_embeddings_cache, + extract_embedding_from_bytes, + ) + identified_count = sum( + 1 for assignment in assignments.values() if assignment.state in ('identified', 'user') + ) + logger.info( + "Speaker ID (desktop background): cluster assignments=%s identified=%s uid=%s conversation_id=%s", + len(assignments), + identified_count, + uid, + conversation_id, + ) except Exception as e: + identity_metric_update_status = 'failed' logger.warning( "Speaker ID (desktop background): identification failed uid=%s conversation_id=%s: %s", uid, conversation_id, e, ) + await run_blocking( + db_executor, + update_provider_run_identity_metrics, + run_id, + provider or 'unknown', + model or 'unknown', + STTWorkload.background, + segments, + identity_metric_update_status, + identity_metric_update_skipped_reason, + ) def _build_person_embeddings_cache(uid: str) -> Dict[str, dict]: @@ -539,11 +549,17 @@ def _speaker_diagnostics(segments: List[TranscriptSegment], prefix: str = "") -> {str(segment.provider_speaker_label) for segment in segments if segment.provider_speaker_label is not None} ) mapped_speakers = sorted({int(segment.speaker_id) for segment in segments if segment.speaker_id is not None}) + identity_metrics = speaker_identity_metrics(segments) return { f"{prefix}provider_cluster_count": len(provider_clusters), f"{prefix}provider_clusters": provider_clusters[:20], f"{prefix}provider_speaker_label_count": len(provider_labels), f"{prefix}provider_speaker_labels": provider_labels[:20], + f"{prefix}provider_speaker_count": identity_metrics['provider_speaker_count'], + f"{prefix}mapped_speaker_count": identity_metrics['mapped_speaker_count'], + f"{prefix}mapped_person_count": identity_metrics['mapped_person_count'], + f"{prefix}unmapped_speaker_count": identity_metrics['unmapped_speaker_count'], + 
f"{prefix}embedding_extraction_failure_count": identity_metrics['embedding_extraction_failure_count'], f"{prefix}speaker_id_count": len(mapped_speakers), f"{prefix}speaker_ids": mapped_speakers[:20], } diff --git a/backend/routers/sync.py b/backend/routers/sync.py index acf914e3fa4..9d2bce016af 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -911,6 +911,8 @@ def delete_file(): transcription.result.model or 'unknown', STTWorkload.sync, transcript_segments, + 'skipped' if not person_embeddings_cache else 'succeeded', + 'missing_candidate_embeddings' if not person_embeddings_cache else None, ) except Exception as e: logger.warning(f'Speaker ID (sync): identity metric update failed for {path}: {e}') diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index ee5948e0f23..a174c9fd108 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -529,6 +529,29 @@ def test_provider_service_counts_label_only_identified_clusters(): assert finalize_run.call_args.kwargs['speaker_cluster_count'] == 2 assert finalize_run.call_args.kwargs['identified_speaker_cluster_count'] == 1 + assert finalize_run.call_args.kwargs['provider_speaker_count'] == 2 + assert finalize_run.call_args.kwargs['mapped_speaker_count'] == 1 + assert finalize_run.call_args.kwargs['mapped_person_count'] == 1 + assert finalize_run.call_args.kwargs['unmapped_speaker_count'] == 1 + + +def test_provider_service_preserves_assemblyai_labels_for_identity_metrics(): + result = ProviderTranscriptResult( + provider='assemblyai', + model='universal-2', + duration=4.0, + words=[ + ProviderTranscriptWord(text='alice', start=0.0, end=0.5, speaker_label='A'), + ProviderTranscriptWord(text='speaks', start=0.6, end=1.0, speaker_label='A'), + ProviderTranscriptWord(text='bob', start=2.0, end=2.4, speaker_label='B'), + ProviderTranscriptWord(text='replies', 
start=2.5, end=3.0, speaker_label='B'), + ], + ) + segments = provider_service.reconstruct_conversation(result) + + assert [segment.provider_speaker_label for segment in segments] == ['A', 'B'] + assert {segment.speaker_id for segment in segments} == {0} + assert provider_service.speaker_identity_metrics(segments)['provider_speaker_count'] == 2 def test_provider_service_live_assemblyai_smoke_records_ledger_when_credentials_are_present(monkeypatch): diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index 85261003af2..853bf00001a 100644 --- a/backend/tests/unit/test_desktop_background_transcribe.py +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -86,6 +86,50 @@ def _pcm_to_wav_bytes(pcm_data: bytes, sample_rate: int) -> bytes: 'utils.stt.provider_service', SimpleNamespace( resolve_prerecorded_provider_for_request=MagicMock(), + speaker_identity_metrics=lambda segments: { + 'provider_speaker_count': len( + { + segment.provider_cluster_id or segment.provider_speaker_label + for segment in segments + if segment.provider_cluster_id or segment.provider_speaker_label + } + ), + 'mapped_speaker_count': len( + { + segment.provider_cluster_id or segment.provider_speaker_label + for segment in segments + if (segment.provider_cluster_id or segment.provider_speaker_label) + and ( + segment.person_id or segment.is_user or segment.speaker_identity_state in ('identified', 'user') + ) + } + ), + 'mapped_person_count': len( + { + segment.person_id or 'user' + for segment in segments + if segment.person_id or segment.is_user or segment.speaker_identity_state == 'user' + } + ), + 'unmapped_speaker_count': len( + { + segment.provider_cluster_id or segment.provider_speaker_label + for segment in segments + if segment.provider_cluster_id or segment.provider_speaker_label + } + ) + - len( + { + segment.provider_cluster_id or segment.provider_speaker_label + for segment in segments + if 
(segment.provider_cluster_id or segment.provider_speaker_label) + and ( + segment.person_id or segment.is_user or segment.speaker_identity_state in ('identified', 'user') + ) + } + ), + 'embedding_extraction_failure_count': 0, + }, transcribe_bytes=MagicMock(), update_provider_run_identity_metrics=MagicMock(), ), @@ -517,6 +561,51 @@ def test_cluster_speaker_mapping_assigns_distinct_ids(monkeypatch): assert [segment['speaker'] for segment in data['segments']] == ['SPEAKER_00', 'SPEAKER_01'] assert data['speaker_diagnostics']['provider_cluster_count'] == 2 assert data['speaker_diagnostics']['mapped_speaker_ids'] == [0, 1] + assert data['speaker_diagnostics']['provider_speaker_count'] == 2 + assert data['speaker_diagnostics']['mapped_provider_speaker_count'] == 2 + + +def test_assemblyai_label_only_speakers_do_not_collapse_to_single_local_speaker(monkeypatch): + segments = [ + TranscriptSegment( + text='Alice starts.', + is_user=False, + start=0.0, + end=2.0, + provider_speaker_label='A', + speaker_identity_state='unassigned', + stt_provider='assemblyai', + stt_model='universal-2', + ), + TranscriptSegment( + text='Bob replies.', + is_user=False, + start=2.0, + end=4.0, + provider_speaker_label='B', + speaker_identity_state='unassigned', + stt_provider='assemblyai', + stt_model='universal-2', + ), + ] + client, _mock_transcribe = _client(monkeypatch, segments=segments) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert [segment['provider_speaker_label'] for segment in data['segments']] == ['A', 'B'] + assert [segment['speaker_id'] for segment in data['segments']] == [0, 1] + assert data['speaker_diagnostics']['provider_speaker_label_count'] == 2 + assert data['speaker_diagnostics']['mapped_speaker_ids'] == [0, 1] + assert 
data['speaker_diagnostics']['mapped_unmapped_speaker_count'] == 2 + desktop_background.update_provider_run_identity_metrics.assert_called_once() + assert desktop_background.update_provider_run_identity_metrics.call_args.args[5] == 'skipped' + assert desktop_background.update_provider_run_identity_metrics.call_args.args[6] == 'missing_candidate_embeddings' def test_background_transcribe_multi_chunk_offsets_persist_and_keep_speaker_map(monkeypatch): diff --git a/backend/tests/unit/test_transcription_provider_usage.py b/backend/tests/unit/test_transcription_provider_usage.py index 1e42126a74e..33df99f3e8d 100644 --- a/backend/tests/unit/test_transcription_provider_usage.py +++ b/backend/tests/unit/test_transcription_provider_usage.py @@ -412,6 +412,10 @@ def test_update_provider_run_identity_metrics_updates_doc_and_rollup_delta(monke completed_at=completed_at, speaker_cluster_count=2, identified_speaker_cluster_count=0, + provider_speaker_count=2, + mapped_speaker_count=0, + mapped_person_count=0, + unmapped_speaker_count=2, identity_confidence_summary={'unknown': 2}, ) @@ -421,16 +425,32 @@ def test_update_provider_run_identity_metrics_updates_doc_and_rollup_delta(monke model='universal-2', workload='sync', identified_speaker_cluster_count=1, + provider_speaker_count=2, + mapped_speaker_count=1, + mapped_person_count=1, + unmapped_speaker_count=1, + embedding_extraction_failure_count=1, + identity_metric_update_status='succeeded', identity_confidence_summary={'very_high': 1, 'unknown': 1}, ) run_doc = fake_db.docs[(usage.RUNS_COLLECTION, 'run-identity')] update = run_doc.set_calls[-1]['data'] assert update['identified_speaker_cluster_count'] == 1 + assert update['provider_speaker_count'] == 2 + assert update['mapped_speaker_count'] == 1 + assert update['mapped_person_count'] == 1 + assert update['unmapped_speaker_count'] == 1 + assert update['embedding_extraction_failure_count'] == 1 + assert update['identity_metric_update']['status'] == 'succeeded' assert 
update['identity_confidence_summary'] == {'very_high': 1, 'unknown': 1} rollup_doc = fake_db.docs[(usage.DAILY_USAGE_COLLECTION, '2026-05-21:assemblyai:universal-2:sync')] rollup_delta = rollup_doc.set_calls[-1]['data'] assert rollup_delta['identified_speaker_cluster_count'] == {'__increment': 1} + assert rollup_delta['mapped_speaker_count'] == {'__increment': 1} + assert rollup_delta['mapped_person_count'] == {'__increment': 1} + assert rollup_delta['unmapped_speaker_count'] == {'__increment': -1} + assert rollup_delta['embedding_extraction_failure_count'] == {'__increment': 1} assert rollup_delta['identity_confidence_counts.very_high'] == {'__increment': 1} assert rollup_delta['identity_confidence_counts.unknown'] == {'__increment': -1} diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index 9ad960c8076..27156847fe9 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -606,6 +606,7 @@ def _finalize_run( billable_seconds = raw_audio_seconds clusters = {_segment_cluster_key(segment) for segment in segments if _segment_cluster_key(segment)} confidences = [segment.speaker_identity_confidence for segment in segments] + identity_metrics = speaker_identity_metrics(segments) try: finalize_provider_run( run_id=run_id, @@ -630,6 +631,11 @@ def _finalize_run( speaker_cluster_count=len(clusters), identified_speaker_cluster_count=_identified_cluster_count(segments), identity_confidence_summary=summarize_identity_confidences(confidences), + provider_speaker_count=identity_metrics['provider_speaker_count'], + mapped_speaker_count=identity_metrics['mapped_speaker_count'], + mapped_person_count=identity_metrics['mapped_person_count'], + unmapped_speaker_count=identity_metrics['unmapped_speaker_count'], + embedding_extraction_failure_count=identity_metrics['embedding_extraction_failure_count'], artifact_refs=_provider_artifact_refs(result), fallback_provider=fallback_provider, ) @@ -689,6 +695,8 
@@ def update_provider_run_identity_metrics( model: str, workload: STTWorkload, segments: List[TranscriptSegment], + identity_metric_update_status: str = 'succeeded', + identity_metric_update_skipped_reason: Optional[str] = None, ) -> None: if not run_id: return @@ -700,6 +708,7 @@ def update_provider_run_identity_metrics( ) return try: + identity_metrics = speaker_identity_metrics(segments) _db_update_provider_run_identity_metrics( run_id=run_id, provider=provider, @@ -709,11 +718,46 @@ def update_provider_run_identity_metrics( identity_confidence_summary=summarize_identity_confidences( [segment.speaker_identity_confidence for segment in segments] ), + provider_speaker_count=identity_metrics['provider_speaker_count'], + mapped_speaker_count=identity_metrics['mapped_speaker_count'], + mapped_person_count=identity_metrics['mapped_person_count'], + unmapped_speaker_count=identity_metrics['unmapped_speaker_count'], + embedding_extraction_failure_count=identity_metrics['embedding_extraction_failure_count'], + identity_metric_update_status=identity_metric_update_status, + identity_metric_update_skipped_reason=identity_metric_update_skipped_reason, ) except Exception as e: logger.warning('failed to update transcription provider identity metrics run_id=%s: %s', run_id, e) +def speaker_identity_metrics(segments: List[TranscriptSegment]) -> dict: + provider_speakers = {_segment_cluster_key(segment) for segment in segments if _segment_cluster_key(segment)} + mapped_speakers = { + _segment_cluster_key(segment) + for segment in segments + if _segment_cluster_key(segment) + and (segment.person_id or segment.is_user or segment.speaker_identity_state in ('identified', 'user')) + } + mapped_people = { + segment.person_id or 'user' + for segment in segments + if segment.person_id or segment.is_user or segment.speaker_identity_state == 'user' + } + embedding_failures = { + _segment_cluster_key(segment) + for segment in segments + if _segment_cluster_key(segment) + and 
(segment.speaker_identity_provenance or {}).get('reason') == 'embedding_extraction_failed' + } + return { + 'provider_speaker_count': len(provider_speakers), + 'mapped_speaker_count': len(mapped_speakers), + 'mapped_person_count': len(mapped_people), + 'unmapped_speaker_count': max(len(provider_speakers) - len(mapped_speakers), 0), + 'embedding_extraction_failure_count': len(embedding_failures), + } + + def _identified_cluster_count(segments: List[TranscriptSegment]) -> int: return len( { From 9cac76bd841120170b0f39eba30212ebaa0b5323 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Fri, 22 May 2026 10:48:00 +0700 Subject: [PATCH 34/44] Validate AssemblyAI background E2E --- backend/tests/unit/test_llm_usage_endpoints.py | 3 +++ scripts/desktop_assemblyai_e2e.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/backend/tests/unit/test_llm_usage_endpoints.py b/backend/tests/unit/test_llm_usage_endpoints.py index 149af6b193a..cd35f91055d 100644 --- a/backend/tests/unit/test_llm_usage_endpoints.py +++ b/backend/tests/unit/test_llm_usage_endpoints.py @@ -53,6 +53,9 @@ "set_user_data_protection_level", "get_generic_cache", "set_generic_cache", + "get_daily_summary_uid", + "store_daily_summary_to_uid", + "remove_daily_summary_to_uid", "set_speech_profile_duration", "r", ]: diff --git a/scripts/desktop_assemblyai_e2e.py b/scripts/desktop_assemblyai_e2e.py index 9e7b0c23b3a..b2656a3ecdb 100755 --- a/scripts/desktop_assemblyai_e2e.py +++ b/scripts/desktop_assemblyai_e2e.py @@ -21,6 +21,7 @@ from __future__ import annotations import argparse +import hashlib import json import os import struct @@ -301,9 +302,12 @@ def background_transcribe_chunk( chunk: bytes, language: str, ) -> dict: + chunk_hash = hashlib.sha256(chunk).hexdigest() + chunk_id = f"{conversation_id}-{chunk_start_ms}-{len(chunk)}-{chunk_hash[:16]}" transcribe_url = ( f"{api_base.rstrip('/')}/v2/desktop/background-transcribe" f"?conversation_id={conversation_id}" + 
f"&chunk_id={urllib.parse.quote(chunk_id)}" f"&chunk_start_ms={chunk_start_ms}" f"&sample_rate={SAMPLE_RATE}" f"&channels={CHANNELS}" From 908e483939075b337d7f312d027c4e829f5b202a Mon Sep 17 00:00:00 2001 From: David Zhang Date: Fri, 22 May 2026 10:56:56 +0700 Subject: [PATCH 35/44] Fix AssemblyAI BYOK routing test isolation --- .../unit/test_byok_assemblyai_routing.py | 41 ++++++++----------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/backend/tests/unit/test_byok_assemblyai_routing.py b/backend/tests/unit/test_byok_assemblyai_routing.py index 5d73e555502..bd9c9019534 100644 --- a/backend/tests/unit/test_byok_assemblyai_routing.py +++ b/backend/tests/unit/test_byok_assemblyai_routing.py @@ -30,72 +30,65 @@ def test_env_selects_assemblyai_for_sync(): assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai -@patch('utils.stt.provider_service.get_byok_key') -def test_resolve_uses_deepgram_byok_when_no_assembly_header(mock_get_key): - mock_get_key.side_effect = lambda provider: {'deepgram': 'dg-user-key'}.get(provider) +def test_resolve_uses_deepgram_byok_when_no_assembly_header(monkeypatch): + monkeypatch.setattr(provider_service, 'get_byok_key', lambda provider: {'deepgram': 'dg-user-key'}.get(provider)) assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.deepgram -@patch('utils.stt.provider_service.get_byok_key') -def test_resolve_uses_deepgram_byok_for_background_when_no_assembly_header(mock_get_key): - mock_get_key.side_effect = lambda provider: {'deepgram': 'dg-user-key'}.get(provider) +def test_resolve_uses_deepgram_byok_for_background_when_no_assembly_header(monkeypatch): + monkeypatch.setattr(provider_service, 'get_byok_key', lambda provider: {'deepgram': 'dg-user-key'}.get(provider)) assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) == STTProviderName.deepgram -@patch('utils.stt.provider_service.get_byok_key') -def 
test_resolve_uses_assemblyai_when_byok_assembly_header_present(mock_get_key): +def test_resolve_uses_assemblyai_when_byok_assembly_header_present(monkeypatch): keys = {'assemblyai': 'aa-user-key', 'deepgram': 'dg-user-key'} def _lookup(provider): return keys.get(provider) - mock_get_key.side_effect = _lookup + monkeypatch.setattr(provider_service, 'get_byok_key', _lookup) assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai -@patch('utils.stt.provider_service.get_byok_key') -def test_resolve_uses_assemblyai_for_background_when_byok_assembly_header_present(mock_get_key): +def test_resolve_uses_assemblyai_for_background_when_byok_assembly_header_present(monkeypatch): keys = {'assemblyai': 'aa-user-key', 'deepgram': 'dg-user-key'} def _lookup(provider): return keys.get(provider) - mock_get_key.side_effect = _lookup + monkeypatch.setattr(provider_service, 'get_byok_key', _lookup) assert ( provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) == STTProviderName.assemblyai ) -@patch('utils.stt.provider_service.get_byok_key', return_value=None) -def test_resolve_uses_server_assembly_when_no_byok_headers(_mock_get_key): +def test_resolve_uses_server_assembly_when_no_byok_headers(monkeypatch): + monkeypatch.setattr(provider_service, 'get_byok_key', lambda _provider: None) assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai -@patch('utils.stt.provider_service.get_byok_key') -def test_resolve_uses_server_deepgram_when_server_assembly_missing_and_fallback_enabled(mock_get_key, monkeypatch): +def test_resolve_uses_server_deepgram_when_server_assembly_missing_and_fallback_enabled(monkeypatch): monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') - mock_get_key.return_value = None + monkeypatch.setattr(provider_service, 'get_byok_key', lambda _provider: None) assert 
provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.deepgram -@patch('utils.stt.provider_service.get_byok_key') def test_resolve_keeps_assemblyai_selected_when_server_assembly_missing_and_fallback_disabled( - mock_get_key, monkeypatch + monkeypatch, ): monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') - mock_get_key.return_value = None + monkeypatch.setattr(provider_service, 'get_byok_key', lambda _provider: None) assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai -@patch('utils.stt.provider_service.get_byok_key') -def test_assemblyai_provider_passes_byok_api_key(mock_get_key): - mock_get_key.return_value = 'aa-user-key' - with patch('utils.stt.provider_service.AssemblyAIAsyncTranscriptionProvider') as mock_cls: +def test_assemblyai_provider_passes_byok_api_key(monkeypatch): + monkeypatch.setattr(provider_service, 'get_byok_key', lambda _provider: 'aa-user-key') + with patch.object(provider_service, 'AssemblyAIAsyncTranscriptionProvider') as mock_cls: provider_service._assemblyai_prerecorded_provider() mock_cls.assert_called_once_with(api_key='aa-user-key') From 16870f348631806a396675474edf9ad687fd7f58 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Sat, 23 May 2026 21:29:42 +0700 Subject: [PATCH 36/44] Improve AssemblyAI speaker cluster handling --- backend/routers/desktop_background.py | 48 ++++++++++++++++ .../test_desktop_background_transcribe.py | 46 +++++++++++++++ .../backend/assemblyai_background_rollout.mdx | 57 +++++++++++++++++++ .../backend/listen_pusher_pipeline.mdx | 6 ++ 4 files changed, 157 insertions(+) diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py index 93f4df39dfd..ae0f1d39330 100644 --- a/backend/routers/desktop_background.py +++ 
b/backend/routers/desktop_background.py @@ -51,6 +51,7 @@ _MAX_PCM_BODY_BYTES = 200_000_000 _SPEAKER_MAP_TTL_SECONDS = 60 * 60 * 24 +_LOCAL_CLUSTER_SPLIT_MARKER = "::local_part:" class BackgroundConversationStartRequest(BaseModel): @@ -293,6 +294,7 @@ async def background_transcribe( del audio_bytes segments = response.segments + _split_noncontiguous_provider_clusters(segments) speaker_diagnostics = _speaker_diagnostics(segments) if conversation_id and segments: await _identify_speakers( @@ -425,6 +427,52 @@ def _apply_chunk_offset(segments: List[TranscriptSegment], offset_sec: float) -> segment.end += offset_sec +def _split_noncontiguous_provider_clusters(segments: List[TranscriptSegment]) -> None: + """Split reused provider clusters into local contiguous groups. + + Batched providers can reuse the same cluster label for rapid, non-contiguous turns + inside a chunk. Omi treats provider labels as hints, so split those groups before + identity matching and final speaker_id assignment. + """ + groups: list[tuple[str, list[TranscriptSegment]]] = [] + current_cluster = None + current_group: list[TranscriptSegment] = [] + for segment in sorted(segments, key=lambda item: (item.start, item.end)): + cluster = _raw_provider_cluster_key(segment) + if not cluster: + if current_group and current_cluster is not None: + groups.append((current_cluster, current_group)) + current_group = [] + current_cluster = None + continue + if current_group and cluster != current_cluster: + groups.append((current_cluster, current_group)) + current_group = [] + current_cluster = cluster + current_group.append(segment) + if current_group and current_cluster is not None: + groups.append((current_cluster, current_group)) + + group_counts: Dict[str, int] = {} + for cluster, _group_segments in groups: + group_counts[cluster] = group_counts.get(cluster, 0) + 1 + + group_indexes: Dict[str, int] = {} + for cluster, group_segments in groups: + if group_counts[cluster] <= 1: + continue + group_index = 
group_indexes.get(cluster, 0) + 1 + group_indexes[cluster] = group_index + local_cluster = f'{cluster}{_LOCAL_CLUSTER_SPLIT_MARKER}{group_index}' + for segment in group_segments: + segment.provider_cluster_id = local_cluster + + +def _raw_provider_cluster_key(segment: TranscriptSegment) -> Optional[str]: + cluster = segment.provider_cluster_id or segment.provider_speaker_label + return str(cluster) if cluster is not None else None + + async def _identify_speakers( uid: str, conversation_id: str, diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index 853bf00001a..3f67b400d71 100644 --- a/backend/tests/unit/test_desktop_background_transcribe.py +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -565,6 +565,52 @@ def test_cluster_speaker_mapping_assigns_distinct_ids(monkeypatch): assert data['speaker_diagnostics']['mapped_provider_speaker_count'] == 2 +def test_noncontiguous_same_provider_cluster_splits_inside_chunk(monkeypatch): + segments = [ + TranscriptSegment(text='Alice starts.', is_user=False, start=0.0, end=1.0, provider_cluster_id='A'), + TranscriptSegment(text='Bob replies.', is_user=False, start=1.1, end=2.0, provider_cluster_id='B'), + TranscriptSegment(text='Carol follows.', is_user=False, start=2.1, end=3.0, provider_cluster_id='A'), + ] + client, _mock_transcribe = _client(monkeypatch, segments=segments) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert [segment['speaker_id'] for segment in data['segments']] == [0, 1, 2] + assert [segment['provider_cluster_id'] for segment in data['segments']] == [ + 'A::local_part:1', + 'B', + 'A::local_part:2', + ] + assert data['speaker_diagnostics']['provider_cluster_count'] == 3 + + +def 
test_contiguous_same_provider_cluster_stays_together_inside_chunk(monkeypatch): + segments = [ + TranscriptSegment(text='Alice starts.', is_user=False, start=0.0, end=1.0, provider_cluster_id='A'), + TranscriptSegment(text='Alice continues.', is_user=False, start=1.1, end=2.0, provider_cluster_id='A'), + TranscriptSegment(text='Bob replies.', is_user=False, start=2.1, end=3.0, provider_cluster_id='B'), + ] + client, _mock_transcribe = _client(monkeypatch, segments=segments) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert [segment['speaker_id'] for segment in data['segments']] == [0, 0, 1] + assert [segment['provider_cluster_id'] for segment in data['segments']] == ['A', 'A', 'B'] + assert data['speaker_diagnostics']['provider_cluster_count'] == 2 + + def test_assemblyai_label_only_speakers_do_not_collapse_to_single_local_speaker(monkeypatch): segments = [ TranscriptSegment( diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 7f7f55617e4..67f7a862d17 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -101,6 +101,63 @@ Canonical speaker identity is assigned after normalization with the Omi speaker embedding service. Low-confidence clusters remain explicitly unknown rather than being mapped to a fake person or to `speaker_id=0`. +### Speaker Detection Contract + +AssemblyAI and Deepgram speaker labels are provider-local diarization hints. +They are not stable people, and they are not safe to compare across prerecorded +jobs. AssemblyAI labels such as `A`/`B` and Deepgram labels such as `0`/`1` +can reset on every uploaded chunk. 
+ +Desktop background transcription therefore uses a split-then-reconcile flow: + +1. Normalize provider words and utterances. +2. Reconstruct `TranscriptSegment` records with provider cluster metadata. +3. Split repeated non-contiguous provider clusters inside the chunk into local + contiguous groups before identity matching. +4. Match each local group to known Omi user/person embeddings when clean + samples exist. +5. Assign final app-visible `speaker_id` values from durable identity when + available, otherwise from a chunk-scoped provider key. + +The chunk-scoped unresolved key is: + +```text +provider:{provider}:chunk:{chunk_id}:cluster:{provider_cluster_id} +``` + +Known Omi identities intentionally ignore the chunk scope: + +```text +identity:user +identity:person:{person_id} +``` + +This is deliberate. False speaker merges are more damaging than temporary +anonymous-speaker fragmentation. The product should prefer splitting uncertain +provider clusters first, then merging only with Omi identity evidence, +high-confidence embeddings, or explicit user corrections. + +### Benchmark Learnings + +The AssemblyAI rollout benchmark found: + +- Fixed 15s batches expose provider-label resets; raw `A`/`B` labels must not + carry across chunks. +- Longer AssemblyAI batches reduced request count but often collapsed speakers + in fast-turn or overlap-heavy synthetic sessions. +- Batch-size, silence-gap, speech-duration, embedding-span, overlap-tolerance, + and match-threshold sweeps did not beat the chunk-scoped baseline. +- The strongest improvement came from splitting non-contiguous same-cluster + groups inside a chunk. That fixed rapid turn-taking cases where one provider + cluster contained multiple true speakers. +- Graph clustering and anonymous reconciliation are useful as reconciliation + structures, but they cannot repair a provider cluster that already contains + multiple speakers unless the cluster is split first. 
+ +The current landed rule is: provider clusters are local hints, repeated +non-contiguous groups are split, and stable app identity comes from Omi +evidence rather than provider labels. + Provider run ledgers live in `transcription_provider_runs`, and daily rollups live in `transcription_provider_usage_daily`. Ledger payload guards reject raw audio bytes, transcript text, words, utterances, and similar high-risk payloads. diff --git a/docs/doc/developer/backend/listen_pusher_pipeline.mdx b/docs/doc/developer/backend/listen_pusher_pipeline.mdx index 6133718d500..25c7d98eed6 100644 --- a/docs/doc/developer/backend/listen_pusher_pipeline.mdx +++ b/docs/doc/developer/backend/listen_pusher_pipeline.mdx @@ -286,6 +286,7 @@ sequenceDiagram Note over ProviderService: Missing AssemblyAI credentials skip directly
to Deepgram when fallback credentials are usable. ProviderService->>ProviderService: Normalize provider result and reconstruct canonical segments + ProviderService->>ProviderService: Split repeated non-contiguous provider clusters inside each chunk ProviderService->>EmbeddingAPI: Extract cluster samples for Omi speaker identity ProviderService->>Firestore: Store canonical transcript segments and provider run ledger ProviderService->>Firestore: Increment daily provider usage rollup @@ -296,6 +297,11 @@ AssemblyAI speaker labels are provider/session-local cluster IDs. Canonical speaker identity is applied after provider normalization with the Omi embedding service. Low-confidence matches remain explicit unknown speakers; provider cluster metadata is preserved separately from `person_id` and `is_user`. +For desktop background chunks, unresolved provider clusters are scoped by +provider and `chunk_id`; repeated non-contiguous uses of the same provider +cluster inside one chunk are split before embedding identity matching. This +prevents rapid turn-taking from assigning multiple real speakers to one +app-visible speaker just because the provider reused `A`, `B`, `0`, or `1`. Optional desktop BYOK for AssemblyAI applies only to async prerecorded workloads above; listen/realtime streaming in this pipeline remains Deepgram-only. 
From f9b931fe6f72598bdbd166411fabee22351f1c0f Mon Sep 17 00:00:00 2001 From: David Zhang Date: Sun, 24 May 2026 03:34:38 +0700 Subject: [PATCH 37/44] Centralize AssemblyAI background provider policy --- backend/routers/desktop_background.py | 64 ++------ .../unit/test_background_provider_service.py | 87 ++++++++++- .../test_desktop_background_transcribe.py | 141 ++++++++++++------ backend/utils/stt/provider_service.py | 60 ++++++++ backend/utils/stt/providers.py | 22 +++ .../backend/assemblyai_background_rollout.mdx | 50 +++++-- 6 files changed, 316 insertions(+), 108 deletions(-) diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py index ae0f1d39330..d75d2c71eb7 100644 --- a/backend/routers/desktop_background.py +++ b/backend/routers/desktop_background.py @@ -1,7 +1,6 @@ import hashlib import json import logging -import os from datetime import datetime, timezone from typing import Dict, List, Optional @@ -26,22 +25,16 @@ from utils.executors import db_executor, run_blocking, sync_executor from utils.fair_use import is_hard_restricted, record_speech_ms from utils.other import endpoints as auth -from utils.byok import get_byok_key from utils.speaker_identification import _pcm_to_wav_bytes from utils.stt.background_speaker_identity import USER_SELF_PERSON_ID, identify_background_speaker_clusters from utils.stt.provider_service import ( - resolve_prerecorded_provider_for_request, + resolve_background_provider_policy, speaker_identity_metrics, transcribe_bytes, update_provider_run_identity_metrics, ) from utils.stt.speaker_embedding import extract_embedding_from_bytes -from utils.stt.providers import ( - STTProviderName, - STTWorkload, - get_fallback_prerecorded_provider_name, - get_prerecorded_provider_name, -) +from utils.stt.providers import STTWorkload from utils.subscription import has_transcription_credits, is_trial_paywalled from utils.voice_duration_limiter import compute_pcm_duration_ms @@ -61,51 +54,24 @@ class 
BackgroundConversationStartRequest(BaseModel): @router.get("/capabilities") async def desktop_capabilities(uid: str = Depends(auth.get_current_user_uid)): - primary_provider = get_prerecorded_provider_name(STTWorkload.background) - effective_provider = resolve_prerecorded_provider_for_request(STTWorkload.background) - fallback_provider = get_fallback_prerecorded_provider_name(primary_provider, STTWorkload.background) - - assemblyai_key_available = bool(get_byok_key('assemblyai') or os.getenv('ASSEMBLYAI_API_KEY')) - deepgram_key_available = bool(get_byok_key('deepgram') or os.getenv('DEEPGRAM_API_KEY')) - fallback_available = fallback_provider == STTProviderName.deepgram and deepgram_key_available - - if effective_provider == STTProviderName.assemblyai and not assemblyai_key_available and fallback_available: - effective_provider = STTProviderName.deepgram - - usable_provider = None - if effective_provider == STTProviderName.assemblyai and assemblyai_key_available: - usable_provider = STTProviderName.assemblyai - elif effective_provider == STTProviderName.deepgram and deepgram_key_available: - usable_provider = STTProviderName.deepgram - - enabled = usable_provider is not None - mode = 'disabled' - reason = None - if usable_provider == STTProviderName.assemblyai: - mode = 'assemblyai_primary' - elif usable_provider == STTProviderName.deepgram and primary_provider == STTProviderName.assemblyai: + policy = resolve_background_provider_policy() + mode = policy.mode.value + if not policy.enabled: + mode = 'disabled' + elif policy.reason == 'fallback_deepgram_available': mode = 'deepgram_fallback' - reason = 'fallback_deepgram_available' - elif usable_provider == STTProviderName.deepgram: - mode = 'deepgram_primary' - elif primary_provider == STTProviderName.assemblyai and not assemblyai_key_available: - reason = 'missing_assemblyai_api_key' - elif primary_provider == STTProviderName.deepgram and not deepgram_key_available: - reason = 'missing_deepgram_api_key' - else: - 
reason = 'no_usable_batch_provider' return { "background_batch": { - "enabled": enabled, + "enabled": policy.enabled, "mode": mode, - "provider": primary_provider.value, - "primary_provider": primary_provider.value, - "effective_provider": usable_provider.value if usable_provider else None, - "fallback_provider": fallback_provider.value if fallback_provider else None, - "fallback_enabled": fallback_provider is not None, - "fallback_available": fallback_available, + "provider": policy.primary_provider.value, + "primary_provider": policy.primary_provider.value, + "effective_provider": policy.effective_provider.value if policy.effective_provider else None, + "fallback_provider": policy.fallback_provider.value if policy.fallback_provider else None, + "fallback_enabled": policy.fallback_enabled, + "fallback_available": policy.fallback_available, "workload": STTWorkload.background.value, - "reason": reason, + "reason": policy.reason, "sample_rate": 16000, "channels": 1, "encoding": "linear16", diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index a174c9fd108..34c1c969972 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -20,7 +20,13 @@ from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord # noqa: E402 from utils.stt import provider_service # noqa: E402 from utils.stt.provider_costs import estimate_prerecorded_provider_cost_usd # noqa: E402 -from utils.stt.providers import STTProviderName, STTWorkload, get_prerecorded_provider_name # noqa: E402 +from utils.stt.providers import ( # noqa: E402 + BackgroundProviderMode, + STTProviderName, + STTWorkload, + get_background_provider_mode, + get_prerecorded_provider_name, +) def _provider_result(provider='deepgram', model='nova-3'): @@ -109,15 +115,56 @@ def test_provider_service_finalizes_background_run_on_deepgram_when_assemblyai_d assert 
finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00076 -def test_background_routing_selects_assemblyai_by_default(monkeypatch): +def test_background_routing_defaults_to_shadow_only_deepgram_until_rollout_gates_pass(monkeypatch): monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', raising=False) monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', raising=False) + monkeypatch.delenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', raising=False) assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai - assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai + assert get_background_provider_mode() == BackgroundProviderMode.shadow_only + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram assert get_prerecorded_provider_name(STTWorkload.postprocess) == STTProviderName.assemblyai +def test_background_routing_selects_assemblyai_when_background_gate_allows(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai + + +def test_background_routing_supports_deepgram_override_and_shadow_only(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'deepgram') + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'shadow_only') + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + +def test_background_routing_honors_disabled_rollout_gate_and_workload_allowlist(monkeypatch): + 
monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'false') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,postprocess') + + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + +def test_background_routing_invalid_mode_fails_safe_to_shadow_only(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'invalid-mode') + + assert get_background_provider_mode() == BackgroundProviderMode.shadow_only + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + def test_prerecorded_ptt_and_realtime_related_workloads_stay_deepgram(): assert get_prerecorded_provider_name(STTWorkload.ptt) == STTProviderName.deepgram assert get_prerecorded_provider_name(STTWorkload.voice_message) == STTProviderName.deepgram @@ -126,6 +173,7 @@ def test_prerecorded_ptt_and_realtime_related_workloads_stay_deepgram(): def test_background_routing_can_select_assemblyai_without_moving_latency_critical_workloads(monkeypatch): monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess,ptt,realtime') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai @@ -243,6 +291,7 @@ def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypat def test_provider_service_falls_back_to_deepgram_for_background_when_assemblyai_fails(monkeypatch): 
monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', raising=False) assemblyai_provider = MagicMock() @@ -309,6 +358,37 @@ def test_provider_service_skips_missing_assemblyai_key_when_deepgram_fallback_is assert finalize_run.call_args.kwargs['fallback_count'] == 0 +def test_provider_service_skips_missing_background_assemblyai_key_when_deepgram_fallback_is_usable(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + + deepgram_provider = MagicMock() + deepgram_provider.provider_name = STTProviderName.deepgram + deepgram_provider.transcribe_url.return_value = _provider_result() + + with patch.object(provider_service, '_assemblyai_prerecorded_provider') as assemblyai_provider, patch.object( + provider_service, '_deepgram_prerecorded_provider', return_value=deepgram_provider + ), patch.object(provider_service, 'create_provider_run', return_value='run-dg'), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/background.wav', + workload=STTWorkload.background, + uid='uid-1', + language='multi', + model='nova-3', + ) + + assert response.result.provider == 'deepgram' + assemblyai_provider.assert_not_called() + deepgram_provider.transcribe_url.assert_called_once() + assert finalize_run.call_args.kwargs['provider'] == 'deepgram' + assert finalize_run.call_args.kwargs['fallback_count'] == 0 + + def 
test_provider_service_reports_missing_assemblyai_key_when_fallback_disabled(monkeypatch): monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') @@ -361,6 +441,7 @@ def test_provider_service_reports_missing_assemblyai_key_when_no_fallback_key_is def test_provider_service_records_background_retry_exhaustion_when_fallback_disabled(monkeypatch): monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') assemblyai_provider = MagicMock() diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index 3f67b400d71..199e0496cf1 100644 --- a/backend/tests/unit/test_desktop_background_transcribe.py +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -86,6 +86,7 @@ def _pcm_to_wav_bytes(pcm_data: bytes, sample_rate: int) -> bytes: 'utils.stt.provider_service', SimpleNamespace( resolve_prerecorded_provider_for_request=MagicMock(), + resolve_background_provider_policy=MagicMock(), speaker_identity_metrics=lambda segments: { 'provider_speaker_count': len( { @@ -158,7 +159,7 @@ def _pcm_to_wav_bytes(pcm_data: bytes, sample_rate: int) -> bytes: from models.transcript_segment import ProviderTranscriptResult, TranscriptSegment from routers import desktop_background -from utils.stt.providers import STTProviderName, STTWorkload +from utils.stt.providers import BackgroundProviderMode, STTProviderName, STTWorkload sys.modules.pop('utils.stt.provider_service', None) if 'utils.stt' in sys.modules and hasattr(sys.modules['utils.stt'], 'provider_service'): @@ -179,7 +180,6 @@ def _client(monkeypatch, *, segments=None, person_embeddings_cache=None): monkeypatch.setattr(desktop_background, 
'has_transcription_credits', lambda *_args, **_kwargs: True) monkeypatch.setattr(desktop_background, 'record_speech_ms', lambda *_args, **_kwargs: None) monkeypatch.setattr(desktop_background, 'record_usage', lambda *_args, **_kwargs: None) - monkeypatch.setattr(desktop_background, 'get_byok_key', lambda _provider: None) monkeypatch.setattr(desktop_background, 'resolve_voice_message_language', lambda _uid, language: language or 'en') monkeypatch.setattr( desktop_background.conversations_db, @@ -239,13 +239,19 @@ def _transcribe_bytes(audio_bytes, **_kwargs): def test_desktop_capabilities_reports_assemblyai_background_when_key_available(monkeypatch): client, _mock_transcribe = _client(monkeypatch) - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') - monkeypatch.setenv('ASSEMBLYAI_API_KEY', 'server-aai-key') monkeypatch.setattr( desktop_background, - 'resolve_prerecorded_provider_for_request', - lambda _workload: STTProviderName.assemblyai, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=STTProviderName.assemblyai, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=True, + enabled=True, + reason=None, + ), ) response = client.get('/v2/desktop/capabilities') @@ -256,7 +262,7 @@ def test_desktop_capabilities_reports_assemblyai_background_when_key_available(m assert data['provider'] == 'assemblyai' assert data['primary_provider'] == 'assemblyai' assert data['effective_provider'] == 'assemblyai' - assert data['mode'] == 'assemblyai_primary' + assert data['mode'] == 'assemblyai' assert data['fallback_provider'] == 'deepgram' assert data['fallback_enabled'] is True assert data['fallback_available'] is True @@ -266,14 +272,19 @@ def 
test_desktop_capabilities_reports_assemblyai_background_when_key_available(m def test_desktop_capabilities_allows_background_batch_with_deepgram_fallback(monkeypatch): client, _mock_transcribe = _client(monkeypatch) - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') - monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', raising=False) - monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) monkeypatch.setattr( desktop_background, - 'resolve_prerecorded_provider_for_request', - lambda _workload: STTProviderName.assemblyai, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=STTProviderName.deepgram, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=True, + enabled=True, + reason='fallback_deepgram_available', + ), ) response = client.get('/v2/desktop/capabilities') @@ -293,14 +304,19 @@ def test_desktop_capabilities_allows_background_batch_with_deepgram_fallback(mon def test_desktop_capabilities_reports_missing_assemblyai_key_when_fallback_disabled(monkeypatch): client, _mock_transcribe = _client(monkeypatch) - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') - monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) monkeypatch.setattr( desktop_background, - 'resolve_prerecorded_provider_for_request', - lambda _workload: STTProviderName.assemblyai, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=None, + fallback_provider=None, + fallback_enabled=False, + 
fallback_available=False, + enabled=False, + reason='missing_assemblyai_api_key', + ), ) response = client.get('/v2/desktop/capabilities') @@ -318,14 +334,19 @@ def test_desktop_capabilities_reports_missing_assemblyai_key_when_fallback_disab def test_desktop_capabilities_reports_no_usable_batch_provider_without_any_key(monkeypatch): client, _mock_transcribe = _client(monkeypatch) - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') - monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) - monkeypatch.delenv('DEEPGRAM_API_KEY', raising=False) monkeypatch.setattr( desktop_background, - 'resolve_prerecorded_provider_for_request', - lambda _workload: STTProviderName.assemblyai, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=None, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=False, + enabled=False, + reason='missing_assemblyai_api_key', + ), ) response = client.get('/v2/desktop/capabilities') @@ -342,16 +363,19 @@ def test_desktop_capabilities_reports_no_usable_batch_provider_without_any_key(m def test_desktop_capabilities_uses_byok_assemblyai(monkeypatch): client, _mock_transcribe = _client(monkeypatch) - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') - monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) - monkeypatch.setattr( - desktop_background, 'get_byok_key', lambda provider: 'aai-user-key' if provider == 'assemblyai' else None - ) monkeypatch.setattr( desktop_background, - 'resolve_prerecorded_provider_for_request', - lambda _workload: STTProviderName.assemblyai, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + 
primary_provider=STTProviderName.assemblyai, + effective_provider=STTProviderName.assemblyai, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=False, + enabled=True, + reason=None, + ), ) response = client.get('/v2/desktop/capabilities') @@ -360,22 +384,50 @@ def test_desktop_capabilities_uses_byok_assemblyai(monkeypatch): data = response.json()['background_batch'] assert data['enabled'] is True assert data['effective_provider'] == 'assemblyai' - assert data['mode'] == 'assemblyai_primary' + assert data['mode'] == 'assemblyai' def test_desktop_capabilities_uses_byok_deepgram_fallback(monkeypatch): client, _mock_transcribe = _client(monkeypatch) - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') - monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') - monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) - monkeypatch.delenv('DEEPGRAM_API_KEY', raising=False) monkeypatch.setattr( - desktop_background, 'get_byok_key', lambda provider: 'dg-user-key' if provider == 'deepgram' else None + desktop_background, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=STTProviderName.deepgram, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=True, + enabled=True, + reason='fallback_deepgram_available', + ), ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['effective_provider'] == 'deepgram' + assert data['mode'] == 'deepgram_fallback' + + +def test_desktop_capabilities_reports_shadow_only_mode(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) monkeypatch.setattr( desktop_background, - 'resolve_prerecorded_provider_for_request', - lambda _workload: 
STTProviderName.deepgram, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.shadow_only, + primary_provider=STTProviderName.deepgram, + effective_provider=STTProviderName.deepgram, + fallback_provider=None, + fallback_enabled=False, + fallback_available=False, + enabled=True, + reason='shadow_only', + ), ) response = client.get('/v2/desktop/capabilities') @@ -383,8 +435,10 @@ def test_desktop_capabilities_uses_byok_deepgram_fallback(monkeypatch): assert response.status_code == 200 data = response.json()['background_batch'] assert data['enabled'] is True + assert data['provider'] == 'deepgram' assert data['effective_provider'] == 'deepgram' - assert data['mode'] == 'deepgram_fallback' + assert data['mode'] == 'shadow_only' + assert data['reason'] == 'shadow_only' def test_background_transcribe_returns_segments_with_offset(monkeypatch): @@ -749,6 +803,7 @@ def test_byok_background_routing_uses_deepgram_when_only_deepgram_key(monkeypatc monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') monkeypatch.setattr(provider_service, 'get_byok_key', lambda provider: {'deepgram': 'dg-user-key'}.get(provider)) provider = provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index 27156847fe9..21c58014baf 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -14,9 +14,11 @@ from utils.stt.deepgram_adapter import provider_result_to_legacy_words from utils.stt.provider_costs import estimate_prerecorded_provider_cost_usd from utils.stt.providers import ( + BackgroundProviderMode, STTProviderName, STTWorkload, assemblyai_prerecorded_fallback_enabled, + get_background_provider_mode, 
get_fallback_prerecorded_provider_name, get_prerecorded_provider_name, ) @@ -117,6 +119,10 @@ def _has_deepgram_key_for_request() -> bool: return bool((get_byok_key('deepgram') if get_byok_key else None) or os.getenv('DEEPGRAM_API_KEY')) +def _has_assemblyai_key_for_request() -> bool: + return bool((get_byok_key('assemblyai') if get_byok_key else None) or os.getenv('ASSEMBLYAI_API_KEY')) + + def _deepgram_client_for_request() -> DeepgramClient: byok = get_byok_key('deepgram') if get_byok_key else None if byok: @@ -133,6 +139,18 @@ class PrerecordedTranscriptionResponse: run_id: Optional[str] +@dataclass(frozen=True) +class BackgroundProviderPolicy: + mode: BackgroundProviderMode + primary_provider: STTProviderName + effective_provider: Optional[STTProviderName] + fallback_provider: Optional[STTProviderName] + fallback_enabled: bool + fallback_available: bool + enabled: bool + reason: Optional[str] + + class ProviderTranscriptionRetriesExhausted(RuntimeError): def __init__(self, provider_error: Exception, retry_count: int): super().__init__(str(provider_error)) @@ -144,6 +162,48 @@ def resolve_prerecorded_language_model(language: Optional[str]) -> Tuple[str, st return get_deepgram_model_for_language(language or 'multi') +def resolve_background_provider_policy() -> BackgroundProviderPolicy: + mode = get_background_provider_mode() + primary_provider = get_prerecorded_provider_name(STTWorkload.background) + effective_provider = resolve_prerecorded_provider_for_request(STTWorkload.background) + fallback_provider = get_fallback_prerecorded_provider_name(primary_provider, STTWorkload.background) + + assemblyai_key_available = _has_assemblyai_key_for_request() + deepgram_key_available = _has_deepgram_key_for_request() + fallback_available = fallback_provider == STTProviderName.deepgram and deepgram_key_available + + usable_provider = None + reason = None + if effective_provider == STTProviderName.assemblyai: + if assemblyai_key_available: + usable_provider = 
STTProviderName.assemblyai + elif fallback_available: + usable_provider = STTProviderName.deepgram + reason = 'fallback_deepgram_available' + else: + reason = 'missing_assemblyai_api_key' + elif effective_provider == STTProviderName.deepgram: + if deepgram_key_available: + usable_provider = STTProviderName.deepgram + if mode == BackgroundProviderMode.shadow_only: + reason = 'shadow_only' + else: + reason = 'missing_deepgram_api_key' + else: + reason = 'no_usable_batch_provider' + + return BackgroundProviderPolicy( + mode=mode, + primary_provider=primary_provider, + effective_provider=usable_provider, + fallback_provider=fallback_provider, + fallback_enabled=fallback_provider is not None, + fallback_available=fallback_available, + enabled=usable_provider is not None, + reason=reason, + ) + + def transcribe_url( audio_url: str, workload: STTWorkload, diff --git a/backend/utils/stt/providers.py b/backend/utils/stt/providers.py index ee68555fa60..609939ac9ea 100644 --- a/backend/utils/stt/providers.py +++ b/backend/utils/stt/providers.py @@ -19,6 +19,12 @@ class STTWorkload(str, Enum): voice_message = 'voice_message' +class BackgroundProviderMode(str, Enum): + assemblyai = 'assemblyai' + deepgram = 'deepgram' + shadow_only = 'shadow_only' + + class PrerecordedSTTProvider(Protocol): provider_name: STTProviderName @@ -93,6 +99,14 @@ class SpeakerIdentityProvider(Protocol): def get_prerecorded_provider_name(workload: STTWorkload) -> STTProviderName: workload = STTWorkload(workload) + if workload == STTWorkload.background: + if ( + get_background_provider_mode() == BackgroundProviderMode.assemblyai + and _assemblyai_prerecorded_enabled() + and workload in _assemblyai_enabled_workloads() + ): + return STTProviderName.assemblyai + return STTProviderName.deepgram if _assemblyai_prerecorded_enabled() and workload in _assemblyai_enabled_workloads(): return STTProviderName.assemblyai return _DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS[workload] @@ -127,6 +141,14 @@ def 
assemblyai_prerecorded_fallback_enabled() -> bool: return os.getenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'true').lower() == 'true' +def get_background_provider_mode() -> BackgroundProviderMode: + configured = os.getenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', BackgroundProviderMode.shadow_only.value) + try: + return BackgroundProviderMode(configured.strip().lower()) + except ValueError: + return BackgroundProviderMode.shadow_only + + def _assemblyai_enabled_workloads() -> set[STTWorkload]: configured = os.getenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') workloads = set() diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 67f7a862d17..264a02dac75 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -8,9 +8,16 @@ description: "Rollout gates, feature flags, instrumentation, and rollback for th ## Scope -AssemblyAI is the MVP async prerecorded provider. Deepgram remains -the provider for mobile/BLE `/v4/listen`, realtime assistant streaming, -Hold-to-Talk streaming, and voice-message finalize semantics. +AssemblyAI is the intended async background provider, but broad background +defaulting is still gated. Current saved-provider benchmark evidence favors +Deepgram on cost-adjusted speaker quality: Deepgram reached 99.79% speaker +purity at about $0.264/hour, while AssemblyAI reached 95.91% at about +$0.734/hour. The rollout path must close or explicitly mitigate that gap before +AssemblyAI becomes the broad default. + +Deepgram remains the provider for mobile/BLE `/v4/listen`, realtime assistant +streaming, Hold-to-Talk streaming, voice-message finalize semantics, and +background fallback. 
Eligible AssemblyAI workloads are: @@ -26,12 +33,14 @@ Ineligible latency-sensitive workloads stay on Deepgram: ## Feature Flags And Environment -AssemblyAI is enabled by default for eligible prerecorded workloads when -credentials are configured. +AssemblyAI remains enabled by default for eligible non-background prerecorded +workloads when credentials are configured. Desktop background has its own +provider mode so canary rollout can be controlled independently. | Variable | Default | Purpose | | --- | --- | --- | | `ASSEMBLYAI_API_KEY` | unset | Required before any AssemblyAI request can run. | +| `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE` | `shadow_only` | Desktop background policy mode: `assemblyai`, `deepgram`, or `shadow_only`. `shadow_only` keeps production background requests on Deepgram while the AssemblyAI path is prepared for canary evidence. | | `ASSEMBLYAI_PRERECORDED_STT_ENABLED` | `true` | Main rollout switch for passive prerecorded workloads. | | `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS` | `sync,background,postprocess` | Comma-separated eligible prerecorded workloads. Unknown or ineligible values are ignored. | | `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED` | `true` | Use Deepgram when AssemblyAI is unavailable or fails and Deepgram credentials are usable. | @@ -43,9 +52,20 @@ credentials are configured. ## Routing And Fallback -`backend/utils/stt/providers.py` owns provider selection. Eligible workloads -use AssemblyAI only when the main flag is enabled and the workload is in -`ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`; otherwise they use Deepgram. +`backend/utils/stt/providers.py` owns provider selection, and +`backend/utils/stt/provider_service.py` resolves request-level credentials and +fallbacks. Eligible non-background workloads use AssemblyAI only when the main +flag is enabled and the workload is in `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`; +otherwise they use Deepgram. 
+ +Desktop background routing is controlled by +`ASSEMBLYAI_BACKGROUND_PROVIDER_MODE`: + +| Mode | Runtime provider | Use | +| --- | --- | --- | +| `shadow_only` | Deepgram | Default until TICKET-026 evals and TICKET-028 canary gates are rollout-ready. | +| `assemblyai` | AssemblyAI when credentials and workload gates allow it; Deepgram fallback when configured and usable. | Canary and eventual default mode. | +| `deepgram` | Deepgram | Kill switch or explicit opt-out without a code deploy. | Desktop Audio Recording uses the background workload through: @@ -63,11 +83,15 @@ or exhausts retries for an eligible prerecorded workload, the failed AssemblyAI run is finalized in the provider ledger and the request retries through Deepgram with fallback metadata. If AssemblyAI credentials are missing before the request starts, provider resolution skips directly to Deepgram when -fallback is enabled and Deepgram credentials are usable. - -Rollback is to set `ASSEMBLYAI_PRERECORDED_STT_ENABLED=false` or remove the -affected workload from `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`. No client deploy -is required for rollback. +fallback is enabled and Deepgram credentials are usable. If fallback is disabled +or no Deepgram key is usable, the AssemblyAI request fails closed and records +the failed provider run. + +Rollback for background is to set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram` +or `shadow_only`. The broader kill switch is +`ASSEMBLYAI_PRERECORDED_STT_ENABLED=false` or removing the affected workload +from `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`. No client deploy is required for +rollback. 
## BYOK (Bring Your Own Keys) From 2473a3e2ada4dbc91fbd50ec6eb5aae1a8696588 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Sun, 24 May 2026 12:52:41 +0700 Subject: [PATCH 38/44] Verify silence-aware background chunking --- .../BackgroundAudioChunker.swift | 19 +++- ...BackgroundTranscriptionConfiguration.swift | 14 +++ .../Tests/BackgroundTranscriptionTests.swift | 89 +++++++++++++++++++ .../backend/assemblyai_background_rollout.mdx | 9 ++ .../backend/listen_pusher_pipeline.mdx | 15 ++-- 5 files changed, 136 insertions(+), 10 deletions(-) diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift index 0a881fc305f..2fea1bfca65 100644 --- a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift @@ -110,7 +110,9 @@ struct BackgroundAudioChunker { var offset = minBytes.alignedToSample while offset + windowBytes <= maxBytes { - if isSilentWindow(start: offset, byteCount: windowBytes), hasSpeech(before: offset) { + if isSilentWindow(start: offset, byteCount: windowBytes), + hasMinimumSpeech(before: offset) + { return offset } offset += configuration.bytesPerSample @@ -130,11 +132,16 @@ struct BackgroundAudioChunker { return true } - private func hasSpeech(before endOffset: Int) -> Bool { + private func hasMinimumSpeech(before endOffset: Int) -> Bool { guard endOffset > 0 else { return false } + let minimumSpeechBytes = max( + configuration.bytesPerSample, + configuration.alignedByteCount(for: configuration.speechActivityDetection.minimumSpeechDuration) + ) var peak = 0 var sumSquares = 0.0 var count = 0 + var speechLikeBytes = 0 for offset in stride( from: 0, to: min(endOffset, buffer.count), by: configuration.bytesPerSample) @@ -142,13 +149,17 @@ struct BackgroundAudioChunker { let amplitude = sampleAmplitude(at: offset) peak = max(peak, amplitude) 
sumSquares += Double(amplitude * amplitude) + if amplitude >= configuration.speechPeakAmplitudeThreshold { + speechLikeBytes += configuration.bytesPerSample + } count += 1 } guard count > 0 else { return false } let rms = sqrt(sumSquares / Double(count)) - return peak >= configuration.speechPeakAmplitudeThreshold - || rms >= Double(configuration.speechRMSAmplitudeThreshold) + return speechLikeBytes >= minimumSpeechBytes + || (rms >= Double(configuration.speechRMSAmplitudeThreshold) + && endOffset >= minimumSpeechBytes) } private func sampleAmplitude(at offset: Int) -> Int { diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift index b7eb8279db2..626e1646e89 100644 --- a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift @@ -13,6 +13,9 @@ struct BackgroundTranscriptionConfiguration: Equatable { var maxChunkTranscriptionAttempts: Int = 3 var requiresSpeechBeforeUpload: Bool = false var speechActivityDetection = SpeechActivityDetectionConfiguration() + var usesSilenceAwareChunking: Bool { + minChunkDuration < maxChunkDuration + } var bytesPerSample: Int { 2 } @@ -29,6 +32,10 @@ struct BackgroundTranscriptionConfiguration: Equatable { } static var cloudBatch: BackgroundTranscriptionConfiguration { + fixedFifteenSecondCloudBatch + } + + static var fixedFifteenSecondCloudBatch: BackgroundTranscriptionConfiguration { BackgroundTranscriptionConfiguration( maxChunkDuration: 15.0, minChunkDuration: 15.0, @@ -45,6 +52,13 @@ struct BackgroundTranscriptionConfiguration: Equatable { ) ) } + + static var silenceAwareCloudBatchCandidate: BackgroundTranscriptionConfiguration { + var configuration = fixedFifteenSecondCloudBatch + configuration.minChunkDuration = 6.0 + configuration.silenceWindowDuration = 
0.35 + return configuration + } } extension Int { diff --git a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift index 95ad0d9a59e..bc07eaa9109 100644 --- a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift +++ b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift @@ -156,6 +156,65 @@ final class BackgroundTranscriptionTests: XCTestCase { XCTAssertEqual(session.snapshot().segments.count, 1) } + func testCloudBatchDefaultKeepsFixedFifteenSecondFallback() { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var chunker = BackgroundAudioChunker(configuration: configuration) + + XCTAssertFalse(configuration.usesSilenceAwareChunking) + XCTAssertEqual(configuration.minChunkDuration, 15.0) + XCTAssertEqual(configuration.maxChunkDuration, 15.0) + + let samplesPerFrame = configuration.sampleRate / 10 + let speechFrame = pcm(samples: Array(repeating: 1_000, count: samplesPerFrame)) + let silenceFrame = pcm(samples: Array(repeating: 0, count: samplesPerFrame)) + var chunks: [BackgroundAudioChunk] = [] + + for frameIndex in 0..<70 { + let frame = frameIndex < 65 ? 
speechFrame : silenceFrame + chunks.append(contentsOf: chunker.append(pcmData: frame, startTime: Double(frameIndex) / 10.0)) + } + XCTAssertTrue(chunks.isEmpty) + + for frameIndex in 70..<160 { + chunks.append(contentsOf: chunker.append(pcmData: speechFrame, startTime: Double(frameIndex) / 10.0)) + } + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].startTime, 0, accuracy: 0.001) + XCTAssertEqual(sampleCount(chunks[0].pcmData), configuration.sampleRate * 15) + } + + func testSilenceAwareCloudBatchCandidateFlushesOnSilenceBeforeHardCap() { + let configuration = BackgroundTranscriptionConfiguration.silenceAwareCloudBatchCandidate + var chunker = BackgroundAudioChunker(configuration: configuration) + + XCTAssertTrue(configuration.usesSilenceAwareChunking) + XCTAssertEqual(configuration.maxChunkDuration, 15.0) + XCTAssertLessThan(configuration.minChunkDuration, configuration.maxChunkDuration) + + let speechSamples = Array(repeating: Int16(1_000), count: Int(Double(configuration.sampleRate) * 6.5)) + let silenceSamples = Array(repeating: Int16(0), count: Int(Double(configuration.sampleRate) * 0.5)) + let chunks = chunker.append(pcmData: pcm(samples: speechSamples + silenceSamples), startTime: 30) + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].startTime, 30, accuracy: 0.001) + XCTAssertLessThan(sampleCount(chunks[0].pcmData), configuration.sampleRate * 15) + XCTAssertGreaterThanOrEqual(sampleCount(chunks[0].pcmData), Int(Double(configuration.sampleRate) * 6.5)) + XCTAssertFalse(chunks[0].isFinal) + } + + func testSilenceAwareCloudBatchCandidateStillHardCapsContinuousSpeech() { + let configuration = BackgroundTranscriptionConfiguration.silenceAwareCloudBatchCandidate + var chunker = BackgroundAudioChunker(configuration: configuration) + + let samples = Array(repeating: Int16(1_000), count: configuration.sampleRate * 16) + let chunks = chunker.append(pcmData: pcm(samples: samples), startTime: 12) + + XCTAssertEqual(chunks.count, 1) + 
XCTAssertEqual(chunks[0].startTime, 12, accuracy: 0.001) + XCTAssertEqual(sampleCount(chunks[0].pcmData), configuration.sampleRate * 15) + } + func testCloudBatchDropsSilentChunksBeforeUpload() async throws { let configuration = BackgroundTranscriptionConfiguration.cloudBatch var uploadCount = 0 @@ -257,6 +316,36 @@ final class BackgroundTranscriptionTests: XCTestCase { ) } + func testSilenceAwareCloudBatchCandidateDropsShortSpeechBeforeUpload() async throws { + let configuration = BackgroundTranscriptionConfiguration.silenceAwareCloudBatchCandidate + var uploadCount = 0 + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + uploadCount += 1 + return [ + Self.backendSegment( + id: "short-speech", + text: "short", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let shortSpeech = Array( + repeating: Int16(1_000), count: Int(Double(configuration.sampleRate) * 0.2)) + let silence = Array( + repeating: Int16(0), count: Int(Double(configuration.sampleRate) * 6.5)) + let result = session.append(pcmData: pcm(samples: shortSpeech + silence), startTime: 0) + + XCTAssertEqual(result.enqueuedChunks, 0) + XCTAssertEqual(session.pendingChunkCount, 0) + let shortSpeechUpload = try await session.transcribeNext() + XCTAssertNil(shortSpeechUpload) + XCTAssertEqual(uploadCount, 0) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) + XCTAssertEqual(session.snapshot().lastSpeechActivityDecision?.reason, .insufficientSpeech) + } + func testCloudBatchFinishSplitsLongTailAtFifteenSecondWindows() { let configuration = BackgroundTranscriptionConfiguration.cloudBatch var chunker = BackgroundAudioChunker(configuration: configuration) diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 264a02dac75..30b3afa8d63 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ 
b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -78,6 +78,15 @@ upload, applies the chunk timeline offset, maps provider speaker clusters to stable numeric `speaker_id` values per conversation, and appends segments to the in-progress conversation when `persist=true`. +Desktop chunking is currently config-gated to the fixed 15s safety profile: +`BackgroundTranscriptionConfiguration.cloudBatch` aliases the fixed-15s +configuration with 0.5s overlap and client-side speech-activity suppression. +The Swift chunker also has a named silence-aware candidate configuration that +can flush after the minimum speech/audio window and a 350ms silence gap, while +still enforcing the same 15s hard cap. Do not switch the runtime background path +to that candidate until Deepgram-vs-AssemblyAI replay evals show no speaker +quality regression on fast-turn, overlap, sparse, and variable-density fixtures. + Deepgram is the prerecorded fallback provider. If AssemblyAI fails, times out, or exhausts retries for an eligible prerecorded workload, the failed AssemblyAI run is finalized in the provider ledger and the request retries through diff --git a/docs/doc/developer/backend/listen_pusher_pipeline.mdx b/docs/doc/developer/backend/listen_pusher_pipeline.mdx index 25c7d98eed6..2d35c9fb5be 100644 --- a/docs/doc/developer/backend/listen_pusher_pipeline.mdx +++ b/docs/doc/developer/backend/listen_pusher_pipeline.mdx @@ -244,12 +244,15 @@ explicit conversation with `POST /v2/desktop/background-conversation/{conversation_id}/finish`. Mobile listen, BLE listen, and PTT continue to use their existing Deepgram paths. -For cloud batch, the desktop client buffers roughly 15s PCM chunks with 0.5s -overlap, drops chunks that do not contain sustained speech activity before -uploading, and uses backpressure instead of stopping recording when AssemblyAI -is slower than realtime. 
AssemblyAI batch requests use the selected single -language rather than `multi` to avoid provider-side language-detection failures -on low-speech audio. +For cloud batch, the desktop client currently uses the fixed 15s safety profile: +15s PCM chunks with 0.5s overlap, client-side sustained-speech checks before +uploading, and backpressure instead of stopping recording when AssemblyAI is +slower than realtime. A silence-aware client configuration exists for eval and +can flush earlier after sustained speech plus a silence gap, but it remains off +the runtime path until replay gates show no speaker-quality regression versus +fixed 15s. AssemblyAI batch requests use the selected single language rather +than `multi` to avoid provider-side language-detection failures on low-speech +audio. Eligible AssemblyAI workloads are `sync`, `background`, and `postprocess`. Latency-critical workloads (`realtime`, streaming `ptt`, and From af356502234634a97d1f7770cd07918bfff3f33a Mon Sep 17 00:00:00 2001 From: David Zhang Date: Sun, 24 May 2026 12:56:40 +0700 Subject: [PATCH 39/44] Harden background speaker reconciliation --- backend/routers/desktop_background.py | 116 +++++++++++++++++- .../unit/test_background_speaker_identity.py | 22 ++++ .../test_desktop_background_transcribe.py | 62 +++++++++- .../utils/stt/background_speaker_identity.py | 10 ++ 4 files changed, 201 insertions(+), 9 deletions(-) diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py index d75d2c71eb7..13c64135c9d 100644 --- a/backend/routers/desktop_background.py +++ b/backend/routers/desktop_background.py @@ -45,6 +45,15 @@ _MAX_PCM_BODY_BYTES = 200_000_000 _SPEAKER_MAP_TTL_SECONDS = 60 * 60 * 24 _LOCAL_CLUSTER_SPLIT_MARKER = "::local_part:" +_ANONYMOUS_CHUNK_NAMESPACE = "chunk" +_IDENTIFIED_SPEAKER_NAMESPACE = "identity" + +_SPEAKER_RECONCILIATION_BUDGETS = { + "max_app_visible_speaker_inflation_ratio": 4.0, + "max_provider_only_false_merge_rate": 0.0, + 
"max_unresolved_anonymous_speaker_duration_seconds": 60 * 60, + "max_split_reconcile_ratio": 10.0, +} class BackgroundConversationStartRequest(BaseModel): @@ -260,8 +269,9 @@ async def background_transcribe( del audio_bytes segments = response.segments - _split_noncontiguous_provider_clusters(segments) + split_diagnostics = _split_noncontiguous_provider_clusters(segments) speaker_diagnostics = _speaker_diagnostics(segments) + speaker_diagnostics.update(split_diagnostics) if conversation_id and segments: await _identify_speakers( uid=uid, @@ -274,8 +284,11 @@ async def background_transcribe( ) _apply_chunk_offset(segments, chunk_start_ms / 1000.0) if conversation_id: - _apply_speaker_ids(conversation_id, segments) + reconciliation_diagnostics = _apply_speaker_ids(conversation_id, chunk_id, segments) + else: + reconciliation_diagnostics = _speaker_reconciliation_diagnostics(segments, {}, {}, {}) speaker_diagnostics.update(_speaker_diagnostics(segments, prefix="mapped_")) + speaker_diagnostics.update(reconciliation_diagnostics) finished_at = datetime.now(timezone.utc) if persist and conversation_id: @@ -393,7 +406,7 @@ def _apply_chunk_offset(segments: List[TranscriptSegment], offset_sec: float) -> segment.end += offset_sec -def _split_noncontiguous_provider_clusters(segments: List[TranscriptSegment]) -> None: +def _split_noncontiguous_provider_clusters(segments: List[TranscriptSegment]) -> Dict[str, object]: """Split reused provider clusters into local contiguous groups. 
Batched providers can reuse the same cluster label for rapid, non-contiguous turns @@ -423,16 +436,28 @@ def _split_noncontiguous_provider_clusters(segments: List[TranscriptSegment]) -> for cluster, _group_segments in groups: group_counts[cluster] = group_counts.get(cluster, 0) + 1 + split_cluster_count = 0 + split_segment_count = 0 + cannot_link_pairs_prevented = 0 group_indexes: Dict[str, int] = {} for cluster, group_segments in groups: if group_counts[cluster] <= 1: continue + split_cluster_count += 1 + split_segment_count += len(group_segments) + cannot_link_pairs_prevented += group_indexes.get(cluster, 0) * len(group_segments) group_index = group_indexes.get(cluster, 0) + 1 group_indexes[cluster] = group_index local_cluster = f'{cluster}{_LOCAL_CLUSTER_SPLIT_MARKER}{group_index}' for segment in group_segments: segment.provider_cluster_id = local_cluster + return { + "speaker_split_cluster_count": split_cluster_count, + "speaker_split_segment_count": split_segment_count, + "cannot_link_violations_prevented": cannot_link_pairs_prevented, + } + def _raw_provider_cluster_key(segment: TranscriptSegment) -> Optional[str]: cluster = segment.provider_cluster_id or segment.provider_speaker_label @@ -516,13 +541,25 @@ def _build_person_embeddings_cache(uid: str) -> Dict[str, dict]: return cache -def _apply_speaker_ids(conversation_id: str, segments: List[TranscriptSegment]) -> None: +def _apply_speaker_ids( + conversation_id: str, chunk_id: Optional[str], segments: List[TranscriptSegment] +) -> Dict[str, object]: speaker_map = _load_speaker_map(conversation_id) changed = False + attempted_reconciliation: Dict[str, str] = {} + accepted_reconciliation: Dict[str, str] = {} + rejected_reconciliation: Dict[str, str] = {} for segment in segments: - cluster = segment.provider_cluster_id or segment.provider_speaker_label or segment.speaker - if not cluster: + raw_cluster = segment.provider_cluster_id or segment.provider_speaker_label or segment.speaker + if not raw_cluster: 
continue + cluster = _speaker_reconciliation_key(chunk_id, segment, str(raw_cluster)) + attempted_reconciliation[str(raw_cluster)] = cluster + if cluster.startswith(f"{_IDENTIFIED_SPEAKER_NAMESPACE}:"): + accepted_reconciliation[str(raw_cluster)] = cluster + else: + rejected_reconciliation[str(raw_cluster)] = "anonymous_provider_cluster_is_chunk_local" + if cluster not in speaker_map: speaker_map[cluster] = len(speaker_map) changed = True @@ -534,6 +571,22 @@ def _apply_speaker_ids(conversation_id: str, segments: List[TranscriptSegment]) if changed: _store_speaker_map(conversation_id, speaker_map) + return _speaker_reconciliation_diagnostics( + segments, + attempted_reconciliation, + accepted_reconciliation, + rejected_reconciliation, + ) + + +def _speaker_reconciliation_key(chunk_id: Optional[str], segment: TranscriptSegment, raw_cluster: str) -> str: + if segment.is_user or segment.speaker_identity_state == "user": + return f"{_IDENTIFIED_SPEAKER_NAMESPACE}:user" + if segment.person_id and segment.speaker_identity_state == "identified": + return f"{_IDENTIFIED_SPEAKER_NAMESPACE}:person:{segment.person_id}" + + namespace = chunk_id or "unspecified" + return f"{_ANONYMOUS_CHUNK_NAMESPACE}:{namespace}:provider:{raw_cluster}" def _speaker_map_key(conversation_id: str) -> str: @@ -555,6 +608,57 @@ def _store_speaker_map(conversation_id: str, speaker_map: Dict[str, int]) -> Non redis_db.r.set(_speaker_map_key(conversation_id), json.dumps(speaker_map), ex=_SPEAKER_MAP_TTL_SECONDS) +def _speaker_reconciliation_diagnostics( + segments: List[TranscriptSegment], + attempted_reconciliation: Dict[str, str], + accepted_reconciliation: Dict[str, str], + rejected_reconciliation: Dict[str, str], +) -> Dict[str, object]: + clean_speech_duration = 0.0 + unknown_speaker_duration = 0.0 + provider_clusters = set() + app_speakers = set() + for segment in segments: + duration = max(0.0, segment.end - segment.start) + if segment.text and len(segment.text.split()) >= 2: + 
clean_speech_duration += duration + if segment.speaker_identity_state in ("unknown", "unassigned", "legacy_ambiguous"): + unknown_speaker_duration += duration + cluster = segment.provider_cluster_id or segment.provider_speaker_label + if cluster: + provider_clusters.add(str(cluster)) + if segment.speaker_id is not None: + app_speakers.add(int(segment.speaker_id)) + + provider_count = max(len(provider_clusters), 1) + speaker_inflation_ratio = round(len(app_speakers) / provider_count, 6) + split_count = sum(1 for cluster in provider_clusters if _LOCAL_CLUSTER_SPLIT_MARKER in cluster) + accepted_count = len(set(accepted_reconciliation.values())) + split_reconcile_ratio = round(split_count / max(accepted_count, 1), 6) if split_count else 0.0 + + budget_violations = [] + if speaker_inflation_ratio > _SPEAKER_RECONCILIATION_BUDGETS["max_app_visible_speaker_inflation_ratio"]: + budget_violations.append("max_app_visible_speaker_inflation_ratio") + if unknown_speaker_duration > _SPEAKER_RECONCILIATION_BUDGETS["max_unresolved_anonymous_speaker_duration_seconds"]: + budget_violations.append("max_unresolved_anonymous_speaker_duration_seconds") + if split_reconcile_ratio > _SPEAKER_RECONCILIATION_BUDGETS["max_split_reconcile_ratio"]: + budget_violations.append("max_split_reconcile_ratio") + + return { + "speaker_reconciliation_attempted_count": len(attempted_reconciliation), + "speaker_reconciliation_accepted_count": len(accepted_reconciliation), + "speaker_reconciliation_rejected_count": len(rejected_reconciliation), + "speaker_reconciliation_rejected_reasons": sorted(set(rejected_reconciliation.values()))[:20], + "provider_only_false_merge_rate": 0.0, + "clean_speech_duration_seconds": round(clean_speech_duration, 3), + "unknown_speaker_duration_seconds": round(unknown_speaker_duration, 3), + "app_visible_speaker_inflation_ratio": speaker_inflation_ratio, + "split_reconcile_ratio": split_reconcile_ratio, + "speaker_reconciliation_budgets": _SPEAKER_RECONCILIATION_BUDGETS, + 
"speaker_reconciliation_budget_violations": budget_violations, + } + + def _speaker_diagnostics(segments: List[TranscriptSegment], prefix: str = "") -> Dict[str, object]: provider_clusters = sorted( {str(segment.provider_cluster_id) for segment in segments if segment.provider_cluster_id is not None} diff --git a/backend/tests/unit/test_background_speaker_identity.py b/backend/tests/unit/test_background_speaker_identity.py index d7bfb8151ab..4d0a152610a 100644 --- a/backend/tests/unit/test_background_speaker_identity.py +++ b/backend/tests/unit/test_background_speaker_identity.py @@ -136,6 +136,28 @@ def test_text_self_introduction_is_hint_only_without_voice_assignment(): assert segments[0].speaker_identity_text_hints[0]['source'] == 'text_self_introduction' +def test_text_hint_negative_controls_do_not_create_identity_hints(): + negative_texts = [ + 'Bonjour, je suis Alice et je parle francais.', + 'Alice said I am Bob during the interview.', + '"I am Alice," she read from the script.', + 'Alice is my manager and she mentioned the roadmap.', + 'Hello Alice, hello Alice, thanks for joining.', + 'This monologue has no actual speaker change.', + 'My name as Alice was mistranscribed by the ASR.', + ] + + segments = [ + _segment(f'negative-{index}', index * 3.0, index * 3.0 + 2.0, cluster=f'cluster-{index}', text=text) + for index, text in enumerate(negative_texts) + ] + + assignments = identify_background_speaker_clusters(segments, audio_bytes=None, person_embeddings_cache={}) + + assert all(not assignment.text_hints for assignment in assignments.values()) + assert all(not segment.speaker_identity_text_hints for segment in segments) + + def test_user_sentinel_is_not_persisted_as_durable_person_identity(): user_embedding = np.array([[1.0, 0.0]], dtype=np.float32) segments = [_segment('me', 0.0, 4.0)] diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py index 199e0496cf1..abebcaeb619 100644 --- 
a/backend/tests/unit/test_desktop_background_transcribe.py +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -642,6 +642,9 @@ def test_noncontiguous_same_provider_cluster_splits_inside_chunk(monkeypatch): 'A::local_part:2', ] assert data['speaker_diagnostics']['provider_cluster_count'] == 3 + assert data['speaker_diagnostics']['speaker_split_cluster_count'] == 2 + assert data['speaker_diagnostics']['speaker_split_segment_count'] == 2 + assert data['speaker_diagnostics']['cannot_link_violations_prevented'] == 1 def test_contiguous_same_provider_cluster_stays_together_inside_chunk(monkeypatch): @@ -708,7 +711,7 @@ def test_assemblyai_label_only_speakers_do_not_collapse_to_single_local_speaker( assert desktop_background.update_provider_run_identity_metrics.call_args.args[6] == 'missing_candidate_embeddings' -def test_background_transcribe_multi_chunk_offsets_persist_and_keep_speaker_map(monkeypatch): +def test_background_transcribe_multi_chunk_offsets_persist_and_keep_anonymous_speakers_chunk_local(monkeypatch): client, mock_transcribe = _client(monkeypatch) speaker_map_store = {} @@ -752,8 +755,61 @@ def _transcribe_bytes(_audio_bytes, **_kwargs): ] assert [segment.start for segment in appended_segments] == [0.1, 14.1, 28.1] assert [segment.end for segment in appended_segments] == [1.0, 15.0, 29.0] - assert [segment.speaker_id for segment in appended_segments] == [0, 1, 0] - assert [segment.speaker for segment in appended_segments] == ['SPEAKER_00', 'SPEAKER_01', 'SPEAKER_00'] + assert [segment.speaker_id for segment in appended_segments] == [0, 1, 2] + assert [segment.speaker for segment in appended_segments] == ['SPEAKER_00', 'SPEAKER_01', 'SPEAKER_02'] + + +def test_background_transcribe_multi_chunk_known_omi_identity_can_reconcile(monkeypatch): + alice_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + client, mock_transcribe = _client( + monkeypatch, + person_embeddings_cache={'person-alice': {'embedding': alice_embedding, 'name': 'Alice'}}, 
+ ) + monkeypatch.setattr(desktop_background, 'extract_embedding_from_bytes', lambda _audio, _filename: alice_embedding) + speaker_map_store = {} + + def _redis_get(key): + return speaker_map_store.get(key) + + def _redis_set(key, value, **_kwargs): + speaker_map_store[key] = value + + monkeypatch.setattr(desktop_background.redis_db.r, 'get', _redis_get) + monkeypatch.setattr(desktop_background.redis_db.r, 'set', _redis_set) + + responses = [ + [TranscriptSegment(text='Alice first chunk.', is_user=False, start=0.1, end=3.0, provider_cluster_id='A')], + [TranscriptSegment(text='Alice second chunk.', is_user=False, start=0.1, end=3.0, provider_cluster_id='B')], + ] + + def _transcribe_bytes(_audio_bytes, **_kwargs): + return SimpleNamespace( + result=ProviderTranscriptResult(provider='assemblyai', model='universal-2', words=[], utterances=[]), + detected_language='en', + segments=[segment.model_copy(deep=True) for segment in responses.pop(0)], + run_id='run-known', + ) + + mock_transcribe.side_effect = _transcribe_bytes + + response_payloads = [] + for chunk_start_ms in (0, 14000): + response = client.post( + f'/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-id-{chunk_start_ms}&chunk_start_ms={chunk_start_ms}', + content=b'\x01\x00' * 16000 * 4, + headers={'Content-Type': 'application/octet-stream'}, + ) + assert response.status_code == 200 + response_payloads.append(response.json()) + + appended_segments = [ + call.args[4][0] + for call in desktop_background.append_background_chunk_to_in_progress_conversation.call_args_list + ] + assert [segment.person_id for segment in appended_segments] == ['person-alice', 'person-alice'] + assert [segment.speaker_id for segment in appended_segments] == [0, 0] + assert response_payloads[-1]['speaker_diagnostics']['speaker_reconciliation_accepted_count'] == 1 + assert response_payloads[-1]['speaker_diagnostics']['speaker_reconciliation_rejected_count'] == 0 def 
test_background_transcribe_identifies_assemblyai_speaker_with_omi_user_embedding(monkeypatch): diff --git a/backend/utils/stt/background_speaker_identity.py b/backend/utils/stt/background_speaker_identity.py index 3b6210fc753..8db7d4c4a1c 100644 --- a/backend/utils/stt/background_speaker_identity.py +++ b/backend/utils/stt/background_speaker_identity.py @@ -338,12 +338,22 @@ def _detect_self_introduction_hint(text: str) -> Optional[str]: match = re.search(pattern, text) if not match: continue + if _looks_like_quoted_or_reported_speech(text, match.start()): + continue name = match.groups()[-1] if name and len(name) >= 2: return name.capitalize() return None +def _looks_like_quoted_or_reported_speech(text: str, match_start: int) -> bool: + prefix = text[:match_start].strip().lower() + if prefix and prefix[-1:] in {'"', "'", '“', '‘'}: + return True + recent_prefix = prefix[-40:] + return bool(re.search(r"\b(said|says|asked|told|quoted|read|wrote|writes)\b", recent_prefix)) + + def _rank_candidates( query_embedding: np.ndarray, person_embeddings_cache: Dict[str, dict], From 2ed8dee7be489f9b82a6467344f1574fe3291228 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Sun, 24 May 2026 13:08:27 +0700 Subject: [PATCH 40/44] Add offline STT provider readiness gate --- .../scripts/stt/provider_comparison_gate.md | 27 + .../scripts/stt/provider_comparison_gate.py | 8 +- .../fixture_good_meeting.assemblyai.json | 2 + ...ixture_good_meeting.assemblyai.rollup.json | 7 +- .../fixture_good_meeting.deepgram.json | 2 + .../fixture_good_meeting.deepgram.rollup.json | 7 +- .../fixtures/stt_provider_eval/manifest.json | 567 ++++++++++++++++++ .../tests/unit/test_provider_evaluation.py | 47 +- backend/utils/stt/provider_evaluation.py | 424 ++++++++++++- 9 files changed, 1057 insertions(+), 34 deletions(-) create mode 100644 backend/scripts/stt/provider_comparison_gate.md diff --git a/backend/scripts/stt/provider_comparison_gate.md b/backend/scripts/stt/provider_comparison_gate.md new file 
mode 100644 index 00000000000..b300d3fd586 --- /dev/null +++ b/backend/scripts/stt/provider_comparison_gate.md @@ -0,0 +1,27 @@ +# STT Provider Comparison Gate + +Run the offline production-readiness eval from the backend directory: + +```bash +python3 scripts/stt/provider_comparison_gate.py \ + --manifest tests/fixtures/stt_provider_eval/manifest.json \ + --output-md /tmp/stt-provider-eval.md \ + --output-json /tmp/stt-provider-eval.json +``` + +The default command replays synthetic and saved provider outputs only. It does not require `ASSEMBLYAI_API_KEY` or `DEEPGRAM_API_KEY`. + +The report compares `always_deepgram`, `always_assemblyai`, `current_policy`, and `shadow_only`. It includes speaker safety, default viability, and canary readiness gates plus an AssemblyAI gap report that names the limiting scenario, likely cause, and rollout mitigation. + +The fixture manifest covers clean turns, fast turns, overlap, sparse speech, low-signal/no-speech, multilingual turns, duplicate chunk replay, provider failure/fallback, saved real-provider E2E output, and saved policy-router output. + +Live-provider smoke tests are optional: + +```bash +ASSEMBLYAI_API_KEY=... DEEPGRAM_API_KEY=... \ +python3 scripts/stt/provider_comparison_gate.py \ + --manifest tests/fixtures/stt_provider_eval/manifest.json \ + --live +``` + +Synthetic and saved-output gates are necessary but insufficient for broad defaulting. TICKET-028 must turn the latest gap-closing report into an AssemblyAI canary/default rollout plan backed by privacy-safe real-session metrics. 
diff --git a/backend/scripts/stt/provider_comparison_gate.py b/backend/scripts/stt/provider_comparison_gate.py index 63b72a70543..322b6cad59f 100644 --- a/backend/scripts/stt/provider_comparison_gate.py +++ b/backend/scripts/stt/provider_comparison_gate.py @@ -22,7 +22,7 @@ from utils.stt.providers import STTProviderName, STTWorkload # noqa: E402 LIVE_PROVIDER_IMPORT_ERROR = None -except ModuleNotFoundError as e: +except Exception as e: _transcribe_bytes_with_provider = None _transcribe_url_with_provider = None STTProviderName = None @@ -120,7 +120,11 @@ def main() -> int: def _prepare_case(case: dict[str, Any], base_path: Path) -> dict[str, Any]: - prepared = {'id': case.get('id') or case.get('case_id')} + prepared = { + 'id': case.get('id') or case.get('case_id'), + 'scenario': case.get('scenario') or case.get('type'), + 'current_policy_provider': case.get('current_policy_provider'), + } prepared['deepgram'] = _load_provider_payload(case, base_path, 'deepgram') prepared['assemblyai'] = _load_provider_payload(case, base_path, 'assemblyai') return prepared diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json index 5d1f4cea332..42920b4c356 100644 --- a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json @@ -6,6 +6,7 @@ "end": 2.3, "provider_cluster_id": "A", "provider_speaker_label": "ASSEMBLYAI_SPEAKER_A", + "oracle_speaker": "alex", "person_id": "person_alex", "speaker_identity_state": "identified", "speaker_identity_confidence": 0.88 @@ -16,6 +17,7 @@ "end": 4.9, "provider_cluster_id": "B", "provider_speaker_label": "ASSEMBLYAI_SPEAKER_B", + "oracle_speaker": "casey", "speaker_identity_state": "unknown", "speaker_identity_confidence": 0.46 } diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json 
b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json index 466c1afec6c..3e10d13a48e 100644 --- a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json @@ -7,6 +7,11 @@ "speech_active_seconds": 4.5, "billable_seconds": 5.0, "estimated_cost_usd": 0.00020, + "latency_seconds": 1.8, + "runtime_seconds": 1.8, "retry_count": 0, - "fallback_count": 0 + "fallback_count": 0, + "split_count": 0, + "accepted_reconciliation_count": 0, + "rejected_reconciliation_count": 1 } diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json index b70a3d30afb..ebb99cc6965 100644 --- a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json @@ -6,6 +6,7 @@ "end": 2.2, "provider_cluster_id": "0", "provider_speaker_label": "SPEAKER_00", + "oracle_speaker": "alex", "person_id": "person_alex", "speaker_identity_state": "identified", "speaker_identity_confidence": 0.91 @@ -16,6 +17,7 @@ "end": 4.8, "provider_cluster_id": "1", "provider_speaker_label": "SPEAKER_01", + "oracle_speaker": "casey", "speaker_identity_state": "unknown", "speaker_identity_confidence": 0.42 } diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json index 334f3acfde2..40932be595a 100644 --- a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json @@ -7,6 +7,11 @@ "speech_active_seconds": 4.5, "billable_seconds": 5.0, "estimated_cost_usd": 0.00036, + "latency_seconds": 1.1, + "runtime_seconds": 1.1, "retry_count": 0, - 
"fallback_count": 0 + "fallback_count": 0, + "split_count": 0, + "accepted_reconciliation_count": 0, + "rejected_reconciliation_count": 1 } diff --git a/backend/tests/fixtures/stt_provider_eval/manifest.json b/backend/tests/fixtures/stt_provider_eval/manifest.json index 20a1aac9c25..10f82985891 100644 --- a/backend/tests/fixtures/stt_provider_eval/manifest.json +++ b/backend/tests/fixtures/stt_provider_eval/manifest.json @@ -2,10 +2,577 @@ "cases": [ { "id": "fixture_good_meeting", + "scenario": "saved_real_provider_e2e", + "current_policy_provider": "deepgram", "deepgram_fixture": "fixture_good_meeting.deepgram.json", "assemblyai_fixture": "fixture_good_meeting.assemblyai.json", "deepgram_rollup": "fixture_good_meeting.deepgram.rollup.json", "assemblyai_rollup": "fixture_good_meeting.assemblyai.rollup.json" + }, + { + "id": "synthetic_clean_turns", + "scenario": "clean_turns", + "current_policy_provider": "assemblyai", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "The first milestone is ready for review.", + "start": 0.0, + "end": 2.0, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "I will check the metrics this afternoon.", + "start": 2.2, + "end": 4.4, + "provider_cluster_id": "1", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.0002, + "latency_seconds": 1.0, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "The first milestone is ready for review.", + "start": 0.1, + "end": 2.1, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "I will check the metrics this afternoon.", + "start": 2.3, + "end": 4.5, + "provider_cluster_id": "B", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "billable_seconds": 
5.0, + "estimated_cost_usd": 0.00036, + "latency_seconds": 1.8, + "fallback_count": 0, + "rejected_reconciliation_count": 1 + } + } + }, + { + "id": "synthetic_fast_turns", + "scenario": "fast_turns", + "current_policy_provider": "assemblyai", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Yes.", + "start": 0.0, + "end": 0.2, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "No.", + "start": 0.3, + "end": 0.5, + "provider_cluster_id": "1", + "oracle_speaker": "casey" + }, + { + "text": "Maybe after lunch.", + "start": 0.6, + "end": 1.1, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 2.0, + "billable_seconds": 2.0, + "estimated_cost_usd": 0.00008, + "latency_seconds": 0.9, + "fallback_count": 0, + "split_count": 1 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Yes.", + "start": 0.0, + "end": 0.2, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "No.", + "start": 0.3, + "end": 0.5, + "provider_cluster_id": "B", + "oracle_speaker": "casey" + }, + { + "text": "Maybe after lunch.", + "start": 0.7, + "end": 1.2, + "provider_cluster_id": "A.2", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 2.0, + "billable_seconds": 2.0, + "estimated_cost_usd": 0.00015, + "latency_seconds": 1.4, + "fallback_count": 0, + "split_count": 1, + "rejected_reconciliation_count": 1 + } + } + }, + { + "id": "synthetic_overlap", + "scenario": "overlap", + "current_policy_provider": "assemblyai", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Start the recording now.", + "start": 0.0, + "end": 1.5, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "I am already taking notes.", + "start": 1.0, + "end": 2.8, + "provider_cluster_id": "1", + 
"oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 3.0, + "billable_seconds": 3.0, + "estimated_cost_usd": 0.00012, + "latency_seconds": 1.2, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Start the recording now.", + "start": 0.0, + "end": 1.6, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "I am already taking notes.", + "start": 1.1, + "end": 2.9, + "provider_cluster_id": "B", + "oracle_speaker": "casey" + }, + { + "text": "notes.", + "start": 2.7, + "end": 2.9, + "provider_cluster_id": "B.2", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 3.0, + "billable_seconds": 3.0, + "estimated_cost_usd": 0.00022, + "latency_seconds": 1.9, + "fallback_count": 0, + "split_count": 1, + "accepted_reconciliation_count": 1 + } + } + }, + { + "id": "synthetic_sparse_speech", + "scenario": "sparse_speech", + "current_policy_provider": "assemblyai", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Okay.", + "start": 12.0, + "end": 12.3, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 30.0, + "speech_active_seconds": 0.5, + "billable_seconds": 30.0, + "estimated_cost_usd": 0.0012, + "latency_seconds": 1.0, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Okay.", + "start": 12.1, + "end": 12.4, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 30.0, + "speech_active_seconds": 0.5, + "billable_seconds": 30.0, + "estimated_cost_usd": 0.0022, + "latency_seconds": 2.0, + "fallback_count": 0 + } + } + }, + { + "id": 
"synthetic_no_speech", + "scenario": "low_signal_no_speech", + "current_policy_provider": "deepgram", + "deepgram": { + "transcript": { + "segments": [] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 15.0, + "speech_active_seconds": 0.0, + "billable_seconds": 15.0, + "estimated_cost_usd": 0.0006, + "latency_seconds": 0.8, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 15.0, + "speech_active_seconds": 0.0, + "billable_seconds": 15.0, + "estimated_cost_usd": 0.0011, + "latency_seconds": 1.2, + "fallback_count": 0 + } + } + }, + { + "id": "synthetic_multilingual_turns", + "scenario": "multilingual_turns", + "current_policy_provider": "assemblyai", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Hola equipo, the build is ready.", + "start": 0.0, + "end": 2.4, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "Merci, I will verify it now.", + "start": 2.6, + "end": 4.7, + "provider_cluster_id": "1", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.0002, + "latency_seconds": 1.3, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Hola equipo, the build is ready.", + "start": 0.1, + "end": 2.5, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "Merci, I will verify it now.", + "start": 2.7, + "end": 4.8, + "provider_cluster_id": "B", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.00036, + "latency_seconds": 2.0, + "fallback_count": 0 + } + } + }, + { + "id": 
"synthetic_duplicate_replay", + "scenario": "duplicate_chunk_replay", + "current_policy_provider": "assemblyai", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "This chunk should only count once.", + "start": 0.0, + "end": 2.0, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 2, + "status_counts": { + "succeeded": 2 + }, + "raw_audio_seconds": 4.0, + "billable_seconds": 2.0, + "estimated_cost_usd": 0.00008, + "latency_seconds": 1.0, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "This chunk should only count once.", + "start": 0.0, + "end": 2.1, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 2, + "status_counts": { + "succeeded": 2 + }, + "raw_audio_seconds": 4.0, + "billable_seconds": 2.0, + "estimated_cost_usd": 0.00015, + "latency_seconds": 1.7, + "fallback_count": 0 + } + } + }, + { + "id": "synthetic_provider_failure_fallback", + "scenario": "provider_failure_fallback", + "current_policy_provider": "deepgram", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Fallback kept the transcript available.", + "start": 0.0, + "end": 2.2, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 10, + "status_counts": { + "succeeded": 10 + }, + "raw_audio_seconds": 22.0, + "billable_seconds": 22.0, + "estimated_cost_usd": 0.00088, + "latency_seconds": 1.1, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Fallback kept the transcript available.", + "start": 0.0, + "end": 2.2, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 10, + "status_counts": { + "succeeded": 9, + "timeout": 0 + }, + "raw_audio_seconds": 22.0, + "billable_seconds": 22.0, + "estimated_cost_usd": 0.00158, + "latency_seconds": 2.1, + "fallback_count": 1 + } + } + }, + { + "id": 
"saved_policy_router_gap", + "scenario": "saved_real_policy_router_outputs", + "current_policy_provider": "deepgram", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "We need battery telemetry before the next build ships and the background recorder should keep stable chunks for the entire review.", + "start": 0.0, + "end": 5.0, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "Okay.", + "start": 5.1, + "end": 6.0, + "provider_cluster_id": "1", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 6.0, + "billable_seconds": 6.0, + "estimated_cost_usd": 0.00044, + "latency_seconds": 1.3, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "We need battery telemetry before the next build ships and the background recorder should keep stable chunks for the entire review.", + "start": 0.1, + "end": 5.1, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "Okay.", + "start": 5.2, + "end": 6.1, + "provider_cluster_id": "A", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 6.0, + "billable_seconds": 6.0, + "estimated_cost_usd": 0.00122, + "latency_seconds": 2.2, + "fallback_count": 0, + "split_count": 0, + "rejected_reconciliation_count": 1 + } + } } ] } diff --git a/backend/tests/unit/test_provider_evaluation.py b/backend/tests/unit/test_provider_evaluation.py index e57b8f8530d..a7a4c3b833d 100644 --- a/backend/tests/unit/test_provider_evaluation.py +++ b/backend/tests/unit/test_provider_evaluation.py @@ -28,6 +28,27 @@ def _load_fixture_case() -> dict: } +def _load_manifest_cases() -> list[dict]: + manifest = json.loads((FIXTURE_DIR / 'manifest.json').read_text()) + cases = [] + for case in manifest['cases']: + prepared = { + 'id': case['id'], + 'scenario': case.get('scenario'), + 
'current_policy_provider': case.get('current_policy_provider'), + } + for provider in ('deepgram', 'assemblyai'): + if provider in case: + prepared[provider] = case[provider] + continue + prepared[provider] = { + 'transcript': json.loads((FIXTURE_DIR / case[f'{provider}_fixture']).read_text()), + 'ledger': json.loads((FIXTURE_DIR / case[f'{provider}_rollup']).read_text()), + } + cases.append(prepared) + return cases + + def test_fixture_report_passes_and_includes_cost_identity_and_timing_metrics(): report = build_comparison_report([_load_fixture_case()]) @@ -38,12 +59,32 @@ def test_fixture_report_passes_and_includes_cost_identity_and_timing_metrics(): assert case['comparison']['transcript_word_error_rate'] == 0.0 assert case['comparison']['average_timestamp_drift_seconds'] > 0.0 assert case['providers']['assemblyai']['speaker_cluster_count'] == 2 + assert case['providers']['assemblyai']['speaker_word_purity'] == 1.0 assert case['providers']['assemblyai']['identified_speaker_cluster_count'] == 1 assert case['providers']['assemblyai']['unknown_speaker_cluster_count'] == 1 assert case['providers']['assemblyai']['low_confidence_identity_rate'] == 0.5 assert all(gate['severity'] == 'pass' for gate in case['gates']) +def test_manifest_report_includes_strategy_rollups_gap_report_and_fragmentation_metrics(): + report = build_comparison_report(_load_manifest_cases()) + + assert report['status'] == 'passed' + assert set(report['strategies']) == {'always_deepgram', 'always_assemblyai', 'current_policy', 'shadow_only'} + assert report['strategies']['always_assemblyai']['provider'] == 'assemblyai' + assert report['strategies']['shadow_only']['provider'] == 'deepgram' + assert report['strategies']['current_policy']['provider'] == 'mixed' + assert report['strategies']['always_assemblyai']['split_count'] >= 2 + assert report['strategies']['always_assemblyai']['estimated_cost_per_hour_usd'] > 0 + assert report['assemblyai_gap_report']['status'] == 'limited' + assert any( + 
item['scenario'] == 'saved_real_policy_router_outputs' + and item['metric'] in {'speaker_word_purity', 'estimated_cost_per_hour_usd'} + for item in report['assemblyai_gap_report']['limiting_scenarios'] + ) + assert any(gate['gate_group'] == 'speaker_safety' for case in report['cases'] for gate in case['gates']) + + def test_threshold_failures_are_reported_for_transcript_drift_and_fallback(): case = _load_fixture_case() case['assemblyai']['transcript']['segments'][0]['text'] = 'Completely unrelated output.' @@ -146,9 +187,11 @@ def test_low_confidence_identity_counts_clusters_not_segments(): def test_compact_markdown_report_is_review_friendly(): - report = build_comparison_report([_load_fixture_case()]) + report = build_comparison_report(_load_manifest_cases()) markdown = compact_markdown_report(report) assert '# STT Provider Evaluation: PASSED' in markdown assert 'fixture_good_meeting' in markdown - assert 'AssemblyAI cost' in markdown + assert 'Strategy Rollup' in markdown + assert 'AssemblyAI Gap Report' in markdown + assert 'TICKET-028' in markdown diff --git a/backend/utils/stt/provider_evaluation.py b/backend/utils/stt/provider_evaluation.py index 316e9194e06..fd6524a23b2 100644 --- a/backend/utils/stt/provider_evaluation.py +++ b/backend/utils/stt/provider_evaluation.py @@ -1,6 +1,15 @@ from dataclasses import dataclass from typing import Any, Optional +PRODUCTION_STRATEGIES = ('always_deepgram', 'always_assemblyai', 'current_policy', 'shadow_only') +PROVIDER_BY_STRATEGY = { + 'always_deepgram': 'deepgram', + 'always_assemblyai': 'assemblyai', + 'shadow_only': 'deepgram', +} +ASSEMBLYAI_COST_PER_HOUR_USD = 0.2592 +DEEPGRAM_COST_PER_HOUR_USD = 0.144 + @dataclass(frozen=True) class ProviderGateThresholds: @@ -10,6 +19,13 @@ class ProviderGateThresholds: max_low_confidence_identity_rate: float = 0.50 max_fallback_rate: float = 0.10 max_failure_rate: float = 0.05 + min_speaker_word_purity: float = 0.95 + min_assemblyai_purity_delta_vs_deepgram: float = -0.05 + 
max_speaker_inflation_ratio: float = 1.75 + max_empty_transcript_rate: float = 0.05 + max_timeout_error_rate: float = 0.05 + max_latency_ratio_vs_deepgram: float = 2.0 + max_cost_ratio_vs_deepgram: float = 3.0 require_instrumentation: bool = True @@ -27,6 +43,8 @@ def build_comparison_report( 'failure_count': len(failures), 'warning_count': len(warnings), 'aggregate': _aggregate_case_reports(case_reports), + 'strategies': _strategy_rollups(case_reports), + 'assemblyai_gap_report': _assemblyai_gap_report(case_reports), 'cases': case_reports, } @@ -45,19 +63,44 @@ def compact_markdown_report(report: dict[str, Any]) -> str: lines = [ f"# STT Provider Evaluation: {report.get('status', 'unknown').upper()}", '', - '| Cases | Failures | Warnings | Avg WER | Avg timestamp drift | AssemblyAI cost | Deepgram cost |', - '| --- | ---: | ---: | ---: | ---: | ---: | ---: |', + '| Cases | Failures | Warnings | Avg WER | Avg timestamp drift | AAI purity | DG purity | AAI cost/hr | DG cost/hr |', + '| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |', ( f"| {report.get('case_count', 0)} | {report.get('failure_count', 0)} | " f"{report.get('warning_count', 0)} | {_fmt_pct(aggregate.get('average_word_error_rate'))} | " f"{_fmt_seconds(aggregate.get('average_timestamp_drift_seconds'))} | " - f"${aggregate.get('assemblyai_estimated_cost_usd', 0.0):.4f} | " - f"${aggregate.get('deepgram_estimated_cost_usd', 0.0):.4f} |" + f"{_fmt_pct(aggregate.get('assemblyai_speaker_word_purity'))} | " + f"{_fmt_pct(aggregate.get('deepgram_speaker_word_purity'))} | " + f"${aggregate.get('assemblyai_estimated_cost_per_hour_usd', 0.0):.3f} | " + f"${aggregate.get('deepgram_estimated_cost_per_hour_usd', 0.0):.3f} |" ), '', - '| Case | WER | Segments DG/AAI | Clusters DG/AAI | Unknown AAI | Low-conf AAI | Fallback AAI | Gates |', - '| --- | ---: | ---: | ---: | ---: | ---: | ---: | --- |', + '## Strategy Rollup', + '', + '| Strategy | Provider | Purity | Covered speakers | App speakers | 
Inflation | Empty | Fallback | Timeout/error | Cost/hr |', + '| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |', ] + for name, strategy in report.get('strategies', {}).items(): + lines.append( + f"| {name} | {strategy.get('provider', 'mixed')} | " + f"{_fmt_pct(strategy.get('speaker_word_purity'))} | " + f"{strategy.get('covered_speaker_count', 0):.1f} | " + f"{strategy.get('app_visible_speaker_count', 0):.1f} | " + f"{strategy.get('speaker_inflation_ratio', 0.0):.2f} | " + f"{_fmt_pct(strategy.get('empty_transcript_rate'))} | " + f"{_fmt_pct(strategy.get('fallback_rate'))} | " + f"{_fmt_pct(strategy.get('timeout_error_rate'))} | " + f"${strategy.get('estimated_cost_per_hour_usd', 0.0):.3f} |" + ) + lines.extend( + [ + '', + '## Cases', + '', + '| Case | Scenario | WER | Purity DG/AAI | App speakers DG/AAI | Splits AAI | Recon AAI | Fallback AAI | Gates |', + '| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | --- |', + ] + ) for case in report.get('cases', []): deepgram = case['providers']['deepgram'] assemblyai = case['providers']['assemblyai'] @@ -65,13 +108,32 @@ def compact_markdown_report(report: dict[str, Any]) -> str: f"{gate['severity']}:{gate['metric']}" for gate in case.get('gates', []) if gate['severity'] != 'pass' ) lines.append( - f"| {case['id']} | {_fmt_pct(case['comparison']['transcript_word_error_rate'])} | " - f"{deepgram['segment_count']}/{assemblyai['segment_count']} | " - f"{deepgram['speaker_cluster_count']}/{assemblyai['speaker_cluster_count']} | " - f"{assemblyai['unknown_speaker_cluster_count']} | " - f"{_fmt_pct(assemblyai['low_confidence_identity_rate'])} | " + f"| {case['id']} | {case.get('scenario', 'unspecified')} | " + f"{_fmt_pct(case['comparison']['transcript_word_error_rate'])} | " + f"{_fmt_pct(deepgram['speaker_word_purity'])}/{_fmt_pct(assemblyai['speaker_word_purity'])} | " + f"{deepgram['app_visible_speaker_count']}/{assemblyai['app_visible_speaker_count']} | " + f"{assemblyai['split_count']} | " + 
f"{assemblyai['accepted_reconciliation_count']}/{assemblyai['rejected_reconciliation_count']} | " f"{_fmt_pct(assemblyai['fallback_rate'])} | {gates or 'pass'} |" ) + gap_report = report.get('assemblyai_gap_report') or {} + lines.extend(['', '## AssemblyAI Gap Report', '']) + lines.append(f"Status: {gap_report.get('status', 'unknown')}.") + limiting_scenarios = gap_report.get('limiting_scenarios') or [] + for item in limiting_scenarios: + lines.append( + f"- {item['scenario']}: {item['metric']} ({item['assemblyai_value']} vs {item['deepgram_value']}) " + f"- likely cause: {item['likely_cause']}; mitigation: {item['mitigation']}" + ) + if not limiting_scenarios: + lines.append('- No material AssemblyAI gap detected in offline synthetic/saved-output fixtures.') + lines.extend( + [ + '', + 'Synthetic and saved-output gates are necessary but insufficient for broad defaulting. ' + 'TICKET-028 must turn this gap report into an AssemblyAI canary/default rollout plan with privacy-safe real-session evidence.', + ] + ) return '\n'.join(lines) @@ -91,6 +153,8 @@ def _compare_case(case: dict[str, Any], thresholds: ProviderGateThresholds) -> d gates = _evaluate_case_gates(deepgram, assemblyai, comparison, thresholds) return { 'id': case.get('id') or case.get('case_id') or 'unknown', + 'scenario': case.get('scenario') or case.get('type') or 'unspecified', + 'current_policy_provider': case.get('current_policy_provider') or 'deepgram', 'providers': {'deepgram': _public_summary(deepgram), 'assemblyai': _public_summary(assemblyai)}, 'comparison': comparison, 'gates': gates, @@ -102,6 +166,14 @@ def summarize_provider_output(provider: str, payload: dict[str, Any]) -> dict[st segments = _extract_segments(transcript) ledger = payload.get('ledger') or payload.get('rollup') or {} clusters = _speaker_clusters(segments) + oracle_speakers = _oracle_speakers(segments) + word_count = sum(len(_words(segment.get('text', ''))) for segment in segments) + raw_audio_seconds = 
_number_from_ledger(ledger, 'raw_audio_seconds') + billable_seconds = _number_from_ledger(ledger, 'billable_seconds') + estimated_cost_usd = _estimated_cost_usd(provider, ledger, billable_seconds or raw_audio_seconds) + latency_seconds = _latency_seconds(ledger) + timeout_error_count = _timeout_error_count(ledger) + run_count = _run_count(ledger) identified_clusters = { _cluster_id(segment) for segment in segments @@ -128,20 +200,34 @@ def summarize_provider_output(provider: str, payload: dict[str, Any]) -> dict[st 'segments': segments, 'text': _transcript_text(segments), 'segment_count': len(segments), - 'word_count': sum(len(_words(segment.get('text', ''))) for segment in segments), + 'word_count': word_count, 'speaker_cluster_count': len(clusters), + 'provider_cluster_count': len(clusters), + 'covered_speaker_count': len(oracle_speakers), + 'app_visible_speaker_count': len(clusters), + 'speaker_inflation_ratio': (len(clusters) / len(oracle_speakers) if oracle_speakers else 0.0), + 'speaker_word_purity': _speaker_word_purity(segments), + 'empty_transcript_rate': 1.0 if word_count == 0 else 0.0, 'identified_speaker_cluster_count': len(identified_clusters), 'unknown_speaker_cluster_count': max(len(clusters) - len(identified_clusters), 0), 'low_confidence_identity_count': len(low_confidence_clusters), 'low_confidence_identity_rate': (len(low_confidence_clusters) / len(clusters) if clusters else 0.0), - 'raw_audio_seconds': _number_from_ledger(ledger, 'raw_audio_seconds'), + 'raw_audio_seconds': raw_audio_seconds, 'speech_active_seconds': _number_from_ledger(ledger, 'speech_active_seconds'), - 'billable_seconds': _number_from_ledger(ledger, 'billable_seconds'), - 'estimated_cost_usd': _number_from_ledger(ledger, 'estimated_cost_usd'), + 'billable_seconds': billable_seconds, + 'estimated_cost_usd': estimated_cost_usd, + 'estimated_cost_per_hour_usd': _cost_per_hour(estimated_cost_usd, billable_seconds or raw_audio_seconds), + 'latency_seconds': latency_seconds, + 
'runtime_seconds': _runtime_seconds(ledger, latency_seconds), 'retry_count': _number_from_ledger(ledger, 'retry_count'), + 'split_count': _number_from_ledger(ledger, 'split_count'), + 'accepted_reconciliation_count': _number_from_ledger(ledger, 'accepted_reconciliation_count'), + 'rejected_reconciliation_count': _number_from_ledger(ledger, 'rejected_reconciliation_count'), 'fallback_count': _number_from_ledger(ledger, 'fallback_count'), 'fallback_rate': _rate_from_ledger(ledger, 'fallback_count'), 'failure_rate': _failure_rate_from_ledger(ledger), + 'timeout_error_count': timeout_error_count, + 'timeout_error_rate': timeout_error_count / run_count, 'has_instrumentation': bool(ledger), } @@ -172,6 +258,7 @@ def _normalize_segment(segment: dict[str, Any]) -> dict[str, Any]: 'provider_cluster_id': segment.get('provider_cluster_id'), 'provider_speaker_label': segment.get('provider_speaker_label'), 'speaker': segment.get('speaker'), + 'oracle_speaker': segment.get('oracle_speaker') or segment.get('expected_speaker'), 'person_id': segment.get('person_id'), 'is_user': segment.get('is_user'), 'speaker_identity_state': segment.get('speaker_identity_state'), @@ -204,6 +291,7 @@ def _segments_from_words(words: list[dict[str, Any]], transcript: dict[str, Any] 'end': float(word.get('end') or 0.0), 'provider_cluster_id': cluster_id, 'provider_speaker_label': word.get('speaker_label') or word.get('provider_speaker_label'), + 'oracle_speaker': word.get('oracle_speaker') or word.get('expected_speaker'), 'stt_provider': transcript.get('provider'), 'stt_model': transcript.get('model'), } @@ -279,6 +367,9 @@ def _evaluate_case_gates( comparison: dict[str, Any], thresholds: ProviderGateThresholds, ) -> list[dict[str, Any]]: + empty_transcript_threshold = ( + 1.0 if deepgram['word_count'] == 0 and assemblyai['word_count'] == 0 else thresholds.max_empty_transcript_rate + ) gates = [ _threshold_gate( 'transcript_word_error_rate', @@ -291,23 +382,77 @@ def _evaluate_case_gates( 
comparison['segment_count_delta_ratio'], thresholds.max_segment_count_delta_ratio, 'failure', + gate_group='speaker_safety', ), _threshold_gate( 'average_timestamp_drift_seconds', comparison['average_timestamp_drift_seconds'], thresholds.max_average_timestamp_drift_seconds, 'warning', + gate_group='canary_readiness', ), _threshold_gate( 'assemblyai_low_confidence_identity_rate', assemblyai['low_confidence_identity_rate'], thresholds.max_low_confidence_identity_rate, 'warning', + gate_group='speaker_safety', ), _threshold_gate( 'assemblyai_fallback_rate', assemblyai['fallback_rate'], thresholds.max_fallback_rate, 'failure' ), _threshold_gate('assemblyai_failure_rate', assemblyai['failure_rate'], thresholds.max_failure_rate, 'failure'), + _threshold_gate( + 'assemblyai_speaker_word_purity', + assemblyai['speaker_word_purity'], + thresholds.min_speaker_word_purity, + 'failure', + minimum=True, + gate_group='speaker_safety', + ), + _threshold_gate( + 'assemblyai_speaker_inflation_ratio', + assemblyai['speaker_inflation_ratio'], + thresholds.max_speaker_inflation_ratio, + 'failure', + gate_group='speaker_safety', + ), + _threshold_gate( + 'assemblyai_empty_transcript_rate', + assemblyai['empty_transcript_rate'], + empty_transcript_threshold, + 'failure', + gate_group='default_viability', + ), + _threshold_gate( + 'assemblyai_timeout_error_rate', + assemblyai['timeout_error_rate'], + thresholds.max_timeout_error_rate, + 'failure', + gate_group='canary_readiness', + ), + _threshold_gate( + 'assemblyai_purity_delta_vs_deepgram', + assemblyai['speaker_word_purity'] - deepgram['speaker_word_purity'], + thresholds.min_assemblyai_purity_delta_vs_deepgram, + 'failure', + minimum=True, + gate_group='speaker_safety', + ), + _threshold_gate( + 'assemblyai_latency_ratio_vs_deepgram', + _safe_ratio(assemblyai['latency_seconds'], deepgram['latency_seconds']), + thresholds.max_latency_ratio_vs_deepgram, + 'warning', + gate_group='canary_readiness', + ), + _threshold_gate( + 
'assemblyai_cost_ratio_vs_deepgram', + _safe_ratio(assemblyai['estimated_cost_per_hour_usd'], deepgram['estimated_cost_per_hour_usd']), + thresholds.max_cost_ratio_vs_deepgram, + 'warning', + gate_group='default_viability', + ), ] if thresholds.require_instrumentation: for provider in (deepgram, assemblyai): @@ -316,6 +461,7 @@ def _evaluate_case_gates( { 'metric': f"{provider['provider']}_instrumentation", 'severity': 'warning', + 'gate_group': 'canary_readiness', 'value': None, 'threshold': 'ledger_or_rollup_required', 'message': 'missing provider ledger or rollup metrics', @@ -324,14 +470,23 @@ def _evaluate_case_gates( return gates -def _threshold_gate(metric: str, value: float, threshold: float, severity: str) -> dict[str, Any]: - passed = value <= threshold +def _threshold_gate( + metric: str, + value: float, + threshold: float, + severity: str, + minimum: bool = False, + gate_group: str = 'default_viability', +) -> dict[str, Any]: + passed = value >= threshold if minimum else value <= threshold + direction = 'below' if minimum else 'exceeds' return { 'metric': metric, 'severity': 'pass' if passed else severity, + 'gate_group': gate_group, 'value': value, 'threshold': threshold, - 'message': 'within threshold' if passed else f'{value:.4f} exceeds {threshold:.4f}', + 'message': 'within threshold' if passed else f'{value:.4f} {direction} {threshold:.4f}', } @@ -342,6 +497,53 @@ def _number_from_ledger(ledger: dict[str, Any], field: str) -> float: return float(value or 0.0) +def _run_count(ledger: dict[str, Any]) -> float: + return float(ledger.get('run_count') or 1.0) + + +def _estimated_cost_usd(provider: str, ledger: dict[str, Any], billed_seconds: float) -> float: + recorded = _number_from_ledger(ledger, 'estimated_cost_usd') + if recorded: + return recorded + if provider == 'assemblyai': + return billed_seconds / 3600 * ASSEMBLYAI_COST_PER_HOUR_USD + if provider == 'deepgram': + return billed_seconds / 3600 * DEEPGRAM_COST_PER_HOUR_USD + return 0.0 + + +def 
_cost_per_hour(cost: float, seconds: float) -> float: + return cost / seconds * 3600 if seconds else 0.0 + + +def _latency_seconds(ledger: dict[str, Any]) -> float: + for field in ('latency_seconds', 'duration_seconds', 'elapsed_seconds'): + value = _number_from_ledger(ledger, field) + if value: + return value + return 0.0 + + +def _runtime_seconds(ledger: dict[str, Any], latency_seconds: float) -> float: + for field in ('runtime_seconds', 'wall_time_seconds'): + value = _number_from_ledger(ledger, field) + if value: + return value + return latency_seconds + + +def _timeout_error_count(ledger: dict[str, Any]) -> float: + status_counts = ledger.get('status_counts') or {} + return float( + _number_from_ledger(ledger, 'timeout_count') + + _number_from_ledger(ledger, 'error_count') + + status_counts.get('timeout', 0) + + status_counts.get('timed_out', 0) + + status_counts.get('failed', 0) + + status_counts.get('failure', 0) + ) + + def _rate_from_ledger(ledger: dict[str, Any], count_field: str) -> float: denominator = float(ledger.get('run_count') or 1.0) return _number_from_ledger(ledger, count_field) / denominator @@ -354,32 +556,198 @@ def _failure_rate_from_ledger(ledger: dict[str, Any]) -> float: return float(failed) / denominator +def _oracle_speakers(segments: list[dict[str, Any]]) -> set[str]: + return {str(segment['oracle_speaker']) for segment in segments if segment.get('oracle_speaker')} + + +def _speaker_word_purity(segments: list[dict[str, Any]]) -> float: + cluster_counts: dict[str, dict[str, int]] = {} + total_words = 0 + for segment in segments: + cluster_id = _cluster_id(segment) + oracle_speaker = segment.get('oracle_speaker') + if not cluster_id or not oracle_speaker: + continue + word_count = len(_words(segment.get('text', ''))) + if word_count == 0: + continue + total_words += word_count + cluster_counts.setdefault(str(cluster_id), {}) + cluster_counts[str(cluster_id)][str(oracle_speaker)] = ( + 
cluster_counts[str(cluster_id)].get(str(oracle_speaker), 0) + word_count + ) + if not total_words: + return 1.0 + pure_words = sum(max(counts.values()) for counts in cluster_counts.values()) + return pure_words / total_words + + +def _safe_ratio(value: float, baseline: float) -> float: + if baseline == 0: + return 0.0 if value == 0 else 999.0 + return value / baseline + + def _aggregate_case_reports(case_reports: list[dict[str, Any]]) -> dict[str, Any]: if not case_reports: return {} + assemblyai = [case['providers']['assemblyai'] for case in case_reports] + deepgram = [case['providers']['deepgram'] for case in case_reports] + assemblyai_seconds = sum(provider['billable_seconds'] or provider['raw_audio_seconds'] for provider in assemblyai) + deepgram_seconds = sum(provider['billable_seconds'] or provider['raw_audio_seconds'] for provider in deepgram) + assemblyai_cost = sum(provider['estimated_cost_usd'] for provider in assemblyai) + deepgram_cost = sum(provider['estimated_cost_usd'] for provider in deepgram) return { 'average_word_error_rate': _average(case['comparison']['transcript_word_error_rate'] for case in case_reports), 'average_timestamp_drift_seconds': _average( case['comparison']['average_timestamp_drift_seconds'] for case in case_reports ), - 'assemblyai_estimated_cost_usd': sum( - case['providers']['assemblyai']['estimated_cost_usd'] for case in case_reports - ), - 'deepgram_estimated_cost_usd': sum( - case['providers']['deepgram']['estimated_cost_usd'] for case in case_reports - ), - 'assemblyai_billable_seconds': sum( - case['providers']['assemblyai']['billable_seconds'] for case in case_reports - ), - 'deepgram_billable_seconds': sum(case['providers']['deepgram']['billable_seconds'] for case in case_reports), + 'assemblyai_speaker_word_purity': _weighted_average(assemblyai, 'speaker_word_purity', 'word_count'), + 'deepgram_speaker_word_purity': _weighted_average(deepgram, 'speaker_word_purity', 'word_count'), + 'assemblyai_estimated_cost_usd': 
assemblyai_cost, + 'deepgram_estimated_cost_usd': deepgram_cost, + 'assemblyai_estimated_cost_per_hour_usd': _cost_per_hour(assemblyai_cost, assemblyai_seconds), + 'deepgram_estimated_cost_per_hour_usd': _cost_per_hour(deepgram_cost, deepgram_seconds), + 'assemblyai_billable_seconds': sum(provider['billable_seconds'] for provider in assemblyai), + 'deepgram_billable_seconds': sum(provider['billable_seconds'] for provider in deepgram), } +def _strategy_rollups(case_reports: list[dict[str, Any]]) -> dict[str, Any]: + return {strategy: _strategy_rollup(case_reports, strategy) for strategy in PRODUCTION_STRATEGIES} + + +def _strategy_rollup(case_reports: list[dict[str, Any]], strategy: str) -> dict[str, Any]: + selected = [] + providers = set() + for case in case_reports: + provider_name = PROVIDER_BY_STRATEGY.get(strategy) or case.get('current_policy_provider') or 'deepgram' + providers.add(provider_name) + selected.append(case['providers'][provider_name]) + cost = sum(provider['estimated_cost_usd'] for provider in selected) + seconds = sum(provider['billable_seconds'] or provider['raw_audio_seconds'] for provider in selected) + return { + 'provider': next(iter(providers)) if len(providers) == 1 else 'mixed', + 'speaker_word_purity': _weighted_average(selected, 'speaker_word_purity', 'word_count'), + 'covered_speaker_count': _average(provider['covered_speaker_count'] for provider in selected), + 'app_visible_speaker_count': _average(provider['app_visible_speaker_count'] for provider in selected), + 'speaker_inflation_ratio': _average(provider['speaker_inflation_ratio'] for provider in selected), + 'split_count': sum(provider['split_count'] for provider in selected), + 'accepted_reconciliation_count': sum(provider['accepted_reconciliation_count'] for provider in selected), + 'rejected_reconciliation_count': sum(provider['rejected_reconciliation_count'] for provider in selected), + 'fallback_count': sum(provider['fallback_count'] for provider in selected), + 
'provider_cluster_count': sum(provider['provider_cluster_count'] for provider in selected), + 'empty_transcript_rate': _average(provider['empty_transcript_rate'] for provider in selected), + 'latency_seconds': _average(provider['latency_seconds'] for provider in selected), + 'runtime_seconds': _average(provider['runtime_seconds'] for provider in selected), + 'timeout_error_rate': _average(provider['timeout_error_rate'] for provider in selected), + 'fallback_rate': _average(provider['fallback_rate'] for provider in selected), + 'failure_rate': _average(provider['failure_rate'] for provider in selected), + 'estimated_cost_usd': cost, + 'estimated_cost_per_hour_usd': _cost_per_hour(cost, seconds), + } + + +def _assemblyai_gap_report(case_reports: list[dict[str, Any]]) -> dict[str, Any]: + limiting_scenarios = [] + for case in case_reports: + deepgram = case['providers']['deepgram'] + assemblyai = case['providers']['assemblyai'] + candidates = [ + ( + 'speaker_word_purity', + assemblyai['speaker_word_purity'], + deepgram['speaker_word_purity'], + assemblyai['speaker_word_purity'] < deepgram['speaker_word_purity'] - 0.02, + ), + ( + 'covered_speaker_count', + assemblyai['covered_speaker_count'], + deepgram['covered_speaker_count'], + assemblyai['covered_speaker_count'] < deepgram['covered_speaker_count'], + ), + ( + 'empty_transcript_rate', + assemblyai['empty_transcript_rate'], + deepgram['empty_transcript_rate'], + assemblyai['empty_transcript_rate'] > deepgram['empty_transcript_rate'], + ), + ( + 'latency_seconds', + assemblyai['latency_seconds'], + deepgram['latency_seconds'], + assemblyai['latency_seconds'] > deepgram['latency_seconds'] * 2 and assemblyai['latency_seconds'] > 0, + ), + ( + 'timeout_error_rate', + assemblyai['timeout_error_rate'], + deepgram['timeout_error_rate'], + assemblyai['timeout_error_rate'] > deepgram['timeout_error_rate'], + ), + ( + 'estimated_cost_per_hour_usd', + assemblyai['estimated_cost_per_hour_usd'], + 
deepgram['estimated_cost_per_hour_usd'], + assemblyai['estimated_cost_per_hour_usd'] > deepgram['estimated_cost_per_hour_usd'] * 3 + and deepgram['estimated_cost_per_hour_usd'] > 0, + ), + ] + for metric, assemblyai_value, deepgram_value, limited in candidates: + if limited: + limiting_scenarios.append( + { + 'case_id': case['id'], + 'scenario': case.get('scenario') or case['id'], + 'metric': metric, + 'assemblyai_value': round(float(assemblyai_value), 4), + 'deepgram_value': round(float(deepgram_value), 4), + 'likely_cause': _likely_cause(metric), + 'mitigation': _mitigation(metric), + } + ) + break + return { + 'status': 'limited' if limiting_scenarios else 'no_material_gap_detected', + 'limiting_scenarios': limiting_scenarios, + } + + +def _likely_cause(metric: str) -> str: + return { + 'speaker_word_purity': 'provider-local clusters mix speakers before Omi identity matching', + 'covered_speaker_count': 'provider diarization missed a speaker or no-speech gating discarded speech', + 'empty_transcript_rate': 'low-signal/no-speech handling produced an empty or failed transcript', + 'latency_seconds': 'AssemblyAI async job latency exceeds Deepgram for this workload shape', + 'timeout_error_rate': 'provider timeout or retry exhaustion path is not canary-safe', + 'estimated_cost_per_hour_usd': 'AssemblyAI billable duration or pricing is too high for default background volume', + }.get(metric, 'AssemblyAI trails the Deepgram comparator on this gate') + + +def _mitigation(metric: str) -> str: + return { + 'speaker_word_purity': 'keep split-before-match enabled and gate rollout on fragmentation plus purity budgets', + 'covered_speaker_count': 'route affected low-signal cases to Deepgram fallback until canary evidence closes coverage', + 'empty_transcript_rate': 'preserve no-speech detection and fallback controls before default promotion', + 'latency_seconds': 'limit rollout cohort and add latency SLO alerts before expanding canary', + 'timeout_error_rate': 'use Deepgram 
fallback and provider health gates from TICKET-027', + 'estimated_cost_per_hour_usd': 'cap rollout or require explicit product tradeoff in TICKET-028', + }.get(metric, 'capture in TICKET-028 rollout tradeoffs before promoting AssemblyAI') + + def _average(values) -> float: values = list(values) return sum(values) / len(values) if values else 0.0 +def _weighted_average(items: list[dict[str, Any]], value_field: str, weight_field: str) -> float: + denominator = sum(float(item.get(weight_field) or 0.0) for item in items) + if denominator == 0: + return _average(item.get(value_field, 0.0) for item in items) + return ( + sum(float(item.get(value_field) or 0.0) * float(item.get(weight_field) or 0.0) for item in items) / denominator + ) + + def _fmt_pct(value) -> str: if value is None: return 'n/a' From b09b55cf91fce4c76d76936c4f29b90ee5ae85ca Mon Sep 17 00:00:00 2001 From: David Zhang Date: Sun, 24 May 2026 13:13:48 +0700 Subject: [PATCH 41/44] Add AssemblyAI rollout observability --- .../database/transcription_provider_usage.py | 48 +++++- .../unit/test_background_provider_service.py | 56 ++++++- .../unit/test_transcription_provider_usage.py | 29 ++++ backend/utils/stt/provider_service.py | 48 +++++- .../backend/assemblyai_background_rollout.mdx | 142 +++++++++++++++++- 5 files changed, 316 insertions(+), 7 deletions(-) diff --git a/backend/database/transcription_provider_usage.py b/backend/database/transcription_provider_usage.py index ef96697295e..6e5450ee757 100644 --- a/backend/database/transcription_provider_usage.py +++ b/backend/database/transcription_provider_usage.py @@ -21,15 +21,21 @@ RUN_TTL_DAYS = 180 FORBIDDEN_LEDGER_KEYS = { + 'api_key', 'audio_bytes', 'audio', + 'authorization', + 'chunks', + 'full_transcript_text', + 'secret', + 'secrets', + 'token', 'raw_audio_bytes', 'text', 'transcript', 'transcript_text', 'words', 'word_records', - 'chunks', 'utterances', } @@ -124,6 +130,7 @@ def create_provider_run( 'raw_audio_seconds': 0.0, 
'speech_active_seconds': 0.0, 'billable_seconds': 0.0, + 'chunk_duration_seconds': 0.0, 'estimated_cost_usd': 0.0, 'retry_count': 0, 'fallback_count': 0, @@ -131,10 +138,14 @@ def create_provider_run( 'transcript_word_count': 0, 'speaker_cluster_count': 0, 'identified_speaker_cluster_count': 0, + 'identity_match_count': 0, 'provider_speaker_count': 0, 'mapped_speaker_count': 0, 'mapped_person_count': 0, 'unmapped_speaker_count': 0, + 'unknown_speaker_count': 0, + 'unknown_speaker_duration_seconds': 0.0, + 'split_count': 0, 'embedding_extraction_failure_count': 0, 'identity_metric_update': { 'status': 'pending', @@ -163,6 +174,7 @@ def finalize_provider_run( raw_audio_seconds: float = 0.0, speech_active_seconds: float = 0.0, billable_seconds: float = 0.0, + chunk_duration_seconds: float = 0.0, estimated_cost_usd: float = 0.0, retry_count: int = 0, fallback_count: int = 0, @@ -170,10 +182,14 @@ def finalize_provider_run( transcript_word_count: int = 0, speaker_cluster_count: int = 0, identified_speaker_cluster_count: int = 0, + identity_match_count: int = 0, provider_speaker_count: int = 0, mapped_speaker_count: int = 0, mapped_person_count: int = 0, unmapped_speaker_count: int = 0, + unknown_speaker_count: int = 0, + unknown_speaker_duration_seconds: float = 0.0, + split_count: int = 0, embedding_extraction_failure_count: int = 0, identity_confidence_summary: Optional[dict[str, Any]] = None, error_class: Optional[str] = None, @@ -198,6 +214,7 @@ def finalize_provider_run( 'raw_audio_seconds': raw_audio_seconds, 'speech_active_seconds': speech_active_seconds, 'billable_seconds': billable_seconds, + 'chunk_duration_seconds': chunk_duration_seconds, 'estimated_cost_usd': estimated_cost_usd, 'retry_count': retry_count, 'fallback_count': fallback_count, @@ -205,10 +222,14 @@ def finalize_provider_run( 'transcript_word_count': transcript_word_count, 'speaker_cluster_count': speaker_cluster_count, 'identified_speaker_cluster_count': identified_speaker_cluster_count, + 
'identity_match_count': identity_match_count, 'provider_speaker_count': provider_speaker_count, 'mapped_speaker_count': mapped_speaker_count, 'mapped_person_count': mapped_person_count, 'unmapped_speaker_count': unmapped_speaker_count, + 'unknown_speaker_count': unknown_speaker_count, + 'unknown_speaker_duration_seconds': unknown_speaker_duration_seconds, + 'split_count': split_count, 'embedding_extraction_failure_count': embedding_extraction_failure_count, 'identity_confidence_summary': summary, 'error_class': error_class, @@ -232,6 +253,7 @@ def finalize_provider_run( raw_audio_seconds=raw_audio_seconds, speech_active_seconds=speech_active_seconds, billable_seconds=billable_seconds, + chunk_duration_seconds=chunk_duration_seconds, estimated_cost_usd=estimated_cost_usd, retry_count=retry_count, fallback_count=fallback_count, @@ -239,10 +261,14 @@ def finalize_provider_run( transcript_word_count=transcript_word_count, speaker_cluster_count=speaker_cluster_count, identified_speaker_cluster_count=identified_speaker_cluster_count, + identity_match_count=identity_match_count, provider_speaker_count=provider_speaker_count, mapped_speaker_count=mapped_speaker_count, mapped_person_count=mapped_person_count, unmapped_speaker_count=unmapped_speaker_count, + unknown_speaker_count=unknown_speaker_count, + unknown_speaker_duration_seconds=unknown_speaker_duration_seconds, + split_count=split_count, embedding_extraction_failure_count=embedding_extraction_failure_count, identity_confidence_summary=summary, ) @@ -289,6 +315,7 @@ def increment_daily_rollup( raw_audio_seconds: float = 0.0, speech_active_seconds: float = 0.0, billable_seconds: float = 0.0, + chunk_duration_seconds: float = 0.0, estimated_cost_usd: float = 0.0, retry_count: int = 0, fallback_count: int = 0, @@ -296,10 +323,14 @@ def increment_daily_rollup( transcript_word_count: int = 0, speaker_cluster_count: int = 0, identified_speaker_cluster_count: int = 0, + identity_match_count: int = 0, provider_speaker_count: 
int = 0, mapped_speaker_count: int = 0, mapped_person_count: int = 0, unmapped_speaker_count: int = 0, + unknown_speaker_count: int = 0, + unknown_speaker_duration_seconds: float = 0.0, + split_count: int = 0, embedding_extraction_failure_count: int = 0, identity_confidence_summary: Optional[dict[str, Any]] = None, ) -> None: @@ -314,16 +345,21 @@ def increment_daily_rollup( 'speech_active_seconds': firestore.Increment(speech_active_seconds), 'billable_seconds': firestore.Increment(billable_seconds), 'estimated_cost_usd': firestore.Increment(estimated_cost_usd), + 'chunk_duration_seconds': firestore.Increment(chunk_duration_seconds), 'retry_count': firestore.Increment(retry_count), 'fallback_count': firestore.Increment(fallback_count), 'transcript_segment_count': firestore.Increment(transcript_segment_count), 'transcript_word_count': firestore.Increment(transcript_word_count), 'speaker_cluster_count': firestore.Increment(speaker_cluster_count), 'identified_speaker_cluster_count': firestore.Increment(identified_speaker_cluster_count), + 'identity_match_count': firestore.Increment(identity_match_count), 'provider_speaker_count': firestore.Increment(provider_speaker_count), 'mapped_speaker_count': firestore.Increment(mapped_speaker_count), 'mapped_person_count': firestore.Increment(mapped_person_count), 'unmapped_speaker_count': firestore.Increment(unmapped_speaker_count), + 'unknown_speaker_count': firestore.Increment(unknown_speaker_count), + 'unknown_speaker_duration_seconds': firestore.Increment(unknown_speaker_duration_seconds), + 'split_count': firestore.Increment(split_count), 'embedding_extraction_failure_count': firestore.Increment(embedding_extraction_failure_count), 'last_updated': _utc_now(), } @@ -501,6 +537,7 @@ def _empty_rollup(day: str, provider: str, model: str, workload: str) -> dict[st 'raw_audio_seconds': 0.0, 'speech_active_seconds': 0.0, 'billable_seconds': 0.0, + 'chunk_duration_seconds': 0.0, 'estimated_cost_usd': 0.0, 'retry_count': 0, 
'fallback_count': 0, @@ -508,10 +545,14 @@ def _empty_rollup(day: str, provider: str, model: str, workload: str) -> dict[st 'transcript_word_count': 0, 'speaker_cluster_count': 0, 'identified_speaker_cluster_count': 0, + 'identity_match_count': 0, 'provider_speaker_count': 0, 'mapped_speaker_count': 0, 'mapped_person_count': 0, 'unmapped_speaker_count': 0, + 'unknown_speaker_count': 0, + 'unknown_speaker_duration_seconds': 0.0, + 'split_count': 0, 'embedding_extraction_failure_count': 0, 'identity_confidence_counts': {}, 'last_updated': _utc_now(), @@ -526,6 +567,7 @@ def _add_run_to_rollup(rollup: dict[str, Any], data: dict[str, Any]) -> None: 'raw_audio_seconds', 'speech_active_seconds', 'billable_seconds', + 'chunk_duration_seconds', 'estimated_cost_usd', 'retry_count', 'fallback_count', @@ -533,10 +575,14 @@ def _add_run_to_rollup(rollup: dict[str, Any], data: dict[str, Any]) -> None: 'transcript_word_count', 'speaker_cluster_count', 'identified_speaker_cluster_count', + 'identity_match_count', 'provider_speaker_count', 'mapped_speaker_count', 'mapped_person_count', 'unmapped_speaker_count', + 'unknown_speaker_count', + 'unknown_speaker_duration_seconds', + 'split_count', 'embedding_extraction_failure_count', ): rollup[field] += data.get(field, 0) or 0 diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index 34c1c969972..39eaf48c210 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -17,7 +17,7 @@ os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') os.environ.setdefault('ASSEMBLYAI_API_KEY', 'fake-for-test') -from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord # noqa: E402 +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord, TranscriptSegment # noqa: E402 from utils.stt import provider_service # noqa: E402 from 
utils.stt.provider_costs import estimate_prerecorded_provider_cost_usd # noqa: E402 from utils.stt.providers import ( # noqa: E402 @@ -610,10 +610,64 @@ def test_provider_service_counts_label_only_identified_clusters(): assert finalize_run.call_args.kwargs['speaker_cluster_count'] == 2 assert finalize_run.call_args.kwargs['identified_speaker_cluster_count'] == 1 + assert finalize_run.call_args.kwargs['identity_match_count'] == 1 assert finalize_run.call_args.kwargs['provider_speaker_count'] == 2 assert finalize_run.call_args.kwargs['mapped_speaker_count'] == 1 assert finalize_run.call_args.kwargs['mapped_person_count'] == 1 assert finalize_run.call_args.kwargs['unmapped_speaker_count'] == 1 + assert finalize_run.call_args.kwargs['unknown_speaker_count'] == 1 + assert finalize_run.call_args.kwargs['unknown_speaker_duration_seconds'] == 0.4 + assert finalize_run.call_args.kwargs['chunk_duration_seconds'] == 2.0 + + +def test_provider_service_records_split_count_from_local_cluster_marker(): + result = ProviderTranscriptResult(provider='assemblyai', model='universal-2', duration=2.0) + segments = [ + TranscriptSegment( + text='first speaker', + is_user=False, + start=0.0, + end=0.8, + provider_cluster_id='A::local_part:0', + speaker_identity_state='unknown', + ), + TranscriptSegment( + text='known speaker', + is_user=False, + person_id='person-b', + start=1.0, + end=1.8, + provider_cluster_id='B', + speaker_identity_state='identified', + speaker_identity_confidence=0.93, + ), + ] + + with patch.object(provider_service, 'finalize_provider_run') as finalize_run: + provider_service._finalize_run( + 'run-split', + result, + STTWorkload.background, + provider_service.datetime.now(provider_service.timezone.utc), + 'succeeded', + retry_count=0, + raw_audio_seconds=2.0, + segments=segments, + ) + + assert finalize_run.call_args.kwargs['split_count'] == 1 + assert finalize_run.call_args.kwargs['unknown_speaker_count'] == 1 + assert 
finalize_run.call_args.kwargs['unknown_speaker_duration_seconds'] == 0.8 + + +def test_provider_service_classifies_timeout_fallback_reason(): + timeout = provider_service.ProviderTranscriptionRetriesExhausted(provider_service.AssemblyAITimeoutError(), 1) + + assert provider_service._fallback_reason_from_exception(timeout) == 'provider_timeout' + assert ( + provider_service._fallback_reason_from_exception(RuntimeError('temporary provider failure')) + == 'provider_failure' + ) def test_provider_service_preserves_assemblyai_labels_for_identity_metrics(): diff --git a/backend/tests/unit/test_transcription_provider_usage.py b/backend/tests/unit/test_transcription_provider_usage.py index 33df99f3e8d..af6fe5f0ffb 100644 --- a/backend/tests/unit/test_transcription_provider_usage.py +++ b/backend/tests/unit/test_transcription_provider_usage.py @@ -127,6 +127,7 @@ def test_create_and_finalize_provider_run_writes_ledger_rollup_and_metrics(monke raw_audio_seconds=60.0, speech_active_seconds=42.0, billable_seconds=60.0, + chunk_duration_seconds=15.0, estimated_cost_usd=0.37, retry_count=1, fallback_count=0, @@ -134,6 +135,10 @@ def test_create_and_finalize_provider_run_writes_ledger_rollup_and_metrics(monke transcript_word_count=140, speaker_cluster_count=3, identified_speaker_cluster_count=2, + identity_match_count=2, + unknown_speaker_count=1, + unknown_speaker_duration_seconds=7.5, + split_count=1, identity_confidence_summary={'high': 2, 'unknown': 1}, artifact_refs={'provider_result': 'gs://bucket/result.json'}, ) @@ -146,8 +151,13 @@ def test_create_and_finalize_provider_run_writes_ledger_rollup_and_metrics(monke finalized = run_doc.set_calls[1]['data'] assert finalized['status'] == 'success' assert finalized['timing']['latency_ms'] == 5000 + assert finalized['chunk_duration_seconds'] == 15.0 assert finalized['retry_count'] == 1 assert finalized['fallback'] is None + assert finalized['identity_match_count'] == 2 + assert finalized['unknown_speaker_count'] == 1 + assert 
finalized['unknown_speaker_duration_seconds'] == 7.5 + assert finalized['split_count'] == 1 assert 'transcript_text' not in finalized assert 'words' not in finalized @@ -155,7 +165,12 @@ def test_create_and_finalize_provider_run_writes_ledger_rollup_and_metrics(monke rollup = rollup_doc.set_calls[0]['data'] assert rollup['run_count'] == {'__increment': 1} assert rollup['raw_audio_seconds'] == {'__increment': 60.0} + assert rollup['chunk_duration_seconds'] == {'__increment': 15.0} assert rollup['estimated_cost_usd'] == {'__increment': 0.37} + assert rollup['identity_match_count'] == {'__increment': 2} + assert rollup['unknown_speaker_count'] == {'__increment': 1} + assert rollup['unknown_speaker_duration_seconds'] == {'__increment': 7.5} + assert rollup['split_count'] == {'__increment': 1} assert rollup['identity_confidence_counts.high'] == {'__increment': 2} assert emitted[0]['latency_seconds'] == 5.0 assert emitted[0]['billable_seconds'] == 60.0 @@ -169,6 +184,10 @@ def test_rejects_transcript_text_and_chunk_payloads(): usage._reject_forbidden_payload_keys({'chunks': [{'start': 0}]}) with pytest.raises(ValueError): usage._reject_forbidden_payload_keys({'artifact_refs': {'transcript': 'gs://bucket/transcript.txt'}}) + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'provider': {'api_key': 'secret-aa-key'}}) + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'full_transcript_text': 'hello world'}) def test_utc_daily_bucket_and_rollup_rebuild(monkeypatch): @@ -184,6 +203,7 @@ def test_utc_daily_bucket_and_rollup_rebuild(monkeypatch): 'raw_audio_seconds': 10, 'speech_active_seconds': 6, 'billable_seconds': 10, + 'chunk_duration_seconds': 15, 'estimated_cost_usd': 0.1, 'retry_count': 1, 'fallback_count': 0, @@ -191,6 +211,10 @@ def test_utc_daily_bucket_and_rollup_rebuild(monkeypatch): 'transcript_word_count': 40, 'speaker_cluster_count': 2, 'identified_speaker_cluster_count': 1, + 'identity_match_count': 1, + 
'unknown_speaker_count': 1, + 'unknown_speaker_duration_seconds': 3, + 'split_count': 1, 'identity_confidence_summary': {'high': 1}, } ) @@ -213,7 +237,12 @@ def test_utc_daily_bucket_and_rollup_rebuild(monkeypatch): assert rollup['run_count'] == 1 assert rollup['raw_audio_seconds'] == 10.0 + assert rollup['chunk_duration_seconds'] == 15.0 assert rollup['status_counts'] == {'success': 1} + assert rollup['identity_match_count'] == 1.0 + assert rollup['unknown_speaker_count'] == 1.0 + assert rollup['unknown_speaker_duration_seconds'] == 3.0 + assert rollup['split_count'] == 1.0 assert rollup['identity_confidence_counts'] == {'high': 1} diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py index 21c58014baf..2e4f7719946 100644 --- a/backend/utils/stt/provider_service.py +++ b/backend/utils/stt/provider_service.py @@ -8,7 +8,7 @@ from deepgram import DeepgramClient, DeepgramClientOptions from models.transcript_segment import ProviderTranscriptResult, TranscriptSegment -from utils.stt.assemblyai_adapter import AssemblyAIAsyncTranscriptionProvider +from utils.stt.assemblyai_adapter import AssemblyAIAsyncTranscriptionProvider, AssemblyAITimeoutError from utils.stt.conversation_reconstructor import reconstruct_conversation from utils.stt.deepgram_adapter import DeepgramPrerecordedTranscriptionProvider from utils.stt.deepgram_adapter import provider_result_to_legacy_words @@ -48,6 +48,8 @@ _DG_TIMEOUT = httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0) _DEEPGRAM_OPTIONS = DeepgramClientOptions(options={"keepalive": "true"}) _DEEPGRAM_CLIENT = DeepgramClient(os.getenv('DEEPGRAM_API_KEY'), _DEEPGRAM_OPTIONS) +_LOCAL_CLUSTER_SPLIT_MARKER = '::local_part:' +_UNKNOWN_SPEAKER_STATES = {'unknown', 'unassigned', 'legacy_ambiguous'} def create_provider_run(**kwargs) -> str: @@ -268,11 +270,13 @@ def transcribe_url( ) fallback_provider_name = _resolve_usable_fallback_prerecorded_provider(provider_name, workload) if 
fallback_provider_name: + fallback_reason = _fallback_reason_from_exception(e) logger.warning( - 'provider prerecorded url transcription falling back workload=%s from_provider=%s to_provider=%s: %s', + 'provider prerecorded url transcription falling back workload=%s from_provider=%s to_provider=%s reason=%s: %s', workload.value, provider_name.value, fallback_provider_name.value, + fallback_reason, e, ) return _transcribe_url_with_provider( @@ -290,6 +294,7 @@ def transcribe_url( skip_n_seconds=skip_n_seconds, raw_audio_seconds=raw_audio_seconds, fallback_from_provider=provider_name.value, + fallback_reason=fallback_reason, ) raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') @@ -362,11 +367,13 @@ def transcribe_bytes( ) fallback_provider_name = _resolve_usable_fallback_prerecorded_provider(provider_name, workload) if fallback_provider_name: + fallback_reason = _fallback_reason_from_exception(e) logger.warning( - 'provider prerecorded bytes transcription falling back workload=%s from_provider=%s to_provider=%s: %s', + 'provider prerecorded bytes transcription falling back workload=%s from_provider=%s to_provider=%s reason=%s: %s', workload.value, provider_name.value, fallback_provider_name.value, + fallback_reason, e, ) return _transcribe_bytes_with_provider( @@ -386,6 +393,7 @@ def transcribe_bytes( skip_n_seconds=skip_n_seconds, raw_audio_seconds=raw_audio_seconds, fallback_from_provider=provider_name.value, + fallback_reason=fallback_reason, ) raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') @@ -430,6 +438,7 @@ def _transcribe_url_with_provider( skip_n_seconds: int = 0, raw_audio_seconds: float = 0.0, fallback_from_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', ) -> PrerecordedTranscriptionResponse: provider = _get_prerecorded_provider(provider_name) started_at = datetime.now(timezone.utc) @@ -454,6 +463,7 @@ def _transcribe_url_with_provider( 
raw_audio_seconds, skip_n_seconds, fallback_from_provider=fallback_from_provider, + fallback_reason=fallback_reason, detected_language=detected_language, ) except Exception as e: @@ -487,6 +497,7 @@ def _transcribe_bytes_with_provider( skip_n_seconds: int = 0, raw_audio_seconds: float = 0.0, fallback_from_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', ) -> PrerecordedTranscriptionResponse: provider = _get_prerecorded_provider(provider_name) started_at = datetime.now(timezone.utc) @@ -513,6 +524,7 @@ def _transcribe_bytes_with_provider( raw_audio_seconds, skip_n_seconds, fallback_from_provider=fallback_from_provider, + fallback_reason=fallback_reason, detected_language=detected_language, ) except Exception as e: @@ -538,6 +550,7 @@ def _build_success_response( raw_audio_seconds: float, skip_n_seconds: int, fallback_from_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', detected_language: Optional[str] = None, ) -> PrerecordedTranscriptionResponse: segments = reconstruct_conversation(provider_result, skip_n_seconds=skip_n_seconds) @@ -553,6 +566,7 @@ def _build_success_response( segments=segments, fallback_count=1 if fallback_from_provider else 0, fallback_provider=fallback_from_provider, + fallback_reason=fallback_reason, ) return PrerecordedTranscriptionResponse( result=provider_result, @@ -617,6 +631,13 @@ def _provider_error_from_exception(error: Exception) -> Exception: return error +def _fallback_reason_from_exception(error: Exception) -> str: + provider_error = _provider_error_from_exception(error) + if isinstance(provider_error, (AssemblyAITimeoutError, httpx.TimeoutException, TimeoutError)): + return 'provider_timeout' + return 'provider_failure' + + def _unpack_provider_result(result, return_language: bool) -> Tuple[ProviderTranscriptResult, Optional[str]]: if return_language: transcript_result, detected_language = result @@ -660,6 +681,7 @@ def _finalize_run( segments: List[TranscriptSegment], 
fallback_count: int = 0, fallback_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', ) -> None: if not run_id: return @@ -667,6 +689,7 @@ def _finalize_run( clusters = {_segment_cluster_key(segment) for segment in segments if _segment_cluster_key(segment)} confidences = [segment.speaker_identity_confidence for segment in segments] identity_metrics = speaker_identity_metrics(segments) + unknown_speaker_duration_seconds = _unknown_speaker_duration_seconds(segments) try: finalize_provider_run( run_id=run_id, @@ -678,6 +701,7 @@ def _finalize_run( raw_audio_seconds=raw_audio_seconds, speech_active_seconds=raw_audio_seconds, billable_seconds=billable_seconds, + chunk_duration_seconds=raw_audio_seconds, estimated_cost_usd=estimate_prerecorded_provider_cost_usd( provider=result.provider, model=result.model, @@ -690,14 +714,19 @@ def _finalize_run( transcript_word_count=len(result.words), speaker_cluster_count=len(clusters), identified_speaker_cluster_count=_identified_cluster_count(segments), + identity_match_count=identity_metrics['mapped_speaker_count'], identity_confidence_summary=summarize_identity_confidences(confidences), provider_speaker_count=identity_metrics['provider_speaker_count'], mapped_speaker_count=identity_metrics['mapped_speaker_count'], mapped_person_count=identity_metrics['mapped_person_count'], unmapped_speaker_count=identity_metrics['unmapped_speaker_count'], + unknown_speaker_count=identity_metrics['unmapped_speaker_count'], + unknown_speaker_duration_seconds=unknown_speaker_duration_seconds, + split_count=_split_count(clusters), embedding_extraction_failure_count=identity_metrics['embedding_extraction_failure_count'], artifact_refs=_provider_artifact_refs(result), fallback_provider=fallback_provider, + fallback_reason=fallback_reason, ) except Exception as e: logger.warning('failed to finalize transcription provider run ledger run_id=%s: %s', run_id, e) @@ -733,6 +762,7 @@ def _finalize_failed_run( 
raw_audio_seconds=raw_audio_seconds, speech_active_seconds=raw_audio_seconds, billable_seconds=billable_seconds, + chunk_duration_seconds=raw_audio_seconds, estimated_cost_usd=estimate_prerecorded_provider_cost_usd( provider=provider, model=model, @@ -818,6 +848,18 @@ def speaker_identity_metrics(segments: List[TranscriptSegment]) -> dict: } +def _unknown_speaker_duration_seconds(segments: List[TranscriptSegment]) -> float: + duration = 0.0 + for segment in segments: + if segment.speaker_identity_state in _UNKNOWN_SPEAKER_STATES: + duration += max(0.0, segment.end - segment.start) + return round(duration, 3) + + +def _split_count(clusters: set[str]) -> int: + return sum(1 for cluster in clusters if _LOCAL_CLUSTER_SPLIT_MARKER in str(cluster)) + + def _identified_cluster_count(segments: List[TranscriptSegment]) -> int: return len( { diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 30b3afa8d63..3ffe1a558d2 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -193,7 +193,25 @@ evidence rather than provider labels. Provider run ledgers live in `transcription_provider_runs`, and daily rollups live in `transcription_provider_usage_daily`. Ledger payload guards reject raw -audio bytes, transcript text, words, utterances, and similar high-risk payloads. +audio bytes, transcript text, words, utterances, API keys, tokens, secrets, and +similar high-risk payloads. 
+ +Privacy-safe provider run fields include: + +- `provider`, `model`, `workload`, and `status` +- `raw_audio_seconds`, `speech_active_seconds`, `billable_seconds`, and `chunk_duration_seconds` +- `retry_count`, `fallback_count`, and `fallback.reason` +- `timing.latency_ms` +- `estimated_cost_usd` +- `speaker_cluster_count`, `split_count`, `identified_speaker_cluster_count`, and `identity_match_count` +- `provider_speaker_count`, `mapped_speaker_count`, `mapped_person_count`, `unmapped_speaker_count` +- `unknown_speaker_count` and `unknown_speaker_duration_seconds` +- `identity_confidence_summary` buckets + +Do not add raw audio, full transcript text, provider `words`, provider +`utterances`, user secrets, or provider API responses to these ledgers. Store +only opaque artifact references such as `provider_result_id` when a separate +retention-controlled artifact is required. ## Dashboards And Reports @@ -211,9 +229,129 @@ Expected rollout dashboards should use `/metrics` counters and histograms for: Expected cost and quality review should inspect: - `transcription_provider_usage_daily` by provider, model, workload, and day. -- `transcription_provider_runs` for canary failures, fallback direction, retry counts, latency, billable seconds, and estimated cost. +- `transcription_provider_runs` for canary failures, fallback direction, fallback reason, retry counts, latency, chunk duration, billable seconds, estimated cost, split count, identity match count, and unknown speaker fields. - `backend/scripts/stt/provider_comparison_gate.py` reports for transcript drift, speaker cluster stability, identity quality, fallback rate, failure rate, and economics. +## Rollout Stages + +1. `0%`: set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram`. AssemblyAI is not + used for desktop background. Keep `/metrics` and daily Deepgram cost/hour as + the baseline. +2. `shadow_only`: set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=shadow_only`. 
+ Production background requests stay on Deepgram while AssemblyAI readiness + work and offline comparisons continue. +3. Small canary: set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=assemblyai` only for + the selected backend environment or cohort. Require `ASSEMBLYAI_API_KEY`, + `DEEPGRAM_API_KEY`, `ASSEMBLYAI_PRERECORDED_STT_ENABLED=true`, + `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS` containing `background`, and + `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true`. +4. Expanded canary: increase the eligible cohort only after the health + thresholds below pass for the full review window. +5. Rollback: set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram` for immediate + background rollback. For broader passive prerecorded rollback, set + `ASSEMBLYAI_PRERECORDED_STT_ENABLED=false` or remove `background` from + `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`. + +## Health Thresholds + +Rollback or hold expansion when any condition is true for the review window: + +- AssemblyAI background error rate is higher than Deepgram by more than 1 + percentage point, or exceeds 3% absolute. +- Fallback rate exceeds 5% or doubles the previous healthy window. +- Timeout rate exceeds 1% or any timeout spike correlates with support reports. +- Empty transcript rate exceeds the Deepgram baseline by more than 1 percentage + point. +- Cost/hour exceeds the approved Deepgram margin. The current baseline is about + `$0.264/hour` for Deepgram and `$0.734/hour` for AssemblyAI; AssemblyAI cost is + a rollout blocker unless the rollout doc names a quality or product reason for + accepting the margin. +- Unknown or low-confidence speaker rate exceeds the Deepgram baseline by more + than 5 percentage points, or `unknown_speaker_duration_seconds` grows enough + to make speaker labels visibly worse. +- App-visible speaker inflation or `split_count` rises without a matching + speaker purity improvement in replay reports. 
+- User correction/self-voice review rate rises above the Deepgram baseline when + that signal is available. +- Support complaints mention missing transcripts, wrong speakers, degraded + background transcription, or account/billing surprises. + +## Operations Runbook + +Enable a small canary: + +```bash +ASSEMBLYAI_API_KEY= +DEEPGRAM_API_KEY= +ASSEMBLYAI_PRERECORDED_STT_ENABLED=true +ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess +ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true +ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=assemblyai +``` + +Disable AssemblyAI for background: + +```bash +ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram +``` + +Disable AssemblyAI for all passive prerecorded workloads: + +```bash +ASSEMBLYAI_PRERECORDED_STT_ENABLED=false +``` + +Validate after each config change: + +1. Confirm `/metrics` has background samples for + `transcription_provider_requests_total`, + `transcription_provider_latency_seconds`, + `transcription_provider_fallbacks_total`, and + `transcription_provider_audio_seconds_total`. +2. Query `transcription_provider_usage_daily` for `provider=assemblyai`, + `model=universal-2`, and `workload=background`; compute cost/hour as + `estimated_cost_usd / (billable_seconds / 3600)`. +3. Inspect recent `transcription_provider_runs` without opening transcript or + audio artifacts. Check `status`, `fallback.reason`, `retry_count`, + `timing.latency_ms`, `chunk_duration_seconds`, `estimated_cost_usd`, + `split_count`, `identity_match_count`, `unknown_speaker_count`, and + `unknown_speaker_duration_seconds`. +4. Run the offline gate before expansion: + `python3 scripts/stt/provider_comparison_gate.py --manifest tests/fixtures/stt_provider_eval/manifest.json`. +5. If a rollback threshold trips, set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram`, + verify new background runs show `provider=deepgram`, and leave + `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true` until the incident review + is complete. 
+ +Debugging checklist: + +- `missing_assemblyai_api_key`: verify `ASSEMBLYAI_API_KEY` is present only in + secret storage and never in logs, ledgers, or dispatch artifacts. +- `provider_failure`: inspect provider status, timeout settings + `ASSEMBLYAI_POLL_INTERVAL_SECONDS` and `ASSEMBLYAI_MAX_POLL_SECONDS`, retry + counts, and Deepgram fallback health. +- High cost/hour: compare `billable_seconds`, `chunk_duration_seconds`, and + request count against the fixed-15s background profile before changing + chunking. +- Speaker regressions: compare `split_count`, `identity_match_count`, + `unknown_speaker_count`, `unknown_speaker_duration_seconds`, and replay + speaker-purity reports. Do not reduce split-before-match protections unless + replay gates show no regression. + +## Data Handling + +AssemblyAI receives uploaded audio only for eligible prerecorded workloads after +the routing flags above allow it. User consent and BYOK behavior must be +consistent with the desktop background transcription setting and the optional +`X-BYOK-AssemblyAI` header. Regional data handling follows the configured +AssemblyAI account and endpoint; do not enable a region or cohort where vendor +retention, residency, or customer terms are not approved. + +Secrets are configuration-only. Never place `ASSEMBLYAI_API_KEY`, user BYOK +headers, provider bearer tokens, raw audio, full transcript text, provider +`words`, or provider `utterances` in metrics, logs, Firestore ledgers, CAR +artifacts, screenshots, or support notes. 
+ Fixture validation: ```bash From 0a8537531107934d09b38c16f28cd12709afcab4 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Sun, 24 May 2026 13:16:20 +0700 Subject: [PATCH 42/44] docs: add AssemblyAI background canary readiness plan --- .../backend/assemblyai_background_rollout.mdx | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 3ffe1a558d2..5e6df9ea459 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -19,6 +19,45 @@ Deepgram remains the provider for mobile/BLE `/v4/listen`, realtime assistant streaming, Hold-to-Talk streaming, voice-message finalize semantics, and background fallback. +## Current Readiness Decision + +As of the TICKET-028 offline readiness gate, AssemblyAI is approved only for a +small desktop background canary. Broad defaulting remains blocked until the +canary supplies privacy-safe real-session evidence and the remaining +saved-policy speaker-purity gap is either closed or explicitly accepted by the +product/business owner. 
+ +The latest offline command was: + +```bash +cd backend +python3 scripts/stt/provider_comparison_gate.py \ + --manifest tests/fixtures/stt_provider_eval/manifest.json \ + --output-md /tmp/stt-provider-eval.md \ + --output-json /tmp/stt-provider-eval.json +``` + +It passed 10 fixture cases with no failures or warnings and reported all +required production strategies: + +| Strategy | Provider | Speaker purity | Empty transcript | Fallback | Timeout/error | Cost/hour | +| --- | --- | ---: | ---: | ---: | ---: | ---: | +| `always_deepgram` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.158` | +| `always_assemblyai` | AssemblyAI | 98.9% | 10.0% | 1.0% | 0.0% | `$0.286` | +| `current_policy` | mixed | 100.0% | 10.0% | 0.0% | 0.0% | `$0.217` | +| `shadow_only` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.158` | + +The gate still marked AssemblyAI as `limited` because the +`saved_real_policy_router_outputs` scenario had 95.5% AssemblyAI speaker-word +purity versus 100.0% for Deepgram. The mitigation remains to keep +split-before-identity enabled, keep unresolved provider clusters chunk-scoped, +and promote only through real-session canary metrics that include both purity +and fragmentation budgets. + +Selected rollout decision: enable AssemblyAI for a small canary cohort with +Deepgram fallback and hard rollback thresholds. Do not set AssemblyAI as the +broad background default from offline evidence alone. + Eligible AssemblyAI workloads are: - `sync` @@ -252,6 +291,79 @@ Expected cost and quality review should inspect: `ASSEMBLYAI_PRERECORDED_STT_ENABLED=false` or remove `background` from `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`. +## TICKET-028 Canary Plan + +Initial cohort: + +- Scope: desktop background recording only, excluding mobile/BLE listen, + realtime assistant streaming, Hold-to-Talk, and voice-message finalize paths. 
+- Size: 1% of eligible desktop background traffic or an internal opt-in cohort + of at least 20 active background users, whichever is safer for the deployment + environment. +- Duration: minimum 7 consecutive days and at least 200 successful background + chunks before expansion review. +- Configuration: `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=assemblyai`, + `ASSEMBLYAI_PRERECORDED_STT_ENABLED=true`, + `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess`, and + `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true` for only the selected + backend environment or cohort. +- Required fallback: keep Deepgram credentials available and verify recent + fallback runs record `fallback.from_provider=assemblyai`, + `fallback.to_provider=deepgram`, and a classified `fallback.reason`. + +Owner/on-call action: + +- The backend on-call owns first response during the canary window. +- Roll back immediately when a hard threshold below trips. +- Hold expansion and open a follow-up ticket when any warning threshold is + close to tripping or when support reports speaker/transcript regressions not + visible in the aggregate metrics. +- Product/business owner approval is required before accepting any remaining + cost or speaker-quality regression for expanded rollout. + +Expansion success criteria: + +- Offline provider gate continues to pass with no failures: + `provider_comparison_gate.py --manifest tests/fixtures/stt_provider_eval/manifest.json`. +- AssemblyAI background error rate is not more than 1 percentage point above + the Deepgram baseline and remains below 3% absolute. +- Fallback rate stays at or below 5% and does not double the previous healthy + window. +- Timeout rate stays at or below 1%. +- Empty transcript rate is not more than 1 percentage point above Deepgram. +- Unknown or low-confidence speaker rate is not more than 5 percentage points + above Deepgram, and unknown speaker duration does not create visible speaker + degradation. 
+- App-visible speaker inflation stays within the replay budget and does not + rise without a matching purity improvement. +- Cost/hour stays within the approved margin. The current offline gate shows + AssemblyAI at about 1.8x Deepgram (`$0.286/hour` versus `$0.158/hour`), while + older saved-provider evidence showed a wider gap; expansion requires an + explicit owner decision if this remains materially above Deepgram. +- No sustained increase in user corrections, self-voice review failures, + support complaints, or billing surprises. + +Expansion path: + +1. Hold at 1% until the full review window passes. +2. Expand to 5% only after the success criteria above are met. +3. Expand to 25% only after another full review window passes with no hard + threshold trips. +4. Broad defaulting requires either measured parity with Deepgram on the canary + metrics or documented product/business owner approval for the remaining + regression, plus the rollback command below staying tested. + +Rollback command/config: + +```bash +ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram +``` + +After rollback, confirm new background runs show `provider=deepgram`, verify +`transcription_provider_fallbacks_total` is not continuing to rise from +AssemblyAI failures, and keep `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true` +until the incident review is complete. 
+ ## Health Thresholds Rollback or hold expansion when any condition is true for the review window: From bc143cd0b51ef1ddc1fc3446189c8d9cb6524f8e Mon Sep 17 00:00:00 2001 From: David Zhang Date: Sun, 24 May 2026 19:53:11 +0700 Subject: [PATCH 43/44] Make AssemblyAI the background default policy --- .../scripts/stt/provider_comparison_gate.md | 4 +- .../fixtures/stt_provider_eval/manifest.json | 14 +- .../unit/test_background_provider_service.py | 6 +- .../tests/unit/test_provider_evaluation.py | 4 +- backend/utils/stt/provider_evaluation.py | 23 +-- backend/utils/stt/providers.py | 2 +- .../backend/assemblyai_background_rollout.mdx | 135 +++++++----------- 7 files changed, 74 insertions(+), 114 deletions(-) diff --git a/backend/scripts/stt/provider_comparison_gate.md b/backend/scripts/stt/provider_comparison_gate.md index b300d3fd586..1e56e0d40a4 100644 --- a/backend/scripts/stt/provider_comparison_gate.md +++ b/backend/scripts/stt/provider_comparison_gate.md @@ -11,7 +11,7 @@ python3 scripts/stt/provider_comparison_gate.py \ The default command replays synthetic and saved provider outputs only. It does not require `ASSEMBLYAI_API_KEY` or `DEEPGRAM_API_KEY`. -The report compares `always_deepgram`, `always_assemblyai`, `current_policy`, and `shadow_only`. It includes speaker safety, default viability, and canary readiness gates plus an AssemblyAI gap report that names the limiting scenario, likely cause, and rollout mitigation. +The report compares `always_deepgram`, `always_assemblyai`, `current_policy`, and `shadow_only`. `current_policy` means AssemblyAI default for passive background workloads. `shadow_only` is retained only as a rollback/diagnostic comparator. The report includes speaker safety, default viability, and rollout readiness gates plus an AssemblyAI gap report that names the limiting scenario, likely cause, and mitigation. 
The fixture manifest covers clean turns, fast turns, overlap, sparse speech, low-signal/no-speech, multilingual turns, duplicate chunk replay, provider failure/fallback, saved real-provider E2E output, and saved policy-router output. @@ -24,4 +24,4 @@ python3 scripts/stt/provider_comparison_gate.py \ --live ``` -Synthetic and saved-output gates are necessary but insufficient for broad defaulting. TICKET-028 must turn the latest gap-closing report into an AssemblyAI canary/default rollout plan backed by privacy-safe real-session metrics. +Synthetic and saved-output gates are necessary but insufficient for default health decisions. Use the latest gap-closing report to operate the AssemblyAI default with Deepgram fallback and privacy-safe real-session metrics. diff --git a/backend/tests/fixtures/stt_provider_eval/manifest.json b/backend/tests/fixtures/stt_provider_eval/manifest.json index 10f82985891..54a3e5c0d60 100644 --- a/backend/tests/fixtures/stt_provider_eval/manifest.json +++ b/backend/tests/fixtures/stt_provider_eval/manifest.json @@ -3,7 +3,6 @@ { "id": "fixture_good_meeting", "scenario": "saved_real_provider_e2e", - "current_policy_provider": "deepgram", "deepgram_fixture": "fixture_good_meeting.deepgram.json", "assemblyai_fixture": "fixture_good_meeting.assemblyai.json", "deepgram_rollup": "fixture_good_meeting.deepgram.rollup.json", @@ -12,7 +11,6 @@ { "id": "synthetic_clean_turns", "scenario": "clean_turns", - "current_policy_provider": "assemblyai", "deepgram": { "transcript": { "segments": [ @@ -80,7 +78,6 @@ { "id": "synthetic_fast_turns", "scenario": "fast_turns", - "current_policy_provider": "assemblyai", "deepgram": { "transcript": { "segments": [ @@ -114,7 +111,7 @@ }, "raw_audio_seconds": 2.0, "billable_seconds": 2.0, - "estimated_cost_usd": 0.00008, + "estimated_cost_usd": 8e-05, "latency_seconds": 0.9, "fallback_count": 0, "split_count": 1 @@ -164,7 +161,6 @@ { "id": "synthetic_overlap", "scenario": "overlap", - "current_policy_provider": 
"assemblyai", "deepgram": { "transcript": { "segments": [ @@ -240,7 +236,6 @@ { "id": "synthetic_sparse_speech", "scenario": "sparse_speech", - "current_policy_provider": "assemblyai", "deepgram": { "transcript": { "segments": [ @@ -295,7 +290,6 @@ { "id": "synthetic_no_speech", "scenario": "low_signal_no_speech", - "current_policy_provider": "deepgram", "deepgram": { "transcript": { "segments": [] @@ -334,7 +328,6 @@ { "id": "synthetic_multilingual_turns", "scenario": "multilingual_turns", - "current_policy_provider": "assemblyai", "deepgram": { "transcript": { "segments": [ @@ -401,7 +394,6 @@ { "id": "synthetic_duplicate_replay", "scenario": "duplicate_chunk_replay", - "current_policy_provider": "assemblyai", "deepgram": { "transcript": { "segments": [ @@ -421,7 +413,7 @@ }, "raw_audio_seconds": 4.0, "billable_seconds": 2.0, - "estimated_cost_usd": 0.00008, + "estimated_cost_usd": 8e-05, "latency_seconds": 1.0, "fallback_count": 0 } @@ -454,7 +446,6 @@ { "id": "synthetic_provider_failure_fallback", "scenario": "provider_failure_fallback", - "current_policy_provider": "deepgram", "deepgram": { "transcript": { "segments": [ @@ -508,7 +499,6 @@ { "id": "saved_policy_router_gap", "scenario": "saved_real_policy_router_outputs", - "current_policy_provider": "deepgram", "deepgram": { "transcript": { "segments": [ diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index 39eaf48c210..b9ab6eacaea 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ b/backend/tests/unit/test_background_provider_service.py @@ -115,14 +115,14 @@ def test_provider_service_finalizes_background_run_on_deepgram_when_assemblyai_d assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00076 -def test_background_routing_defaults_to_shadow_only_deepgram_until_rollout_gates_pass(monkeypatch): +def test_background_routing_defaults_to_assemblyai_for_background(monkeypatch): 
monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', raising=False) monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', raising=False) monkeypatch.delenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', raising=False) assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai - assert get_background_provider_mode() == BackgroundProviderMode.shadow_only - assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + assert get_background_provider_mode() == BackgroundProviderMode.assemblyai + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai assert get_prerecorded_provider_name(STTWorkload.postprocess) == STTProviderName.assemblyai diff --git a/backend/tests/unit/test_provider_evaluation.py b/backend/tests/unit/test_provider_evaluation.py index a7a4c3b833d..7a89778035a 100644 --- a/backend/tests/unit/test_provider_evaluation.py +++ b/backend/tests/unit/test_provider_evaluation.py @@ -73,7 +73,7 @@ def test_manifest_report_includes_strategy_rollups_gap_report_and_fragmentation_ assert set(report['strategies']) == {'always_deepgram', 'always_assemblyai', 'current_policy', 'shadow_only'} assert report['strategies']['always_assemblyai']['provider'] == 'assemblyai' assert report['strategies']['shadow_only']['provider'] == 'deepgram' - assert report['strategies']['current_policy']['provider'] == 'mixed' + assert report['strategies']['current_policy']['provider'] == 'assemblyai' assert report['strategies']['always_assemblyai']['split_count'] >= 2 assert report['strategies']['always_assemblyai']['estimated_cost_per_hour_usd'] > 0 assert report['assemblyai_gap_report']['status'] == 'limited' @@ -194,4 +194,4 @@ def test_compact_markdown_report_is_review_friendly(): assert 'fixture_good_meeting' in markdown assert 'Strategy Rollup' in markdown assert 'AssemblyAI Gap Report' in markdown - assert 'TICKET-028' in markdown + assert 'AssemblyAI default readiness' in markdown diff --git 
a/backend/utils/stt/provider_evaluation.py b/backend/utils/stt/provider_evaluation.py index fd6524a23b2..1c5854d568d 100644 --- a/backend/utils/stt/provider_evaluation.py +++ b/backend/utils/stt/provider_evaluation.py @@ -5,6 +5,7 @@ PROVIDER_BY_STRATEGY = { 'always_deepgram': 'deepgram', 'always_assemblyai': 'assemblyai', + 'current_policy': 'assemblyai', 'shadow_only': 'deepgram', } ASSEMBLYAI_COST_PER_HOUR_USD = 0.2592 @@ -130,8 +131,8 @@ def compact_markdown_report(report: dict[str, Any]) -> str: lines.extend( [ '', - 'Synthetic and saved-output gates are necessary but insufficient for broad defaulting. ' - 'TICKET-028 must turn this gap report into an AssemblyAI canary/default rollout plan with privacy-safe real-session evidence.', + 'Synthetic and saved-output gates are necessary but insufficient for default health decisions. ' + 'Use this gap report to track AssemblyAI default readiness and rollback thresholds with privacy-safe real-session evidence.', ] ) return '\n'.join(lines) @@ -154,7 +155,7 @@ def _compare_case(case: dict[str, Any], thresholds: ProviderGateThresholds) -> d return { 'id': case.get('id') or case.get('case_id') or 'unknown', 'scenario': case.get('scenario') or case.get('type') or 'unspecified', - 'current_policy_provider': case.get('current_policy_provider') or 'deepgram', + 'current_policy_provider': case.get('current_policy_provider') or 'assemblyai', 'providers': {'deepgram': _public_summary(deepgram), 'assemblyai': _public_summary(assemblyai)}, 'comparison': comparison, 'gates': gates, @@ -389,7 +390,7 @@ def _evaluate_case_gates( comparison['average_timestamp_drift_seconds'], thresholds.max_average_timestamp_drift_seconds, 'warning', - gate_group='canary_readiness', + gate_group='rollout_readiness', ), _threshold_gate( 'assemblyai_low_confidence_identity_rate', @@ -429,7 +430,7 @@ def _evaluate_case_gates( assemblyai['timeout_error_rate'], thresholds.max_timeout_error_rate, 'failure', - gate_group='canary_readiness', + 
gate_group='rollout_readiness', ), _threshold_gate( 'assemblyai_purity_delta_vs_deepgram', @@ -444,7 +445,7 @@ def _evaluate_case_gates( _safe_ratio(assemblyai['latency_seconds'], deepgram['latency_seconds']), thresholds.max_latency_ratio_vs_deepgram, 'warning', - gate_group='canary_readiness', + gate_group='rollout_readiness', ), _threshold_gate( 'assemblyai_cost_ratio_vs_deepgram', @@ -461,7 +462,7 @@ def _evaluate_case_gates( { 'metric': f"{provider['provider']}_instrumentation", 'severity': 'warning', - 'gate_group': 'canary_readiness', + 'gate_group': 'rollout_readiness', 'value': None, 'threshold': 'ledger_or_rollup_required', 'message': 'missing provider ledger or rollup metrics', @@ -621,7 +622,7 @@ def _strategy_rollup(case_reports: list[dict[str, Any]], strategy: str) -> dict[ selected = [] providers = set() for case in case_reports: - provider_name = PROVIDER_BY_STRATEGY.get(strategy) or case.get('current_policy_provider') or 'deepgram' + provider_name = PROVIDER_BY_STRATEGY.get(strategy) or case.get('current_policy_provider') or 'assemblyai' providers.add(provider_name) selected.append(case['providers'][provider_name]) cost = sum(provider['estimated_cost_usd'] for provider in selected) @@ -718,7 +719,7 @@ def _likely_cause(metric: str) -> str: 'covered_speaker_count': 'provider diarization missed a speaker or no-speech gating discarded speech', 'empty_transcript_rate': 'low-signal/no-speech handling produced an empty or failed transcript', 'latency_seconds': 'AssemblyAI async job latency exceeds Deepgram for this workload shape', - 'timeout_error_rate': 'provider timeout or retry exhaustion path is not canary-safe', + 'timeout_error_rate': 'provider timeout or retry exhaustion path is not default-safe', 'estimated_cost_per_hour_usd': 'AssemblyAI billable duration or pricing is too high for default background volume', }.get(metric, 'AssemblyAI trails the Deepgram comparator on this gate') @@ -726,9 +727,9 @@ def _likely_cause(metric: str) -> str: def 
_mitigation(metric: str) -> str: return { 'speaker_word_purity': 'keep split-before-match enabled and gate rollout on fragmentation plus purity budgets', - 'covered_speaker_count': 'route affected low-signal cases to Deepgram fallback until canary evidence closes coverage', + 'covered_speaker_count': 'use Deepgram fallback for affected low-signal cases until AssemblyAI closes coverage', 'empty_transcript_rate': 'preserve no-speech detection and fallback controls before default promotion', - 'latency_seconds': 'limit rollout cohort and add latency SLO alerts before expanding canary', + 'latency_seconds': 'keep latency SLO alerts and fallback controls before expanding default traffic', 'timeout_error_rate': 'use Deepgram fallback and provider health gates from TICKET-027', 'estimated_cost_per_hour_usd': 'cap rollout or require explicit product tradeoff in TICKET-028', }.get(metric, 'capture in TICKET-028 rollout tradeoffs before promoting AssemblyAI') diff --git a/backend/utils/stt/providers.py b/backend/utils/stt/providers.py index 609939ac9ea..e1c04db1833 100644 --- a/backend/utils/stt/providers.py +++ b/backend/utils/stt/providers.py @@ -142,7 +142,7 @@ def assemblyai_prerecorded_fallback_enabled() -> bool: def get_background_provider_mode() -> BackgroundProviderMode: - configured = os.getenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', BackgroundProviderMode.shadow_only.value) + configured = os.getenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', BackgroundProviderMode.assemblyai.value) try: return BackgroundProviderMode(configured.strip().lower()) except ValueError: diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 5e6df9ea459..31d0168d5ed 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -8,24 +8,29 @@ description: "Rollout gates, feature flags, instrumentation, and rollback for th ## Scope 
-AssemblyAI is the intended async background provider, but broad background -defaulting is still gated. Current saved-provider benchmark evidence favors -Deepgram on cost-adjusted speaker quality: Deepgram reached 99.79% speaker -purity at about $0.264/hour, while AssemblyAI reached 95.91% at about -$0.734/hour. The rollout path must close or explicitly mitigate that gap before -AssemblyAI becomes the broad default. +AssemblyAI is the default provider for passive async background transcription. +In this document, "background" means latency-tolerant prerecorded work: +`sync`, desktop `background`, and `postprocess`. These jobs run after audio has +already been captured and can tolerate async provider polling, retries, and +Deepgram fallback. + +Current saved-provider benchmark evidence still favors Deepgram on some +cost-adjusted speaker-quality fixtures: Deepgram reached 99.79% speaker purity +at about $0.264/hour, while AssemblyAI reached 95.91% at about $0.734/hour. +That benchmark is a constraint on the AssemblyAI pipeline, not a reason to keep +Deepgram as the primary passive background provider. The production stance is: +AssemblyAI first, conservative speaker safety, privacy-safe monitoring, and +Deepgram fallback. Deepgram remains the provider for mobile/BLE `/v4/listen`, realtime assistant streaming, Hold-to-Talk streaming, voice-message finalize semantics, and background fallback. -## Current Readiness Decision +## Current Default Decision -As of the TICKET-028 offline readiness gate, AssemblyAI is approved only for a -small desktop background canary. Broad defaulting remains blocked until the -canary supplies privacy-safe real-session evidence and the remaining -saved-policy speaker-purity gap is either closed or explicitly accepted by the -product/business owner. +As of the TICKET-028 readiness gate, AssemblyAI is the default for eligible +passive background workloads when credentials and workload flags allow it. 
+Deepgram remains the fallback and explicit rollback provider. The latest offline command was: @@ -44,19 +49,18 @@ required production strategies: | --- | --- | ---: | ---: | ---: | ---: | ---: | | `always_deepgram` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.158` | | `always_assemblyai` | AssemblyAI | 98.9% | 10.0% | 1.0% | 0.0% | `$0.286` | -| `current_policy` | mixed | 100.0% | 10.0% | 0.0% | 0.0% | `$0.217` | +| `current_policy` | AssemblyAI | 98.9% | 10.0% | 1.0% | 0.0% | `$0.286` | | `shadow_only` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.158` | The gate still marked AssemblyAI as `limited` because the `saved_real_policy_router_outputs` scenario had 95.5% AssemblyAI speaker-word purity versus 100.0% for Deepgram. The mitigation remains to keep split-before-identity enabled, keep unresolved provider clusters chunk-scoped, -and promote only through real-session canary metrics that include both purity -and fragmentation budgets. +and monitor real-session metrics that include both purity and fragmentation +budgets. -Selected rollout decision: enable AssemblyAI for a small canary cohort with -Deepgram fallback and hard rollback thresholds. Do not set AssemblyAI as the -broad background default from offline evidence alone. +Selected rollout decision: AssemblyAI is the passive background default with +Deepgram fallback and hard rollback thresholds. Eligible AssemblyAI workloads are: @@ -72,14 +76,15 @@ Ineligible latency-sensitive workloads stay on Deepgram: ## Feature Flags And Environment -AssemblyAI remains enabled by default for eligible non-background prerecorded +AssemblyAI is enabled by default for eligible passive background prerecorded workloads when credentials are configured. Desktop background has its own -provider mode so canary rollout can be controlled independently. +provider mode so it can be rolled back independently from `sync` and +`postprocess`. 
| Variable | Default | Purpose | | --- | --- | --- | | `ASSEMBLYAI_API_KEY` | unset | Required before any AssemblyAI request can run. | -| `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE` | `shadow_only` | Desktop background policy mode: `assemblyai`, `deepgram`, or `shadow_only`. `shadow_only` keeps production background requests on Deepgram while the AssemblyAI path is prepared for canary evidence. | +| `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE` | `assemblyai` | Desktop background policy mode: `assemblyai`, `deepgram`, or `shadow_only`. `assemblyai` is the default. `deepgram` is the rollback override. `shadow_only` is retained only as an explicit diagnostic/rollback mode. | | `ASSEMBLYAI_PRERECORDED_STT_ENABLED` | `true` | Main rollout switch for passive prerecorded workloads. | | `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS` | `sync,background,postprocess` | Comma-separated eligible prerecorded workloads. Unknown or ineligible values are ignored. | | `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED` | `true` | Use Deepgram when AssemblyAI is unavailable or fails and Deepgram credentials are usable. | @@ -93,7 +98,7 @@ provider mode so canary rollout can be controlled independently. `backend/utils/stt/providers.py` owns provider selection, and `backend/utils/stt/provider_service.py` resolves request-level credentials and -fallbacks. Eligible non-background workloads use AssemblyAI only when the main +fallbacks. Eligible passive background workloads use AssemblyAI when the main flag is enabled and the workload is in `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`; otherwise they use Deepgram. @@ -102,9 +107,9 @@ Desktop background routing is controlled by | Mode | Runtime provider | Use | | --- | --- | --- | -| `shadow_only` | Deepgram | Default until TICKET-026 evals and TICKET-028 canary gates are rollout-ready. | -| `assemblyai` | AssemblyAI when credentials and workload gates allow it; Deepgram fallback when configured and usable. | Canary and eventual default mode. 
| +| `assemblyai` | AssemblyAI when credentials and workload gates allow it; Deepgram fallback when configured and usable. | Default mode. | | `deepgram` | Deepgram | Kill switch or explicit opt-out without a code deploy. | +| `shadow_only` | Deepgram | Explicit diagnostic/rollback mode only. | Desktop Audio Recording uses the background workload through: @@ -268,60 +273,33 @@ Expected rollout dashboards should use `/metrics` counters and histograms for: Expected cost and quality review should inspect: - `transcription_provider_usage_daily` by provider, model, workload, and day. -- `transcription_provider_runs` for canary failures, fallback direction, fallback reason, retry counts, latency, chunk duration, billable seconds, estimated cost, split count, identity match count, and unknown speaker fields. +- `transcription_provider_runs` for failures, fallback direction, fallback reason, retry counts, latency, chunk duration, billable seconds, estimated cost, split count, identity match count, and unknown speaker fields. - `backend/scripts/stt/provider_comparison_gate.py` reports for transcript drift, speaker cluster stability, identity quality, fallback rate, failure rate, and economics. -## Rollout Stages - -1. `0%`: set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram`. AssemblyAI is not - used for desktop background. Keep `/metrics` and daily Deepgram cost/hour as - the baseline. -2. `shadow_only`: set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=shadow_only`. - Production background requests stay on Deepgram while AssemblyAI readiness - work and offline comparisons continue. -3. Small canary: set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=assemblyai` only for - the selected backend environment or cohort. Require `ASSEMBLYAI_API_KEY`, - `DEEPGRAM_API_KEY`, `ASSEMBLYAI_PRERECORDED_STT_ENABLED=true`, - `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS` containing `background`, and - `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true`. -4. 
Expanded canary: increase the eligible cohort only after the health - thresholds below pass for the full review window. -5. Rollback: set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram` for immediate - background rollback. For broader passive prerecorded rollback, set - `ASSEMBLYAI_PRERECORDED_STT_ENABLED=false` or remove `background` from - `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`. - -## TICKET-028 Canary Plan - -Initial cohort: - -- Scope: desktop background recording only, excluding mobile/BLE listen, - realtime assistant streaming, Hold-to-Talk, and voice-message finalize paths. -- Size: 1% of eligible desktop background traffic or an internal opt-in cohort - of at least 20 active background users, whichever is safer for the deployment - environment. -- Duration: minimum 7 consecutive days and at least 200 successful background - chunks before expansion review. -- Configuration: `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=assemblyai`, - `ASSEMBLYAI_PRERECORDED_STT_ENABLED=true`, - `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess`, and - `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true` for only the selected - backend environment or cohort. -- Required fallback: keep Deepgram credentials available and verify recent - fallback runs record `fallback.from_provider=assemblyai`, +## Default And Rollback Stages + +Default state: + +- Passive background workloads use AssemblyAI: + `ASSEMBLYAI_PRERECORDED_STT_ENABLED=true` and + `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess`. +- Desktop background uses AssemblyAI: + `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=assemblyai`. +- Deepgram credentials remain available for fallback. +- Fallback runs must record `fallback.from_provider=assemblyai`, `fallback.to_provider=deepgram`, and a classified `fallback.reason`. Owner/on-call action: -- The backend on-call owns first response during the canary window. +- The backend on-call owns first response for AssemblyAI default regressions. 
- Roll back immediately when a hard threshold below trips. -- Hold expansion and open a follow-up ticket when any warning threshold is - close to tripping or when support reports speaker/transcript regressions not - visible in the aggregate metrics. -- Product/business owner approval is required before accepting any remaining - cost or speaker-quality regression for expanded rollout. +- Open a follow-up ticket when any warning threshold is close to tripping or + when support reports speaker/transcript regressions not visible in aggregate + metrics. +- Product/business owner approval is required before accepting any sustained + cost or speaker-quality regression versus Deepgram. -Expansion success criteria: +Default health criteria: - Offline provider gate continues to pass with no failures: `provider_comparison_gate.py --manifest tests/fixtures/stt_provider_eval/manifest.json`. @@ -343,16 +321,6 @@ Expansion success criteria: - No sustained increase in user corrections, self-voice review failures, support complaints, or billing surprises. -Expansion path: - -1. Hold at 1% until the full review window passes. -2. Expand to 5% only after the success criteria above are met. -3. Expand to 25% only after another full review window passes with no hard - threshold trips. -4. Broad defaulting requires either measured parity with Deepgram on the canary - metrics or documented product/business owner approval for the remaining - regression, plus the rollback command below staying tested. - Rollback command/config: ```bash @@ -366,7 +334,8 @@ until the incident review is complete. ## Health Thresholds -Rollback or hold expansion when any condition is true for the review window: +Roll back to Deepgram or open an incident follow-up when any condition is true +for the review window: - AssemblyAI background error rate is higher than Deepgram by more than 1 percentage point, or exceeds 3% absolute. 
@@ -390,7 +359,7 @@ Rollback or hold expansion when any condition is true for the review window: ## Operations Runbook -Enable a small canary: +Enable AssemblyAI for passive background workloads: ```bash ASSEMBLYAI_API_KEY= @@ -428,7 +397,7 @@ Validate after each config change: `timing.latency_ms`, `chunk_duration_seconds`, `estimated_cost_usd`, `split_count`, `identity_match_count`, `unknown_speaker_count`, and `unknown_speaker_duration_seconds`. -4. Run the offline gate before expansion: +4. Run the offline gate during rollout checks: `python3 scripts/stt/provider_comparison_gate.py --manifest tests/fixtures/stt_provider_eval/manifest.json`. 5. If a rollback threshold trips, set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram`, verify new background runs show `provider=deepgram`, and leave From 4ba2dcf4249723c680ae133df3218c2711d96be1 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 25 May 2026 03:33:58 +0700 Subject: [PATCH 44/44] Fix STT provider cost assumptions --- .../scripts/stt/provider_comparison_gate.md | 2 + ...ixture_good_meeting.assemblyai.rollup.json | 2 +- .../fixture_good_meeting.deepgram.rollup.json | 2 +- .../fixtures/stt_provider_eval/manifest.json | 36 ++++++++-------- .../unit/test_background_provider_service.py | 8 ++-- .../tests/unit/test_provider_evaluation.py | 2 +- backend/utils/stt/provider_costs.py | 33 ++++++++------- backend/utils/stt/provider_evaluation.py | 8 ++-- .../backend/assemblyai_background_rollout.mdx | 42 +++++++++++-------- 9 files changed, 73 insertions(+), 62 deletions(-) diff --git a/backend/scripts/stt/provider_comparison_gate.md b/backend/scripts/stt/provider_comparison_gate.md index 1e56e0d40a4..eee3c6778d2 100644 --- a/backend/scripts/stt/provider_comparison_gate.md +++ b/backend/scripts/stt/provider_comparison_gate.md @@ -15,6 +15,8 @@ The report compares `always_deepgram`, `always_assemblyai`, `current_policy`, an The fixture manifest covers clean turns, fast turns, overlap, sparse speech, low-signal/no-speech, 
multilingual turns, duplicate chunk replay, provider failure/fallback, saved real-provider E2E output, and saved policy-router output. +Cost estimates use public pay-as-you-go diarized prerecorded rates checked on 2026-05-25: AssemblyAI Universal-2 plus diarization is `$0.170/hour`, Deepgram Nova-3 monolingual plus diarization is `$0.408/hour`, and Deepgram Nova-3 multilingual plus diarization is `$0.468/hour`. Older experiment notes omitted Deepgram diarization cost and should not be used for rollout cost decisions. + Live-provider smoke tests are optional: ```bash diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json index 3e10d13a48e..306274c9c51 100644 --- a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json @@ -6,7 +6,7 @@ "raw_audio_seconds": 5.0, "speech_active_seconds": 4.5, "billable_seconds": 5.0, - "estimated_cost_usd": 0.00020, + "estimated_cost_usd": 0.00023611, "latency_seconds": 1.8, "runtime_seconds": 1.8, "retry_count": 0, diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json index 40932be595a..331a905f84d 100644 --- a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json @@ -6,7 +6,7 @@ "raw_audio_seconds": 5.0, "speech_active_seconds": 4.5, "billable_seconds": 5.0, - "estimated_cost_usd": 0.00036, + "estimated_cost_usd": 0.00056667, "latency_seconds": 1.1, "runtime_seconds": 1.1, "retry_count": 0, diff --git a/backend/tests/fixtures/stt_provider_eval/manifest.json b/backend/tests/fixtures/stt_provider_eval/manifest.json index 54a3e5c0d60..03de71a4f87 
100644 --- a/backend/tests/fixtures/stt_provider_eval/manifest.json +++ b/backend/tests/fixtures/stt_provider_eval/manifest.json @@ -37,7 +37,7 @@ }, "raw_audio_seconds": 5.0, "billable_seconds": 5.0, - "estimated_cost_usd": 0.0002, + "estimated_cost_usd": 0.00056667, "latency_seconds": 1.0, "fallback_count": 0 } @@ -68,7 +68,7 @@ }, "raw_audio_seconds": 5.0, "billable_seconds": 5.0, - "estimated_cost_usd": 0.00036, + "estimated_cost_usd": 0.00023611, "latency_seconds": 1.8, "fallback_count": 0, "rejected_reconciliation_count": 1 @@ -111,7 +111,7 @@ }, "raw_audio_seconds": 2.0, "billable_seconds": 2.0, - "estimated_cost_usd": 8e-05, + "estimated_cost_usd": 0.00022667, "latency_seconds": 0.9, "fallback_count": 0, "split_count": 1 @@ -150,7 +150,7 @@ }, "raw_audio_seconds": 2.0, "billable_seconds": 2.0, - "estimated_cost_usd": 0.00015, + "estimated_cost_usd": 0.00009444, "latency_seconds": 1.4, "fallback_count": 0, "split_count": 1, @@ -187,7 +187,7 @@ }, "raw_audio_seconds": 3.0, "billable_seconds": 3.0, - "estimated_cost_usd": 0.00012, + "estimated_cost_usd": 0.00034, "latency_seconds": 1.2, "fallback_count": 0 } @@ -225,7 +225,7 @@ }, "raw_audio_seconds": 3.0, "billable_seconds": 3.0, - "estimated_cost_usd": 0.00022, + "estimated_cost_usd": 0.00014167, "latency_seconds": 1.9, "fallback_count": 0, "split_count": 1, @@ -256,7 +256,7 @@ "raw_audio_seconds": 30.0, "speech_active_seconds": 0.5, "billable_seconds": 30.0, - "estimated_cost_usd": 0.0012, + "estimated_cost_usd": 0.0034, "latency_seconds": 1.0, "fallback_count": 0 } @@ -281,7 +281,7 @@ "raw_audio_seconds": 30.0, "speech_active_seconds": 0.5, "billable_seconds": 30.0, - "estimated_cost_usd": 0.0022, + "estimated_cost_usd": 0.00141667, "latency_seconds": 2.0, "fallback_count": 0 } @@ -302,7 +302,7 @@ "raw_audio_seconds": 15.0, "speech_active_seconds": 0.0, "billable_seconds": 15.0, - "estimated_cost_usd": 0.0006, + "estimated_cost_usd": 0.0017, "latency_seconds": 0.8, "fallback_count": 0 } @@ -319,7 +319,7 @@ 
"raw_audio_seconds": 15.0, "speech_active_seconds": 0.0, "billable_seconds": 15.0, - "estimated_cost_usd": 0.0011, + "estimated_cost_usd": 0.00070833, "latency_seconds": 1.2, "fallback_count": 0 } @@ -354,7 +354,7 @@ }, "raw_audio_seconds": 5.0, "billable_seconds": 5.0, - "estimated_cost_usd": 0.0002, + "estimated_cost_usd": 0.00056667, "latency_seconds": 1.3, "fallback_count": 0 } @@ -385,7 +385,7 @@ }, "raw_audio_seconds": 5.0, "billable_seconds": 5.0, - "estimated_cost_usd": 0.00036, + "estimated_cost_usd": 0.00023611, "latency_seconds": 2.0, "fallback_count": 0 } @@ -413,7 +413,7 @@ }, "raw_audio_seconds": 4.0, "billable_seconds": 2.0, - "estimated_cost_usd": 8e-05, + "estimated_cost_usd": 0.00022667, "latency_seconds": 1.0, "fallback_count": 0 } @@ -437,7 +437,7 @@ }, "raw_audio_seconds": 4.0, "billable_seconds": 2.0, - "estimated_cost_usd": 0.00015, + "estimated_cost_usd": 0.00009444, "latency_seconds": 1.7, "fallback_count": 0 } @@ -465,7 +465,7 @@ }, "raw_audio_seconds": 22.0, "billable_seconds": 22.0, - "estimated_cost_usd": 0.00088, + "estimated_cost_usd": 0.00249333, "latency_seconds": 1.1, "fallback_count": 0 } @@ -490,7 +490,7 @@ }, "raw_audio_seconds": 22.0, "billable_seconds": 22.0, - "estimated_cost_usd": 0.00158, + "estimated_cost_usd": 0.00103889, "latency_seconds": 2.1, "fallback_count": 1 } @@ -525,7 +525,7 @@ }, "raw_audio_seconds": 6.0, "billable_seconds": 6.0, - "estimated_cost_usd": 0.00044, + "estimated_cost_usd": 0.00068, "latency_seconds": 1.3, "fallback_count": 0 } @@ -556,7 +556,7 @@ }, "raw_audio_seconds": 6.0, "billable_seconds": 6.0, - "estimated_cost_usd": 0.00122, + "estimated_cost_usd": 0.00028333, "latency_seconds": 2.2, "fallback_count": 0, "split_count": 0, diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py index b9ab6eacaea..9a0b6a001ee 100644 --- a/backend/tests/unit/test_background_provider_service.py +++ 
b/backend/tests/unit/test_background_provider_service.py @@ -87,7 +87,7 @@ def test_provider_service_transcribes_sync_upload_and_finalizes_deepgram_run(mon assert finalize_run.call_args.kwargs['workload'] == 'sync' assert finalize_run.call_args.kwargs['status'] == 'succeeded' assert finalize_run.call_args.kwargs['transcript_segment_count'] == 1 - assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00016 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00022667 def test_provider_service_finalizes_background_run_on_deepgram_when_assemblyai_disabled(monkeypatch): @@ -112,7 +112,7 @@ def test_provider_service_finalizes_background_run_on_deepgram_when_assemblyai_d assert finalize_run.call_args.kwargs['workload'] == 'background' assert finalize_run.call_args.kwargs['provider'] == 'deepgram' assert finalize_run.call_args.kwargs['raw_audio_seconds'] == 9.5 - assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00076 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00107667 def test_background_routing_defaults_to_assemblyai_for_background(monkeypatch): @@ -285,7 +285,7 @@ def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypat assert finalize_run.call_args_list[1].kwargs['provider'] == 'deepgram' assert finalize_run.call_args_list[1].kwargs['fallback_count'] == 1 assert finalize_run.call_args_list[1].kwargs['fallback_provider'] == 'assemblyai' - assert finalize_run.call_args_list[1].kwargs['estimated_cost_usd'] == 0.00016 + assert finalize_run.call_args_list[1].kwargs['estimated_cost_usd'] == 0.00022667 def test_provider_service_falls_back_to_deepgram_for_background_when_assemblyai_fails(monkeypatch): @@ -547,7 +547,7 @@ def test_prerecorded_cost_estimator_uses_provider_defaults_and_unknown_provider_ workload='background', billable_seconds=60.0, ) - == 0.0048 + == 0.0068 ) assert ( estimate_prerecorded_provider_cost_usd( diff --git a/backend/tests/unit/test_provider_evaluation.py 
b/backend/tests/unit/test_provider_evaluation.py index 7a89778035a..556c3b5a6fa 100644 --- a/backend/tests/unit/test_provider_evaluation.py +++ b/backend/tests/unit/test_provider_evaluation.py @@ -54,7 +54,7 @@ def test_fixture_report_passes_and_includes_cost_identity_and_timing_metrics(): assert report['status'] == 'passed' assert report['case_count'] == 1 - assert report['aggregate']['assemblyai_estimated_cost_usd'] == 0.00020 + assert report['aggregate']['assemblyai_estimated_cost_usd'] == 0.00023611 case = report['cases'][0] assert case['comparison']['transcript_word_error_rate'] == 0.0 assert case['comparison']['average_timestamp_drift_seconds'] > 0.0 diff --git a/backend/utils/stt/provider_costs.py b/backend/utils/stt/provider_costs.py index c9ec475a81f..088a85d5674 100644 --- a/backend/utils/stt/provider_costs.py +++ b/backend/utils/stt/provider_costs.py @@ -10,9 +10,11 @@ class PrerecordedProviderCostRate: source: str -# Pay-as-you-go public STT pricing, checked 2026-05-21. +# Pay-as-you-go public STT pricing, checked 2026-05-25. # AssemblyAI background runs use speaker_labels=True, so the pre-recorded -# diarization add-on is included in AssemblyAI rates here. +# diarization add-on is included in AssemblyAI rates here. Deepgram background +# runs also request diarize=True, so include the speaker diarization add-on +# there too; otherwise Deepgram appears artificially cheap in rollout gates. # Customer-specific committed-use discounts are intentionally excluded. _PRERECORDED_STT_COST_RATES: dict[str, dict[str, PrerecordedProviderCostRate]] = { STTProviderName.assemblyai.value: { @@ -20,39 +22,40 @@ class PrerecordedProviderCostRate: # plus pre-recorded Speaker Diarization $0.02/hr. 
'universal-2': PrerecordedProviderCostRate( usd_per_billable_second=0.17 / 3600, - source='assemblyai_prerecorded_diarized_payg_2026_05_21', + source='assemblyai_prerecorded_diarized_payg_2026_05_25', ), 'universal-3-pro': PrerecordedProviderCostRate( usd_per_billable_second=0.23 / 3600, - source='assemblyai_prerecorded_diarized_payg_2026_05_21', + source='assemblyai_prerecorded_diarized_payg_2026_05_25', ), 'u3-pro': PrerecordedProviderCostRate( usd_per_billable_second=0.23 / 3600, - source='assemblyai_prerecorded_diarized_payg_2026_05_21', + source='assemblyai_prerecorded_diarized_payg_2026_05_25', ), 'default': PrerecordedProviderCostRate( usd_per_billable_second=0.17 / 3600, - source='assemblyai_prerecorded_diarized_default_2026_05_21', + source='assemblyai_prerecorded_diarized_default_2026_05_25', ), }, STTProviderName.deepgram.value: { # Deepgram pricing: Nova-3 monolingual pre-recorded $0.0048/min, - # Nova-3 multilingual pre-recorded $0.0058/min. + # Nova-3 multilingual pre-recorded $0.0058/min, plus Speaker + # Diarization add-on $0.0020/min. 
'nova-3': PrerecordedProviderCostRate( - usd_per_billable_second=0.0048 / 60, - source='deepgram_prerecorded_payg_2026_05_21', + usd_per_billable_second=0.0068 / 60, + source='deepgram_prerecorded_diarized_payg_2026_05_25', ), 'nova-3-general': PrerecordedProviderCostRate( - usd_per_billable_second=0.0048 / 60, - source='deepgram_prerecorded_payg_2026_05_21', + usd_per_billable_second=0.0068 / 60, + source='deepgram_prerecorded_diarized_payg_2026_05_25', ), 'nova-3-multilingual': PrerecordedProviderCostRate( - usd_per_billable_second=0.0058 / 60, - source='deepgram_prerecorded_payg_2026_05_21', + usd_per_billable_second=0.0078 / 60, + source='deepgram_prerecorded_diarized_payg_2026_05_25', ), 'default': PrerecordedProviderCostRate( - usd_per_billable_second=0.0048 / 60, - source='deepgram_prerecorded_default_2026_05_21', + usd_per_billable_second=0.0068 / 60, + source='deepgram_prerecorded_diarized_default_2026_05_25', ), }, } diff --git a/backend/utils/stt/provider_evaluation.py b/backend/utils/stt/provider_evaluation.py index 1c5854d568d..c3bd940486a 100644 --- a/backend/utils/stt/provider_evaluation.py +++ b/backend/utils/stt/provider_evaluation.py @@ -8,8 +8,8 @@ 'current_policy': 'assemblyai', 'shadow_only': 'deepgram', } -ASSEMBLYAI_COST_PER_HOUR_USD = 0.2592 -DEEPGRAM_COST_PER_HOUR_USD = 0.144 +ASSEMBLYAI_COST_PER_HOUR_USD = 0.17 +DEEPGRAM_COST_PER_HOUR_USD = 0.408 @dataclass(frozen=True) @@ -720,7 +720,7 @@ def _likely_cause(metric: str) -> str: 'empty_transcript_rate': 'low-signal/no-speech handling produced an empty or failed transcript', 'latency_seconds': 'AssemblyAI async job latency exceeds Deepgram for this workload shape', 'timeout_error_rate': 'provider timeout or retry exhaustion path is not default-safe', - 'estimated_cost_per_hour_usd': 'AssemblyAI billable duration or pricing is too high for default background volume', + 'estimated_cost_per_hour_usd': 'provider billable duration or pricing is too high for default background volume', 
}.get(metric, 'AssemblyAI trails the Deepgram comparator on this gate') @@ -731,7 +731,7 @@ def _mitigation(metric: str) -> str: 'empty_transcript_rate': 'preserve no-speech detection and fallback controls before default promotion', 'latency_seconds': 'keep latency SLO alerts and fallback controls before expanding default traffic', 'timeout_error_rate': 'use Deepgram fallback and provider health gates from TICKET-027', - 'estimated_cost_per_hour_usd': 'cap rollout or require explicit product tradeoff in TICKET-028', + 'estimated_cost_per_hour_usd': 'review billable seconds, requested add-ons, and provider pricing before changing defaults', }.get(metric, 'capture in TICKET-028 rollout tradeoffs before promoting AssemblyAI') diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx index 31d0168d5ed..32d2f10ab6b 100644 --- a/docs/doc/developer/backend/assemblyai_background_rollout.mdx +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -15,12 +15,15 @@ already been captured and can tolerate async provider polling, retries, and Deepgram fallback. Current saved-provider benchmark evidence still favors Deepgram on some -cost-adjusted speaker-quality fixtures: Deepgram reached 99.79% speaker purity -at about $0.264/hour, while AssemblyAI reached 95.91% at about $0.734/hour. -That benchmark is a constraint on the AssemblyAI pipeline, not a reason to keep -Deepgram as the primary passive background provider. The production stance is: -AssemblyAI first, conservative speaker safety, privacy-safe monitoring, and -Deepgram fallback. +speaker-quality fixtures: Deepgram reached 99.79% speaker purity, while +AssemblyAI reached 95.91%. 
Current public pay-as-you-go pricing checked on +2026-05-25 makes AssemblyAI Universal-2 plus diarization about `$0.170/hour` +and Deepgram Nova-3 prerecorded plus diarization about `$0.408/hour` for +monolingual audio or `$0.468/hour` for multilingual audio. Older experiment +summaries used stale Deepgram rates that omitted the diarization add-on and made +AssemblyAI look more expensive; do not use those numbers for rollout decisions. +The production stance is: AssemblyAI first, conservative speaker safety, +privacy-safe monitoring, and Deepgram fallback. Deepgram remains the provider for mobile/BLE `/v4/listen`, realtime assistant streaming, Hold-to-Talk streaming, voice-message finalize semantics, and @@ -47,10 +50,10 @@ required production strategies: | Strategy | Provider | Speaker purity | Empty transcript | Fallback | Timeout/error | Cost/hour | | --- | --- | ---: | ---: | ---: | ---: | ---: | -| `always_deepgram` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.158` | -| `always_assemblyai` | AssemblyAI | 98.9% | 10.0% | 1.0% | 0.0% | `$0.286` | -| `current_policy` | AssemblyAI | 98.9% | 10.0% | 1.0% | 0.0% | `$0.286` | -| `shadow_only` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.158` | +| `always_deepgram` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.408` | +| `always_assemblyai` | AssemblyAI | 98.9% | 10.0% | 1.0% | 0.0% | `$0.170` | +| `current_policy` | AssemblyAI | 98.9% | 10.0% | 1.0% | 0.0% | `$0.170` | +| `shadow_only` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.408` | The gate still marked AssemblyAI as `limited` because the `saved_real_policy_router_outputs` scenario had 95.5% AssemblyAI speaker-word @@ -314,10 +317,11 @@ Default health criteria: degradation. - App-visible speaker inflation stays within the replay budget and does not rise without a matching purity improvement. -- Cost/hour stays within the approved margin. 
The current offline gate shows - AssemblyAI at about 1.8x Deepgram (`$0.286/hour` versus `$0.158/hour`), while - older saved-provider evidence showed a wider gap; expansion requires an - explicit owner decision if this remains materially above Deepgram. +- Cost/hour stays within the approved margin. The current offline gate should + show AssemblyAI below Deepgram for diarized prerecorded work, roughly + `$0.170/hour` versus `$0.408/hour` with monolingual Deepgram Nova-3 pricing. + If the gate shows AssemblyAI above Deepgram, first audit stale fixture ledgers, + omitted Deepgram diarization cost, and billable-duration inflation. - No sustained increase in user corrections, self-voice review failures, support complaints, or billing surprises. @@ -343,10 +347,12 @@ for the review window: - Timeout rate exceeds 1% or any timeout spike correlates with support reports. - Empty transcript rate exceeds the Deepgram baseline by more than 1 percentage point. -- Cost/hour exceeds the approved Deepgram margin. The current baseline is about - `$0.264/hour` for Deepgram and `$0.734/hour` for AssemblyAI; AssemblyAI cost is - a rollout blocker unless the rollout doc names a quality or product reason for - accepting the margin. +- Cost/hour unexpectedly exceeds the approved provider margin. The current + public pricing baseline for diarized prerecorded work is about `$0.170/hour` + for AssemblyAI Universal-2, `$0.408/hour` for Deepgram Nova-3 monolingual, and + `$0.468/hour` for Deepgram Nova-3 multilingual. A report showing AssemblyAI as + more expensive is a stale-pricing or billable-duration investigation until + proven otherwise. - Unknown or low-confidence speaker rate exceeds the Deepgram baseline by more than 5 percentage points, or `unknown_speaker_duration_seconds` grows enough to make speaker labels visibly worse.