diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 58c3fe2ab01..224249a1cb2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -108,7 +108,7 @@ jobs: if: steps.changed.outputs.has_frontend == 'true' || steps.changed.outputs.has_personas == 'true' uses: actions/setup-node@v4 with: - node-version: '18' + node-version: '22' cache: 'npm' cache-dependency-path: | web/frontend/package-lock.json diff --git a/.gitignore b/.gitignore index 13d9ed256b2..4ea395c4b38 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,9 @@ myenv/ venv/ .DS_Store dump/ -/scripts/ +/scripts/* +!/scripts/desktop_assemblyai_e2e.py +!/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md *.zip *.wav node_modules @@ -27,10 +29,15 @@ yarn.lock .packages .pub-cache/ .pub/ +.swiftpm/ + +# Build / compile artifacts (cross-ecosystem; prefer these over per-tool *.o/*.d rules) build/ dist/ .build/ -.swiftpm/ +target/ +**/.build-agent-*/ +**/.build-hybrid-*/ # VS Code .vscode/* diff --git a/AGENTS.md b/AGENTS.md index 471e536fca3..b8ce9f4d339 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -49,7 +49,8 @@ backend (main.py) ├── ws ──► pusher (pusher/) ├── ──────► diarizer (diarizer/) ├── ──────► vad (modal/) - └── ──────► deepgram (self-hosted or cloud) + ├── ──────► deepgram (self-hosted or cloud) + └── ──────► assemblyai (cloud, background async when enabled) pusher ├── ──────► diarizer (diarizer/) @@ -63,12 +64,13 @@ notifications-job (modal/job.py) [cron] Helm charts: `backend/charts/{backend-listen,pusher,diarizer,vad,deepgram-self-hosted,agent-proxy}/` -- **backend** (`main.py`) — REST API. Streams audio to pusher via WebSocket (`utils/pusher.py`). Calls diarizer for speaker embeddings (`utils/stt/speaker_embedding.py`). Calls vad for voice activity detection and speaker identification (`utils/stt/vad.py`, `utils/stt/speech_profile.py`). Calls deepgram for STT (`utils/stt/streaming.py`). +- **backend** (`main.py`) — REST API. Streams realtime/listen audio to pusher via WebSocket (`utils/pusher.py`). Calls diarizer for speaker embeddings (`utils/stt/speaker_embedding.py`). Calls vad for voice activity detection and speaker identification (`utils/stt/vad.py`, `utils/stt/speech_profile.py`). Calls deepgram for realtime and Hold-to-Talk STT (`utils/stt/streaming.py`) and for prerecorded fallback. Calls AssemblyAI for explicitly enabled async/background prerecorded workloads through `utils/stt/provider_service.py`. - **pusher** (`pusher/main.py`) — Receives audio via binary WebSocket protocol. Calls diarizer and deepgram for speaker sample extraction (`utils/speaker_identification.py` → `utils/speaker_sample.py`). - **agent-proxy** (`agent-proxy/main.py`) — GKE. WebSocket proxy at `wss://agent.omi.me/v1/agent/ws`. Validates Firebase ID token, looks up `agentVm` in Firestore, proxies bidirectionally to VM's `ws://:8080/ws`. VM credentials never leave the server. - **diarizer** (`diarizer/main.py`) — GPU. Speaker embeddings at `/v2/embedding`. Called by backend and pusher (`HOSTED_SPEAKER_EMBEDDING_API_URL`). - **vad** (`modal/main.py`) — GPU. `/v1/vad` (voice activity detection) and `/v1/speaker-identification` (speaker matching). Called by backend only (`HOSTED_VAD_API_URL`, `HOSTED_SPEECH_PROFILE_API_URL`). -- **deepgram** — STT. Streaming uses self-hosted (`DEEPGRAM_SELF_HOSTED_URL`) or cloud based on `DEEPGRAM_SELF_HOSTED_ENABLED` (`utils/stt/streaming.py`). Pre-recorded always uses Deepgram cloud (`utils/stt/pre_recorded.py`). Called by backend and pusher. +- **deepgram** — STT. Streaming uses self-hosted (`DEEPGRAM_SELF_HOSTED_URL`) or cloud based on `DEEPGRAM_SELF_HOSTED_ENABLED` (`utils/stt/streaming.py`). Prerecorded Deepgram cloud remains the default and fallback provider through `utils/stt/provider_service.py`/`utils/stt/pre_recorded.py`. Called by backend and pusher. +- **assemblyai** — Async/background STT. Used only by backend for feature-flagged prerecorded workloads (`sync`, `background`, `postprocess`) through `utils/stt/assemblyai_adapter.py`; provider speaker labels remain session-local metadata. - **notifications-job** (`modal/job.py`) — Cron job, reads Firestore/Redis, sends push notifications. Keep this map up to date. When adding, removing, or changing inter-service calls, update this section and the matching section in `CLAUDE.md`. diff --git a/CLAUDE.md b/CLAUDE.md index 96bbcb9fc00..71571e05b67 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -52,7 +52,8 @@ backend (main.py) ├── ws ──► pusher (pusher/) ├── ──────► diarizer (diarizer/) ├── ──────► vad (modal/) - └── ──────► deepgram (self-hosted or cloud) + ├── ──────► deepgram (self-hosted or cloud) + └── ──────► assemblyai (cloud, background async when enabled) pusher ├── ──────► diarizer (diarizer/) diff --git a/app/lib/backend/schema/transcript_segment.dart b/app/lib/backend/schema/transcript_segment.dart index 56091c48a68..605dc09b292 100644 --- a/app/lib/backend/schema/transcript_segment.dart +++ b/app/lib/backend/schema/transcript_segment.dart @@ -33,6 +33,13 @@ class TranscriptSegment { List translations = []; bool speechProfileProcessed; String? sttProvider; + String? sttModel; + String? providerClusterId; + String? providerSpeakerLabel; + String speakerIdentityState; + double? speakerIdentityConfidence; + String? speakerIdentitySource; + String? speakerIdentityVersion; TranscriptSegment({ required this.id, @@ -45,11 +52,22 @@ class TranscriptSegment { required this.translations, this.speechProfileProcessed = true, this.sttProvider, + this.sttModel, + this.providerClusterId, + this.providerSpeakerLabel, + this.speakerIdentityState = 'legacy_ambiguous', + this.speakerIdentityConfidence, + this.speakerIdentitySource, + this.speakerIdentityVersion, }) { final parts = speaker?.split('_') ?? []; speakerId = parts.length > 1 ? (int.tryParse(parts[1]) ?? 0) : 0; } + bool get hasExplicitUnknownSpeakerIdentity => speakerIdentityState == 'unknown'; + + bool get hasLegacySpeakerLabel => speaker != null && speaker!.isNotEmpty; + @override String toString() { return 'TranscriptSegment: {id: $id text: $text, speaker: $speakerId, isUser: $isUser, start: $start, end: $end}'; @@ -66,7 +84,7 @@ class TranscriptSegment { return TranscriptSegment( id: (json['id'] ?? '') as String, text: json['text'] as String, - speaker: (json['speaker'] ?? 'SPEAKER_00') as String, + speaker: json['speaker'] as String?, isUser: (json['is_user'] ?? false) as bool, personId: json['person_id'], start: double.tryParse(json['start'].toString()) ?? 0.0, @@ -74,6 +92,15 @@ class TranscriptSegment { translations: json['translations'] != null ? Translation.fromJsonList(json['translations'] as List) : [], speechProfileProcessed: (json['speech_profile_processed'] ?? true) as bool, sttProvider: json['stt_provider'] as String?, + sttModel: json['stt_model'] as String?, + providerClusterId: json['provider_cluster_id'] as String?, + providerSpeakerLabel: json['provider_speaker_label'] as String?, + speakerIdentityState: (json['speaker_identity_state'] ?? 'legacy_ambiguous') as String, + speakerIdentityConfidence: json['speaker_identity_confidence'] != null + ? double.tryParse(json['speaker_identity_confidence'].toString()) + : null, + speakerIdentitySource: json['speaker_identity_source'] as String?, + speakerIdentityVersion: json['speaker_identity_version'] as String?, ); } @@ -88,6 +115,13 @@ class TranscriptSegment { 'end': end, 'translations': translations.map((t) => t.toJson()).toList(), if (sttProvider != null) 'stt_provider': sttProvider, + if (sttModel != null) 'stt_model': sttModel, + if (providerClusterId != null) 'provider_cluster_id': providerClusterId, + if (providerSpeakerLabel != null) 'provider_speaker_label': providerSpeakerLabel, + 'speaker_identity_state': speakerIdentityState, + if (speakerIdentityConfidence != null) 'speaker_identity_confidence': speakerIdentityConfidence, + if (speakerIdentitySource != null) 'speaker_identity_source': speakerIdentitySource, + if (speakerIdentityVersion != null) 'speaker_identity_version': speakerIdentityVersion, }; } @@ -195,7 +229,7 @@ class TranscriptSegment { if (segment.personId != null && peopleMap.containsKey(segment.personId)) { speakerName = peopleMap[segment.personId]!; } else { - var displayId = '${getDisplaySpeakerId(segment.speakerId, segments)}'; + var displayId = getDisplaySpeakerIdForSegment(segment, segments); speakerName = speakerLabelBuilder != null ? speakerLabelBuilder(displayId) : 'Speaker $displayId'; } transcript += '$timestampStr $speakerName: $segmentText '; @@ -242,4 +276,25 @@ class TranscriptSegment { // Normalize: subtract minimum and add 1 to make it 1-indexed return speakerId - minSpeakerId + 1; } + + static String getDisplaySpeakerIdForSegment(TranscriptSegment segment, List segments) { + if (segment.hasLegacySpeakerLabel) { + return '${getDisplaySpeakerId(segment.speakerId, segments)}'; + } + + final providerClusterId = segment.providerClusterId; + if (providerClusterId != null && providerClusterId.isNotEmpty) { + final clusterIds = []; + for (final item in segments) { + final clusterId = item.providerClusterId; + if (!item.isUser && clusterId != null && clusterId.isNotEmpty && !clusterIds.contains(clusterId)) { + clusterIds.add(clusterId); + } + } + final index = clusterIds.indexOf(providerClusterId); + if (index >= 0) return '${index + 1}'; + } + + return '?'; + } } diff --git a/app/lib/pages/conversation_detail/page.dart b/app/lib/pages/conversation_detail/page.dart index 094606e4b77..cb2ba4b5284 100644 --- a/app/lib/pages/conversation_detail/page.dart +++ b/app/lib/pages/conversation_detail/page.dart @@ -1456,7 +1456,7 @@ class _TranscriptWidgetsState extends State with AutomaticKee final person = segment.personId != null ? SharedPreferencesUtil().getPersonById(segment.personId!) : null; final speakerName = person?.name ?? - context.l10n.speakerWithId('${TranscriptSegment.getDisplaySpeakerId(segment.speakerId, segments)}'); + context.l10n.speakerWithId(TranscriptSegment.getDisplaySpeakerIdForSegment(segment, segments)); PlatformManager.instance.analytics.editSegmentTextStarted(); bool saved = false; showEditSegmentBottomSheet( diff --git a/app/lib/widgets/transcript.dart b/app/lib/widgets/transcript.dart index cea056f70da..120aad8cab4 100644 --- a/app/lib/widgets/transcript.dart +++ b/app/lib/widgets/transcript.dart @@ -508,7 +508,7 @@ class _TranscriptWidgetState extends State { ? 'omi' : (person?.name ?? context.l10n.speakerWithId( - '${TranscriptSegment.getDisplaySpeakerId(data.speakerId, widget.segments)}', + TranscriptSegment.getDisplaySpeakerIdForSegment(data, widget.segments), )), style: TextStyle( color: data.speakerId == omiSpeakerId || person != null diff --git a/app/test/widgets/transcript_test.dart b/app/test/widgets/transcript_test.dart index 9ed35e9fc09..ded47d5821a 100644 --- a/app/test/widgets/transcript_test.dart +++ b/app/test/widgets/transcript_test.dart @@ -42,6 +42,63 @@ void main() { ); } + group('TranscriptSegment compatibility metadata', () { + test('decodes provider speaker metadata without changing legacy fields', () { + final segment = TranscriptSegment.fromJson({ + 'id': 'seg-provider', + 'text': 'Hello', + 'speaker': null, + 'speaker_id': 0, + 'is_user': false, + 'person_id': null, + 'start': 0.0, + 'end': 1.0, + 'stt_provider': 'provider-a', + 'stt_model': 'async-large', + 'provider_cluster_id': 'speaker-alpha', + 'provider_speaker_label': null, + 'speaker_identity_state': 'unknown', + 'speaker_identity_confidence': null, + 'speaker_identity_source': null, + 'speaker_identity_version': 'v1', + }); + + expect(segment.speaker, isNull); + expect(segment.speakerId, 0); + expect(segment.sttProvider, 'provider-a'); + expect(segment.sttModel, 'async-large'); + expect(segment.providerClusterId, 'speaker-alpha'); + expect(segment.speakerIdentityState, 'unknown'); + expect(segment.speakerIdentityVersion, 'v1'); + }); + + test('uses provider cluster labels when legacy speaker label is absent', () { + final first = TranscriptSegment.fromJson({ + 'id': 'seg-provider-a', + 'text': 'Hello', + 'speaker': null, + 'is_user': false, + 'start': 0.0, + 'end': 1.0, + 'provider_cluster_id': 'speaker-alpha', + 'speaker_identity_state': 'unknown', + }); + final second = TranscriptSegment.fromJson({ + 'id': 'seg-provider-b', + 'text': 'Hi', + 'speaker': null, + 'is_user': false, + 'start': 1.0, + 'end': 2.0, + 'provider_cluster_id': 'speaker-beta', + 'speaker_identity_state': 'unknown', + }); + + expect(TranscriptSegment.getDisplaySpeakerIdForSegment(first, [first, second]), '1'); + expect(TranscriptSegment.getDisplaySpeakerIdForSegment(second, [first, second]), '2'); + }); + }); + group('Speaker label display', () { testWidgets('shows person name when personId is set and in cache', (tester) async { final now = DateTime.now(); diff --git a/backend/.env.template b/backend/.env.template index b4397a8caf6..fefc456fab5 100644 --- a/backend/.env.template +++ b/backend/.env.template @@ -14,6 +14,17 @@ REDIS_DB_PASSWORD= DEEPGRAM_API_KEY= +# AssemblyAI async prerecorded STT. Enabled by default for eligible workloads: sync, background, postprocess. +ASSEMBLYAI_API_KEY= +ASSEMBLYAI_PRERECORDED_STT_ENABLED=true +ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess +ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true +ASSEMBLYAI_STT_MODEL=universal-2 +ASSEMBLYAI_BASE_URL=https://api.assemblyai.com +ASSEMBLYAI_POLL_INTERVAL_SECONDS=3 +ASSEMBLYAI_MAX_POLL_SECONDS=900 +ASSEMBLYAI_SMOKE_AUDIO_URL= + ADMIN_KEY= OPENAI_API_KEY= diff --git a/backend/database/conversations.py b/backend/database/conversations.py index 3a58052a58e..47996c5b3e5 100644 --- a/backend/database/conversations.py +++ b/backend/database/conversations.py @@ -874,6 +874,36 @@ def update_conversation_segments( return +def update_conversation_segments_and_background_chunks( + uid: str, + conversation_id: str, + segments: List[dict], + background_processed_chunks: Dict[str, dict], + finished_at: datetime = None, + data_protection_level: str = None, +): + doc_ref = db.collection('users').document(uid).collection(conversations_collection).document(conversation_id) + if data_protection_level is not None: + doc_level = data_protection_level + else: + doc_snapshot = doc_ref.get(field_paths=['data_protection_level']) + if not doc_snapshot.exists: + return + doc_level = doc_snapshot.to_dict().get('data_protection_level', 'standard') + update_payload = { + 'transcript_segments': segments, + 'background_processed_chunks': background_processed_chunks, + } + if finished_at: + update_payload['finished_at'] = finished_at + prepared_payload = _prepare_conversation_for_write(update_payload, uid, doc_level) + try: + doc_ref.update(prepared_payload) + except NotFound: + # Document was deleted between cache read and write — safe to skip + return + + # *********************************** # ********** VISIBILITY ************* # *********************************** diff --git a/backend/database/self_voice_review.py b/backend/database/self_voice_review.py new file mode 100644 index 00000000000..cd4f7696737 --- /dev/null +++ b/backend/database/self_voice_review.py @@ -0,0 +1,224 @@ +from datetime import datetime, timedelta, timezone +from typing import Any, Optional + +from google.cloud import firestore +from google.cloud.firestore_v1 import FieldFilter + +from ._client import db, document_id_from_seed + +CANDIDATES_COLLECTION = 'self_voice_review_candidates' +NEGATIVE_MARKERS_COLLECTION = 'self_voice_review_negative_markers' +CONFIRMED_SAMPLES_COLLECTION = 'self_voice_review_confirmed_samples' +DEFAULT_CANDIDATE_TTL_DAYS = 30 +FORBIDDEN_CANDIDATE_KEYS = {'text', 'transcript', 'transcript_text', 'words', 'utterances', 'audio_bytes', 'raw_audio'} + + +def utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def candidate_id_from_source(conversation_id: str, provider_cluster_id: str, segment_ids: list[str]) -> str: + seed = ':'.join(['self_voice_review', conversation_id, provider_cluster_id, ','.join(sorted(segment_ids))]) + return document_id_from_seed(seed) + + +def marker_id_from_source(conversation_id: str, provider_cluster_id: str) -> str: + return document_id_from_seed(':'.join(['self_voice_negative_marker', conversation_id, provider_cluster_id])) + + +def _user_ref(uid: str): + return db.collection('users').document(uid) + + +def _candidate_ref(uid: str, candidate_id: str): + return _user_ref(uid).collection(CANDIDATES_COLLECTION).document(candidate_id) + + +def _negative_marker_ref(uid: str, marker_id: str): + return _user_ref(uid).collection(NEGATIVE_MARKERS_COLLECTION).document(marker_id) + + +def _confirmed_sample_ref(uid: str, sample_id: str): + return _user_ref(uid).collection(CONFIRMED_SAMPLES_COLLECTION).document(sample_id) + + +def get_candidate(uid: str, candidate_id: str) -> Optional[dict[str, Any]]: + snapshot = _candidate_ref(uid, candidate_id).get() + if not snapshot.exists: + return None + data = snapshot.to_dict() or {} + data.setdefault('candidate_id', snapshot.id) + return data + + +def has_negative_marker(uid: str, marker_id: str) -> bool: + return _negative_marker_ref(uid, marker_id).get().exists + + +def recently_shown_source_exists( + uid: str, + conversation_id: str, + provider_cluster_id: str, + now: Optional[datetime] = None, +) -> bool: + now = now or utc_now() + query = ( + _user_ref(uid) + .collection(CANDIDATES_COLLECTION) + .where(filter=FieldFilter('source.conversation_id', '==', conversation_id)) + .where(filter=FieldFilter('source.provider_cluster_id', '==', provider_cluster_id)) + .where(filter=FieldFilter('cooldown_until', '>', now)) + .limit(1) + ) + return bool(list(query.stream())) + + +def upsert_candidate(uid: str, candidate: dict[str, Any]) -> bool: + candidate_id = candidate['candidate_id'] + ref = _candidate_ref(uid, candidate_id) + if ref.get().exists: + return False + + now = candidate.get('created_at') or utc_now() + payload = { + **candidate, + 'uid': uid, + 'review_status': candidate.get('review_status', 'pending'), + 'created_at': now, + 'updated_at': now, + 'expires_at': candidate.get('expires_at') or now + timedelta(days=DEFAULT_CANDIDATE_TTL_DAYS), + } + _reject_forbidden_candidate_keys(payload) + ref.set(payload, merge=False) + return True + + +def list_pending_candidates(uid: str, limit: int = 20, confidence_bucket: Optional[str] = None) -> list[dict[str, Any]]: + query = _user_ref(uid).collection(CANDIDATES_COLLECTION).where(filter=FieldFilter('review_status', '==', 'pending')) + if confidence_bucket: + query = query.where(filter=FieldFilter('confidence_bucket', '==', confidence_bucket)) + query = query.order_by('created_at', direction=firestore.Query.DESCENDING).limit(limit) + + candidates = [] + for snapshot in query.stream(): + data = snapshot.to_dict() or {} + data.setdefault('candidate_id', snapshot.id) + candidates.append(data) + return candidates + + +def mark_candidate_confirmed( + uid: str, + candidate_id: str, + embedding_version: str, + reviewed_at: Optional[datetime] = None, +) -> None: + reviewed_at = reviewed_at or utc_now() + _candidate_ref(uid, candidate_id).update( + { + 'review_status': 'confirmed', + 'reviewed_at': reviewed_at, + 'negative_review_marker': None, + 'updated_at': reviewed_at, + 'confirmed_sample': { + 'candidate_id': candidate_id, + 'embedding_version': embedding_version, + 'confirmed_at': reviewed_at, + 'revisable': True, + }, + } + ) + _confirmed_sample_ref(uid, candidate_id).set( + { + 'candidate_id': candidate_id, + 'source': 'self_voice_review', + 'embedding_version': embedding_version, + 'confirmed_at': reviewed_at, + 'deleted_at': None, + }, + merge=True, + ) + + +def mark_candidate_rejected( + uid: str, + candidate: dict[str, Any], + reviewed_at: Optional[datetime] = None, +) -> str: + reviewed_at = reviewed_at or utc_now() + source = candidate.get('source') or {} + marker_id = marker_id_from_source(source.get('conversation_id', ''), source.get('provider_cluster_id', '')) + marker = { + 'marker_id': marker_id, + 'candidate_id': candidate['candidate_id'], + 'conversation_id': source.get('conversation_id'), + 'provider_cluster_id': source.get('provider_cluster_id'), + 'segment_ids': source.get('segment_ids', []), + 'negative_review': True, + 'reviewed_at': reviewed_at, + } + _negative_marker_ref(uid, marker_id).set(marker, merge=True) + _candidate_ref(uid, candidate['candidate_id']).update( + { + 'review_status': 'rejected', + 'reviewed_at': reviewed_at, + 'negative_review_marker': marker, + 'updated_at': reviewed_at, + } + ) + return marker_id + + +def mark_candidate_skipped( + uid: str, + candidate_id: str, + cooldown_until: datetime, + reviewed_at: Optional[datetime] = None, +) -> None: + reviewed_at = reviewed_at or utc_now() + _candidate_ref(uid, candidate_id).update( + { + 'review_status': 'pending', + 'last_review_action': 'skipped', + 'reviewed_at': reviewed_at, + 'cooldown_until': cooldown_until, + 'updated_at': reviewed_at, + } + ) + + +def delete_confirmed_sample(uid: str, candidate_id: str, deleted_at: Optional[datetime] = None) -> bool: + deleted_at = deleted_at or utc_now() + candidate = get_candidate(uid, candidate_id) + if not candidate or candidate.get('review_status') != 'confirmed': + return False + _candidate_ref(uid, candidate_id).update( + { + 'review_status': 'deleted', + 'reviewed_at': deleted_at, + 'updated_at': deleted_at, + 'confirmed_sample.deleted_at': deleted_at, + } + ) + _confirmed_sample_ref(uid, candidate_id).set({'deleted_at': deleted_at}, merge=True) + return True + + +def _reject_forbidden_candidate_keys(payload: dict[str, Any]) -> None: + forbidden = _find_forbidden_candidate_keys(payload) + if forbidden: + raise ValueError(f'self voice review candidate contains forbidden keys: {sorted(forbidden)}') + + +def _find_forbidden_candidate_keys(value: Any) -> set[str]: + if isinstance(value, dict): + forbidden = FORBIDDEN_CANDIDATE_KEYS & set(value) + for nested in value.values(): + forbidden.update(_find_forbidden_candidate_keys(nested)) + return forbidden + if isinstance(value, list): + forbidden = set() + for nested in value: + forbidden.update(_find_forbidden_candidate_keys(nested)) + return forbidden + return set() diff --git a/backend/database/transcription_provider_usage.py b/backend/database/transcription_provider_usage.py new file mode 100644 index 00000000000..6e5450ee757 --- /dev/null +++ b/backend/database/transcription_provider_usage.py @@ -0,0 +1,590 @@ +from datetime import datetime, timedelta, timezone +from typing import Any, List, Optional +from uuid import uuid4 + +from google.cloud import firestore +from google.cloud.firestore_v1 import FieldFilter + +from ._client import db +from utils.metrics import ( + identity_confidence_bucket, + observe_transcription_provider_audio_seconds, + observe_transcription_provider_fallback, + observe_transcription_provider_identity_confidence, + observe_transcription_provider_request, + observe_transcription_provider_retry, + observe_transcription_provider_speaker_clusters, +) + +RUNS_COLLECTION = 'transcription_provider_runs' +DAILY_USAGE_COLLECTION = 'transcription_provider_usage_daily' +RUN_TTL_DAYS = 180 + +FORBIDDEN_LEDGER_KEYS = { + 'api_key', + 'audio_bytes', + 'audio', + 'authorization', + 'chunks', + 'full_transcript_text', + 'secret', + 'secrets', + 'token', + 'raw_audio_bytes', + 'text', + 'transcript', + 'transcript_text', + 'words', + 'word_records', + 'utterances', +} + + +def utc_day_bucket(value: Optional[datetime] = None) -> str: + value = value or datetime.now(timezone.utc) + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).strftime('%Y-%m-%d') + + +def _safe_doc_id_part(value: str) -> str: + return str(value or 'unknown').replace('/', '_') + + +def daily_rollup_doc_id(day: str, provider: str, model: str, workload: str) -> str: + return ':'.join( + [ + _safe_doc_id_part(day), + _safe_doc_id_part(provider), + _safe_doc_id_part(model), + _safe_doc_id_part(workload), + ] + ) + + +def _run_ref(run_id: str): + return db.collection(RUNS_COLLECTION).document(run_id) + + +def _rollup_ref(day: str, provider: str, model: str, workload: str): + return db.collection(DAILY_USAGE_COLLECTION).document(daily_rollup_doc_id(day, provider, model, workload)) + + +def _utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def _ttl_expires_at(now: datetime) -> datetime: + return now + timedelta(days=RUN_TTL_DAYS) + + +def _reject_forbidden_payload_keys(payload: dict[str, Any]) -> None: + forbidden = _find_forbidden_payload_keys(payload) + if forbidden: + raise ValueError(f'transcription provider ledger payload contains forbidden keys: {sorted(forbidden)}') + + +def _find_forbidden_payload_keys(value: Any) -> set[str]: + if isinstance(value, dict): + forbidden = FORBIDDEN_LEDGER_KEYS & set(value) + for nested in value.values(): + forbidden.update(_find_forbidden_payload_keys(nested)) + return forbidden + if isinstance(value, list): + forbidden = set() + for nested in value: + forbidden.update(_find_forbidden_payload_keys(nested)) + return forbidden + return set() + + +def create_provider_run( + uid: str, + provider: str, + model: str, + workload: str, + run_id: Optional[str] = None, + conversation_id: Optional[str] = None, + provider_job_ref: Optional[str] = None, + artifact_refs: Optional[dict[str, str]] = None, + started_at: Optional[datetime] = None, +) -> str: + run_id = run_id or str(uuid4()) + now = _utc_now() + started_at = started_at or now + payload = { + 'run_id': run_id, + 'uid': uid, + 'conversation_id': conversation_id, + 'provider': provider, + 'model': model, + 'workload': workload, + 'status': 'started', + 'provider_job_ref': provider_job_ref, + 'artifact_refs': artifact_refs or {}, + 'timing': { + 'started_at': started_at, + 'completed_at': None, + 'latency_ms': None, + }, + 'raw_audio_seconds': 0.0, + 'speech_active_seconds': 0.0, + 'billable_seconds': 0.0, + 'chunk_duration_seconds': 0.0, + 'estimated_cost_usd': 0.0, + 'retry_count': 0, + 'fallback_count': 0, + 'transcript_segment_count': 0, + 'transcript_word_count': 0, + 'speaker_cluster_count': 0, + 'identified_speaker_cluster_count': 0, + 'identity_match_count': 0, + 'provider_speaker_count': 0, + 'mapped_speaker_count': 0, + 'mapped_person_count': 0, + 'unmapped_speaker_count': 0, + 'unknown_speaker_count': 0, + 'unknown_speaker_duration_seconds': 0.0, + 'split_count': 0, + 'embedding_extraction_failure_count': 0, + 'identity_metric_update': { + 'status': 'pending', + 'skipped_reason': None, + 'updated_at': None, + }, + 'identity_confidence_summary': {}, + 'error_class': None, + 'created_at': now, + 'updated_at': now, + 'expires_at': _ttl_expires_at(now), + } + _reject_forbidden_payload_keys(payload) + _run_ref(run_id).set(payload, merge=False) + return run_id + + +def finalize_provider_run( + run_id: str, + provider: str, + model: str, + workload: str, + status: str, + started_at: datetime, + completed_at: Optional[datetime] = None, + raw_audio_seconds: float = 0.0, + speech_active_seconds: float = 0.0, + billable_seconds: float = 0.0, + chunk_duration_seconds: float = 0.0, + estimated_cost_usd: float = 0.0, + retry_count: int = 0, + fallback_count: int = 0, + transcript_segment_count: int = 0, + transcript_word_count: int = 0, + speaker_cluster_count: int = 0, + identified_speaker_cluster_count: int = 0, + identity_match_count: int = 0, + provider_speaker_count: int = 0, + mapped_speaker_count: int = 0, + mapped_person_count: int = 0, + unmapped_speaker_count: int = 0, + unknown_speaker_count: int = 0, + unknown_speaker_duration_seconds: float = 0.0, + split_count: int = 0, + embedding_extraction_failure_count: int = 0, + identity_confidence_summary: Optional[dict[str, Any]] = None, + error_class: Optional[str] = None, + artifact_refs: Optional[dict[str, str]] = None, + fallback_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', +) -> None: + completed_at = completed_at or _utc_now() + if started_at.tzinfo is None: + started_at = started_at.replace(tzinfo=timezone.utc) + if completed_at.tzinfo is None: + completed_at = completed_at.replace(tzinfo=timezone.utc) + latency_seconds = max((completed_at - started_at).total_seconds(), 0.0) + summary = identity_confidence_summary or {} + payload = { + 'status': status, + 'timing': { + 'started_at': started_at, + 'completed_at': completed_at, + 'latency_ms': int(latency_seconds * 1000), + }, + 'raw_audio_seconds': raw_audio_seconds, + 'speech_active_seconds': speech_active_seconds, + 'billable_seconds': billable_seconds, + 'chunk_duration_seconds': chunk_duration_seconds, + 'estimated_cost_usd': estimated_cost_usd, + 'retry_count': retry_count, + 'fallback_count': fallback_count, + 'transcript_segment_count': transcript_segment_count, + 'transcript_word_count': transcript_word_count, + 'speaker_cluster_count': speaker_cluster_count, + 'identified_speaker_cluster_count': identified_speaker_cluster_count, + 'identity_match_count': identity_match_count, + 'provider_speaker_count': provider_speaker_count, + 'mapped_speaker_count': mapped_speaker_count, + 'mapped_person_count': mapped_person_count, + 'unmapped_speaker_count': unmapped_speaker_count, + 'unknown_speaker_count': unknown_speaker_count, + 'unknown_speaker_duration_seconds': unknown_speaker_duration_seconds, + 'split_count': split_count, + 'embedding_extraction_failure_count': embedding_extraction_failure_count, + 'identity_confidence_summary': summary, + 'error_class': error_class, + 'artifact_refs': artifact_refs or {}, + 'fallback': _fallback_details( + fallback_count=fallback_count, + from_provider=fallback_provider, + to_provider=provider, + reason=fallback_reason, + ), + 'updated_at': completed_at, + } + _reject_forbidden_payload_keys(payload) + _run_ref(run_id).set(payload, merge=True) + increment_daily_rollup( + day=utc_day_bucket(completed_at), + provider=provider, + model=model, + workload=workload, + status=status, + raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=speech_active_seconds, + billable_seconds=billable_seconds, + chunk_duration_seconds=chunk_duration_seconds, + estimated_cost_usd=estimated_cost_usd, + retry_count=retry_count, + fallback_count=fallback_count, + transcript_segment_count=transcript_segment_count, + transcript_word_count=transcript_word_count, + speaker_cluster_count=speaker_cluster_count, + identified_speaker_cluster_count=identified_speaker_cluster_count, + identity_match_count=identity_match_count, + provider_speaker_count=provider_speaker_count, + mapped_speaker_count=mapped_speaker_count, + mapped_person_count=mapped_person_count, + unmapped_speaker_count=unmapped_speaker_count, + unknown_speaker_count=unknown_speaker_count, + unknown_speaker_duration_seconds=unknown_speaker_duration_seconds, + split_count=split_count, + embedding_extraction_failure_count=embedding_extraction_failure_count, + identity_confidence_summary=summary, + ) + emit_provider_run_metrics( + provider=provider, + model=model, + workload=workload, + status=status, + latency_seconds=latency_seconds, + raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=speech_active_seconds, + billable_seconds=billable_seconds, + retry_count=retry_count, + fallback_count=fallback_count, + fallback_provider=fallback_provider, + fallback_reason=fallback_reason, + speaker_cluster_count=speaker_cluster_count, + identified_speaker_cluster_count=identified_speaker_cluster_count, + identity_confidence_summary=summary, + ) + + +def _fallback_details( + fallback_count: int, + from_provider: Optional[str], + to_provider: str, + reason: str, +) -> Optional[dict[str, Any]]: + if fallback_count <= 0: + return None + return { + 'from_provider': from_provider or 'unknown', + 'to_provider': to_provider, + 'reason': reason, + } + + +def increment_daily_rollup( + day: str, + provider: str, + model: str, + workload: str, + status: str, + raw_audio_seconds: float = 0.0, + speech_active_seconds: float = 0.0, + billable_seconds: float = 0.0, + chunk_duration_seconds: float = 0.0, + estimated_cost_usd: float = 0.0, + retry_count: int = 0, + fallback_count: int = 0, + transcript_segment_count: int = 0, + transcript_word_count: int = 0, + speaker_cluster_count: int = 0, + identified_speaker_cluster_count: int = 0, + identity_match_count: int = 0, + provider_speaker_count: int = 0, + mapped_speaker_count: int = 0, + mapped_person_count: int = 0, + unmapped_speaker_count: int = 0, + unknown_speaker_count: int = 0, + unknown_speaker_duration_seconds: float = 0.0, + split_count: int = 0, + embedding_extraction_failure_count: int = 0, + identity_confidence_summary: Optional[dict[str, Any]] = None, +) -> None: + update = { + 'day': day, + 'provider': provider, + 'model': model, + 'workload': workload, + 'run_count': firestore.Increment(1), + f'status_counts.{status}': firestore.Increment(1), + 'raw_audio_seconds': firestore.Increment(raw_audio_seconds), + 'speech_active_seconds': firestore.Increment(speech_active_seconds), + 'billable_seconds': firestore.Increment(billable_seconds), + 'estimated_cost_usd': firestore.Increment(estimated_cost_usd), + 'chunk_duration_seconds': firestore.Increment(chunk_duration_seconds), + 'retry_count': firestore.Increment(retry_count), + 'fallback_count': firestore.Increment(fallback_count), + 'transcript_segment_count': firestore.Increment(transcript_segment_count), + 'transcript_word_count': firestore.Increment(transcript_word_count), + 'speaker_cluster_count': firestore.Increment(speaker_cluster_count), + 'identified_speaker_cluster_count': firestore.Increment(identified_speaker_cluster_count), + 'identity_match_count': firestore.Increment(identity_match_count), + 'provider_speaker_count': firestore.Increment(provider_speaker_count), + 'mapped_speaker_count': firestore.Increment(mapped_speaker_count), + 'mapped_person_count': firestore.Increment(mapped_person_count), + 'unmapped_speaker_count': firestore.Increment(unmapped_speaker_count), + 'unknown_speaker_count': firestore.Increment(unknown_speaker_count), + 'unknown_speaker_duration_seconds': firestore.Increment(unknown_speaker_duration_seconds), + 'split_count': firestore.Increment(split_count), + 'embedding_extraction_failure_count': firestore.Increment(embedding_extraction_failure_count), + 'last_updated': _utc_now(), + } + for bucket, count in (identity_confidence_summary or {}).items(): + if isinstance(count, (int, float)) and count > 0: + update[f'identity_confidence_counts.{bucket}'] = firestore.Increment(count) + _rollup_ref(day, provider, model, workload).set(update, merge=True) + + +def rebuild_daily_rollup_from_runs(day: str, provider: str, model: str, workload: str) -> dict[str, Any]: + query = ( + db.collection(RUNS_COLLECTION) + .where(filter=FieldFilter('provider', '==', provider)) + .where(filter=FieldFilter('model', '==', model)) + .where(filter=FieldFilter('workload', '==', workload)) + ) + rollup = _empty_rollup(day, provider, model, workload) + for doc in query.stream(): + data = doc.to_dict() or {} + completed_at = (data.get('timing') or {}).get('completed_at') + if not completed_at or utc_day_bucket(completed_at) != day: + continue + _add_run_to_rollup(rollup, data) + _rollup_ref(day, provider, model, workload).set(rollup, merge=False) + return rollup + + +def purge_provider_runs_for_user(uid: str, batch_size: int = 400) -> int: + deleted = 0 + query = db.collection(RUNS_COLLECTION).where(filter=FieldFilter('uid', '==', uid)) + batch = db.batch() + pending = 0 + for doc in query.stream(): + batch.delete(doc.reference) + deleted += 1 + pending += 1 + if pending >= batch_size: + batch.commit() + batch = db.batch() + pending = 0 + if pending: + batch.commit() + return deleted + + +def update_provider_run_identity_metrics( + run_id: str, + provider: str, + model: str, + workload: str, + identified_speaker_cluster_count: int, + identity_confidence_summary: Optional[dict[str, Any]] = None, + provider_speaker_count: int = 0, + mapped_speaker_count: int = 0, + mapped_person_count: int = 0, + unmapped_speaker_count: int = 0, + embedding_extraction_failure_count: int = 0, + identity_metric_update_status: str = 'succeeded', + identity_metric_update_skipped_reason: Optional[str] = None, +) -> None: + ref = _run_ref(run_id) + snapshot = ref.get() + if not snapshot.exists: + return + data = snapshot.to_dict() or {} + previous_identified = data.get('identified_speaker_cluster_count', 0) or 0 + previous_provider_speakers = data.get('provider_speaker_count', data.get('speaker_cluster_count', 0)) or 0 + previous_mapped_speakers = data.get('mapped_speaker_count', previous_identified) or 0 + previous_mapped_people = data.get('mapped_person_count', 0) or 0 + previous_unmapped_speakers = data.get('unmapped_speaker_count', 0) or 0 + previous_embedding_failures = data.get('embedding_extraction_failure_count', 0) or 0 + previous_summary = data.get('identity_confidence_summary') or {} + summary = identity_confidence_summary or {} + completed_at = (data.get('timing') or {}).get('completed_at') or _utc_now() + + ref.set( + { + 'identified_speaker_cluster_count': identified_speaker_cluster_count, + 'provider_speaker_count': provider_speaker_count, + 'mapped_speaker_count': mapped_speaker_count, + 'mapped_person_count': mapped_person_count, + 'unmapped_speaker_count': unmapped_speaker_count, + 'embedding_extraction_failure_count': embedding_extraction_failure_count, + 'identity_metric_update': { + 'status': identity_metric_update_status, + 'skipped_reason': identity_metric_update_skipped_reason, + 'updated_at': _utc_now(), + }, + 'identity_confidence_summary': summary, + 'updated_at': _utc_now(), + }, + merge=True, + ) + + rollup_update = { + 'identified_speaker_cluster_count': firestore.Increment(identified_speaker_cluster_count - previous_identified), + 'provider_speaker_count': firestore.Increment(provider_speaker_count - previous_provider_speakers), + 'mapped_speaker_count': firestore.Increment(mapped_speaker_count - previous_mapped_speakers), + 'mapped_person_count': firestore.Increment(mapped_person_count - previous_mapped_people), + 'unmapped_speaker_count': firestore.Increment(unmapped_speaker_count - previous_unmapped_speakers), + 'embedding_extraction_failure_count': firestore.Increment( + embedding_extraction_failure_count - previous_embedding_failures + ), + 'last_updated': _utc_now(), + } + for bucket in set(previous_summary) | set(summary): + delta = (summary.get(bucket, 0) or 0) - (previous_summary.get(bucket, 0) or 0) + if delta: + rollup_update[f'identity_confidence_counts.{bucket}'] = firestore.Increment(delta) + _rollup_ref(utc_day_bucket(completed_at), provider, model, workload).set(rollup_update, merge=True) + + +def emit_provider_run_metrics( + provider: str, + model: str, + workload: str, + status: str, + latency_seconds: float, + raw_audio_seconds: float, + speech_active_seconds: float, + billable_seconds: float, + retry_count: int, + fallback_count: int, + speaker_cluster_count: int, + identified_speaker_cluster_count: int, + identity_confidence_summary: Optional[dict[str, Any]] = None, + fallback_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', +) -> None: + observe_transcription_provider_request(provider, model, workload, status, latency_seconds) + observe_transcription_provider_audio_seconds( + provider, + model, + workload, + raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=speech_active_seconds, + billable_seconds=billable_seconds, + ) + observe_transcription_provider_retry(provider, model, workload, 'provider_retry', retry_count) + if fallback_count > 0: + observe_transcription_provider_fallback( + fallback_provider or 'unknown', + provider, + workload, + fallback_reason, + fallback_count, + ) + observe_transcription_provider_speaker_clusters( + provider, + model, + workload, + speaker_cluster_count=speaker_cluster_count, + identified_speaker_cluster_count=identified_speaker_cluster_count, + ) + for bucket, count in (identity_confidence_summary or {}).items(): + observe_transcription_provider_identity_confidence(provider, model, workload, bucket, count) + + +def summarize_identity_confidences(confidences: List[Optional[float]]) -> dict[str, int]: + summary: dict[str, int] = {} + for confidence in confidences: + bucket = identity_confidence_bucket(confidence) + summary[bucket] = summary.get(bucket, 0) + 1 + return summary + + +def _empty_rollup(day: str, provider: str, model: str, workload: str) -> dict[str, Any]: + return { + 'day': day, + 'provider': provider, + 'model': model, + 'workload': workload, + 'run_count': 0, + 'status_counts': {}, + 'raw_audio_seconds': 0.0, + 'speech_active_seconds': 0.0, + 'billable_seconds': 0.0, + 'chunk_duration_seconds': 0.0, + 'estimated_cost_usd': 0.0, + 'retry_count': 0, + 'fallback_count': 0, + 'transcript_segment_count': 0, + 'transcript_word_count': 0, + 'speaker_cluster_count': 0, + 'identified_speaker_cluster_count': 0, + 'identity_match_count': 0, + 'provider_speaker_count': 0, + 'mapped_speaker_count': 0, + 'mapped_person_count': 0, + 'unmapped_speaker_count': 0, + 'unknown_speaker_count': 0, + 'unknown_speaker_duration_seconds': 0.0, + 'split_count': 0, + 'embedding_extraction_failure_count': 0, + 'identity_confidence_counts': {}, + 'last_updated': _utc_now(), + } + + +def _add_run_to_rollup(rollup: dict[str, Any], data: dict[str, Any]) -> None: + rollup['run_count'] += 1 + status = data.get('status') or 'unknown' + rollup['status_counts'][status] = rollup['status_counts'].get(status, 0) + 1 + for field in ( + 'raw_audio_seconds', + 'speech_active_seconds', + 'billable_seconds', + 'chunk_duration_seconds', + 'estimated_cost_usd', + 'retry_count', + 'fallback_count', + 'transcript_segment_count', + 'transcript_word_count', + 'speaker_cluster_count', + 'identified_speaker_cluster_count', + 'identity_match_count', + 'provider_speaker_count', + 'mapped_speaker_count', + 'mapped_person_count', + 'unmapped_speaker_count', + 'unknown_speaker_count', + 'unknown_speaker_duration_seconds', + 'split_count', + 'embedding_extraction_failure_count', + ): + rollup[field] += data.get(field, 0) or 0 + for bucket, count in (data.get('identity_confidence_summary') or {}).items(): + rollup['identity_confidence_counts'][bucket] = rollup['identity_confidence_counts'].get(bucket, 0) + count diff --git a/backend/main.py b/backend/main.py index 3ecd2546125..2ef585b5bf6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,6 +14,7 @@ from routers import ( chat, + desktop_background, firmware, transcribe, notifications, @@ -86,6 +87,7 @@ app = FastAPI() app.include_router(transcribe.router) +app.include_router(desktop_background.router) app.include_router(conversations.router) app.include_router(action_items.router) app.include_router(task_integrations.router) diff --git a/backend/models/conversation.py b/backend/models/conversation.py index 7bcce5517f3..214107946e8 100644 --- a/backend/models/conversation.py +++ b/backend/models/conversation.py @@ -92,6 +92,7 @@ class Conversation(BaseModel): plugins_results: List[PluginResult] = [] external_data: Optional[Dict] = None + background_processed_chunks: Dict[str, Dict] = Field(default_factory=dict) app_id: Optional[str] = None discarded: bool = False diff --git a/backend/models/message_event.py b/backend/models/message_event.py index 0bd386689f5..cca36fd3146 100644 --- a/backend/models/message_event.py +++ b/backend/models/message_event.py @@ -154,6 +154,13 @@ class SpeakerLabelSuggestionEvent(MessageEvent): person_id: str person_name: str segment_id: str + version: int = 1 + provider_cluster_id: Optional[str] = None + speaker_identity_state: Optional[str] = None + confidence: Optional[float] = None + source: Optional[str] = None + provenance: Optional[dict[str, Any]] = None + candidates: Optional[List[dict[str, Any]]] = None def to_json(self): j = self.model_dump(mode="json") diff --git a/backend/models/transcript_segment.py b/backend/models/transcript_segment.py index 44647ce7aca..7bf045d7f87 100644 --- a/backend/models/transcript_segment.py +++ b/backend/models/transcript_segment.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Optional, List, Tuple +from typing import Any, Literal, Optional, List, Tuple import uuid import re from pydantic import BaseModel, Field @@ -23,6 +23,38 @@ class Translation(BaseModel): text: str +SpeakerIdentityState = Literal['unknown', 'legacy_ambiguous', 'unassigned', 'identified', 'user'] + + +class ProviderTranscriptWord(BaseModel): + text: str + start: float + end: float + provider_cluster_id: Optional[str] = None + speaker_label: Optional[str] = None + confidence: Optional[float] = None + + +class ProviderTranscriptUtterance(BaseModel): + text: str + start: float + end: float + provider_cluster_id: Optional[str] = None + speaker_label: Optional[str] = None + confidence: Optional[float] = None + words: Optional[List[ProviderTranscriptWord]] = None + + +class ProviderTranscriptResult(BaseModel): + provider: str + model: Optional[str] = None + language: Optional[str] = None + duration: Optional[float] = None + words: List[ProviderTranscriptWord] = [] + utterances: List[ProviderTranscriptUtterance] = [] + raw_provider_result_id: Optional[str] = None + + class TranscriptSegment(BaseModel): id: Optional[str] = None text: str @@ -35,6 +67,16 @@ class TranscriptSegment(BaseModel): translations: Optional[List[Translation]] = [] speech_profile_processed: bool = True stt_provider: Optional[str] = None + stt_model: Optional[str] = None + provider_cluster_id: Optional[str] = None + provider_speaker_label: Optional[str] = None + speaker_identity_state: SpeakerIdentityState = 'legacy_ambiguous' + speaker_identity_confidence: Optional[float] = None + speaker_identity_source: Optional[str] = None + speaker_identity_version: Optional[str] = None + speaker_identity_provenance: Optional[dict[str, Any]] = None + speaker_identity_candidates: Optional[List[dict[str, Any]]] = None + speaker_identity_text_hints: Optional[List[dict[str, Any]]] = None def __init__(self, **data): super().__init__(**data) @@ -45,10 +87,16 @@ def __init__(self, **data): try: self.speaker_id = int(self.speaker.split('_', 1)[1]) except (ValueError, IndexError): - self.speaker_id = 0 - else: + if self.speaker_id is None: + self.speaker_id = 0 + elif self.speaker_id is None: self.speaker_id = 0 + if self.person_id and self.speaker_identity_state in ('legacy_ambiguous', 'unassigned', 'unknown'): + self.speaker_identity_state = 'user' if self.is_user else 'identified' + elif self.is_user and self.speaker_identity_state in ('legacy_ambiguous', 'unassigned', 'unknown'): + self.speaker_identity_state = 'user' + def get_timestamp_string(self): start_duration = timedelta(seconds=int(self.start)) end_duration = timedelta(seconds=int(self.end)) @@ -60,6 +108,15 @@ def segments_as_string(segments, include_timestamps=False, user_name: str = None user_name = 'User' transcript = '' people_map = {person.id: person.name for person in people} if people else {} + provider_cluster_display_ids = {} + next_provider_display_id = 1 + for segment in segments: + provider_cluster_key = _provider_display_cluster_key(segment) + if segment.is_user or not provider_cluster_key: + continue + if provider_cluster_key not in provider_cluster_display_ids: + provider_cluster_display_ids[provider_cluster_key] = next_provider_display_id + next_provider_display_id += 1 include_timestamps = include_timestamps and TranscriptSegment.can_display_seconds(segments) for segment in segments: segment_text = segment.text.strip() @@ -69,7 +126,11 @@ def segments_as_string(segments, include_timestamps=False, user_name: str = None if segment.person_id and segment.person_id in people_map: speaker_name = people_map[segment.person_id] else: - speaker_name = f'Speaker {segment.speaker_id}' + provider_cluster_key = _provider_display_cluster_key(segment) + if provider_cluster_key in provider_cluster_display_ids: + speaker_name = f'Speaker {provider_cluster_display_ids[provider_cluster_key]}' + else: + speaker_name = f'Speaker {segment.speaker_id}' transcript += f'{timestamp_str}{speaker_name}: {segment_text}\n\n' return transcript.strip() @@ -228,6 +289,14 @@ def _merge(a, b: TranscriptSegment): return segments, joined_similar_segments, removed_ids +def _provider_display_cluster_key(segment: TranscriptSegment) -> Optional[str]: + if not segment.provider_cluster_id and not segment.provider_speaker_label: + return None + if segment.speaker_id not in (None, 0): + return None + return segment.provider_cluster_id or segment.provider_speaker_label + + class ImprovedTranscriptSegment(BaseModel): speaker_id: int = Field(..., description='The correctly assigned speaker id') text: str = Field(..., description='The corrected text of the segment') diff --git a/backend/routers/desktop_background.py b/backend/routers/desktop_background.py new file mode 100644 index 00000000000..13c64135c9d --- /dev/null +++ b/backend/routers/desktop_background.py @@ -0,0 +1,683 @@ +import hashlib +import json +import logging +from datetime import datetime, timezone +from typing import Dict, List, Optional + +import numpy as np +from fastapi import APIRouter, Depends, Header, HTTPException, Request +from pydantic import BaseModel + +import database.conversations as conversations_db +import database.users as users_db +from database import redis_db +from models.conversation_enums import ConversationSource, ConversationStatus +from models.transcript_segment import TranscriptSegment +from utils.analytics import record_usage +from utils.chat import resolve_voice_message_language +from utils.conversations.desktop_background import ( + DesktopBackgroundConversationError, + append_background_chunk_to_in_progress_conversation, + create_in_progress_desktop_conversation, + finish_desktop_background_conversation, + get_background_chunk_record, +) +from utils.executors import db_executor, run_blocking, sync_executor +from utils.fair_use import is_hard_restricted, record_speech_ms +from utils.other import endpoints as auth +from utils.speaker_identification import _pcm_to_wav_bytes +from utils.stt.background_speaker_identity import USER_SELF_PERSON_ID, identify_background_speaker_clusters +from utils.stt.provider_service import ( + resolve_background_provider_policy, + speaker_identity_metrics, + transcribe_bytes, + update_provider_run_identity_metrics, +) +from utils.stt.speaker_embedding import extract_embedding_from_bytes +from utils.stt.providers import STTWorkload +from utils.subscription import has_transcription_credits, is_trial_paywalled +from utils.voice_duration_limiter import compute_pcm_duration_ms + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/v2/desktop", tags=["desktop-background"]) + +_MAX_PCM_BODY_BYTES = 200_000_000 +_SPEAKER_MAP_TTL_SECONDS = 60 * 60 * 24 +_LOCAL_CLUSTER_SPLIT_MARKER = "::local_part:" +_ANONYMOUS_CHUNK_NAMESPACE = "chunk" +_IDENTIFIED_SPEAKER_NAMESPACE = "identity" + +_SPEAKER_RECONCILIATION_BUDGETS = { + "max_app_visible_speaker_inflation_ratio": 4.0, + "max_provider_only_false_merge_rate": 0.0, + "max_unresolved_anonymous_speaker_duration_seconds": 60 * 60, + "max_split_reconcile_ratio": 10.0, +} + + +class BackgroundConversationStartRequest(BaseModel): + language: Optional[str] = None + source: Optional[str] = "desktop" + + +@router.get("/capabilities") +async def desktop_capabilities(uid: str = Depends(auth.get_current_user_uid)): + policy = resolve_background_provider_policy() + mode = policy.mode.value + if not policy.enabled: + mode = 'disabled' + elif policy.reason == 'fallback_deepgram_available': + mode = 'deepgram_fallback' + return { + "background_batch": { + "enabled": policy.enabled, + "mode": mode, + "provider": policy.primary_provider.value, + "primary_provider": policy.primary_provider.value, + "effective_provider": policy.effective_provider.value if policy.effective_provider else None, + "fallback_provider": policy.fallback_provider.value if policy.fallback_provider else None, + "fallback_enabled": policy.fallback_enabled, + "fallback_available": policy.fallback_available, + "workload": STTWorkload.background.value, + "reason": policy.reason, + "sample_rate": 16000, + "channels": 1, + "encoding": "linear16", + "max_chunk_seconds": 15, + } + } + + +@router.post("/background-conversation/start") +async def start_background_conversation( + body: BackgroundConversationStartRequest, + uid: str = Depends(auth.get_current_user_uid), + x_app_platform: Optional[str] = Header(None, alias='X-App-Platform'), +): + if is_trial_paywalled(uid, x_app_platform or body.source or "desktop"): + raise HTTPException(status_code=402, detail={'error': 'quota_exceeded', 'plan_type': 'basic'}) + + language = resolve_voice_message_language(uid, body.language) + try: + source = ConversationSource(body.source or "desktop") + except ValueError: + raise HTTPException(status_code=422, detail='Invalid source') + + conversation_id = await run_blocking( + db_executor, + create_in_progress_desktop_conversation, + uid, + language, + source, + ) + return {"conversation_id": conversation_id} + + +@router.post("/background-conversation/{conversation_id}/finish") +async def finish_background_conversation( + conversation_id: str, + uid: str = Depends(auth.with_rate_limit(auth.get_current_user_uid, "desktop:background_conversation_finish")), +): + try: + return await run_blocking( + db_executor, + finish_desktop_background_conversation, + uid, + conversation_id, + ) + except DesktopBackgroundConversationError as e: + raise HTTPException(status_code=e.status_code, detail=str(e)) + + +@router.post("/background-transcribe") +async def background_transcribe( + request: Request, + uid: str = Depends(auth.with_rate_limit(auth.get_current_user_uid, "desktop:background_transcribe")), + x_app_platform: Optional[str] = Header(None, alias='X-App-Platform'), +): + if is_trial_paywalled(uid, x_app_platform or "desktop"): + raise HTTPException(status_code=402, detail={'error': 'quota_exceeded', 'plan_type': 'basic'}) + if await run_blocking(db_executor, is_hard_restricted, uid): + raise HTTPException(status_code=429, detail='Transcription temporarily restricted') + if not await run_blocking(db_executor, has_transcription_credits, uid, source="desktop"): + raise HTTPException(status_code=429, detail='Transcription credits exhausted') + + content_length = request.headers.get("content-length") + if content_length: + try: + content_length_value = int(content_length) + except ValueError: + raise HTTPException(status_code=422, detail='content-length must be an integer') + if content_length_value > _MAX_PCM_BODY_BYTES: + raise HTTPException(status_code=413, detail=f'Body too large (max {_MAX_PCM_BODY_BYTES} bytes)') + + audio_bytes = await request.body() + if not audio_bytes: + raise HTTPException(status_code=400, detail='No audio data provided') + if len(audio_bytes) > _MAX_PCM_BODY_BYTES: + del audio_bytes + raise HTTPException(status_code=413, detail=f'Body too large (max {_MAX_PCM_BODY_BYTES} bytes)') + + try: + sample_rate = int(request.query_params.get("sample_rate", "16000")) + channels = int(request.query_params.get("channels", "1")) + except ValueError: + del audio_bytes + raise HTTPException(status_code=422, detail='sample_rate and channels must be integers') + + if sample_rate < 8000 or sample_rate > 48000: + del audio_bytes + raise HTTPException(status_code=422, detail='sample_rate must be between 8000 and 48000') + if channels != 1: + del audio_bytes + raise HTTPException(status_code=422, detail='channels must be 1') + + chunk_start_ms_raw = request.query_params.get("chunk_start_ms") + if chunk_start_ms_raw is None: + del audio_bytes + raise HTTPException(status_code=400, detail='chunk_start_ms is required') + try: + chunk_start_ms = int(chunk_start_ms_raw) + except ValueError: + del audio_bytes + raise HTTPException(status_code=422, detail='chunk_start_ms must be an integer') + if chunk_start_ms < 0: + del audio_bytes + raise HTTPException(status_code=422, detail='chunk_start_ms must be non-negative') + + conversation_id = request.query_params.get("conversation_id") + persist = _parse_persist(request.query_params.get("persist"), default=bool(conversation_id)) + chunk_id = request.query_params.get("chunk_id") + encoding = request.query_params.get("encoding", "linear16") + if persist: + if not conversation_id: + del audio_bytes + raise HTTPException(status_code=400, detail='conversation_id is required when persist=true') + if not chunk_id: + del audio_bytes + raise HTTPException(status_code=400, detail='chunk_id is required when persist=true') + if not _is_valid_chunk_id(chunk_id): + del audio_bytes + raise HTTPException(status_code=422, detail='chunk_id is invalid') + await _validate_in_progress_conversation(uid, conversation_id) + + payload_hash = _background_chunk_payload_hash(audio_bytes, sample_rate, channels, encoding, chunk_start_ms) + if persist and conversation_id and chunk_id: + try: + existing_chunk = await run_blocking( + db_executor, get_background_chunk_record, uid, conversation_id, chunk_id + ) + except DesktopBackgroundConversationError as e: + del audio_bytes + raise HTTPException(status_code=e.status_code, detail=str(e)) + if existing_chunk: + del audio_bytes + if existing_chunk.get('payload_hash') != payload_hash: + logger.warning( + "desktop_background_transcribe chunk payload conflict uid=%s conversation_id=%s chunk_id=%s", + uid, + conversation_id, + chunk_id, + ) + raise HTTPException(status_code=409, detail='chunk_id payload mismatch') + logger.info( + "desktop_background_transcribe duplicate chunk uid=%s conversation_id=%s chunk_id=%s provider=%s", + uid, + conversation_id, + chunk_id, + existing_chunk.get('provider'), + ) + return { + "segments": [], + "language": None, + "provider": existing_chunk.get('provider'), + "run_id": existing_chunk.get('run_id'), + "chunk_duration_ms": existing_chunk.get('chunk_duration_ms'), + "chunk_id": chunk_id, + "duplicate": True, + "speaker_diagnostics": _speaker_diagnostics([]), + } + + language = resolve_voice_message_language(uid, request.query_params.get("language")) + keywords = _parse_context_keywords(request.query_params.get("keywords")) + duration_ms = compute_pcm_duration_ms(len(audio_bytes), sample_rate, channels) + duration_sec = duration_ms / 1000.0 + + try: + audio_for_stt = _pcm_to_wav_bytes(audio_bytes, sample_rate) if encoding == "linear16" else audio_bytes + response = await run_blocking( + sync_executor, + transcribe_bytes, + audio_for_stt, + workload=STTWorkload.background, + uid=uid, + conversation_id=conversation_id, + sample_rate=sample_rate, + diarize=True, + encoding=None if encoding == "linear16" else encoding, + channels=channels, + language=language, + return_language=language == "multi", + keywords=keywords, + raw_audio_seconds=duration_sec, + ) + except RuntimeError as e: + logger.error("Desktop background transcription failed: %s", e) + raise HTTPException(status_code=500, detail=f'Transcription failed: {str(e)}') + finally: + del audio_bytes + + segments = response.segments + split_diagnostics = _split_noncontiguous_provider_clusters(segments) + speaker_diagnostics = _speaker_diagnostics(segments) + speaker_diagnostics.update(split_diagnostics) + if conversation_id and segments: + await _identify_speakers( + uid=uid, + conversation_id=conversation_id, + audio_bytes=audio_for_stt, + segments=segments, + provider=response.result.provider if response.result else None, + model=response.result.model if response.result else None, + run_id=response.run_id, + ) + _apply_chunk_offset(segments, chunk_start_ms / 1000.0) + if conversation_id: + reconciliation_diagnostics = _apply_speaker_ids(conversation_id, chunk_id, segments) + else: + reconciliation_diagnostics = _speaker_reconciliation_diagnostics(segments, {}, {}, {}) + speaker_diagnostics.update(_speaker_diagnostics(segments, prefix="mapped_")) + speaker_diagnostics.update(reconciliation_diagnostics) + + finished_at = datetime.now(timezone.utc) + if persist and conversation_id: + try: + append_result = await run_blocking( + db_executor, + append_background_chunk_to_in_progress_conversation, + uid, + conversation_id, + chunk_id, + payload_hash, + segments, + finished_at, + response.result.provider if response.result else None, + response.run_id, + chunk_start_ms, + duration_ms, + ) + except DesktopBackgroundConversationError as e: + raise HTTPException(status_code=e.status_code, detail=str(e)) + if append_result.duplicate: + segments = append_result.segments + + await run_blocking(db_executor, record_speech_ms, uid, duration_ms, source='background') + await run_blocking(db_executor, record_usage, uid, transcription_seconds=duration_sec, speech_seconds=duration_sec) + provider = response.result.provider if response.result else None + logger.info( + "desktop_background_transcribe completed uid=%s conversation_id=%s workload=background provider=%s run_id=%s " + "chunk_id=%s chunk_start_ms=%s chunk_duration_ms=%s segments=%s persisted=%s", + uid, + conversation_id, + provider, + response.run_id, + chunk_id, + chunk_start_ms, + duration_ms, + len(segments), + bool(persist and conversation_id), + ) + + return { + "segments": [segment.model_dump() for segment in segments], + "language": response.detected_language or language, + "provider": provider, + "run_id": response.run_id, + "chunk_duration_ms": duration_ms, + "chunk_id": chunk_id, + "duplicate": False, + "speaker_diagnostics": speaker_diagnostics, + } + + +def _parse_persist(raw: Optional[str], default: bool) -> bool: + if raw is None: + return default + return raw.strip().lower() not in ("0", "false", "no") + + +def _parse_context_keywords(raw: Optional[str]) -> List[str]: + if not raw: + return [] + keywords: List[str] = [] + seen = set() + for item in raw.split(','): + keyword = item.strip() + if len(keyword) < 2 or len(keyword) > 80: + continue + dedupe_key = keyword.lower() + if dedupe_key in seen: + continue + seen.add(dedupe_key) + keywords.append(keyword) + if len(keywords) >= 100: + break + return keywords + + +def _is_valid_chunk_id(chunk_id: str) -> bool: + if len(chunk_id) < 8 or len(chunk_id) > 160: + return False + return all(char.isalnum() or char in ('-', '_') for char in chunk_id) + + +def _background_chunk_payload_hash( + audio_bytes: bytes, + sample_rate: int, + channels: int, + encoding: str, + chunk_start_ms: int, +) -> str: + hasher = hashlib.sha256() + hasher.update(str(sample_rate).encode('ascii')) + hasher.update(b':') + hasher.update(str(channels).encode('ascii')) + hasher.update(b':') + hasher.update(encoding.encode('utf-8')) + hasher.update(b':') + hasher.update(str(chunk_start_ms).encode('ascii')) + hasher.update(b':') + hasher.update(audio_bytes) + return hasher.hexdigest() + + +async def _validate_in_progress_conversation(uid: str, conversation_id: str) -> None: + conversation = await run_blocking(db_executor, conversations_db.get_conversation, uid, conversation_id) + if not conversation: + raise HTTPException(status_code=404, detail='conversation_id not found') + if conversation.get('status') != ConversationStatus.in_progress: + raise HTTPException(status_code=404, detail='conversation is not in_progress') + + +def _apply_chunk_offset(segments: List[TranscriptSegment], offset_sec: float) -> None: + for segment in segments: + segment.start += offset_sec + segment.end += offset_sec + + +def _split_noncontiguous_provider_clusters(segments: List[TranscriptSegment]) -> Dict[str, object]: + """Split reused provider clusters into local contiguous groups. + + Batched providers can reuse the same cluster label for rapid, non-contiguous turns + inside a chunk. Omi treats provider labels as hints, so split those groups before + identity matching and final speaker_id assignment. + """ + groups: list[tuple[str, list[TranscriptSegment]]] = [] + current_cluster = None + current_group: list[TranscriptSegment] = [] + for segment in sorted(segments, key=lambda item: (item.start, item.end)): + cluster = _raw_provider_cluster_key(segment) + if not cluster: + if current_group and current_cluster is not None: + groups.append((current_cluster, current_group)) + current_group = [] + current_cluster = None + continue + if current_group and cluster != current_cluster: + groups.append((current_cluster, current_group)) + current_group = [] + current_cluster = cluster + current_group.append(segment) + if current_group and current_cluster is not None: + groups.append((current_cluster, current_group)) + + group_counts: Dict[str, int] = {} + for cluster, _group_segments in groups: + group_counts[cluster] = group_counts.get(cluster, 0) + 1 + + split_cluster_count = 0 + split_segment_count = 0 + cannot_link_pairs_prevented = 0 + group_indexes: Dict[str, int] = {} + for cluster, group_segments in groups: + if group_counts[cluster] <= 1: + continue + split_cluster_count += 1 + split_segment_count += len(group_segments) + cannot_link_pairs_prevented += group_indexes.get(cluster, 0) * len(group_segments) + group_index = group_indexes.get(cluster, 0) + 1 + group_indexes[cluster] = group_index + local_cluster = f'{cluster}{_LOCAL_CLUSTER_SPLIT_MARKER}{group_index}' + for segment in group_segments: + segment.provider_cluster_id = local_cluster + + return { + "speaker_split_cluster_count": split_cluster_count, + "speaker_split_segment_count": split_segment_count, + "cannot_link_violations_prevented": cannot_link_pairs_prevented, + } + + +def _raw_provider_cluster_key(segment: TranscriptSegment) -> Optional[str]: + cluster = segment.provider_cluster_id or segment.provider_speaker_label + return str(cluster) if cluster is not None else None + + +async def _identify_speakers( + uid: str, + conversation_id: str, + audio_bytes: bytes, + segments: List[TranscriptSegment], + provider: Optional[str], + model: Optional[str], + run_id: Optional[str], +) -> None: + """Apply Omi speaker identity to AssemblyAI background chunk-local segments.""" + identity_metric_update_status = 'succeeded' + identity_metric_update_skipped_reason = None + try: + person_embeddings_cache = await run_blocking(db_executor, _build_person_embeddings_cache, uid) + if not person_embeddings_cache: + identity_metric_update_status = 'skipped' + identity_metric_update_skipped_reason = 'missing_candidate_embeddings' + else: + assignments = await run_blocking( + sync_executor, + identify_background_speaker_clusters, + segments, + audio_bytes, + person_embeddings_cache, + extract_embedding_from_bytes, + ) + identified_count = sum( + 1 for assignment in assignments.values() if assignment.state in ('identified', 'user') + ) + logger.info( + "Speaker ID (desktop background): cluster assignments=%s identified=%s uid=%s conversation_id=%s", + len(assignments), + identified_count, + uid, + conversation_id, + ) + except Exception as e: + identity_metric_update_status = 'failed' + logger.warning( + "Speaker ID (desktop background): identification failed uid=%s conversation_id=%s: %s", + uid, + conversation_id, + e, + ) + await run_blocking( + db_executor, + update_provider_run_identity_metrics, + run_id, + provider or 'unknown', + model or 'unknown', + STTWorkload.background, + segments, + identity_metric_update_status, + identity_metric_update_skipped_reason, + ) + + +def _build_person_embeddings_cache(uid: str) -> Dict[str, dict]: + cache: Dict[str, dict] = {} + + embedding_list = users_db.get_user_speaker_embedding(uid) + if embedding_list: + user_embedding = np.array(embedding_list, dtype=np.float32).reshape(1, -1) + cache[USER_SELF_PERSON_ID] = {'embedding': user_embedding, 'name': 'User'} + + people = users_db.get_people(uid) + for person in people or []: + embedding = person.get('speaker_embedding') + if embedding and person.get('speech_samples'): + cache[person['id']] = { + 'embedding': np.array(embedding, dtype=np.float32).reshape(1, -1), + 'name': person['name'], + } + + return cache + + +def _apply_speaker_ids( + conversation_id: str, chunk_id: Optional[str], segments: List[TranscriptSegment] +) -> Dict[str, object]: + speaker_map = _load_speaker_map(conversation_id) + changed = False + attempted_reconciliation: Dict[str, str] = {} + accepted_reconciliation: Dict[str, str] = {} + rejected_reconciliation: Dict[str, str] = {} + for segment in segments: + raw_cluster = segment.provider_cluster_id or segment.provider_speaker_label or segment.speaker + if not raw_cluster: + continue + cluster = _speaker_reconciliation_key(chunk_id, segment, str(raw_cluster)) + attempted_reconciliation[str(raw_cluster)] = cluster + if cluster.startswith(f"{_IDENTIFIED_SPEAKER_NAMESPACE}:"): + accepted_reconciliation[str(raw_cluster)] = cluster + else: + rejected_reconciliation[str(raw_cluster)] = "anonymous_provider_cluster_is_chunk_local" + + if cluster not in speaker_map: + speaker_map[cluster] = len(speaker_map) + changed = True + speaker_id = speaker_map[cluster] + segment.speaker_id = speaker_id + segment.speaker = f"SPEAKER_{speaker_id:02d}" + if segment.speaker_identity_state == "legacy_ambiguous": + segment.speaker_identity_state = "unassigned" + + if changed: + _store_speaker_map(conversation_id, speaker_map) + return _speaker_reconciliation_diagnostics( + segments, + attempted_reconciliation, + accepted_reconciliation, + rejected_reconciliation, + ) + + +def _speaker_reconciliation_key(chunk_id: Optional[str], segment: TranscriptSegment, raw_cluster: str) -> str: + if segment.is_user or segment.speaker_identity_state == "user": + return f"{_IDENTIFIED_SPEAKER_NAMESPACE}:user" + if segment.person_id and segment.speaker_identity_state == "identified": + return f"{_IDENTIFIED_SPEAKER_NAMESPACE}:person:{segment.person_id}" + + namespace = chunk_id or "unspecified" + return f"{_ANONYMOUS_CHUNK_NAMESPACE}:{namespace}:provider:{raw_cluster}" + + +def _speaker_map_key(conversation_id: str) -> str: + return f"desktop_batch_speaker_map:{conversation_id}" + + +def _load_speaker_map(conversation_id: str) -> Dict[str, int]: + raw = redis_db.r.get(_speaker_map_key(conversation_id)) + if not raw: + return {} + try: + return {str(key): int(value) for key, value in json.loads(raw).items()} + except (TypeError, ValueError, json.JSONDecodeError): + logger.warning("Invalid desktop batch speaker map for conversation_id=%s", conversation_id) + return {} + + +def _store_speaker_map(conversation_id: str, speaker_map: Dict[str, int]) -> None: + redis_db.r.set(_speaker_map_key(conversation_id), json.dumps(speaker_map), ex=_SPEAKER_MAP_TTL_SECONDS) + + +def _speaker_reconciliation_diagnostics( + segments: List[TranscriptSegment], + attempted_reconciliation: Dict[str, str], + accepted_reconciliation: Dict[str, str], + rejected_reconciliation: Dict[str, str], +) -> Dict[str, object]: + clean_speech_duration = 0.0 + unknown_speaker_duration = 0.0 + provider_clusters = set() + app_speakers = set() + for segment in segments: + duration = max(0.0, segment.end - segment.start) + if segment.text and len(segment.text.split()) >= 2: + clean_speech_duration += duration + if segment.speaker_identity_state in ("unknown", "unassigned", "legacy_ambiguous"): + unknown_speaker_duration += duration + cluster = segment.provider_cluster_id or segment.provider_speaker_label + if cluster: + provider_clusters.add(str(cluster)) + if segment.speaker_id is not None: + app_speakers.add(int(segment.speaker_id)) + + provider_count = max(len(provider_clusters), 1) + speaker_inflation_ratio = round(len(app_speakers) / provider_count, 6) + split_count = sum(1 for cluster in provider_clusters if _LOCAL_CLUSTER_SPLIT_MARKER in cluster) + accepted_count = len(set(accepted_reconciliation.values())) + split_reconcile_ratio = round(split_count / max(accepted_count, 1), 6) if split_count else 0.0 + + budget_violations = [] + if speaker_inflation_ratio > _SPEAKER_RECONCILIATION_BUDGETS["max_app_visible_speaker_inflation_ratio"]: + budget_violations.append("max_app_visible_speaker_inflation_ratio") + if unknown_speaker_duration > _SPEAKER_RECONCILIATION_BUDGETS["max_unresolved_anonymous_speaker_duration_seconds"]: + budget_violations.append("max_unresolved_anonymous_speaker_duration_seconds") + if split_reconcile_ratio > _SPEAKER_RECONCILIATION_BUDGETS["max_split_reconcile_ratio"]: + budget_violations.append("max_split_reconcile_ratio") + + return { + "speaker_reconciliation_attempted_count": len(attempted_reconciliation), + "speaker_reconciliation_accepted_count": len(accepted_reconciliation), + "speaker_reconciliation_rejected_count": len(rejected_reconciliation), + "speaker_reconciliation_rejected_reasons": sorted(set(rejected_reconciliation.values()))[:20], + "provider_only_false_merge_rate": 0.0, + "clean_speech_duration_seconds": round(clean_speech_duration, 3), + "unknown_speaker_duration_seconds": round(unknown_speaker_duration, 3), + "app_visible_speaker_inflation_ratio": speaker_inflation_ratio, + "split_reconcile_ratio": split_reconcile_ratio, + "speaker_reconciliation_budgets": _SPEAKER_RECONCILIATION_BUDGETS, + "speaker_reconciliation_budget_violations": budget_violations, + } + + +def _speaker_diagnostics(segments: List[TranscriptSegment], prefix: str = "") -> Dict[str, object]: + provider_clusters = sorted( + {str(segment.provider_cluster_id) for segment in segments if segment.provider_cluster_id is not None} + ) + provider_labels = sorted( + {str(segment.provider_speaker_label) for segment in segments if segment.provider_speaker_label is not None} + ) + mapped_speakers = sorted({int(segment.speaker_id) for segment in segments if segment.speaker_id is not None}) + identity_metrics = speaker_identity_metrics(segments) + return { + f"{prefix}provider_cluster_count": len(provider_clusters), + f"{prefix}provider_clusters": provider_clusters[:20], + f"{prefix}provider_speaker_label_count": len(provider_labels), + f"{prefix}provider_speaker_labels": provider_labels[:20], + f"{prefix}provider_speaker_count": identity_metrics['provider_speaker_count'], + f"{prefix}mapped_speaker_count": identity_metrics['mapped_speaker_count'], + f"{prefix}mapped_person_count": identity_metrics['mapped_person_count'], + f"{prefix}unmapped_speaker_count": identity_metrics['unmapped_speaker_count'], + f"{prefix}embedding_extraction_failure_count": identity_metrics['embedding_extraction_failure_count'], + f"{prefix}speaker_id_count": len(mapped_speakers), + f"{prefix}speaker_ids": mapped_speakers[:20], + } diff --git a/backend/routers/sync.py b/backend/routers/sync.py index 4cd3e72fe51..9d2bce016af 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -61,7 +61,8 @@ from utils.byok import get_byok_keys, set_byok_keys from utils.http_client import _get_semaphore from utils.log_sanitizer import sanitize -from utils.stt.pre_recorded import deepgram_prerecorded, get_deepgram_model_for_language, postprocess_words +from utils.stt import provider_service as stt_provider_service +from utils.stt.providers import STTWorkload from utils.stt.vad import vad_is_empty from utils.fair_use import ( record_speech_ms, @@ -75,13 +76,8 @@ FAIR_USE_ENABLED, FAIR_USE_RESTRICT_DAILY_DG_MS, ) -from utils.speaker_assignment import process_speaker_assigned_segments -from utils.speaker_identification import detect_speaker_from_text -from utils.stt.speaker_embedding import ( - extract_embedding_from_bytes, - compare_embeddings, - SPEAKER_MATCH_THRESHOLD, -) +from utils.stt.background_speaker_identity import identify_background_speaker_clusters +from utils.stt.speaker_embedding import extract_embedding_from_bytes from utils.subscription import has_transcription_credits logger = logging.getLogger(__name__) @@ -821,111 +817,21 @@ def identify_speakers_for_segments( person_embeddings_cache: Dict[str, dict], uid: str, ) -> None: - """Identify speakers in transcript segments using voice embeddings and text detection. - - Modifies segments in-place by assigning person_id and is_user fields. + """Identify background speakers once per provider cluster. - Steps: - 1. Voice embedding matching (requires audio_bytes and non-empty cache): - For each unique speaker_id, find the longest segment (>=1s), extract audio clip, - get embedding, match against person_embeddings_cache. - 2. Text-based detection ("I am X") runs independently for all unmatched speakers. - 3. Apply assignments via process_speaker_assigned_segments. + Text self-introduction is retained as hint metadata only. It must not create + or apply a durable speaker identity without voice evidence. """ - speaker_to_person_map: Dict[int, Tuple[str, str]] = {} - segment_person_assignment_map: Dict[str, str] = {} - - # Group segments by speaker_id, find best (longest) segment per speaker for embedding - speaker_segments: Dict[int, List[TranscriptSegment]] = {} - for seg in transcript_segments: - sid = seg.speaker_id if seg.speaker_id is not None else 0 - speaker_segments.setdefault(sid, []).append(seg) - - # Voice embedding matching (only when audio and cached embeddings are available) - # Track matched person_ids so each person is only assigned to one speaker - # (diarization tells us speakers are distinct — no person can be two speakers). - matched_person_ids: set = set() - - if audio_bytes and person_embeddings_cache: - # Sort speakers by best single segment duration (longest first) — this is the clip - # actually used for embedding, so it determines match quality. - # Note: matched_person_ids assumes diarization is correct (one person = one speaker). - # If diarization fragments one person across speaker IDs, only the best match wins. - sorted_speakers = sorted( - speaker_segments.items(), - key=lambda kv: max(s.end - s.start for s in kv[1]), - reverse=True, - ) - - for speaker_id, segments in sorted_speakers: - best_seg = max(segments, key=lambda s: s.end - s.start) - seg_duration = best_seg.end - best_seg.start - - if seg_duration < SPEAKER_ID_MIN_AUDIO: - continue - - clip_wav = _extract_speaker_clip_wav(audio_bytes, best_seg.start, best_seg.end) - if not clip_wav: - continue - - try: - query_embedding = extract_embedding_from_bytes(clip_wav, "sync_speaker.wav") - except (ValueError, Exception) as e: - logger.info(f'Speaker ID: embedding extraction failed for speaker {speaker_id}: {e} uid={uid}') - continue - - # Compare only against unmatched candidates (each person can be one speaker) - best_match = None - best_distance = float('inf') - for person_id, data in person_embeddings_cache.items(): - if person_id in matched_person_ids: - continue - distance = compare_embeddings(query_embedding, data['embedding']) - if distance < best_distance: - best_distance = distance - best_match = (person_id, data['name']) - - if best_match and best_distance < SPEAKER_MATCH_THRESHOLD: - person_id, person_name = best_match - speaker_to_person_map[speaker_id] = (person_id, person_name) - segment_person_assignment_map[best_seg.id] = person_id - matched_person_ids.add(person_id) - logger.info( - f'Speaker ID (sync): speaker {speaker_id} -> {person_id} ' - f'(distance={best_distance:.3f}) uid={uid}' - ) - - # Text-based detection runs independently for all unmatched speakers. - # For speaker_id > 0 (diarized): update both speaker_to_person_map and per-segment map. - # For speaker_id <= 0 (undiarized): only assign per-segment (avoid mapping all speaker_id=0 - # segments to one person when diarization is inactive). - for speaker_id, segments in speaker_segments.items(): - if speaker_id in speaker_to_person_map: - continue - for seg in segments: - detected_name = detect_speaker_from_text(seg.text) - if detected_name: - person = users_db.get_person_by_name(uid, detected_name) - if person: - # Per-segment assignment always applies - segment_person_assignment_map[seg.id] = person['id'] - # Update speaker map only when diarization is active - if speaker_id > 0: - speaker_to_person_map[speaker_id] = (person['id'], person['name']) - logger.info( - f'Speaker ID (sync): text detection speaker {speaker_id} -> ' - f'{person["id"]} via "{detected_name}" uid={uid}' - ) - if speaker_id > 0: - break # One match per diarized speaker is enough - - # Apply all assignments to segments - if speaker_to_person_map or segment_person_assignment_map: - process_speaker_assigned_segments( - transcript_segments, - segment_person_assignment_map, - speaker_to_person_map, - ) + assignments = identify_background_speaker_clusters( + transcript_segments, + audio_bytes, + person_embeddings_cache or {}, + embedding_extractor=extract_embedding_from_bytes, + ) + identified_count = sum(1 for assignment in assignments.values() if assignment.state in ('identified', 'user')) + logger.info( + f'Speaker ID (sync): cluster identity assignments={len(assignments)} identified={identified_count} uid={uid}' + ) def process_segment( @@ -957,29 +863,34 @@ def delete_file(): single_language_mode = prefs.get('single_language_mode', False) if single_language_mode and user_language: - dg_language, dg_model = get_deepgram_model_for_language(user_language) + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model(user_language) else: - dg_language, dg_model = get_deepgram_model_for_language('multi') + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model('multi') # When single-language mode is active, trust the user's language choice # rather than Deepgram's detection (avoids overriding explicit selection). use_return_language = not (single_language_mode and user_language) - words, detected_language = deepgram_prerecorded( + transcription = stt_provider_service.transcribe_url( url, + workload=STTWorkload.sync, + uid=uid, + conversation_id=target_conversation_id, speakers_count=3, - attempts=0, - return_language=True, - language=dg_language, - model=dg_model, + return_language=use_return_language, + language=stt_language, + model=stt_model, keywords=vocabulary if vocabulary else None, + raw_audio_seconds=get_wav_duration(path), ) + words = transcription.words + detected_language = transcription.detected_language or stt_language language = user_language if (single_language_mode and user_language) else detected_language if not words: # DG processed audio successfully but found no speech (silence/noise). # Real DG failures now raise RuntimeError and are caught by the except block. logger.info(f'No transcript words for segment {path} (silence or noise-only audio)') return - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments if not transcript_segments: logger.warning(f'Postprocessing returned empty for segment {path} (words present but no segments)') return @@ -993,6 +904,18 @@ def delete_file(): finally: if audio_bytes: del audio_bytes + try: + stt_provider_service.update_provider_run_identity_metrics( + transcription.run_id, + transcription.result.provider, + transcription.result.model or 'unknown', + STTWorkload.sync, + transcript_segments, + 'skipped' if not person_embeddings_cache else 'succeeded', + 'missing_candidate_embeddings' if not person_embeddings_cache else None, + ) + except Exception as e: + logger.warning(f'Speaker ID (sync): identity metric update failed for {path}: {e}') timestamp = get_timestamp_from_path(path) segment_end_timestamp = timestamp + transcript_segments[-1].end diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py index 532a83fc5b0..520ed8cae39 100644 --- a/backend/routers/transcribe.py +++ b/backend/routers/transcribe.py @@ -34,7 +34,6 @@ should_update_speaker_to_person_map, ) import database.conversations as conversations_db -import database.calendar_meetings as calendar_db import database.users as user_db from utils.byok import get_byok_keys, extract_byok_from_websocket, set_byok_keys from database.users import get_user_transcription_preferences @@ -43,8 +42,8 @@ from models.conversation import Conversation from models.conversation_enums import ConversationSource, ConversationStatus from utils.conversations.factory import deserialize_conversation +from utils.conversations.desktop_background import create_in_progress_desktop_conversation from models.conversation_photo import ConversationPhoto -from models.structured import Structured from models.transcript_segment import TranscriptSegment from models.message_event import ( ConversationEvent, @@ -798,61 +797,14 @@ async def _create_new_in_progress_conversation(): logger.error(f"Invalid conversation source '{source}', defaulting to 'omi' {uid} {session_id}") conversation_source = ConversationSource.omi - new_conversation_id = str(uuid.uuid4()) - stub_conversation = Conversation( - id=new_conversation_id, - created_at=datetime.now(timezone.utc), - started_at=datetime.now(timezone.utc), - finished_at=datetime.now(timezone.utc), - structured=Structured(), - language=language, - transcript_segments=[], - photos=[], - status=ConversationStatus.in_progress, + new_conversation_id = create_in_progress_desktop_conversation( + uid, + language, source=conversation_source, private_cloud_sync_enabled=private_cloud_sync_enabled, call_id=call_id if is_multi_channel else None, + session_id=session_id, ) - conversations_db.upsert_conversation(uid, conversation_data=stub_conversation.dict()) - redis_db.set_in_progress_conversation_id(uid, new_conversation_id) - - detected_meeting_id = None - - # Only check for meetings if source is desktop - if conversation_source == ConversationSource.desktop: - now = datetime.now(timezone.utc) - # Check ±2 minute window - time_window = timedelta(minutes=2) - start_range = now - time_window - end_range = now + time_window - - meetings = calendar_db.get_meetings_in_time_range(uid, start_range, end_range) - - if len(meetings) == 1: - # Exactly one meeting found - detected_meeting_id = meetings[0]['id'] - elif len(meetings) > 1: - closest_meeting = None - smallest_diff = None - - for meeting in meetings: - # Calculate absolute time difference between meeting start and now - time_diff = abs((meeting['start_time'] - now).total_seconds()) - - if smallest_diff is None or time_diff < smallest_diff: - smallest_diff = time_diff - closest_meeting = meeting - - if closest_meeting: - detected_meeting_id = closest_meeting['id'] - logger.info( - f"Selected closest meeting: {closest_meeting['title']} (diff: {smallest_diff}s) {uid} {session_id}" - ) - - # Store meeting association if auto-detected - if detected_meeting_id: - redis_db.set_conversation_meeting_id(new_conversation_id, detected_meeting_id) - current_conversation_id = new_conversation_id logger.info(f"Created new stub conversation: {new_conversation_id} {uid} {session_id}") diff --git a/backend/routers/users.py b/backend/routers/users.py index 119d372bdeb..f01ef23d9fb 100644 --- a/backend/routers/users.py +++ b/backend/routers/users.py @@ -775,6 +775,7 @@ def get_user_usage_stats_endpoint( _SHA256_HEX_RE = re.compile(r'^[a-f0-9]{64}$') _BYOK_REQUIRED_PROVIDERS = {'openai', 'anthropic', 'gemini', 'deepgram'} +_BYOK_ALLOWED_PROVIDERS = _BYOK_REQUIRED_PROVIDERS | {'assemblyai'} class BYOKActivateRequest(BaseModel): @@ -796,7 +797,7 @@ def activate_byok_endpoint(data: BYOKActivateRequest, uid: str = Depends(auth.ge detail=f"Missing fingerprints for providers: {sorted(missing)}", ) for provider, fp in data.fingerprints.items(): - if provider not in _BYOK_REQUIRED_PROVIDERS: + if provider not in _BYOK_ALLOWED_PROVIDERS: raise HTTPException(status_code=400, detail=f"Unknown provider: {provider}") if not _SHA256_HEX_RE.match(fp): raise HTTPException( diff --git a/backend/run-local.sh b/backend/run-local.sh new file mode 100755 index 00000000000..6080e983e30 --- /dev/null +++ b/backend/run-local.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail +cd "$(dirname "$0")" +export DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib:${DYLD_FALLBACK_LIBRARY_PATH:-}" +exec ./venv/bin/uvicorn main:app --host 127.0.0.1 --port 8080 --env-file .env diff --git a/backend/scripts/stt/provider_comparison_gate.md b/backend/scripts/stt/provider_comparison_gate.md new file mode 100644 index 00000000000..eee3c6778d2 --- /dev/null +++ b/backend/scripts/stt/provider_comparison_gate.md @@ -0,0 +1,29 @@ +# STT Provider Comparison Gate + +Run the offline production-readiness eval from the backend directory: + +```bash +python3 scripts/stt/provider_comparison_gate.py \ + --manifest tests/fixtures/stt_provider_eval/manifest.json \ + --output-md /tmp/stt-provider-eval.md \ + --output-json /tmp/stt-provider-eval.json +``` + +The default command replays synthetic and saved provider outputs only. It does not require `ASSEMBLYAI_API_KEY` or `DEEPGRAM_API_KEY`. + +The report compares `always_deepgram`, `always_assemblyai`, `current_policy`, and `shadow_only`. `current_policy` means AssemblyAI default for passive background workloads. `shadow_only` is retained only as a rollback/diagnostic comparator. The report includes speaker safety, default viability, and rollout readiness gates plus an AssemblyAI gap report that names the limiting scenario, likely cause, and mitigation. + +The fixture manifest covers clean turns, fast turns, overlap, sparse speech, low-signal/no-speech, multilingual turns, duplicate chunk replay, provider failure/fallback, saved real-provider E2E output, and saved policy-router output. + +Cost estimates use public pay-as-you-go diarized prerecorded rates checked on 2026-05-25: AssemblyAI Universal-2 plus diarization is `$0.170/hour`, Deepgram Nova-3 monolingual plus diarization is `$0.408/hour`, and Deepgram Nova-3 multilingual plus diarization is `$0.468/hour`. Older experiment notes omitted Deepgram diarization cost and should not be used for rollout cost decisions. + +Live-provider smoke tests are optional: + +```bash +ASSEMBLYAI_API_KEY=... DEEPGRAM_API_KEY=... \ +python3 scripts/stt/provider_comparison_gate.py \ + --manifest tests/fixtures/stt_provider_eval/manifest.json \ + --live +``` + +Synthetic and saved-output gates are necessary but insufficient for default health decisions. Use the latest gap-closing report to operate the AssemblyAI default with Deepgram fallback and privacy-safe real-session metrics. diff --git a/backend/scripts/stt/provider_comparison_gate.py b/backend/scripts/stt/provider_comparison_gate.py new file mode 100644 index 00000000000..322b6cad59f --- /dev/null +++ b/backend/scripts/stt/provider_comparison_gate.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import sys +from pathlib import Path +from typing import Any, Optional + +BACKEND_ROOT = Path(__file__).resolve().parents[2] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + +from utils.stt.provider_evaluation import ( # noqa: E402 + ProviderGateThresholds, + build_comparison_report, + compact_markdown_report, + evaluate_report_gates, +) + +try: + from utils.stt.provider_service import _transcribe_bytes_with_provider, _transcribe_url_with_provider # noqa: E402 + from utils.stt.providers import STTProviderName, STTWorkload # noqa: E402 + + LIVE_PROVIDER_IMPORT_ERROR = None +except Exception as e: + _transcribe_bytes_with_provider = None + _transcribe_url_with_provider = None + STTProviderName = None + STTWorkload = None + LIVE_PROVIDER_IMPORT_ERROR = e + + +def main() -> int: + parser = argparse.ArgumentParser( + description='Compare Deepgram and AssemblyAI background transcription outputs and apply rollout gates.' + ) + parser.add_argument( + '--manifest', + action='append', + default=[], + help='JSON manifest with cases, fixtures, optional audio_url/audio_file, and ledger/rollup files.', + ) + parser.add_argument('--live', action='store_true', help='Run providers for manifest audio_url/audio_file cases.') + parser.add_argument('--uid', default=None, help='Optional uid used to write provider ledger rows during live runs.') + parser.add_argument( + '--conversation-prefix', default='stt-eval', help='Conversation id prefix for live ledger rows.' + ) + parser.add_argument('--output-json', default=None, help='Write full JSON report to this path.') + parser.add_argument('--output-md', default=None, help='Write compact Markdown report to this path.') + parser.add_argument('--fail-on-warning', action='store_true', help='Return non-zero for warning gates too.') + parser.add_argument('--max-wer', type=float, default=ProviderGateThresholds.max_transcript_word_error_rate) + parser.add_argument( + '--max-segment-delta-ratio', type=float, default=ProviderGateThresholds.max_segment_count_delta_ratio + ) + parser.add_argument( + '--max-timestamp-drift', type=float, default=ProviderGateThresholds.max_average_timestamp_drift_seconds + ) + parser.add_argument( + '--max-low-confidence-rate', type=float, default=ProviderGateThresholds.max_low_confidence_identity_rate + ) + parser.add_argument('--max-fallback-rate', type=float, default=ProviderGateThresholds.max_fallback_rate) + parser.add_argument('--max-failure-rate', type=float, default=ProviderGateThresholds.max_failure_rate) + parser.add_argument( + '--allow-missing-instrumentation', + action='store_true', + help='Do not warn when fixture cases omit provider ledger or rollup metrics.', + ) + args = parser.parse_args() + + if not args.manifest: + parser.error('at least one --manifest is required') + + thresholds = ProviderGateThresholds( + max_transcript_word_error_rate=args.max_wer, + max_segment_count_delta_ratio=args.max_segment_delta_ratio, + max_average_timestamp_drift_seconds=args.max_timestamp_drift, + max_low_confidence_identity_rate=args.max_low_confidence_rate, + max_fallback_rate=args.max_fallback_rate, + max_failure_rate=args.max_failure_rate, + require_instrumentation=not args.allow_missing_instrumentation, + ) + cases = [] + skipped_live_cases = [] + for manifest_path in args.manifest: + manifest = _load_json(Path(manifest_path)) + manifest_cases = manifest.get('cases') if isinstance(manifest, dict) else manifest + for index, case in enumerate(manifest_cases): + prepared = _prepare_case(case, Path(manifest_path).parent) + if args.live and _case_has_audio(case): + live_case = _run_live_case( + case, + base_path=Path(manifest_path).parent, + uid=args.uid, + conversation_id=f'{args.conversation_prefix}-{case.get("id", index)}', + ) + if live_case: + prepared = live_case + else: + skipped_live_cases.append(case.get('id') or str(index)) + cases.append(prepared) + + report = build_comparison_report(cases, thresholds) + if skipped_live_cases: + report['skipped_live_cases'] = skipped_live_cases + markdown = compact_markdown_report(report) + print(markdown) + + if args.output_json: + _write_text(Path(args.output_json), json.dumps(report, indent=2, sort_keys=True) + '\n') + if args.output_md: + _write_text(Path(args.output_md), markdown + '\n') + + passed, messages = evaluate_report_gates(report, fail_on_warning=args.fail_on_warning) + if not passed: + print('\nGate messages:', file=sys.stderr) + for message in messages: + print(f'- {message}', file=sys.stderr) + return 1 + return 0 + + +def _prepare_case(case: dict[str, Any], base_path: Path) -> dict[str, Any]: + prepared = { + 'id': case.get('id') or case.get('case_id'), + 'scenario': case.get('scenario') or case.get('type'), + 'current_policy_provider': case.get('current_policy_provider'), + } + prepared['deepgram'] = _load_provider_payload(case, base_path, 'deepgram') + prepared['assemblyai'] = _load_provider_payload(case, base_path, 'assemblyai') + return prepared + + +def _load_provider_payload(case: dict[str, Any], base_path: Path, provider: str) -> dict[str, Any]: + payload = {} + inline = case.get(provider) + if inline: + payload.update(inline) + fixture_path = case.get(f'{provider}_fixture') or case.get(f'{provider}_transcript') + if fixture_path: + payload['transcript'] = _load_json(_resolve_path(base_path, fixture_path)) + ledger_path = case.get(f'{provider}_ledger') or case.get(f'{provider}_rollup') + if ledger_path: + payload['ledger'] = _load_json(_resolve_path(base_path, ledger_path)) + return payload + + +def _run_live_case( + case: dict[str, Any], + base_path: Path, + uid: Optional[str], + conversation_id: str, +) -> Optional[dict[str, Any]]: + if LIVE_PROVIDER_IMPORT_ERROR: + print( + f"Skipping live case {case.get('id', 'unknown')}: provider dependencies are unavailable " + f"({LIVE_PROVIDER_IMPORT_ERROR}).", + file=sys.stderr, + ) + return None + if not _credentials_available(): + print( + f"Skipping live case {case.get('id', 'unknown')}: " 'DEEPGRAM_API_KEY and ASSEMBLYAI_API_KEY are required.', + file=sys.stderr, + ) + return None + workload = STTWorkload(case.get('workload') or STTWorkload.background.value) + language = case.get('language') + raw_audio_seconds = float(case.get('raw_audio_seconds') or 0.0) + common_kwargs = { + 'workload': workload, + 'uid': uid, + 'conversation_id': conversation_id, + 'language': language, + 'model': case.get('model') or 'nova-3', + 'raw_audio_seconds': raw_audio_seconds, + 'return_language': bool(case.get('return_language', False)), + 'diarize': bool(case.get('diarize', True)), + } + if case.get('audio_url'): + deepgram = _transcribe_url_with_provider(STTProviderName.deepgram, case['audio_url'], **common_kwargs) + assemblyai = _transcribe_url_with_provider( + STTProviderName.assemblyai, + case['audio_url'], + **{**common_kwargs, 'model': case.get('assemblyai_model') or 'universal-2'}, + ) + else: + audio_bytes = _resolve_path(base_path, case['audio_file']).read_bytes() + deepgram = _transcribe_bytes_with_provider(STTProviderName.deepgram, audio_bytes, **common_kwargs) + assemblyai = _transcribe_bytes_with_provider( + STTProviderName.assemblyai, + audio_bytes, + **{**common_kwargs, 'model': case.get('assemblyai_model') or 'universal-2'}, + ) + del audio_bytes + return { + 'id': case.get('id') or conversation_id, + 'deepgram': {'transcript': {'segments': [segment.dict() for segment in deepgram.segments]}}, + 'assemblyai': {'transcript': {'segments': [segment.dict() for segment in assemblyai.segments]}}, + } + + +def _case_has_audio(case: dict[str, Any]) -> bool: + return bool(case.get('audio_url') or case.get('audio_file')) + + +def _credentials_available() -> bool: + return bool(os.getenv('DEEPGRAM_API_KEY') and os.getenv('ASSEMBLYAI_API_KEY')) + + +def _load_json(path: Path) -> Any: + with path.open('r', encoding='utf-8') as handle: + return json.load(handle) + + +def _write_text(path: Path, text: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(text, encoding='utf-8') + + +def _resolve_path(base_path: Path, raw_path: str) -> Path: + path = Path(raw_path) + if path.is_absolute(): + return path + return base_path / path + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/backend/test.sh b/backend/test.sh index afcc8870be2..90e7cc551ca 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -14,6 +14,11 @@ pytest tests/unit/test_speaker_sample_migration.py -v pytest tests/unit/test_short_audio_embedding.py -v pytest tests/unit/test_users_add_sample_transaction.py -v pytest tests/unit/test_voice_message_language.py -v +pytest tests/unit/test_assemblyai_adapter.py -v +pytest tests/unit/test_background_provider_service.py -v +pytest tests/unit/test_conversation_reconstructor.py -v +pytest tests/unit/test_provider_evaluation.py -v +pytest tests/unit/test_transcription_provider_usage.py -v pytest tests/unit/test_speaker_assignment.py -v pytest tests/unit/test_speaker_id_pipeline.py -v pytest tests/unit/test_user_speaker_embedding.py -v diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json new file mode 100644 index 00000000000..42920b4c356 --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.json @@ -0,0 +1,25 @@ +{ + "segments": [ + { + "text": "We should launch the beta next Monday.", + "start": 0.1, + "end": 2.3, + "provider_cluster_id": "A", + "provider_speaker_label": "ASSEMBLYAI_SPEAKER_A", + "oracle_speaker": "alex", + "person_id": "person_alex", + "speaker_identity_state": "identified", + "speaker_identity_confidence": 0.88 + }, + { + "text": "I will prepare the dashboard before then.", + "start": 2.6, + "end": 4.9, + "provider_cluster_id": "B", + "provider_speaker_label": "ASSEMBLYAI_SPEAKER_B", + "oracle_speaker": "casey", + "speaker_identity_state": "unknown", + "speaker_identity_confidence": 0.46 + } + ] +} diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json new file mode 100644 index 00000000000..306274c9c51 --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.assemblyai.rollup.json @@ -0,0 +1,17 @@ +{ + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "speech_active_seconds": 4.5, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.00023611, + "latency_seconds": 1.8, + "runtime_seconds": 1.8, + "retry_count": 0, + "fallback_count": 0, + "split_count": 0, + "accepted_reconciliation_count": 0, + "rejected_reconciliation_count": 1 +} diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json new file mode 100644 index 00000000000..ebb99cc6965 --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.json @@ -0,0 +1,25 @@ +{ + "segments": [ + { + "text": "We should launch the beta next Monday.", + "start": 0.0, + "end": 2.2, + "provider_cluster_id": "0", + "provider_speaker_label": "SPEAKER_00", + "oracle_speaker": "alex", + "person_id": "person_alex", + "speaker_identity_state": "identified", + "speaker_identity_confidence": 0.91 + }, + { + "text": "I will prepare the dashboard before then.", + "start": 2.5, + "end": 4.8, + "provider_cluster_id": "1", + "provider_speaker_label": "SPEAKER_01", + "oracle_speaker": "casey", + "speaker_identity_state": "unknown", + "speaker_identity_confidence": 0.42 + } + ] +} diff --git a/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json new file mode 100644 index 00000000000..331a905f84d --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/fixture_good_meeting.deepgram.rollup.json @@ -0,0 +1,17 @@ +{ + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "speech_active_seconds": 4.5, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.00056667, + "latency_seconds": 1.1, + "runtime_seconds": 1.1, + "retry_count": 0, + "fallback_count": 0, + "split_count": 0, + "accepted_reconciliation_count": 0, + "rejected_reconciliation_count": 1 +} diff --git a/backend/tests/fixtures/stt_provider_eval/manifest.json b/backend/tests/fixtures/stt_provider_eval/manifest.json new file mode 100644 index 00000000000..03de71a4f87 --- /dev/null +++ b/backend/tests/fixtures/stt_provider_eval/manifest.json @@ -0,0 +1,568 @@ +{ + "cases": [ + { + "id": "fixture_good_meeting", + "scenario": "saved_real_provider_e2e", + "deepgram_fixture": "fixture_good_meeting.deepgram.json", + "assemblyai_fixture": "fixture_good_meeting.assemblyai.json", + "deepgram_rollup": "fixture_good_meeting.deepgram.rollup.json", + "assemblyai_rollup": "fixture_good_meeting.assemblyai.rollup.json" + }, + { + "id": "synthetic_clean_turns", + "scenario": "clean_turns", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "The first milestone is ready for review.", + "start": 0.0, + "end": 2.0, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "I will check the metrics this afternoon.", + "start": 2.2, + "end": 4.4, + "provider_cluster_id": "1", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.00056667, + "latency_seconds": 1.0, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "The first milestone is ready for review.", + "start": 0.1, + "end": 2.1, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "I will check the metrics this afternoon.", + "start": 2.3, + "end": 4.5, + "provider_cluster_id": "B", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.00023611, + "latency_seconds": 1.8, + "fallback_count": 0, + "rejected_reconciliation_count": 1 + } + } + }, + { + "id": "synthetic_fast_turns", + "scenario": "fast_turns", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Yes.", + "start": 0.0, + "end": 0.2, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "No.", + "start": 0.3, + "end": 0.5, + "provider_cluster_id": "1", + "oracle_speaker": "casey" + }, + { + "text": "Maybe after lunch.", + "start": 0.6, + "end": 1.1, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 2.0, + "billable_seconds": 2.0, + "estimated_cost_usd": 0.00022667, + "latency_seconds": 0.9, + "fallback_count": 0, + "split_count": 1 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Yes.", + "start": 0.0, + "end": 0.2, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "No.", + "start": 0.3, + "end": 0.5, + "provider_cluster_id": "B", + "oracle_speaker": "casey" + }, + { + "text": "Maybe after lunch.", + "start": 0.7, + "end": 1.2, + "provider_cluster_id": "A.2", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 2.0, + "billable_seconds": 2.0, + "estimated_cost_usd": 0.00009444, + "latency_seconds": 1.4, + "fallback_count": 0, + "split_count": 1, + "rejected_reconciliation_count": 1 + } + } + }, + { + "id": "synthetic_overlap", + "scenario": "overlap", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Start the recording now.", + "start": 0.0, + "end": 1.5, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "I am already taking notes.", + "start": 1.0, + "end": 2.8, + "provider_cluster_id": "1", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 3.0, + "billable_seconds": 3.0, + "estimated_cost_usd": 0.00034, + "latency_seconds": 1.2, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Start the recording now.", + "start": 0.0, + "end": 1.6, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "I am already taking notes.", + "start": 1.1, + "end": 2.9, + "provider_cluster_id": "B", + "oracle_speaker": "casey" + }, + { + "text": "notes.", + "start": 2.7, + "end": 2.9, + "provider_cluster_id": "B.2", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 3.0, + "billable_seconds": 3.0, + "estimated_cost_usd": 0.00014167, + "latency_seconds": 1.9, + "fallback_count": 0, + "split_count": 1, + "accepted_reconciliation_count": 1 + } + } + }, + { + "id": "synthetic_sparse_speech", + "scenario": "sparse_speech", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Okay.", + "start": 12.0, + "end": 12.3, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 30.0, + "speech_active_seconds": 0.5, + "billable_seconds": 30.0, + "estimated_cost_usd": 0.0034, + "latency_seconds": 1.0, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Okay.", + "start": 12.1, + "end": 12.4, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 30.0, + "speech_active_seconds": 0.5, + "billable_seconds": 30.0, + "estimated_cost_usd": 0.00141667, + "latency_seconds": 2.0, + "fallback_count": 0 + } + } + }, + { + "id": "synthetic_no_speech", + "scenario": "low_signal_no_speech", + "deepgram": { + "transcript": { + "segments": [] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 15.0, + "speech_active_seconds": 0.0, + "billable_seconds": 15.0, + "estimated_cost_usd": 0.0017, + "latency_seconds": 0.8, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 15.0, + "speech_active_seconds": 0.0, + "billable_seconds": 15.0, + "estimated_cost_usd": 0.00070833, + "latency_seconds": 1.2, + "fallback_count": 0 + } + } + }, + { + "id": "synthetic_multilingual_turns", + "scenario": "multilingual_turns", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Hola equipo, the build is ready.", + "start": 0.0, + "end": 2.4, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "Merci, I will verify it now.", + "start": 2.6, + "end": 4.7, + "provider_cluster_id": "1", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.00056667, + "latency_seconds": 1.3, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Hola equipo, the build is ready.", + "start": 0.1, + "end": 2.5, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "Merci, I will verify it now.", + "start": 2.7, + "end": 4.8, + "provider_cluster_id": "B", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 5.0, + "billable_seconds": 5.0, + "estimated_cost_usd": 0.00023611, + "latency_seconds": 2.0, + "fallback_count": 0 + } + } + }, + { + "id": "synthetic_duplicate_replay", + "scenario": "duplicate_chunk_replay", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "This chunk should only count once.", + "start": 0.0, + "end": 2.0, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 2, + "status_counts": { + "succeeded": 2 + }, + "raw_audio_seconds": 4.0, + "billable_seconds": 2.0, + "estimated_cost_usd": 0.00022667, + "latency_seconds": 1.0, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "This chunk should only count once.", + "start": 0.0, + "end": 2.1, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 2, + "status_counts": { + "succeeded": 2 + }, + "raw_audio_seconds": 4.0, + "billable_seconds": 2.0, + "estimated_cost_usd": 0.00009444, + "latency_seconds": 1.7, + "fallback_count": 0 + } + } + }, + { + "id": "synthetic_provider_failure_fallback", + "scenario": "provider_failure_fallback", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "Fallback kept the transcript available.", + "start": 0.0, + "end": 2.2, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 10, + "status_counts": { + "succeeded": 10 + }, + "raw_audio_seconds": 22.0, + "billable_seconds": 22.0, + "estimated_cost_usd": 0.00249333, + "latency_seconds": 1.1, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "Fallback kept the transcript available.", + "start": 0.0, + "end": 2.2, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + } + ] + }, + "ledger": { + "run_count": 10, + "status_counts": { + "succeeded": 9, + "timeout": 0 + }, + "raw_audio_seconds": 22.0, + "billable_seconds": 22.0, + "estimated_cost_usd": 0.00103889, + "latency_seconds": 2.1, + "fallback_count": 1 + } + } + }, + { + "id": "saved_policy_router_gap", + "scenario": "saved_real_policy_router_outputs", + "deepgram": { + "transcript": { + "segments": [ + { + "text": "We need battery telemetry before the next build ships and the background recorder should keep stable chunks for the entire review.", + "start": 0.0, + "end": 5.0, + "provider_cluster_id": "0", + "oracle_speaker": "alex" + }, + { + "text": "Okay.", + "start": 5.1, + "end": 6.0, + "provider_cluster_id": "1", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 6.0, + "billable_seconds": 6.0, + "estimated_cost_usd": 0.00068, + "latency_seconds": 1.3, + "fallback_count": 0 + } + }, + "assemblyai": { + "transcript": { + "segments": [ + { + "text": "We need battery telemetry before the next build ships and the background recorder should keep stable chunks for the entire review.", + "start": 0.1, + "end": 5.1, + "provider_cluster_id": "A", + "oracle_speaker": "alex" + }, + { + "text": "Okay.", + "start": 5.2, + "end": 6.1, + "provider_cluster_id": "A", + "oracle_speaker": "casey" + } + ] + }, + "ledger": { + "run_count": 1, + "status_counts": { + "succeeded": 1 + }, + "raw_audio_seconds": 6.0, + "billable_seconds": 6.0, + "estimated_cost_usd": 0.00028333, + "latency_seconds": 2.2, + "fallback_count": 0, + "split_count": 0, + "rejected_reconciliation_count": 1 + } + } + } + ] +} diff --git a/backend/tests/unit/test_action_item_date_validation.py b/backend/tests/unit/test_action_item_date_validation.py index 5a36a8455c2..4bce2d459f0 100644 --- a/backend/tests/unit/test_action_item_date_validation.py +++ b/backend/tests/unit/test_action_item_date_validation.py @@ -82,6 +82,8 @@ def _load_module_from_file(module_name, file_path): if mod_name not in sys.modules: _stub_module(mod_name) +sys.modules["database.auth"].get_user_name = MagicMock(return_value="Test User") + # Stub database.action_items action_items_db = _stub_module("database.action_items") action_items_db.create_action_item = MagicMock(return_value="test-item-id") diff --git a/backend/tests/unit/test_assemblyai_adapter.py b/backend/tests/unit/test_assemblyai_adapter.py new file mode 100644 index 00000000000..2c42bfde443 --- /dev/null +++ b/backend/tests/unit/test_assemblyai_adapter.py @@ -0,0 +1,234 @@ +import httpx +import os +import pytest + +from utils.stt.assemblyai_adapter import ( + AssemblyAIAsyncTranscriptionProvider, + AssemblyAIProviderError, + AssemblyAITimeoutError, + normalize_assemblyai_transcript_result, +) + + +class FakeResponse: + def __init__(self, payload, status_code=200): + self._payload = payload + self.status_code = status_code + self.request = httpx.Request('GET', 'https://api.assemblyai.com/test') + + def json(self): + return self._payload + + def raise_for_status(self): + if self.status_code >= 400: + raise httpx.HTTPStatusError('failed', request=self.request, response=self) + + +class FakeClient: + def __init__(self, responses): + self.responses = list(responses) + self.requests = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def request(self, method, url, **kwargs): + self.requests.append((method, url, kwargs)) + response = self.responses.pop(0) + if isinstance(response, Exception): + raise response + return response + + +def _completed_transcript(): + return { + 'id': 'aai-transcript-1', + 'status': 'completed', + 'language_code': 'en_us', + 'speech_model_used': 'universal-2', + 'audio_duration': 2.5, + 'utterances': [ + { + 'speaker': 'A', + 'text': 'Hello world.', + 'start': 0, + 'end': 1100, + 'confidence': 0.93, + 'words': [ + {'speaker': 'A', 'text': 'Hello', 'start': 0, 'end': 400, 'confidence': 0.94}, + {'speaker': 'A', 'text': 'world.', 'start': 500, 'end': 1100, 'confidence': 0.91}, + ], + } + ], + } + + +def test_assemblyai_result_normalizes_utterances_words_and_speaker_clusters(): + result = normalize_assemblyai_transcript_result(_completed_transcript(), model='universal-2') + + assert result.provider == 'assemblyai' + assert result.model == 'universal-2' + assert result.language == 'en' + assert result.duration == 2.5 + assert result.raw_provider_result_id == 'aai-transcript-1' + assert result.utterances[0].provider_cluster_id == 'A' + assert result.utterances[0].speaker_label == 'ASSEMBLYAI_SPEAKER_A' + assert result.words[1].text == 'world.' + assert result.words[1].start == 0.5 + assert result.words[1].provider_cluster_id == 'A' + + +def test_assemblyai_result_does_not_report_multi_when_detection_returns_no_language(): + transcript = _completed_transcript() + transcript.pop('language_code') + + result = normalize_assemblyai_transcript_result(transcript, model='universal-2', language='multi') + + assert result.language is None + + +def test_assemblyai_transcribe_url_submits_diarization_and_polls_to_completion(): + fake_client = FakeClient( + [ + FakeResponse({'id': 'aai-transcript-1', 'status': 'queued'}), + FakeResponse({'id': 'aai-transcript-1', 'status': 'processing'}), + FakeResponse(_completed_transcript()), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=5, + sleeper=lambda seconds: None, + ) + + result, detected_language = provider.transcribe_url( + 'https://example.test/audio.wav', + speakers_count=2, + return_language=True, + language='multi', + model='universal-2', + keywords=['Omi'], + ) + + assert result.provider == 'assemblyai' + assert detected_language == 'en' + submit_payload = fake_client.requests[0][2]['json'] + assert submit_payload['audio_url'] == 'https://example.test/audio.wav' + assert submit_payload['speaker_labels'] is True + assert submit_payload['speech_models'] == ['universal-2'] + assert submit_payload['speakers_expected'] == 2 + assert submit_payload['language_detection'] is True + assert submit_payload['keyterms_prompt'] == ['Omi'] + assert fake_client.requests[1][0] == 'GET' + + +def test_assemblyai_transcribe_bytes_uploads_then_transcribes_upload_url(): + fake_client = FakeClient( + [ + FakeResponse({'upload_url': 'https://cdn.assemblyai.test/uploaded.wav'}), + FakeResponse({'id': 'aai-transcript-1'}), + FakeResponse(_completed_transcript()), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=5, + sleeper=lambda seconds: None, + ) + + result = provider.transcribe_bytes(b'audio-bytes', diarize=True) + + assert result.provider == 'assemblyai' + assert fake_client.requests[0][0] == 'POST' + assert fake_client.requests[0][1].endswith('/v2/upload') + assert fake_client.requests[1][2]['json']['audio_url'] == 'https://cdn.assemblyai.test/uploaded.wav' + + +def test_assemblyai_failure_status_normalizes_to_provider_error(): + fake_client = FakeClient( + [ + FakeResponse({'id': 'aai-transcript-1'}), + FakeResponse({'id': 'aai-transcript-1', 'status': 'error', 'error': 'unsupported media'}), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=5, + sleeper=lambda seconds: None, + ) + + with pytest.raises(AssemblyAIProviderError, match='unsupported media'): + provider.transcribe_url('https://example.test/audio.wav') + + +def test_assemblyai_poll_timeout_raises_timeout_error(): + current_time = {'value': 0.0} + + def clock(): + current_time['value'] += 2.0 + return current_time['value'] + + fake_client = FakeClient( + [ + FakeResponse({'id': 'aai-transcript-1'}), + FakeResponse({'id': 'aai-transcript-1', 'status': 'processing'}), + FakeResponse({'id': 'aai-transcript-1', 'status': 'processing'}), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=1, + sleeper=lambda seconds: None, + clock=clock, + ) + + with pytest.raises(AssemblyAITimeoutError): + provider.transcribe_url('https://example.test/audio.wav') + + +def test_assemblyai_retries_retryable_http_once(): + fake_client = FakeClient( + [ + FakeResponse({'temporarily': 'busy'}, status_code=503), + FakeResponse({'id': 'aai-transcript-1'}), + FakeResponse(_completed_transcript()), + ] + ) + provider = AssemblyAIAsyncTranscriptionProvider( + api_key='test-key', + client_factory=lambda: fake_client, + poll_interval_seconds=0, + max_poll_seconds=5, + sleeper=lambda seconds: None, + ) + + result = provider.transcribe_url('https://example.test/audio.wav') + + assert result.provider == 'assemblyai' + assert len(fake_client.requests) == 3 + + +def test_assemblyai_live_smoke_with_gated_credentials(): + api_key = os.getenv('ASSEMBLYAI_API_KEY') + audio_url = os.getenv('ASSEMBLYAI_SMOKE_AUDIO_URL') + if not api_key or not audio_url: + pytest.skip('ASSEMBLYAI_API_KEY and ASSEMBLYAI_SMOKE_AUDIO_URL are required for live smoke') + + provider = AssemblyAIAsyncTranscriptionProvider(api_key=api_key, poll_interval_seconds=3, max_poll_seconds=180) + + result = provider.transcribe_url(audio_url, diarize=True, language='en', model='universal-2') + + assert result.provider == 'assemblyai' + assert result.raw_provider_result_id + assert result.words or result.utterances diff --git a/backend/tests/unit/test_async_app_integrations.py b/backend/tests/unit/test_async_app_integrations.py index 796b029f200..f4d805b890a 100644 --- a/backend/tests/unit/test_async_app_integrations.py +++ b/backend/tests/unit/test_async_app_integrations.py @@ -36,6 +36,8 @@ "calendar_meetings", "vector_db", "apps", + "announcements", + "user_usage", "llm_usage", "chat", "goals", @@ -55,6 +57,8 @@ sys.modules["database.redis_db"].get_daily_notification_count = MagicMock(return_value=0) sys.modules["database.vector_db"].query_vectors_by_metadata = MagicMock(return_value=[]) sys.modules["database.apps"].record_app_usage = MagicMock() +sys.modules["database.announcements"].compare_versions = MagicMock(return_value=0) +sys.modules["database.user_usage"].get_monthly_chat_usage = MagicMock() sys.modules["database.llm_usage"].record_llm_usage = MagicMock() sys.modules["database.chat"].add_app_message = MagicMock(return_value={"id": "msg-1"}) sys.modules["database.chat"].get_app_messages = MagicMock(return_value=[]) @@ -139,6 +143,14 @@ def _noop_track(uid, feature): sys.modules["utils.executors"] = types.ModuleType("utils.executors") sys.modules["utils.executors"].critical_executor = _TPE(max_workers=2, thread_name_prefix="test-critical") sys.modules["utils.executors"].storage_executor = _TPE(max_workers=2, thread_name_prefix="test-storage") +sys.modules["utils.executors"].db_executor = _TPE(max_workers=2, thread_name_prefix="test-db") + + +async def _run_blocking(_executor, fn, *args, **kwargs): + return fn(*args, **kwargs) + + +sys.modules["utils.executors"].run_blocking = _run_blocking import importlib diff --git a/backend/tests/unit/test_available_plans_resilience.py b/backend/tests/unit/test_available_plans_resilience.py index ef19495fdfc..49ff8b629c5 100644 --- a/backend/tests/unit/test_available_plans_resilience.py +++ b/backend/tests/unit/test_available_plans_resilience.py @@ -73,6 +73,7 @@ def _compare_versions(a, b): _announcements_mod._compare_versions = _compare_versions +_announcements_mod.compare_versions = _compare_versions # database.users needs the functions payment.py imports by name _users_mod = sys.modules["database.users"] @@ -126,6 +127,7 @@ def _compare_versions(a, b): _endpoints_mod = sys.modules["utils.other.endpoints"] _endpoints_mod.get_current_user_uid = lambda: "test-user" +_endpoints_mod.get_current_user_uid_no_byok_validation = lambda: "test-user" # Ensure utils.other has endpoints attr for `from utils.other import endpoints` sys.modules["utils.other"].endpoints = _endpoints_mod diff --git a/backend/tests/unit/test_background_provider_service.py b/backend/tests/unit/test_background_provider_service.py new file mode 100644 index 00000000000..9a0b6a001ee --- /dev/null +++ b/backend/tests/unit/test_background_provider_service.py @@ -0,0 +1,722 @@ +import os +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + +sys.modules['deepgram'].DeepgramClient = MagicMock +sys.modules['deepgram'].DeepgramClientOptions = MagicMock +sys.modules.setdefault('database._client', types.SimpleNamespace(db=MagicMock())) + +os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') +os.environ.setdefault('ASSEMBLYAI_API_KEY', 'fake-for-test') + +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord, TranscriptSegment # noqa: E402 +from utils.stt import provider_service # noqa: E402 +from utils.stt.provider_costs import estimate_prerecorded_provider_cost_usd # noqa: E402 +from utils.stt.providers import ( # noqa: E402 + BackgroundProviderMode, + STTProviderName, + STTWorkload, + get_background_provider_mode, + get_prerecorded_provider_name, +) + + +def _provider_result(provider='deepgram', model='nova-3'): + return ProviderTranscriptResult( + provider=provider, + model=model, + language='en', + duration=2.0, + words=[ + ProviderTranscriptWord( + text='hello', + start=0.0, + end=0.4, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ), + ProviderTranscriptWord( + text='world', + start=0.5, + end=1.0, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ), + ], + ) + + +def test_provider_service_transcribes_sync_upload_and_finalizes_deepgram_run(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'false') + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.deepgram + fake_provider.transcribe_url.return_value = (_provider_result(), 'en') + + with patch.object(provider_service, '_deepgram_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-sync' + ) as create_run, patch.object(provider_service, 'finalize_provider_run') as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + conversation_id='conversation-1', + return_language=True, + language='multi', + model='nova-3', + keywords=['Omi'], + ) + + assert response.detected_language == 'en' + assert [segment.text for segment in response.segments] == ['Hello world'] + assert response.words[0]['stt_provider'] == 'deepgram' + fake_provider.transcribe_url.assert_called_once() + assert fake_provider.transcribe_url.call_args.kwargs['keywords'] == ['Omi'] + create_run.assert_called_once() + assert create_run.call_args.kwargs['workload'] == 'sync' + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['run_id'] == 'run-sync' + assert finalize_run.call_args.kwargs['provider'] == 'deepgram' + assert finalize_run.call_args.kwargs['workload'] == 'sync' + assert finalize_run.call_args.kwargs['status'] == 'succeeded' + assert finalize_run.call_args.kwargs['transcript_segment_count'] == 1 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00022667 + + +def test_provider_service_finalizes_background_run_on_deepgram_when_assemblyai_disabled(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'false') + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.deepgram + fake_provider.transcribe_url.return_value = _provider_result() + + with patch.object(provider_service, '_deepgram_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-background' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/background.wav', + workload=STTWorkload.background, + uid='uid-1', + model='nova-3', + raw_audio_seconds=9.5, + ) + + assert response.result.provider == 'deepgram' + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['workload'] == 'background' + assert finalize_run.call_args.kwargs['provider'] == 'deepgram' + assert finalize_run.call_args.kwargs['raw_audio_seconds'] == 9.5 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00107667 + + +def test_background_routing_defaults_to_assemblyai_for_background(monkeypatch): + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', raising=False) + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', raising=False) + monkeypatch.delenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', raising=False) + + assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai + assert get_background_provider_mode() == BackgroundProviderMode.assemblyai + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai + assert get_prerecorded_provider_name(STTWorkload.postprocess) == STTProviderName.assemblyai + + +def test_background_routing_selects_assemblyai_when_background_gate_allows(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai + + +def test_background_routing_supports_deepgram_override_and_shadow_only(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'deepgram') + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'shadow_only') + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + +def test_background_routing_honors_disabled_rollout_gate_and_workload_allowlist(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'false') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,postprocess') + + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + +def test_background_routing_invalid_mode_fails_safe_to_shadow_only(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'invalid-mode') + + assert get_background_provider_mode() == BackgroundProviderMode.shadow_only + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.deepgram + + +def test_prerecorded_ptt_and_realtime_related_workloads_stay_deepgram(): + assert get_prerecorded_provider_name(STTWorkload.ptt) == STTProviderName.deepgram + assert get_prerecorded_provider_name(STTWorkload.voice_message) == STTProviderName.deepgram + + +def test_background_routing_can_select_assemblyai_without_moving_latency_critical_workloads(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess,ptt,realtime') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + + assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai + assert get_prerecorded_provider_name(STTWorkload.background) == STTProviderName.assemblyai + assert get_prerecorded_provider_name(STTWorkload.postprocess) == STTProviderName.assemblyai + assert get_prerecorded_provider_name(STTWorkload.ptt) == STTProviderName.deepgram + assert get_prerecorded_provider_name(STTWorkload.voice_message) == STTProviderName.deepgram + + +def test_background_call_sites_use_provider_service_layer(): + backend_root = Path(__file__).resolve().parents[2] + with open(backend_root / 'routers/sync.py') as f: + sync_source = f.read() + with open(backend_root / 'utils/conversations/postprocess_conversation.py') as f: + postprocess_source = f.read() + with open(backend_root / 'utils/chat.py') as f: + chat_source = f.read() + + assert 'from utils.stt.pre_recorded import' not in sync_source + assert 'stt_provider_service.transcribe_url' in sync_source + assert 'workload=STTWorkload.sync' in sync_source + + assert 'from utils.stt.pre_recorded import' not in postprocess_source + assert 'stt_provider_service.transcribe_url' in postprocess_source + assert 'workload=STTWorkload.postprocess' in postprocess_source + + assert 'from utils.stt.pre_recorded import' not in chat_source + assert 'workload=STTWorkload.voice_message' in chat_source + assert 'workload=STTWorkload.ptt' in chat_source + + +def test_provider_service_uses_assemblyai_for_enabled_sync_workload(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.assemblyai + fake_provider.transcribe_url.return_value = (_provider_result(provider='assemblyai', model='universal-2'), 'en') + + with patch.object(provider_service, '_assemblyai_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ) as create_run, patch.object(provider_service, 'finalize_provider_run') as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + conversation_id='conversation-1', + return_language=True, + language='multi', + model='nova-3', + keywords=['Omi'], + ) + + assert response.result.provider == 'assemblyai' + assert response.result.model == 'universal-2' + assert response.words[0]['stt_provider'] == 'assemblyai' + fake_provider.transcribe_url.assert_called_once() + assert fake_provider.transcribe_url.call_args.kwargs['model'] == 'universal-2' + create_run.assert_called_once() + assert create_run.call_args.kwargs['provider'] == 'assemblyai' + assert create_run.call_args.kwargs['model'] == 'universal-2' + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['artifact_refs'] == {} + assert finalize_run.call_args.kwargs['billable_seconds'] == 2.0 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.00009444 + + +def test_provider_service_falls_back_to_deepgram_when_assemblyai_fails(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + + assemblyai_provider = MagicMock() + assemblyai_provider.provider_name = STTProviderName.assemblyai + assemblyai_provider.transcribe_url.side_effect = RuntimeError('AssemblyAI failed') + + deepgram_provider = MagicMock() + deepgram_provider.provider_name = STTProviderName.deepgram + deepgram_provider.transcribe_url.return_value = _provider_result() + + with patch.object( + provider_service, '_assemblyai_prerecorded_provider', return_value=assemblyai_provider + ), patch.object(provider_service, '_deepgram_prerecorded_provider', return_value=deepgram_provider), patch.object( + provider_service, 'create_provider_run', side_effect=['run-aai', 'run-dg'] + ), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + conversation_id='conversation-1', + language='multi', + model='nova-3', + raw_audio_seconds=2.0, + ) + + assert response.result.provider == 'deepgram' + assert assemblyai_provider.transcribe_url.call_count == 2 + deepgram_provider.transcribe_url.assert_called_once() + assert deepgram_provider.transcribe_url.call_args.kwargs['model'] == 'nova-3' + assert finalize_run.call_args_list[0].kwargs['run_id'] == 'run-aai' + assert finalize_run.call_args_list[0].kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args_list[0].kwargs['status'] == 'failed' + assert finalize_run.call_args_list[0].kwargs['retry_count'] == 1 + assert finalize_run.call_args_list[0].kwargs['error_class'] == 'RuntimeError' + assert finalize_run.call_args_list[0].kwargs['billable_seconds'] == 2.0 + assert finalize_run.call_args_list[0].kwargs['estimated_cost_usd'] == 0.00009444 + assert finalize_run.call_args_list[1].kwargs['run_id'] == 'run-dg' + assert finalize_run.call_args_list[1].kwargs['provider'] == 'deepgram' + assert finalize_run.call_args_list[1].kwargs['fallback_count'] == 1 + assert finalize_run.call_args_list[1].kwargs['fallback_provider'] == 'assemblyai' + assert finalize_run.call_args_list[1].kwargs['estimated_cost_usd'] == 0.00022667 + + +def test_provider_service_falls_back_to_deepgram_for_background_when_assemblyai_fails(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', raising=False) + + assemblyai_provider = MagicMock() + assemblyai_provider.provider_name = STTProviderName.assemblyai + assemblyai_provider.transcribe_url.side_effect = RuntimeError('AssemblyAI failed') + + deepgram_provider = MagicMock() + deepgram_provider.provider_name = STTProviderName.deepgram + deepgram_provider.transcribe_url.return_value = _provider_result() + + with patch.object( + provider_service, '_assemblyai_prerecorded_provider', return_value=assemblyai_provider + ), patch.object(provider_service, '_deepgram_prerecorded_provider', return_value=deepgram_provider), patch.object( + provider_service, 'create_provider_run', side_effect=['run-aai', 'run-dg'] + ), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/background.wav', + workload=STTWorkload.background, + uid='uid-1', + language='multi', + model='nova-3', + raw_audio_seconds=2.0, + ) + + assert response.result.provider == 'deepgram' + assert assemblyai_provider.transcribe_url.call_count == 2 + deepgram_provider.transcribe_url.assert_called_once() + assert finalize_run.call_args_list[0].kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args_list[0].kwargs['status'] == 'failed' + assert finalize_run.call_args_list[1].kwargs['provider'] == 'deepgram' + assert finalize_run.call_args_list[1].kwargs['fallback_count'] == 1 + assert finalize_run.call_args_list[1].kwargs['fallback_provider'] == 'assemblyai' + + +def test_provider_service_skips_missing_assemblyai_key_when_deepgram_fallback_is_usable(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + + deepgram_provider = MagicMock() + deepgram_provider.provider_name = STTProviderName.deepgram + deepgram_provider.transcribe_url.return_value = _provider_result() + + with patch.object(provider_service, '_assemblyai_prerecorded_provider') as assemblyai_provider, patch.object( + provider_service, '_deepgram_prerecorded_provider', return_value=deepgram_provider + ), patch.object(provider_service, 'create_provider_run', return_value='run-dg'), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + language='multi', + model='nova-3', + ) + + assert response.result.provider == 'deepgram' + assemblyai_provider.assert_not_called() + deepgram_provider.transcribe_url.assert_called_once() + assert finalize_run.call_args.kwargs['provider'] == 'deepgram' + assert finalize_run.call_args.kwargs['fallback_count'] == 0 + + +def test_provider_service_skips_missing_background_assemblyai_key_when_deepgram_fallback_is_usable(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + + deepgram_provider = MagicMock() + deepgram_provider.provider_name = STTProviderName.deepgram + deepgram_provider.transcribe_url.return_value = _provider_result() + + with patch.object(provider_service, '_assemblyai_prerecorded_provider') as assemblyai_provider, patch.object( + provider_service, '_deepgram_prerecorded_provider', return_value=deepgram_provider + ), patch.object(provider_service, 'create_provider_run', return_value='run-dg'), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/background.wav', + workload=STTWorkload.background, + uid='uid-1', + language='multi', + model='nova-3', + ) + + assert response.result.provider == 'deepgram' + assemblyai_provider.assert_not_called() + deepgram_provider.transcribe_url.assert_called_once() + assert finalize_run.call_args.kwargs['provider'] == 'deepgram' + assert finalize_run.call_args.kwargs['fallback_count'] == 0 + + +def test_provider_service_reports_missing_assemblyai_key_when_fallback_disabled(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + + with patch.object(provider_service, '_deepgram_prerecorded_provider') as deepgram_provider, patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + with pytest.raises(RuntimeError, match='ASSEMBLYAI_API_KEY is not configured'): + provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + language='multi', + model='nova-3', + ) + + deepgram_provider.assert_not_called() + assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['status'] == 'failed' + assert finalize_run.call_args.kwargs['error_class'] == 'AssemblyAIProviderError' + + +def test_provider_service_reports_missing_assemblyai_key_when_no_fallback_key_is_usable(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.delenv('DEEPGRAM_API_KEY', raising=False) + + with patch.object(provider_service, '_deepgram_prerecorded_provider') as deepgram_provider, patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + with pytest.raises(RuntimeError, match='ASSEMBLYAI_API_KEY is not configured'): + provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + language='multi', + model='nova-3', + ) + + deepgram_provider.assert_not_called() + assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['status'] == 'failed' + assert finalize_run.call_args.kwargs['error_class'] == 'AssemblyAIProviderError' + + +def test_provider_service_records_background_retry_exhaustion_when_fallback_disabled(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'background') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') + + assemblyai_provider = MagicMock() + assemblyai_provider.provider_name = STTProviderName.assemblyai + assemblyai_provider.transcribe_url.side_effect = RuntimeError('AssemblyAI failed') + + with patch.object( + provider_service, '_assemblyai_prerecorded_provider', return_value=assemblyai_provider + ), patch.object(provider_service, '_deepgram_prerecorded_provider') as deepgram_provider, patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ), patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + with pytest.raises(RuntimeError, match='assemblyai transcription failed after 2 attempts'): + provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.background, + uid='uid-1', + language='multi', + model='nova-3', + ) + + assert assemblyai_provider.transcribe_url.call_count == 2 + deepgram_provider.assert_not_called() + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['run_id'] == 'run-aai' + assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['status'] == 'failed' + assert finalize_run.call_args.kwargs['retry_count'] == 1 + assert finalize_run.call_args.kwargs['fallback_count'] == 0 + assert finalize_run.call_args.kwargs['error_class'] == 'RuntimeError' + + +def test_provider_service_records_successful_after_retry(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.assemblyai + fake_provider.transcribe_url.side_effect = [ + RuntimeError('temporary AssemblyAI failure'), + (_provider_result(provider='assemblyai', model='universal-2'), 'en'), + ] + + with patch.object(provider_service, '_assemblyai_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-aai' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + response = provider_service.transcribe_url( + 'https://example.test/audio.wav', + workload=STTWorkload.sync, + uid='uid-1', + return_language=True, + language='multi', + model='nova-3', + ) + + assert response.result.provider == 'assemblyai' + assert fake_provider.transcribe_url.call_count == 2 + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['status'] == 'succeeded' + assert finalize_run.call_args.kwargs['retry_count'] == 1 + assert finalize_run.call_args.kwargs['fallback_count'] == 0 + + +def test_provider_service_records_zero_cost_for_zero_duration_success(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + + fake_provider = MagicMock() + fake_provider.provider_name = STTProviderName.assemblyai + provider_result = _provider_result(provider='assemblyai', model='universal-2') + provider_result.duration = 0.0 + fake_provider.transcribe_url.return_value = provider_result + + with patch.object(provider_service, '_assemblyai_prerecorded_provider', return_value=fake_provider), patch.object( + provider_service, 'create_provider_run', return_value='run-zero' + ), patch.object(provider_service, 'finalize_provider_run') as finalize_run: + provider_service.transcribe_url( + 'https://example.test/zero.wav', + workload=STTWorkload.sync, + uid='uid-1', + raw_audio_seconds=0.0, + ) + + assert finalize_run.call_args.kwargs['billable_seconds'] == 0.0 + assert finalize_run.call_args.kwargs['estimated_cost_usd'] == 0.0 + + +def test_prerecorded_cost_estimator_uses_provider_defaults_and_unknown_provider_zero(): + assert ( + estimate_prerecorded_provider_cost_usd( + provider='assemblyai', + model='future-model', + workload='background', + billable_seconds=60.0, + ) + == 0.00283333 + ) + assert ( + estimate_prerecorded_provider_cost_usd( + provider='deepgram', + model='future-model', + workload='background', + billable_seconds=60.0, + ) + == 0.0068 + ) + assert ( + estimate_prerecorded_provider_cost_usd( + provider='unknown-provider', + model='future-model', + workload='background', + billable_seconds=60.0, + ) + == 0.0 + ) + + +def test_provider_service_counts_user_identity_as_identified_cluster(): + result = _provider_result(provider='assemblyai', model='universal-2') + segments = provider_service.reconstruct_conversation(result) + segments[0].is_user = True + segments[0].speaker_identity_state = 'user' + + with patch.object(provider_service, 'finalize_provider_run') as finalize_run: + provider_service._finalize_run( + 'run-user', + result, + STTWorkload.sync, + provider_service.datetime.now(provider_service.timezone.utc), + 'succeeded', + retry_count=0, + raw_audio_seconds=2.0, + segments=segments, + ) + + assert finalize_run.call_args.kwargs['identified_speaker_cluster_count'] == 1 + + +def test_provider_service_counts_label_only_identified_clusters(): + result = ProviderTranscriptResult( + provider='assemblyai', + model='universal-2', + duration=2.0, + words=[ + ProviderTranscriptWord(text='hello', start=0.0, end=0.4, speaker_label='A'), + ProviderTranscriptWord(text='again', start=0.5, end=0.8, speaker_label='A'), + ProviderTranscriptWord(text='there', start=1.0, end=1.4, speaker_label='B'), + ], + ) + segments = provider_service.reconstruct_conversation(result) + segments[0].person_id = 'person-a' + segments[0].speaker_identity_state = 'identified' + + with patch.object(provider_service, 'finalize_provider_run') as finalize_run: + provider_service._finalize_run( + 'run-labels', + result, + STTWorkload.sync, + provider_service.datetime.now(provider_service.timezone.utc), + 'succeeded', + retry_count=0, + raw_audio_seconds=2.0, + segments=segments, + ) + + assert finalize_run.call_args.kwargs['speaker_cluster_count'] == 2 + assert finalize_run.call_args.kwargs['identified_speaker_cluster_count'] == 1 + assert finalize_run.call_args.kwargs['identity_match_count'] == 1 + assert finalize_run.call_args.kwargs['provider_speaker_count'] == 2 + assert finalize_run.call_args.kwargs['mapped_speaker_count'] == 1 + assert finalize_run.call_args.kwargs['mapped_person_count'] == 1 + assert finalize_run.call_args.kwargs['unmapped_speaker_count'] == 1 + assert finalize_run.call_args.kwargs['unknown_speaker_count'] == 1 + assert finalize_run.call_args.kwargs['unknown_speaker_duration_seconds'] == 0.4 + assert finalize_run.call_args.kwargs['chunk_duration_seconds'] == 2.0 + + +def test_provider_service_records_split_count_from_local_cluster_marker(): + result = ProviderTranscriptResult(provider='assemblyai', model='universal-2', duration=2.0) + segments = [ + TranscriptSegment( + text='first speaker', + is_user=False, + start=0.0, + end=0.8, + provider_cluster_id='A::local_part:0', + speaker_identity_state='unknown', + ), + TranscriptSegment( + text='known speaker', + is_user=False, + person_id='person-b', + start=1.0, + end=1.8, + provider_cluster_id='B', + speaker_identity_state='identified', + speaker_identity_confidence=0.93, + ), + ] + + with patch.object(provider_service, 'finalize_provider_run') as finalize_run: + provider_service._finalize_run( + 'run-split', + result, + STTWorkload.background, + provider_service.datetime.now(provider_service.timezone.utc), + 'succeeded', + retry_count=0, + raw_audio_seconds=2.0, + segments=segments, + ) + + assert finalize_run.call_args.kwargs['split_count'] == 1 + assert finalize_run.call_args.kwargs['unknown_speaker_count'] == 1 + assert finalize_run.call_args.kwargs['unknown_speaker_duration_seconds'] == 0.8 + + +def test_provider_service_classifies_timeout_fallback_reason(): + timeout = provider_service.ProviderTranscriptionRetriesExhausted(provider_service.AssemblyAITimeoutError(), 1) + + assert provider_service._fallback_reason_from_exception(timeout) == 'provider_timeout' + assert ( + provider_service._fallback_reason_from_exception(RuntimeError('temporary provider failure')) + == 'provider_failure' + ) + + +def test_provider_service_preserves_assemblyai_labels_for_identity_metrics(): + result = ProviderTranscriptResult( + provider='assemblyai', + model='universal-2', + duration=4.0, + words=[ + ProviderTranscriptWord(text='alice', start=0.0, end=0.5, speaker_label='A'), + ProviderTranscriptWord(text='speaks', start=0.6, end=1.0, speaker_label='A'), + ProviderTranscriptWord(text='bob', start=2.0, end=2.4, speaker_label='B'), + ProviderTranscriptWord(text='replies', start=2.5, end=3.0, speaker_label='B'), + ], + ) + segments = provider_service.reconstruct_conversation(result) + + assert [segment.provider_speaker_label for segment in segments] == ['A', 'B'] + assert {segment.speaker_id for segment in segments} == {0} + assert provider_service.speaker_identity_metrics(segments)['provider_speaker_count'] == 2 + + +def test_provider_service_live_assemblyai_smoke_records_ledger_when_credentials_are_present(monkeypatch): + api_key = os.getenv('ASSEMBLYAI_API_KEY') + audio_url = os.getenv('ASSEMBLYAI_SMOKE_AUDIO_URL') + if not api_key or not audio_url: + pytest.skip('ASSEMBLYAI_API_KEY and ASSEMBLYAI_SMOKE_AUDIO_URL are required for live smoke') + + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync') + + with patch.object(provider_service, 'create_provider_run', return_value='run-aai-live') as create_run, patch.object( + provider_service, 'finalize_provider_run' + ) as finalize_run: + response = provider_service.transcribe_url( + audio_url, + workload=STTWorkload.sync, + uid='uid-live-smoke', + conversation_id='conversation-live-smoke', + language='en', + model='nova-3', + raw_audio_seconds=1.0, + ) + + assert response.result.provider == 'assemblyai' + assert response.result.raw_provider_result_id + create_run.assert_called_once() + finalize_run.assert_called_once() + assert finalize_run.call_args.kwargs['provider'] == 'assemblyai' + assert finalize_run.call_args.kwargs['status'] == 'succeeded' + assert ( + finalize_run.call_args.kwargs['artifact_refs']['provider_result_id'] == response.result.raw_provider_result_id + ) diff --git a/backend/tests/unit/test_background_speaker_identity.py b/backend/tests/unit/test_background_speaker_identity.py new file mode 100644 index 00000000000..4d0a152610a --- /dev/null +++ b/backend/tests/unit/test_background_speaker_identity.py @@ -0,0 +1,209 @@ +import io +import wave + +import numpy as np + +from models.message_event import SpeakerLabelSuggestionEvent +from models.transcript_segment import TranscriptSegment +from utils.stt.background_speaker_identity import ( + SPEAKER_IDENTITY_SOURCE, + SPEAKER_IDENTITY_VERSION, + identify_background_speaker_clusters, + select_representative_cluster_spans, +) + + +def _segment(segment_id, start, end, cluster='cluster-a', speaker='SPEAKER_01', text='hello from speaker'): + return TranscriptSegment( + id=segment_id, + text=text, + speaker=speaker, + is_user=False, + start=start, + end=end, + provider_cluster_id=cluster, + provider_speaker_label=speaker, + speaker_identity_state='unassigned', + ) + + +def _wav_bytes(duration_seconds=30, sample_rate=16000): + samples = np.zeros(int(duration_seconds * sample_rate), dtype=np.int16) + out = io.BytesIO() + with wave.open(out, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(samples.tobytes()) + return out.getvalue() + + +def test_cluster_sampling_prefers_clean_spans_and_caps_total_duration(): + cluster_segments = [ + _segment('short', 0.0, 0.5), + _segment('overlap', 1.0, 8.0), + _segment('clean-long', 9.0, 24.0), + _segment('clean-second', 24.5, 33.0), + ] + all_segments = cluster_segments + [_segment('other-overlap', 2.0, 3.0, cluster='cluster-b')] + + spans = select_representative_cluster_spans(cluster_segments, all_segments) + + assert [span.segment_id for span in spans] == ['clean-long'] + assert sum(span.duration for span in spans) == 10.0 + assert all(span.segment_id != 'overlap' for span in spans) + assert all(span.duration >= 1.0 for span in spans) + + +def test_cluster_identity_applies_one_voice_assignment_to_all_cluster_segments(): + alice_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + bob_embedding = np.array([[0.0, 1.0]], dtype=np.float32) + segments = [ + _segment('a1', 0.0, 3.0, cluster='cluster-a', speaker='SPEAKER_01'), + _segment('a2', 3.5, 6.5, cluster='cluster-a', speaker='SPEAKER_01'), + _segment('b1', 7.0, 10.0, cluster='cluster-b', speaker='SPEAKER_02'), + ] + cache = { + 'person-alice': {'embedding': alice_embedding, 'name': 'Alice'}, + 'person-bob': {'embedding': bob_embedding, 'name': 'Bob'}, + } + + def fake_extract(_audio, _filename): + return alice_embedding + + assignments = identify_background_speaker_clusters(segments, _wav_bytes(), cache, embedding_extractor=fake_extract) + + assert assignments['cluster-a'].person_id == 'person-alice' + assert assignments['cluster-a'].confidence == 1.0 + assert segments[0].person_id == 'person-alice' + assert segments[1].person_id == 'person-alice' + assert segments[0].speaker_identity_source == SPEAKER_IDENTITY_SOURCE + assert segments[0].speaker_identity_version == SPEAKER_IDENTITY_VERSION + assert segments[0].speaker_identity_provenance['provider_cluster_id'] == 'cluster-a' + assert segments[0].speaker_identity_candidates[0]['person_id'] == 'person-alice' + + +def test_low_confidence_cluster_remains_explicitly_unknown_with_candidate_metadata(): + query_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + distant_embedding = np.array([[0.0, 1.0]], dtype=np.float32) + segments = [_segment('a1', 0.0, 3.0)] + cache = {'person-distant': {'embedding': distant_embedding, 'name': 'Distant'}} + + assignments = identify_background_speaker_clusters( + segments, + _wav_bytes(), + cache, + embedding_extractor=lambda _audio, _filename: query_embedding, + ) + + assert assignments['cluster-a'].state == 'unknown' + assert assignments['cluster-a'].reason == 'below_threshold' + assert segments[0].speaker_identity_state == 'unknown' + assert segments[0].person_id is None + assert segments[0].speaker_identity_candidates[0]['person_id'] == 'person-distant' + assert segments[0].speaker_identity_confidence is None + + +def test_unknown_assignment_demotes_stale_person_identity(): + query_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + distant_embedding = np.array([[0.0, 1.0]], dtype=np.float32) + segments = [_segment('a1', 0.0, 3.0)] + segments[0].person_id = 'stale-person' + segments[0].speaker_identity_state = 'identified' + cache = {'person-distant': {'embedding': distant_embedding, 'name': 'Distant'}} + + identify_background_speaker_clusters( + segments, + _wav_bytes(), + cache, + embedding_extractor=lambda _audio, _filename: query_embedding, + ) + + assert segments[0].speaker_identity_state == 'unknown' + assert segments[0].person_id is None + assert segments[0].is_user is False + + +def test_text_self_introduction_is_hint_only_without_voice_assignment(): + segments = [_segment('intro', 0.0, 3.0, text='I am Alice and I joined the call.')] + + assignments = identify_background_speaker_clusters(segments, audio_bytes=None, person_embeddings_cache={}) + + assert assignments['cluster-a'].state == 'unknown' + assert assignments['cluster-a'].text_hints[0]['detected_name'] == 'Alice' + assert segments[0].person_id is None + assert segments[0].speaker_identity_state == 'unknown' + assert segments[0].speaker_identity_text_hints[0]['source'] == 'text_self_introduction' + + +def test_text_hint_negative_controls_do_not_create_identity_hints(): + negative_texts = [ + 'Bonjour, je suis Alice et je parle francais.', + 'Alice said I am Bob during the interview.', + '"I am Alice," she read from the script.', + 'Alice is my manager and she mentioned the roadmap.', + 'Hello Alice, hello Alice, thanks for joining.', + 'This monologue has no actual speaker change.', + 'My name as Alice was mistranscribed by the ASR.', + ] + + segments = [ + _segment(f'negative-{index}', index * 3.0, index * 3.0 + 2.0, cluster=f'cluster-{index}', text=text) + for index, text in enumerate(negative_texts) + ] + + assignments = identify_background_speaker_clusters(segments, audio_bytes=None, person_embeddings_cache={}) + + assert all(not assignment.text_hints for assignment in assignments.values()) + assert all(not segment.speaker_identity_text_hints for segment in segments) + + +def test_user_sentinel_is_not_persisted_as_durable_person_identity(): + user_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + segments = [_segment('me', 0.0, 4.0)] + cache = {'user': {'embedding': user_embedding, 'name': 'User'}} + + identify_background_speaker_clusters( + segments, + _wav_bytes(), + cache, + embedding_extractor=lambda _audio, _filename: user_embedding, + ) + + assert segments[0].is_user is True + assert segments[0].person_id is None + assert segments[0].speaker_identity_state == 'user' + assert segments[0].speaker_identity_candidates[0]['person_id'] is None + assert segments[0].speaker_identity_candidates[0]['is_user'] is True + + +def test_speaker_label_suggestion_event_preserves_legacy_shape_and_accepts_cluster_metadata(): + legacy = SpeakerLabelSuggestionEvent( + speaker_id=1, + person_id='person-1', + person_name='Alice', + segment_id='segment-1', + ).to_json() + + assert legacy['type'] == 'speaker_label_suggestion' + assert legacy['version'] == 1 + assert legacy['speaker_id'] == 1 + assert legacy['person_id'] == 'person-1' + + extended = SpeakerLabelSuggestionEvent( + speaker_id=1, + person_id='person-1', + person_name='Alice', + segment_id='segment-1', + version=2, + provider_cluster_id='cluster-a', + speaker_identity_state='identified', + confidence=0.91, + source=SPEAKER_IDENTITY_SOURCE, + provenance={'sample_seconds': 6.0}, + candidates=[{'person_id': 'person-1', 'confidence': 0.91}], + ).to_json() + + assert extended['version'] == 2 + assert extended['provider_cluster_id'] == 'cluster-a' + assert extended['provenance']['sample_seconds'] == 6.0 diff --git a/backend/tests/unit/test_batch_upload_storage.py b/backend/tests/unit/test_batch_upload_storage.py index 08045593a20..28ae6e96a50 100644 --- a/backend/tests/unit/test_batch_upload_storage.py +++ b/backend/tests/unit/test_batch_upload_storage.py @@ -24,8 +24,6 @@ sys.modules.setdefault("google.cloud.storage", _mock_gcs_storage) sys.modules.setdefault("google.cloud.storage.transfer_manager", MagicMock()) sys.modules.setdefault("google.cloud.exceptions", MagicMock()) -sys.modules.setdefault("google.oauth2", MagicMock()) -sys.modules.setdefault("google.oauth2.service_account", MagicMock()) from utils.other import storage as storage_mod diff --git a/backend/tests/unit/test_byok_assemblyai_routing.py b/backend/tests/unit/test_byok_assemblyai_routing.py new file mode 100644 index 00000000000..bd9c9019534 --- /dev/null +++ b/backend/tests/unit/test_byok_assemblyai_routing.py @@ -0,0 +1,96 @@ +import os +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + +sys.modules['deepgram'].DeepgramClient = MagicMock +sys.modules['deepgram'].DeepgramClientOptions = MagicMock +sys.modules.setdefault('database._client', types.SimpleNamespace(db=MagicMock())) + +os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') + +from utils.stt import provider_service # noqa: E402 +from utils.stt.providers import STTProviderName, STTWorkload, get_prerecorded_provider_name # noqa: E402 + + +@pytest.fixture(autouse=True) +def _enable_assemblyai_routing(monkeypatch): + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.setenv('ASSEMBLYAI_API_KEY', 'aa-server-key') + + +def test_env_selects_assemblyai_for_sync(): + assert get_prerecorded_provider_name(STTWorkload.sync) == STTProviderName.assemblyai + + +def test_resolve_uses_deepgram_byok_when_no_assembly_header(monkeypatch): + monkeypatch.setattr(provider_service, 'get_byok_key', lambda provider: {'deepgram': 'dg-user-key'}.get(provider)) + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.deepgram + + +def test_resolve_uses_deepgram_byok_for_background_when_no_assembly_header(monkeypatch): + monkeypatch.setattr(provider_service, 'get_byok_key', lambda provider: {'deepgram': 'dg-user-key'}.get(provider)) + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) == STTProviderName.deepgram + + +def test_resolve_uses_assemblyai_when_byok_assembly_header_present(monkeypatch): + keys = {'assemblyai': 'aa-user-key', 'deepgram': 'dg-user-key'} + + def _lookup(provider): + return keys.get(provider) + + monkeypatch.setattr(provider_service, 'get_byok_key', _lookup) + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai + + +def test_resolve_uses_assemblyai_for_background_when_byok_assembly_header_present(monkeypatch): + keys = {'assemblyai': 'aa-user-key', 'deepgram': 'dg-user-key'} + + def _lookup(provider): + return keys.get(provider) + + monkeypatch.setattr(provider_service, 'get_byok_key', _lookup) + assert ( + provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) == STTProviderName.assemblyai + ) + + +def test_resolve_uses_server_assembly_when_no_byok_headers(monkeypatch): + monkeypatch.setattr(provider_service, 'get_byok_key', lambda _provider: None) + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai + + +def test_resolve_uses_server_deepgram_when_server_assembly_missing_and_fallback_enabled(monkeypatch): + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + monkeypatch.setattr(provider_service, 'get_byok_key', lambda _provider: None) + + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.deepgram + + +def test_resolve_keeps_assemblyai_selected_when_server_assembly_missing_and_fallback_disabled( + monkeypatch, +): + monkeypatch.delenv('ASSEMBLYAI_API_KEY', raising=False) + monkeypatch.setenv('DEEPGRAM_API_KEY', 'dg-server-key') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'false') + monkeypatch.setattr(provider_service, 'get_byok_key', lambda _provider: None) + + assert provider_service.resolve_prerecorded_provider_for_request(STTWorkload.sync) == STTProviderName.assemblyai + + +def test_assemblyai_provider_passes_byok_api_key(monkeypatch): + monkeypatch.setattr(provider_service, 'get_byok_key', lambda _provider: 'aa-user-key') + with patch.object(provider_service, 'AssemblyAIAsyncTranscriptionProvider') as mock_cls: + provider_service._assemblyai_prerecorded_provider() + mock_cls.assert_called_once_with(api_key='aa-user-key') + + +# Activation endpoint tests live in test_byok_security.py::TestBYOKActivationValidation diff --git a/backend/tests/unit/test_byok_security.py b/backend/tests/unit/test_byok_security.py index 5353218fce5..3d664a641d6 100644 --- a/backend/tests/unit/test_byok_security.py +++ b/backend/tests/unit/test_byok_security.py @@ -113,7 +113,7 @@ def _make_ws(self, headers: Dict[str, str]) -> MagicMock: ws.headers = headers return ws - def test_extracts_all_four_headers(self): + def test_extracts_all_byok_headers_including_assemblyai(self): from utils.byok import extract_byok_from_websocket ws = self._make_ws( @@ -122,10 +122,17 @@ def test_extracts_all_four_headers(self): 'x-byok-anthropic': 'sk-a', 'x-byok-gemini': 'sk-g', 'x-byok-deepgram': 'sk-d', + 'x-byok-assemblyai': 'sk-aa', } ) keys = extract_byok_from_websocket(ws) - assert keys == {'openai': 'sk-o', 'anthropic': 'sk-a', 'gemini': 'sk-g', 'deepgram': 'sk-d'} + assert keys == { + 'openai': 'sk-o', + 'anthropic': 'sk-a', + 'gemini': 'sk-g', + 'deepgram': 'sk-d', + 'assemblyai': 'sk-aa', + } def test_returns_empty_when_no_headers(self): from utils.byok import extract_byok_from_websocket @@ -400,10 +407,10 @@ def test_transcription_not_bypassed_when_no_deepgram_header(self, mock_users_db, class TestBYOKHeadersConstant: - def test_headers_has_all_four_providers(self): + def test_headers_has_required_and_optional_providers(self): from utils.byok import BYOK_HEADERS - assert set(BYOK_HEADERS.keys()) == {'openai', 'anthropic', 'gemini', 'deepgram'} + assert set(BYOK_HEADERS.keys()) == {'openai', 'anthropic', 'gemini', 'deepgram', 'assemblyai'} def test_headers_are_lowercase(self): from utils.byok import BYOK_HEADERS @@ -550,10 +557,26 @@ def test_deactivation_calls_clear(self, mock_users_db): def test_production_constants_match(self): """Verify the test regex matches the production regex.""" - from routers.users import _SHA256_HEX_RE as prod_re, _BYOK_REQUIRED_PROVIDERS as prod_providers + from routers.users import ( + _BYOK_ALLOWED_PROVIDERS as prod_allowed, + _BYOK_REQUIRED_PROVIDERS as prod_required, + _SHA256_HEX_RE as prod_re, + ) assert prod_re.pattern == _SHA256_HEX_RE.pattern - assert prod_providers == {'openai', 'anthropic', 'gemini', 'deepgram'} + assert prod_required == {'openai', 'anthropic', 'gemini', 'deepgram'} + assert prod_allowed == prod_required | {'assemblyai'} + + @patch('routers.users.users_db') + def test_optional_assemblyai_fingerprint_accepted(self, mock_users_db): + from routers.users import BYOKActivateRequest, activate_byok_endpoint + + fps = self._valid_fingerprints() + fps['assemblyai'] = hashlib.sha256(b'sk-assemblyai').hexdigest() + data = BYOKActivateRequest(fingerprints=fps) + result = activate_byok_endpoint(data, uid='test-uid') + assert result == {'active': True} + mock_users_db.set_byok_active.assert_called_once_with('test-uid', fps) # --------------------------------------------------------------------------- diff --git a/backend/tests/unit/test_chat_quota.py b/backend/tests/unit/test_chat_quota.py index ef779bb9aef..488f1785aa2 100644 --- a/backend/tests/unit/test_chat_quota.py +++ b/backend/tests/unit/test_chat_quota.py @@ -23,9 +23,12 @@ def _compare_versions(a, b): _announcements_mod._compare_versions = _compare_versions +_announcements_mod.compare_versions = _compare_versions # Create stubs for database modules used by get_chat_quota_snapshot -_db_users_mod = types.SimpleNamespace(get_user_valid_subscription=MagicMock()) +_db_users_mod = types.SimpleNamespace( + get_user_valid_subscription=MagicMock(), is_byok_active=MagicMock(return_value=False) +) _db_user_usage_mod = types.SimpleNamespace(get_monthly_chat_usage=MagicMock()) sys.modules.setdefault("database._client", types.SimpleNamespace(db=MagicMock())) @@ -319,10 +322,8 @@ def test_enforcement_allowed(self, monkeypatch): ): sub_mod.enforce_chat_quota("uid123") # no exception - def test_enforcement_exceeded_raises_402(self, monkeypatch): - """When user exceeds quota, raises HTTPException 402.""" - from fastapi import HTTPException - + def test_enforcement_paid_plan_exceeded_allows_overage(self, monkeypatch): + """When a paid plan exceeds quota, no exception is raised.""" sub_mod = _reload_subscription_module() with patch.object( @@ -336,23 +337,41 @@ def test_enforcement_exceeded_raises_402(self, monkeypatch): 'limit': 2000, 'reset_at': _RESET_AT, }, + ): + sub_mod.enforce_chat_quota("uid123") # paid plans enter overage mode + + def test_enforcement_free_plan_exceeded_raises_402(self, monkeypatch): + """When a free user exceeds quota, raises HTTPException 402.""" + from fastapi import HTTPException + + sub_mod = _reload_subscription_module() + + with patch.object( + sub_mod, + "get_chat_quota_snapshot", + return_value={ + 'allowed': False, + 'plan': PlanType.basic, + 'unit': 'questions', + 'used': 31, + 'limit': 30, + 'reset_at': _RESET_AT, + }, ): with pytest.raises(HTTPException) as exc_info: sub_mod.enforce_chat_quota("uid123") assert exc_info.value.status_code == 402 assert exc_info.value.detail['error'] == 'quota_exceeded' - assert exc_info.value.detail['plan'] == 'Neo' - assert exc_info.value.detail['plan_type'] == 'unlimited' + assert exc_info.value.detail['plan'] == 'Free' + assert exc_info.value.detail['plan_type'] == 'basic' assert exc_info.value.detail['unit'] == 'questions' - assert exc_info.value.detail['used'] == 2001 - assert exc_info.value.detail['limit'] == 2000 + assert exc_info.value.detail['used'] == 31 + assert exc_info.value.detail['limit'] == 30 assert exc_info.value.detail['reset_at'] == _RESET_AT - def test_enforcement_402_operator_plan(self, monkeypatch): - """Operator plan shows correct display name in 402 detail.""" - from fastapi import HTTPException - + def test_enforcement_operator_plan_allows_overage(self, monkeypatch): + """Operator plan allows overage after included quota.""" sub_mod = _reload_subscription_module() with patch.object( @@ -367,16 +386,10 @@ def test_enforcement_402_operator_plan(self, monkeypatch): 'reset_at': _RESET_AT, }, ): - with pytest.raises(HTTPException) as exc_info: - sub_mod.enforce_chat_quota("uid123") - - assert exc_info.value.status_code == 402 - assert exc_info.value.detail['plan'] == 'Operator' - - def test_enforcement_402_architect_cost_based(self, monkeypatch): - """Architect plan shows cost_usd unit in 402 detail.""" - from fastapi import HTTPException + sub_mod.enforce_chat_quota("uid123") + def test_enforcement_architect_cost_based_allows_overage(self, monkeypatch): + """Architect plan allows cost-based overage after included quota.""" sub_mod = _reload_subscription_module() with patch.object( @@ -391,9 +404,4 @@ def test_enforcement_402_architect_cost_based(self, monkeypatch): 'reset_at': _RESET_AT, }, ): - with pytest.raises(HTTPException) as exc_info: - sub_mod.enforce_chat_quota("uid123") - - assert exc_info.value.status_code == 402 - assert exc_info.value.detail['unit'] == 'cost_usd' - assert exc_info.value.detail['used'] == 400.5 + sub_mod.enforce_chat_quota("uid123") diff --git a/backend/tests/unit/test_conversation_reconstructor.py b/backend/tests/unit/test_conversation_reconstructor.py new file mode 100644 index 00000000000..4b2f857582a --- /dev/null +++ b/backend/tests/unit/test_conversation_reconstructor.py @@ -0,0 +1,226 @@ +import sys +from unittest.mock import MagicMock + +from models.transcript_segment import ( + ProviderTranscriptResult, + ProviderTranscriptUtterance, + ProviderTranscriptWord, +) +from utils.stt.conversation_reconstructor import ConversationReconstructor, reconstruct_conversation + +if 'deepgram' not in sys.modules: + sys.modules['deepgram'] = MagicMock() + +from utils.stt.deepgram_adapter import normalize_deepgram_prerecorded_result + + +def _word(text, start, end, cluster=None, label=None): + return ProviderTranscriptWord( + text=text, + start=start, + end=end, + provider_cluster_id=cluster, + speaker_label=label, + ) + + +def test_reconstructs_word_only_provider_result_with_stable_ordering_and_cluster_metadata(): + result = ProviderTranscriptResult( + provider='test-provider', + model='async-model', + words=[ + _word('later', 2.0, 2.5, cluster='cluster-b', label='SPEAKER_01'), + _word('hello', 0.0, 0.4, cluster='cluster-a', label='SPEAKER_00'), + _word('world.', 0.5, 1.0, cluster='cluster-a', label='SPEAKER_00'), + ], + ) + + segments = reconstruct_conversation(result) + + assert [segment.text for segment in segments] == ['Hello world.', 'Later'] + assert segments[0].start == 0.0 + assert segments[0].end == 1.0 + assert segments[0].provider_cluster_id == 'cluster-a' + assert segments[0].speaker == 'SPEAKER_00' + assert segments[0].speaker_identity_state == 'unassigned' + assert segments[0].stt_provider == 'test-provider' + assert segments[0].stt_model == 'async-model' + assert segments[1].provider_cluster_id == 'cluster-b' + + +def test_reconstructs_label_only_provider_result_without_collapsing_clusters(): + result = ProviderTranscriptResult( + provider='test-provider', + model='async-model', + words=[ + _word('hello', 0.0, 0.4, label='A'), + _word('there', 0.5, 0.8, label='A'), + _word('hi', 1.0, 1.2, label='B'), + ], + ) + + segments = reconstruct_conversation(result) + + assert [segment.text for segment in segments] == ['Hello there', 'Hi'] + assert segments[0].provider_cluster_id is None + assert segments[0].provider_speaker_label == 'A' + assert segments[0].speaker_identity_state == 'unknown' + assert segments[1].provider_speaker_label == 'B' + + +def test_reconstructs_utterance_only_provider_result_with_explicit_unknown_identity(): + result = ProviderTranscriptResult( + provider='assemblyai', + utterances=[ + ProviderTranscriptUtterance( + text='opaque speaker label', + start=4.0, + end=5.0, + provider_cluster_id='speaker-a', + speaker_label='A', + ), + ProviderTranscriptUtterance(text='no cluster', start=5.2, end=6.0), + ], + ) + + segments = reconstruct_conversation(result) + + assert [segment.text for segment in segments] == ['Opaque speaker label', 'No cluster'] + assert segments[0].provider_cluster_id == 'speaker-a' + assert segments[0].provider_speaker_label == 'A' + assert segments[0].speaker is None + assert segments[0].speaker_identity_state == 'unknown' + assert segments[1].provider_cluster_id is None + assert segments[1].speaker_identity_state == 'unknown' + + +def test_mixed_utterances_and_words_do_not_duplicate_words_covered_by_utterances(): + result = ProviderTranscriptResult( + provider='mixed', + utterances=[ + ProviderTranscriptUtterance( + text='Hello world.', + start=0.0, + end=1.1, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ) + ], + words=[ + _word('Hello', 0.0, 0.4, cluster='0', label='SPEAKER_00'), + _word('world.', 0.5, 1.1, cluster='0', label='SPEAKER_00'), + _word('Outside.', 2.0, 2.6, cluster='1', label='SPEAKER_01'), + ], + ) + + segments = reconstruct_conversation(result) + + assert [segment.text for segment in segments] == ['Hello world.', 'Outside.'] + assert segments[0].provider_cluster_id == '0' + assert segments[1].provider_cluster_id == '1' + + +def test_reconstructor_preserves_legacy_deepgram_prerecorded_parity(): + deepgram_result = { + 'metadata': {'request_id': 'dg-request-1', 'duration': 3.25}, + 'results': { + 'channels': [ + { + 'detected_language': 'en-US', + 'alternatives': [ + { + 'words': [ + { + 'word': 'hello', + 'punctuated_word': 'Hello', + 'start': 0.0, + 'end': 0.4, + 'confidence': 0.91, + 'speaker': 2, + }, + { + 'word': 'world', + 'punctuated_word': 'world.', + 'start': 0.5, + 'end': 1.1, + 'confidence': 0.88, + 'speaker': 2, + }, + ] + } + ], + } + ], + 'utterances': [ + { + 'transcript': 'Hello world.', + 'start': 0.0, + 'end': 1.1, + 'confidence': 0.9, + 'speaker': 2, + } + ], + }, + } + + result = normalize_deepgram_prerecorded_result(deepgram_result, model='nova-3') + segments = reconstruct_conversation(result) + + assert len(segments) == 1 + assert segments[0].text == 'Hello world.' + assert segments[0].speaker == 'SPEAKER_02' + assert segments[0].speaker_id == 2 + assert segments[0].provider_cluster_id == '2' + assert segments[0].provider_speaker_label == 'SPEAKER_02' + assert segments[0].stt_provider == 'deepgram' + assert segments[0].stt_model == 'nova-3' + + +def test_overlap_duplicate_candidates_keep_longer_text_once(): + reconstructor = ConversationReconstructor() + result = ProviderTranscriptResult( + provider='test-provider', + utterances=[ + ProviderTranscriptUtterance( + text='hello', + start=0.0, + end=1.0, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ), + ProviderTranscriptUtterance( + text='hello there', + start=0.5, + end=1.5, + provider_cluster_id='0', + speaker_label='SPEAKER_00', + ), + ], + ) + + segments = reconstructor.reconstruct(result) + + assert len(segments) == 1 + assert segments[0].text == 'Hello there' + assert segments[0].start == 0.0 + assert segments[0].end == 1.5 + + +def test_skip_window_marks_dominant_preskip_cluster_as_user(): + result = ProviderTranscriptResult( + provider='test-provider', + words=[ + _word('my', 0.0, 0.2, cluster='me', label='SPEAKER_00'), + _word('voice', 0.3, 0.6, cluster='me', label='SPEAKER_00'), + _word('after', 3.0, 3.4, cluster='me', label='SPEAKER_00'), + _word('guest', 3.5, 4.0, cluster='guest', label='SPEAKER_01'), + ], + ) + + segments = reconstruct_conversation(result, skip_n_seconds=2) + + assert segments[0].text == 'After' + assert segments[0].is_user is True + assert segments[0].speaker_identity_state == 'user' + assert segments[0].start == 0.0 + assert segments[1].is_user is False diff --git a/backend/tests/unit/test_daily_summary_race_condition.py b/backend/tests/unit/test_daily_summary_race_condition.py index 3289eb8e45b..65fcdbae4d3 100644 --- a/backend/tests/unit/test_daily_summary_race_condition.py +++ b/backend/tests/unit/test_daily_summary_race_condition.py @@ -76,6 +76,7 @@ def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) "utils.webhooks", "utils.conversations", "utils.conversations.factory", + "utils.subscription", ]: if name not in sys.modules: mod = _stub_module(name) @@ -100,6 +101,9 @@ def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) utils_webhooks = sys.modules["utils.webhooks"] utils_webhooks.day_summary_webhook = MagicMock() +utils_subscription = sys.modules["utils.subscription"] +utils_subscription.is_trial_paywalled = MagicMock(return_value=False) + # Stub models for name in ["models.notification_message", "models.conversation"]: if name not in sys.modules: diff --git a/backend/tests/unit/test_desktop_background_transcribe.py b/backend/tests/unit/test_desktop_background_transcribe.py new file mode 100644 index 00000000000..abebcaeb619 --- /dev/null +++ b/backend/tests/unit/test_desktop_background_transcribe.py @@ -0,0 +1,917 @@ +import io +import os +import sys +from types import SimpleNamespace +import wave +from unittest.mock import MagicMock + +import numpy as np + +os.environ.setdefault('DEEPGRAM_API_KEY', 'fake-for-test') + +for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + +sys.modules['deepgram'].DeepgramClient = MagicMock +sys.modules['deepgram'].DeepgramClientOptions = MagicMock + +if 'google.api_core.exceptions' not in sys.modules: + sys.modules['google'] = MagicMock() + sys.modules['google.api_core'] = MagicMock() + sys.modules['google.api_core.exceptions'] = MagicMock(NotFound=Exception) + sys.modules['google.cloud'] = MagicMock() + sys.modules['google.cloud.firestore'] = MagicMock() + sys.modules['google.cloud.firestore_v1'] = MagicMock(FieldFilter=MagicMock) +sys.modules['google.cloud.firestore'].Increment = lambda value: value +database_client_stub = sys.modules.setdefault('database._client', SimpleNamespace()) +database_client_stub.db = getattr(database_client_stub, 'db', MagicMock()) +database_client_stub.document_id_from_seed = getattr(database_client_stub, 'document_id_from_seed', MagicMock()) +sys.modules.setdefault('database.conversations', MagicMock()) +sys.modules.setdefault('database.users', MagicMock()) +sys.modules.setdefault('database.redis_db', SimpleNamespace(r=MagicMock())) + + +class _DesktopBackgroundConversationError(ValueError): + def __init__(self, message: str, status_code: int = 400): + super().__init__(message) + self.status_code = status_code + + +sys.modules.setdefault( + 'utils.conversations.desktop_background', + SimpleNamespace( + DesktopBackgroundConversationError=_DesktopBackgroundConversationError, + append_background_chunk_to_in_progress_conversation=MagicMock(), + append_segments_to_in_progress_conversation=MagicMock(), + create_in_progress_desktop_conversation=MagicMock(return_value='conv-1'), + finish_desktop_background_conversation=MagicMock(return_value={'id': 'conv-1', 'status': 'completed'}), + get_background_chunk_record=MagicMock(return_value=None), + ), +) +sys.modules.setdefault( + 'utils.chat', SimpleNamespace(resolve_voice_message_language=lambda _uid, language: language or 'en') +) +sys.modules.setdefault('utils.analytics', SimpleNamespace(record_usage=lambda *_args, **_kwargs: None)) +sys.modules.setdefault('utils.byok', SimpleNamespace(get_byok_key=lambda _provider: None)) +sys.modules.setdefault( + 'utils.fair_use', + SimpleNamespace( + is_hard_restricted=lambda *_args, **_kwargs: False, + record_speech_ms=lambda *_args, **_kwargs: None, + ), +) +sys.modules.setdefault( + 'utils.subscription', + SimpleNamespace( + has_transcription_credits=lambda *_args, **_kwargs: True, + is_trial_paywalled=lambda *_args, **_kwargs: False, + ), +) + + +def _pcm_to_wav_bytes(pcm_data: bytes, sample_rate: int) -> bytes: + buffer = io.BytesIO() + with wave.open(buffer, 'wb') as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(sample_rate) + wav.writeframes(pcm_data) + return buffer.getvalue() + + +sys.modules.setdefault('utils.speaker_identification', SimpleNamespace(_pcm_to_wav_bytes=_pcm_to_wav_bytes)) +sys.modules.setdefault('utils.stt.speaker_embedding', SimpleNamespace(extract_embedding_from_bytes=MagicMock())) +sys.modules.setdefault( + 'utils.stt.provider_service', + SimpleNamespace( + resolve_prerecorded_provider_for_request=MagicMock(), + resolve_background_provider_policy=MagicMock(), + speaker_identity_metrics=lambda segments: { + 'provider_speaker_count': len( + { + segment.provider_cluster_id or segment.provider_speaker_label + for segment in segments + if segment.provider_cluster_id or segment.provider_speaker_label + } + ), + 'mapped_speaker_count': len( + { + segment.provider_cluster_id or segment.provider_speaker_label + for segment in segments + if (segment.provider_cluster_id or segment.provider_speaker_label) + and ( + segment.person_id or segment.is_user or segment.speaker_identity_state in ('identified', 'user') + ) + } + ), + 'mapped_person_count': len( + { + segment.person_id or 'user' + for segment in segments + if segment.person_id or segment.is_user or segment.speaker_identity_state == 'user' + } + ), + 'unmapped_speaker_count': len( + { + segment.provider_cluster_id or segment.provider_speaker_label + for segment in segments + if segment.provider_cluster_id or segment.provider_speaker_label + } + ) + - len( + { + segment.provider_cluster_id or segment.provider_speaker_label + for segment in segments + if (segment.provider_cluster_id or segment.provider_speaker_label) + and ( + segment.person_id or segment.is_user or segment.speaker_identity_state in ('identified', 'user') + ) + } + ), + 'embedding_extraction_failure_count': 0, + }, + transcribe_bytes=MagicMock(), + update_provider_run_identity_metrics=MagicMock(), + ), +) +sys.modules.setdefault( + 'utils.voice_duration_limiter', + SimpleNamespace( + compute_pcm_duration_ms=lambda byte_length, sample_rate, channels: int( + byte_length * 1000 / (sample_rate * channels * 2) + ) + ), +) +sys.modules.setdefault('utils.other.hume', MagicMock()) +import utils.other as _utils_other + +setattr(_utils_other, 'hume', sys.modules['utils.other.hume']) +_endpoints_stub = SimpleNamespace( + get_current_user_uid=lambda: 'test-uid', + with_rate_limit=lambda dep, _policy: dep, +) +sys.modules.setdefault('utils.other.endpoints', _endpoints_stub) +setattr(_utils_other, 'endpoints', _endpoints_stub) + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from models.transcript_segment import ProviderTranscriptResult, TranscriptSegment +from routers import desktop_background +from utils.stt.providers import BackgroundProviderMode, STTProviderName, STTWorkload + +sys.modules.pop('utils.stt.provider_service', None) +if 'utils.stt' in sys.modules and hasattr(sys.modules['utils.stt'], 'provider_service'): + delattr(sys.modules['utils.stt'], 'provider_service') + + +def _client(monkeypatch, *, segments=None, person_embeddings_cache=None): + app = FastAPI() + app.include_router(desktop_background.router) + for route in app.routes: + if hasattr(route, 'dependant'): + for dep in route.dependant.dependencies: + if dep.call is not None: + app.dependency_overrides[dep.call] = lambda: 'test-uid' + + monkeypatch.setattr(desktop_background, 'is_trial_paywalled', lambda *_args, **_kwargs: False) + monkeypatch.setattr(desktop_background, 'is_hard_restricted', lambda *_args, **_kwargs: False) + monkeypatch.setattr(desktop_background, 'has_transcription_credits', lambda *_args, **_kwargs: True) + monkeypatch.setattr(desktop_background, 'record_speech_ms', lambda *_args, **_kwargs: None) + monkeypatch.setattr(desktop_background, 'record_usage', lambda *_args, **_kwargs: None) + monkeypatch.setattr(desktop_background, 'resolve_voice_message_language', lambda _uid, language: language or 'en') + monkeypatch.setattr( + desktop_background.conversations_db, + 'get_conversation', + lambda _uid, _cid: {'id': _cid, 'status': 'in_progress'}, + ) + monkeypatch.setattr(desktop_background, 'get_background_chunk_record', MagicMock(return_value=None)) + monkeypatch.setattr( + desktop_background, + 'append_background_chunk_to_in_progress_conversation', + MagicMock(return_value=SimpleNamespace(appended=True, duplicate=False, segments=[], chunk_record=None)), + ) + monkeypatch.setattr( + desktop_background, + 'finish_desktop_background_conversation', + MagicMock(return_value={'id': 'conv-1', 'status': 'completed'}), + ) + if not hasattr(desktop_background.redis_db, 'r'): + monkeypatch.setattr(desktop_background.redis_db, 'r', MagicMock(), raising=False) + monkeypatch.setattr(desktop_background.redis_db.r, 'get', lambda _key: None) + monkeypatch.setattr(desktop_background.redis_db.r, 'set', lambda *_args, **_kwargs: None) + monkeypatch.setattr( + desktop_background, + '_build_person_embeddings_cache', + lambda _uid: person_embeddings_cache or {}, + ) + monkeypatch.setattr(desktop_background, 'update_provider_run_identity_metrics', MagicMock()) + + default_segments = segments or [ + TranscriptSegment( + id='seg-1', + text='Hello world.', + speaker='SPEAKER_00', + speaker_id=0, + is_user=False, + start=0.5, + end=1.2, + provider_cluster_id='A', + provider_speaker_label='ASSEMBLYAI_SPEAKER_A', + stt_provider='assemblyai', + stt_model='universal-2', + ) + ] + + def _transcribe_bytes(audio_bytes, **_kwargs): + return SimpleNamespace( + result=ProviderTranscriptResult(provider='assemblyai', model='universal-2', words=[], utterances=[]), + detected_language='en', + segments=[segment.model_copy(deep=True) for segment in default_segments], + run_id='run-1', + ) + + mock_transcribe = MagicMock(side_effect=_transcribe_bytes) + monkeypatch.setattr(desktop_background, 'transcribe_bytes', mock_transcribe) + return TestClient(app), mock_transcribe + + +def test_desktop_capabilities_reports_assemblyai_background_when_key_available(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setattr( + desktop_background, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=STTProviderName.assemblyai, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=True, + enabled=True, + reason=None, + ), + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['provider'] == 'assemblyai' + assert data['primary_provider'] == 'assemblyai' + assert data['effective_provider'] == 'assemblyai' + assert data['mode'] == 'assemblyai' + assert data['fallback_provider'] == 'deepgram' + assert data['fallback_enabled'] is True + assert data['fallback_available'] is True + assert data['workload'] == 'background' + assert data['reason'] is None + + +def test_desktop_capabilities_allows_background_batch_with_deepgram_fallback(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setattr( + desktop_background, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=STTProviderName.deepgram, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=True, + enabled=True, + reason='fallback_deepgram_available', + ), + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['provider'] == 'assemblyai' + assert data['primary_provider'] == 'assemblyai' + assert data['effective_provider'] == 'deepgram' + assert data['mode'] == 'deepgram_fallback' + assert data['fallback_provider'] == 'deepgram' + assert data['fallback_enabled'] is True + assert data['fallback_available'] is True + assert data['reason'] == 'fallback_deepgram_available' + + +def test_desktop_capabilities_reports_missing_assemblyai_key_when_fallback_disabled(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setattr( + desktop_background, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=None, + fallback_provider=None, + fallback_enabled=False, + fallback_available=False, + enabled=False, + reason='missing_assemblyai_api_key', + ), + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is False + assert data['provider'] == 'assemblyai' + assert data['effective_provider'] is None + assert data['mode'] == 'disabled' + assert data['fallback_provider'] is None + assert data['fallback_enabled'] is False + assert data['reason'] == 'missing_assemblyai_api_key' + + +def test_desktop_capabilities_reports_no_usable_batch_provider_without_any_key(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setattr( + desktop_background, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=None, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=False, + enabled=False, + reason='missing_assemblyai_api_key', + ), + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is False + assert data['primary_provider'] == 'assemblyai' + assert data['effective_provider'] is None + assert data['fallback_provider'] == 'deepgram' + assert data['fallback_available'] is False + assert data['reason'] == 'missing_assemblyai_api_key' + + +def test_desktop_capabilities_uses_byok_assemblyai(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setattr( + desktop_background, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=STTProviderName.assemblyai, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=False, + enabled=True, + reason=None, + ), + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['effective_provider'] == 'assemblyai' + assert data['mode'] == 'assemblyai' + + +def test_desktop_capabilities_uses_byok_deepgram_fallback(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setattr( + desktop_background, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.assemblyai, + primary_provider=STTProviderName.assemblyai, + effective_provider=STTProviderName.deepgram, + fallback_provider=STTProviderName.deepgram, + fallback_enabled=True, + fallback_available=True, + enabled=True, + reason='fallback_deepgram_available', + ), + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['effective_provider'] == 'deepgram' + assert data['mode'] == 'deepgram_fallback' + + +def test_desktop_capabilities_reports_shadow_only_mode(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setattr( + desktop_background, + 'resolve_background_provider_policy', + lambda: SimpleNamespace( + mode=BackgroundProviderMode.shadow_only, + primary_provider=STTProviderName.deepgram, + effective_provider=STTProviderName.deepgram, + fallback_provider=None, + fallback_enabled=False, + fallback_available=False, + enabled=True, + reason='shadow_only', + ), + ) + + response = client.get('/v2/desktop/capabilities') + + assert response.status_code == 200 + data = response.json()['background_batch'] + assert data['enabled'] is True + assert data['provider'] == 'deepgram' + assert data['effective_provider'] == 'deepgram' + assert data['mode'] == 'shadow_only' + assert data['reason'] == 'shadow_only' + + +def test_background_transcribe_returns_segments_with_offset(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=12000', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert data['provider'] == 'assemblyai' + assert data['run_id'] == 'run-1' + assert data['segments'][0]['start'] == 12.5 + assert data['segments'][0]['end'] == 13.2 + assert data['segments'][0]['speaker_id'] == 0 + assert data['segments'][0]['speaker'] == 'SPEAKER_00' + + +def test_finish_background_conversation_uses_explicit_conversation_id(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post('/v2/desktop/background-conversation/conv-1/finish') + + assert response.status_code == 200 + assert response.json()['id'] == 'conv-1' + desktop_background.finish_desktop_background_conversation.assert_called_once_with('test-uid', 'conv-1') + + +def test_finish_background_conversation_maps_validation_error(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + desktop_background.finish_desktop_background_conversation.side_effect = ( + desktop_background.DesktopBackgroundConversationError('conversation is not in_progress', status_code=409) + ) + + response = client.post('/v2/desktop/background-conversation/conv-1/finish') + + assert response.status_code == 409 + assert response.json()['detail'] == 'conversation is not in_progress' + + +def test_background_transcribe_wraps_linear16_pcm_as_wav(monkeypatch): + client, mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0&sample_rate=16000', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + audio_arg = mock_transcribe.call_args.args[0] + assert audio_arg[:4] == b'RIFF' + assert b'WAVE' in audio_arg[:16] + assert mock_transcribe.call_args.kwargs['workload'].value == 'background' + + +def test_background_transcribe_persists_segments(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + desktop_background.append_background_chunk_to_in_progress_conversation.assert_called_once() + + +def test_background_transcribe_requires_chunk_id_when_persisting(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 400 + assert response.json()['detail'] == 'chunk_id is required when persist=true' + + +def test_background_transcribe_duplicate_chunk_returns_success_without_appending(monkeypatch): + client, mock_transcribe = _client(monkeypatch) + payload = b'\x01\x00' * 1600 + payload_hash = desktop_background._background_chunk_payload_hash(payload, 16000, 1, 'linear16', 0) + monkeypatch.setattr( + desktop_background, + 'get_background_chunk_record', + MagicMock( + return_value={ + 'payload_hash': payload_hash, + 'provider': 'assemblyai', + 'run_id': 'run-previous', + 'chunk_duration_ms': 100, + } + ), + ) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=payload, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert data['duplicate'] is True + assert data['segments'] == [] + assert data['provider'] == 'assemblyai' + assert data['run_id'] == 'run-previous' + mock_transcribe.assert_not_called() + desktop_background.append_background_chunk_to_in_progress_conversation.assert_not_called() + + +def test_background_transcribe_rejects_conflicting_duplicate_chunk(monkeypatch): + client, mock_transcribe = _client(monkeypatch) + monkeypatch.setattr( + desktop_background, + 'get_background_chunk_record', + MagicMock(return_value={'payload_hash': 'different', 'provider': 'assemblyai'}), + ) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 409 + assert response.json()['detail'] == 'chunk_id payload mismatch' + mock_transcribe.assert_not_called() + desktop_background.append_background_chunk_to_in_progress_conversation.assert_not_called() + + +def test_background_transcribe_can_skip_persist_without_conversation(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + record_speech_ms = MagicMock() + record_usage = MagicMock() + monkeypatch.setattr(desktop_background, 'record_speech_ms', record_speech_ms) + monkeypatch.setattr(desktop_background, 'record_usage', record_usage) + + response = client.post( + '/v2/desktop/background-transcribe?chunk_start_ms=0&persist=false', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + desktop_background.append_background_chunk_to_in_progress_conversation.assert_not_called() + record_speech_ms.assert_called_once() + record_usage.assert_called_once() + + +def test_cluster_speaker_mapping_assigns_distinct_ids(monkeypatch): + segments = [ + TranscriptSegment(text='One', is_user=False, start=0.0, end=1.0, provider_cluster_id='A'), + TranscriptSegment(text='Two', is_user=False, start=1.0, end=2.0, provider_cluster_id='B'), + ] + client, _mock_transcribe = _client(monkeypatch, segments=segments) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert [segment['speaker_id'] for segment in data['segments']] == [0, 1] + assert [segment['speaker'] for segment in data['segments']] == ['SPEAKER_00', 'SPEAKER_01'] + assert data['speaker_diagnostics']['provider_cluster_count'] == 2 + assert data['speaker_diagnostics']['mapped_speaker_ids'] == [0, 1] + assert data['speaker_diagnostics']['provider_speaker_count'] == 2 + assert data['speaker_diagnostics']['mapped_provider_speaker_count'] == 2 + + +def test_noncontiguous_same_provider_cluster_splits_inside_chunk(monkeypatch): + segments = [ + TranscriptSegment(text='Alice starts.', is_user=False, start=0.0, end=1.0, provider_cluster_id='A'), + TranscriptSegment(text='Bob replies.', is_user=False, start=1.1, end=2.0, provider_cluster_id='B'), + TranscriptSegment(text='Carol follows.', is_user=False, start=2.1, end=3.0, provider_cluster_id='A'), + ] + client, _mock_transcribe = _client(monkeypatch, segments=segments) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert [segment['speaker_id'] for segment in data['segments']] == [0, 1, 2] + assert [segment['provider_cluster_id'] for segment in data['segments']] == [ + 'A::local_part:1', + 'B', + 'A::local_part:2', + ] + assert data['speaker_diagnostics']['provider_cluster_count'] == 3 + assert data['speaker_diagnostics']['speaker_split_cluster_count'] == 2 + assert data['speaker_diagnostics']['speaker_split_segment_count'] == 2 + assert data['speaker_diagnostics']['cannot_link_violations_prevented'] == 1 + + +def test_contiguous_same_provider_cluster_stays_together_inside_chunk(monkeypatch): + segments = [ + TranscriptSegment(text='Alice starts.', is_user=False, start=0.0, end=1.0, provider_cluster_id='A'), + TranscriptSegment(text='Alice continues.', is_user=False, start=1.1, end=2.0, provider_cluster_id='A'), + TranscriptSegment(text='Bob replies.', is_user=False, start=2.1, end=3.0, provider_cluster_id='B'), + ] + client, _mock_transcribe = _client(monkeypatch, segments=segments) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert [segment['speaker_id'] for segment in data['segments']] == [0, 0, 1] + assert [segment['provider_cluster_id'] for segment in data['segments']] == ['A', 'A', 'B'] + assert data['speaker_diagnostics']['provider_cluster_count'] == 2 + + +def test_assemblyai_label_only_speakers_do_not_collapse_to_single_local_speaker(monkeypatch): + segments = [ + TranscriptSegment( + text='Alice starts.', + is_user=False, + start=0.0, + end=2.0, + provider_speaker_label='A', + speaker_identity_state='unassigned', + stt_provider='assemblyai', + stt_model='universal-2', + ), + TranscriptSegment( + text='Bob replies.', + is_user=False, + start=2.0, + end=4.0, + provider_speaker_label='B', + speaker_identity_state='unassigned', + stt_provider='assemblyai', + stt_model='universal-2', + ), + ] + client, _mock_transcribe = _client(monkeypatch, segments=segments) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + assert [segment['provider_speaker_label'] for segment in data['segments']] == ['A', 'B'] + assert [segment['speaker_id'] for segment in data['segments']] == [0, 1] + assert data['speaker_diagnostics']['provider_speaker_label_count'] == 2 + assert data['speaker_diagnostics']['mapped_speaker_ids'] == [0, 1] + assert data['speaker_diagnostics']['mapped_unmapped_speaker_count'] == 2 + desktop_background.update_provider_run_identity_metrics.assert_called_once() + assert desktop_background.update_provider_run_identity_metrics.call_args.args[5] == 'skipped' + assert desktop_background.update_provider_run_identity_metrics.call_args.args[6] == 'missing_candidate_embeddings' + + +def test_background_transcribe_multi_chunk_offsets_persist_and_keep_anonymous_speakers_chunk_local(monkeypatch): + client, mock_transcribe = _client(monkeypatch) + speaker_map_store = {} + + def _redis_get(key): + return speaker_map_store.get(key) + + def _redis_set(key, value, **_kwargs): + speaker_map_store[key] = value + + monkeypatch.setattr(desktop_background.redis_db.r, 'get', _redis_get) + monkeypatch.setattr(desktop_background.redis_db.r, 'set', _redis_set) + + responses = [ + [TranscriptSegment(text='First.', is_user=False, start=0.1, end=1.0, provider_cluster_id='A')], + [TranscriptSegment(text='Second.', is_user=False, start=0.1, end=1.0, provider_cluster_id='B')], + [TranscriptSegment(text='Third.', is_user=False, start=0.1, end=1.0, provider_cluster_id='A')], + ] + + def _transcribe_bytes(_audio_bytes, **_kwargs): + return SimpleNamespace( + result=ProviderTranscriptResult(provider='assemblyai', model='universal-2', words=[], utterances=[]), + detected_language='en', + segments=[segment.model_copy(deep=True) for segment in responses.pop(0)], + run_id='run-multi', + ) + + mock_transcribe.side_effect = _transcribe_bytes + + for chunk_start_ms in (0, 14000, 28000): + response = client.post( + f'/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-id-{chunk_start_ms}&chunk_start_ms={chunk_start_ms}', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + assert response.status_code == 200 + assert response.json()['provider'] == 'assemblyai' + + appended_segments = [ + call.args[4][0] + for call in desktop_background.append_background_chunk_to_in_progress_conversation.call_args_list + ] + assert [segment.start for segment in appended_segments] == [0.1, 14.1, 28.1] + assert [segment.end for segment in appended_segments] == [1.0, 15.0, 29.0] + assert [segment.speaker_id for segment in appended_segments] == [0, 1, 2] + assert [segment.speaker for segment in appended_segments] == ['SPEAKER_00', 'SPEAKER_01', 'SPEAKER_02'] + + +def test_background_transcribe_multi_chunk_known_omi_identity_can_reconcile(monkeypatch): + alice_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + client, mock_transcribe = _client( + monkeypatch, + person_embeddings_cache={'person-alice': {'embedding': alice_embedding, 'name': 'Alice'}}, + ) + monkeypatch.setattr(desktop_background, 'extract_embedding_from_bytes', lambda _audio, _filename: alice_embedding) + speaker_map_store = {} + + def _redis_get(key): + return speaker_map_store.get(key) + + def _redis_set(key, value, **_kwargs): + speaker_map_store[key] = value + + monkeypatch.setattr(desktop_background.redis_db.r, 'get', _redis_get) + monkeypatch.setattr(desktop_background.redis_db.r, 'set', _redis_set) + + responses = [ + [TranscriptSegment(text='Alice first chunk.', is_user=False, start=0.1, end=3.0, provider_cluster_id='A')], + [TranscriptSegment(text='Alice second chunk.', is_user=False, start=0.1, end=3.0, provider_cluster_id='B')], + ] + + def _transcribe_bytes(_audio_bytes, **_kwargs): + return SimpleNamespace( + result=ProviderTranscriptResult(provider='assemblyai', model='universal-2', words=[], utterances=[]), + detected_language='en', + segments=[segment.model_copy(deep=True) for segment in responses.pop(0)], + run_id='run-known', + ) + + mock_transcribe.side_effect = _transcribe_bytes + + response_payloads = [] + for chunk_start_ms in (0, 14000): + response = client.post( + f'/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-id-{chunk_start_ms}&chunk_start_ms={chunk_start_ms}', + content=b'\x01\x00' * 16000 * 4, + headers={'Content-Type': 'application/octet-stream'}, + ) + assert response.status_code == 200 + response_payloads.append(response.json()) + + appended_segments = [ + call.args[4][0] + for call in desktop_background.append_background_chunk_to_in_progress_conversation.call_args_list + ] + assert [segment.person_id for segment in appended_segments] == ['person-alice', 'person-alice'] + assert [segment.speaker_id for segment in appended_segments] == [0, 0] + assert response_payloads[-1]['speaker_diagnostics']['speaker_reconciliation_accepted_count'] == 1 + assert response_payloads[-1]['speaker_diagnostics']['speaker_reconciliation_rejected_count'] == 0 + + +def test_background_transcribe_identifies_assemblyai_speaker_with_omi_user_embedding(monkeypatch): + user_embedding = np.array([[1.0, 0.0]], dtype=np.float32) + segments = [ + TranscriptSegment( + text='This is my voice.', + is_user=False, + start=0.0, + end=3.0, + provider_cluster_id='A', + provider_speaker_label='ASSEMBLYAI_SPEAKER_A', + speaker_identity_state='unassigned', + stt_provider='assemblyai', + stt_model='universal-2', + ) + ] + client, _mock_transcribe = _client( + monkeypatch, + segments=segments, + person_embeddings_cache={'user': {'embedding': user_embedding, 'name': 'User'}}, + ) + monkeypatch.setattr(desktop_background, 'extract_embedding_from_bytes', lambda _audio, _filename: user_embedding) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=12000', + content=b'\x01\x00' * 16000 * 4, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 200 + data = response.json() + segment = data['segments'][0] + assert segment['stt_provider'] == 'assemblyai' + assert segment['provider_cluster_id'] == 'A' + assert segment['is_user'] is True + assert segment['person_id'] is None + assert segment['speaker_identity_state'] == 'user' + assert segment['speaker_identity_source'] == 'omi_speaker_embedding' + assert segment['speaker_identity_provenance']['provider_cluster_id'] == 'A' + assert segment['start'] == 12.0 + desktop_background.update_provider_run_identity_metrics.assert_called_once() + + +def test_byok_background_routing_uses_deepgram_when_only_deepgram_key(monkeypatch): + from utils.stt import provider_service + + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true') + monkeypatch.setenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + monkeypatch.setenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', 'assemblyai') + monkeypatch.setattr(provider_service, 'get_byok_key', lambda provider: {'deepgram': 'dg-user-key'}.get(provider)) + + provider = provider_service.resolve_prerecorded_provider_for_request(STTWorkload.background) + + assert provider == STTProviderName.deepgram + + +def test_background_transcribe_rejects_empty_body(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'', + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 400 + + +def test_background_transcribe_rejects_stereo_pcm_for_v1(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0&channels=2', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 422 + assert response.json()['detail'] == 'channels must be 1' + + +def test_background_transcribe_rejects_malformed_content_length(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=conv-1&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream', 'Content-Length': 'not-an-int'}, + ) + + assert response.status_code == 422 + + +def test_background_transcribe_rejects_invalid_conversation(monkeypatch): + client, _mock_transcribe = _client(monkeypatch) + monkeypatch.setattr(desktop_background.conversations_db, 'get_conversation', lambda _uid, _cid: None) + + response = client.post( + '/v2/desktop/background-transcribe?conversation_id=missing&chunk_id=chunk-001&chunk_start_ms=0', + content=b'\x01\x00' * 1600, + headers={'Content-Type': 'application/octet-stream'}, + ) + + assert response.status_code == 404 diff --git a/backend/tests/unit/test_desktop_migration.py b/backend/tests/unit/test_desktop_migration.py index 289d88811e9..9ef746913d3 100644 --- a/backend/tests/unit/test_desktop_migration.py +++ b/backend/tests/unit/test_desktop_migration.py @@ -82,6 +82,7 @@ def _stub_package(name): field_filter_stub.FieldFilter = MagicMock() sys.modules["google.cloud.firestore_v1"].FieldFilter = field_filter_stub.FieldFilter sys.modules["google.cloud.firestore_v1"].transactional = lambda f: f +sys.modules["database.redis_db"].try_acquire_user_platform_write_lock = MagicMock(return_value=True) # Add backend dir to sys.path sys.path.insert(0, str(BACKEND_DIR)) @@ -1694,7 +1695,7 @@ def test_returns_title(self, mock_update): session_id='s1', messages=[TitleMessageInput(text='hi', sender='human'), TitleMessageInput(text='hello', sender='ai')], ) - with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(llm_mini=mock_llm)}): + with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(get_llm=MagicMock(return_value=mock_llm))}): result = generate_session_title(request, uid='u1') assert result == {'title': 'Project Discussion'} @@ -1713,7 +1714,7 @@ def test_empty_response_defaults_to_new_chat(self, mock_update): session_id='s1', messages=[TitleMessageInput(text='hi', sender='human')], ) - with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(llm_mini=mock_llm)}): + with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(get_llm=MagicMock(return_value=mock_llm))}): result = generate_session_title(request, uid='u1') assert result == {'title': 'New Chat'} diff --git a/backend/tests/unit/test_desktop_transcribe.py b/backend/tests/unit/test_desktop_transcribe.py index 3b750a4f180..bfc0d5e764d 100644 --- a/backend/tests/unit/test_desktop_transcribe.py +++ b/backend/tests/unit/test_desktop_transcribe.py @@ -114,6 +114,7 @@ 'utils.llm.goals', 'utils.llm.usage_tracker', 'utils.conversations', + 'utils.conversations.factory', 'utils.conversations.process_conversation', 'utils.notifications', 'utils.other.storage', @@ -131,9 +132,22 @@ ]: sys.modules.setdefault(_ufull, MagicMock()) +for _pkg_name in ['utils.conversations', 'utils.retrieval', 'utils.llm']: + if _pkg_name in sys.modules: + sys.modules[_pkg_name].__path__ = [_pkg_name.replace('.', '/')] + # Force-import real models.chat (has no project deps, needed for FastAPI response_model) import importlib.util as _ilu +_segment_spec = _ilu.spec_from_file_location( + 'models.transcript_segment', + os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'transcript_segment.py'), +) +_real_segment = _ilu.module_from_spec(_segment_spec) +_segment_spec.loader.exec_module(_real_segment) +sys.modules['models.transcript_segment'] = _real_segment +setattr(_models_pkg, 'transcript_segment', _real_segment) + _chat_spec = _ilu.spec_from_file_location( 'models.chat', os.path.join(os.path.dirname(__file__), '..', '..', 'models', 'chat.py') ) @@ -144,6 +158,9 @@ # Now safe to import the modules under test from utils.stt.pre_recorded import deepgram_prerecorded_from_bytes +import utils.chat as _chat_mod + +setattr(sys.modules['utils'], 'chat', _chat_mod) # --------------------------------------------------------------------------- # deepgram_prerecorded_from_bytes: encoding/language/model options @@ -297,125 +314,128 @@ def test_return_language_extracts_detected_lang(self, mock_client): # --------------------------------------------------------------------------- +def _mock_pcm_transcription(*, text=None, detected_language=None, words=None, segments=None): + from utils.stt.provider_service import PrerecordedTranscriptionResponse + + if segments is None: + segment = MagicMock() + segment.text = text + segments = [segment] if text is not None else [] + if words is None: + words = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': text}] if text is not None else [] + return PrerecordedTranscriptionResponse( + result=MagicMock(), + detected_language=detected_language, + segments=segments, + words=words, + run_id=None, + ) + + class TestTranscribePcmBytes: """Verify transcribe_pcm_bytes passes language/model and propagates errors.""" - @patch('utils.chat.postprocess_words') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_language_model_forwarded(self, mock_get_model, mock_dg, mock_postprocess): - """stt_language and stt_model should be passed to deepgram_prerecorded_from_bytes.""" + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_language_model_forwarded(self, mock_resolve_model, mock_transcribe): + """stt_language and stt_model should be passed to transcribe_bytes.""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('es', 'nova-3') - mock_dg.return_value = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hola'}] - mock_seg = MagicMock() - mock_seg.text = 'Hola' - mock_postprocess.return_value = [mock_seg] + mock_resolve_model.return_value = ('es', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription(text='Hola') text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='es') - mock_dg.assert_called_once() - call_kwargs = mock_dg.call_args[1] + mock_transcribe.assert_called_once() + call_kwargs = mock_transcribe.call_args[1] assert call_kwargs['language'] == 'es' assert call_kwargs['model'] == 'nova-3' assert call_kwargs['encoding'] == 'linear16' assert text == 'Hola' - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_runtime_error_propagates(self, mock_get_model, mock_dg): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_runtime_error_propagates(self, mock_resolve_model, mock_transcribe): """RuntimeError from Deepgram should propagate (not be caught).""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('en', 'nova-3') - mock_dg.side_effect = RuntimeError('Deepgram failed') + mock_resolve_model.return_value = ('en', 'nova-3') + mock_transcribe.side_effect = RuntimeError('Deepgram failed') with pytest.raises(RuntimeError, match='Deepgram failed'): transcribe_pcm_bytes(b'\x00' * 100, 'test-uid') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_empty_words_returns_none(self, mock_get_model, mock_dg): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_empty_words_returns_none(self, mock_resolve_model, mock_transcribe): """Empty word list should return (None, language).""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('en', 'nova-3') - mock_dg.return_value = [] + mock_resolve_model.return_value = ('en', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription(text=None, words=[]) text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='en') assert text is None assert lang == 'en' - @patch('utils.chat.postprocess_words') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_multi_language_returns_detected_language(self, mock_get_model, mock_dg, mock_postprocess): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_multi_language_returns_detected_language(self, mock_resolve_model, mock_transcribe): """Multi-language mode should return the Deepgram-detected language, not hardcoded 'en'.""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('multi', 'nova-3') - # return_language=True path returns (words, detected_lang) - mock_dg.return_value = ([{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Bonjour'}], 'fr') - mock_seg = MagicMock() - mock_seg.text = 'Bonjour' - mock_postprocess.return_value = [mock_seg] + mock_resolve_model.return_value = ('multi', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription(text='Bonjour', detected_language='fr') text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='multi') assert text == 'Bonjour' assert lang == 'fr' - # Verify return_language=True was passed - call_kwargs = mock_dg.call_args[1] + call_kwargs = mock_transcribe.call_args[1] assert call_kwargs['return_language'] is True - @patch('utils.chat.postprocess_words') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_chinese_language_uses_nova3(self, mock_get_model, mock_dg, mock_postprocess): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_chinese_language_uses_nova3(self, mock_resolve_model, mock_transcribe): """Chinese should use nova-3 model.""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('zh', 'nova-3') - mock_dg.return_value = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': '你好'}] - mock_seg = MagicMock() - mock_seg.text = '你好' - mock_postprocess.return_value = [mock_seg] + mock_resolve_model.return_value = ('zh', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription(text='你好') - text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='zh') + transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='zh') - call_kwargs = mock_dg.call_args[1] + call_kwargs = mock_transcribe.call_args[1] assert call_kwargs['model'] == 'nova-3' assert call_kwargs['language'] == 'zh' - @patch('utils.chat.postprocess_words') - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_whitespace_only_transcript_returns_none(self, mock_get_model, mock_dg, mock_postprocess): + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_whitespace_only_transcript_returns_none(self, mock_resolve_model, mock_transcribe): """Whitespace-only transcript after postprocessing should return (None, language).""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('en', 'nova-3') - mock_dg.return_value = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': ' '}] - mock_seg = MagicMock() - mock_seg.text = ' ' - mock_postprocess.return_value = [mock_seg] + mock_resolve_model.return_value = ('en', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription(text=' ') text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='en') assert text is None assert lang == 'en' - @patch('utils.chat.deepgram_prerecorded_from_bytes') - @patch('utils.chat.get_deepgram_model_for_language') - def test_postprocess_empty_returns_none(self, mock_get_model, mock_dg): - """postprocess_words returning empty list should return (None, language).""" + @patch('utils.chat.stt_provider_service.transcribe_bytes') + @patch('utils.chat.stt_provider_service.resolve_prerecorded_language_model') + def test_empty_segments_returns_none(self, mock_resolve_model, mock_transcribe): + """Empty segments should return (None, language).""" from utils.chat import transcribe_pcm_bytes - mock_get_model.return_value = ('en', 'nova-3') - mock_dg.return_value = [{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'hello'}] - # postprocess_words is imported at module level; mock it - with patch('utils.chat.postprocess_words', return_value=[]): - text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='en') + mock_resolve_model.return_value = ('en', 'nova-3') + mock_transcribe.return_value = _mock_pcm_transcription( + text='hello', + words=[{'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'hello'}], + segments=[], + ) + + text, lang = transcribe_pcm_bytes(b'\x00' * 100, 'test-uid', language='en') assert text is None assert lang == 'en' @@ -433,11 +453,11 @@ def test_retry_raises_after_max_attempts(self, mock_client): """After 3 failed attempts, should raise RuntimeError.""" mock_client.listen.rest.v.return_value.transcribe_file.side_effect = Exception('connection timeout') - with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'): + with pytest.raises(RuntimeError, match='Deepgram transcription failed after 2 attempts'): deepgram_prerecorded_from_bytes(b'\x00' * 100, encoding='linear16') - # Should have been called 3 times (attempts 0, 1, 2) - assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 3 + # Should have been called 2 times (attempts 0, 1) + assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 2 @patch('utils.stt.pre_recorded._deepgram_client') def test_return_language_empty_words_returns_detected_lang(self, mock_client): @@ -462,10 +482,10 @@ def test_no_channels_raises_and_retries(self, mock_client): mock_response.to_dict.return_value = {'results': {'channels': []}} mock_client.listen.rest.v.return_value.transcribe_file.return_value = mock_response - with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'): + with pytest.raises(RuntimeError, match='Deepgram transcription failed after 2 attempts'): deepgram_prerecorded_from_bytes(b'\x00' * 100) - assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 3 + assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 2 # --------------------------------------------------------------------------- @@ -497,6 +517,10 @@ def _stub_router_deps(): 'utils.social', 'utils.speaker_assignment', 'utils.speaker_identification', + 'utils.conversations.factory', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.stt.background_speaker_identity', 'utils.stt.speaker_embedding', 'utils.stt.vad', 'utils.stt.streaming', @@ -509,6 +533,10 @@ def _stub_router_deps(): rdb = sys.modules.get('database.redis_db') if rdb: rdb.check_rate_limit = MagicMock(return_value=(True, 99, 0)) + subscription = sys.modules.get('utils.subscription') + if subscription: + subscription.is_trial_paywalled = MagicMock(return_value=False) + subscription.enforce_chat_quota = MagicMock() def _make_chat_client(): diff --git a/backend/tests/unit/test_dg_usage_batch.py b/backend/tests/unit/test_dg_usage_batch.py index c32a8362728..0523ced1c71 100644 --- a/backend/tests/unit/test_dg_usage_batch.py +++ b/backend/tests/unit/test_dg_usage_batch.py @@ -75,6 +75,7 @@ def setup_method(self): 'database.users', 'database.user_usage', 'database.conversations', + 'utils.subscription', 'firebase_admin', 'firebase_admin.messaging', ]: @@ -83,6 +84,8 @@ def setup_method(self): sys.modules['database._client'].db = MagicMock() sys.modules['database.redis_db'].r = MagicMock() + sys.modules['utils.subscription'].has_transcription_credits = MagicMock(return_value=True) + sys.modules['utils.subscription'].is_paid_plan = MagicMock(return_value=True) os.environ.setdefault('FAIR_USE_ENABLED', 'true') os.environ.setdefault('ENCRYPTION_SECRET', 'test-secret-key-that-is-long-enough-for-encryption-32ch') diff --git a/backend/tests/unit/test_fair_use_classifier.py b/backend/tests/unit/test_fair_use_classifier.py index 88bf803cdbb..0c69ddddf2c 100644 --- a/backend/tests/unit/test_fair_use_classifier.py +++ b/backend/tests/unit/test_fair_use_classifier.py @@ -167,7 +167,8 @@ async def test_parses_llm_response_correctly(self): 'reasoning': 'Clear audiobook pattern', } ) - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') @@ -192,7 +193,8 @@ async def test_handles_markdown_code_block_response(self): llm_response = MagicMock() llm_response.content = '```json\n{"misuse_score": 0.1, "usage_type": "none", "confidence": 0.9, "evidence": [], "reasoning": "Normal"}\n```' - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == pytest.approx(0.1) @@ -215,7 +217,8 @@ async def test_clamps_score_to_valid_range(self): llm_response.content = json.dumps( {'misuse_score': 1.5, 'usage_type': 'none', 'confidence': -0.2, 'evidence': [], 'reasoning': ''} ) - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == 1.0 @@ -237,7 +240,8 @@ async def test_returns_default_on_json_parse_error(self): llm_response = MagicMock() llm_response.content = 'This is not JSON at all' - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == 0.0 @@ -256,7 +260,8 @@ async def test_returns_default_on_llm_error(self): 'created_at': now, } ] - _llm_clients.llm_mini.ainvoke = AsyncMock(side_effect=Exception('LLM timeout')) + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(side_effect=Exception('LLM timeout')) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == 0.0 diff --git a/backend/tests/unit/test_fair_use_upgrade.py b/backend/tests/unit/test_fair_use_upgrade.py index 8548f55106e..a6d39c40e8f 100644 --- a/backend/tests/unit/test_fair_use_upgrade.py +++ b/backend/tests/unit/test_fair_use_upgrade.py @@ -245,7 +245,7 @@ def test_checkout_session_completed_calls_clear(self): # The clear call should appear between checkout handler and the next event type next_event_idx = source.find("customer.subscription.updated", checkout_idx) block = source[checkout_idx:next_event_idx] - assert 'clear_fair_use_on_upgrade(uid)' in block + assert 'clear_fair_use_on_upgrade, uid' in block def test_subscription_event_calls_clear(self): """customer.subscription.* path must call clear_fair_use_on_upgrade.""" @@ -254,7 +254,7 @@ def test_subscription_event_calls_clear(self): assert sub_idx != -1, "customer.subscription handler not found" next_event_idx = source.find("subscription_schedule.completed", sub_idx) block = source[sub_idx:next_event_idx] - assert 'clear_fair_use_on_upgrade(uid)' in block + assert 'clear_fair_use_on_upgrade, uid' in block def test_schedule_completed_calls_clear(self): """subscription_schedule.completed path must call clear_fair_use_on_upgrade.""" @@ -262,5 +262,5 @@ def test_schedule_completed_calls_clear(self): schedule_idx = source.find("'subscription_schedule.completed'") assert schedule_idx != -1, "subscription_schedule handler not found" # Get a reasonable block after the schedule handler - block = source[schedule_idx : schedule_idx + 1500] - assert 'clear_fair_use_on_upgrade(uid)' in block + block = source[schedule_idx : schedule_idx + 2500] + assert 'clear_fair_use_on_upgrade, uid' in block diff --git a/backend/tests/unit/test_firestore_read_ops_cache.py b/backend/tests/unit/test_firestore_read_ops_cache.py index 21d6c77edca..b5e94c175bd 100644 --- a/backend/tests/unit/test_firestore_read_ops_cache.py +++ b/backend/tests/unit/test_firestore_read_ops_cache.py @@ -629,8 +629,8 @@ def test_checkout_completed_calls_invalidation(self): """checkout.session.completed path must call set_credits_invalidation_signal.""" source = self._read_source(self.PAYMENT_SOURCE_FILE) # The invalidation call should appear after _update_subscription_from_session - idx_update = source.find('_update_subscription_from_session(uid, session)') - idx_signal = source.find('set_credits_invalidation_signal(uid)', idx_update) + idx_update = source.find('_update_subscription_from_session, uid, session') + idx_signal = source.find('set_credits_invalidation_signal, uid', idx_update) assert ( idx_signal > idx_update ), "set_credits_invalidation_signal must be called after _update_subscription_from_session" @@ -653,16 +653,16 @@ def test_schedule_completed_calls_invalidation(self): idx_scheduled = source.find("Scheduled upgrade completed for user") assert idx_scheduled > 0 # Find the invalidation call before the log line (it's called right after update) - section = source[idx_scheduled - 200 : idx_scheduled] - assert 'set_credits_invalidation_signal(uid)' in section + section = source[idx_scheduled - 500 : idx_scheduled] + assert 'set_credits_invalidation_signal, uid' in section def test_schedule_canceled_calls_invalidation(self): """subscription_schedule.canceled must call set_credits_invalidation_signal.""" source = self._read_source(self.PAYMENT_SOURCE_FILE) idx_canceled = source.find("Subscription schedule canceled for user") assert idx_canceled > 0 - section = source[idx_canceled - 200 : idx_canceled] - assert 'set_credits_invalidation_signal(uid)' in section + section = source[idx_canceled - 300 : idx_canceled] + assert 'set_credits_invalidation_signal, uid' in section def test_transcribe_imports_invalidation_check(self): """transcribe.py must import check_credits_invalidation.""" diff --git a/backend/tests/unit/test_folder_name_enrichment.py b/backend/tests/unit/test_folder_name_enrichment.py index 667579e2d06..c249f64a6d0 100644 --- a/backend/tests/unit/test_folder_name_enrichment.py +++ b/backend/tests/unit/test_folder_name_enrichment.py @@ -295,6 +295,19 @@ def test_missing_speaker_id_defaults_to_zero(self): populate_speaker_names('uid1', conversations) assert conversations[0]['transcript_segments'][0]['speaker_name'] == 'Speaker 0' + def test_provider_label_unknown_cluster_gets_stable_display_name(self): + conversations = [ + { + 'transcript_segments': [ + {'text': 'hi', 'provider_speaker_label': 'A', 'speaker_id': 0}, + {'text': 'there', 'provider_speaker_label': 'B', 'speaker_id': 0}, + ] + } + ] + populate_speaker_names('uid1', conversations) + assert conversations[0]['transcript_segments'][0]['speaker_name'] == 'Speaker 1' + assert conversations[0]['transcript_segments'][1]['speaker_name'] == 'Speaker 2' + def test_user_profile_missing_name_falls_back_to_user(self): _mock_get_user_profile.return_value = {"name": None} conversations = [{'transcript_segments': [{'text': 'hi', 'is_user': True, 'speaker_id': 0}]}] diff --git a/backend/tests/unit/test_geocoding_cache.py b/backend/tests/unit/test_geocoding_cache.py index 8e67f2b61cb..f4be3f8a314 100644 --- a/backend/tests/unit/test_geocoding_cache.py +++ b/backend/tests/unit/test_geocoding_cache.py @@ -11,6 +11,7 @@ import json import sys import types +from contextlib import asynccontextmanager from unittest.mock import MagicMock, patch # Mock database._client before importing anything that touches GCP @@ -32,6 +33,14 @@ _http_mod.get_maps_client = MagicMock() _http_mod.get_webhook_client = MagicMock() + +@asynccontextmanager +async def _maps_semaphore(): + yield + + +_http_mod.get_maps_semaphore = MagicMock(return_value=_maps_semaphore()) + from models.geolocation import Geolocation from utils.conversations.location import get_google_maps_location @@ -44,7 +53,7 @@ def test_3_decimal_rounding(self): # 37.78512 -> 37.785, -122.40932 -> -122.409 with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = {"status": "OK", "results": []} mock_req.get.return_value = mock_resp @@ -80,7 +89,7 @@ def test_cache_hit_returns_geolocation(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = json.dumps(cached) - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: result = get_google_maps_location(37.78512, -122.40932) # Should NOT call Google API @@ -96,7 +105,7 @@ def test_cache_hit_no_api_key_needed(self): with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = json.dumps(cached) with patch.dict("os.environ", {}, clear=True): - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: result = get_google_maps_location(37.785, -122.409) mock_req.get.assert_not_called() assert result is not None @@ -118,7 +127,7 @@ def test_cache_miss_calls_api_and_caches(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -139,7 +148,7 @@ def test_cache_miss_calls_api_and_caches(self): def test_api_no_results_returns_none(self): with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = {"status": "OK", "results": []} mock_req.get.return_value = mock_resp @@ -165,7 +174,7 @@ def test_redis_read_failure_falls_through_to_api(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.side_effect = ConnectionError("Redis down") - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -191,7 +200,7 @@ def test_redis_write_failure_still_returns_result(self): with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None mock_r.set.side_effect = ConnectionError("Redis down") - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -211,7 +220,7 @@ def test_api_status_not_ok_returns_none(self): """Non-OK status (e.g. ZERO_RESULTS, OVER_QUERY_LIMIT) returns None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = {"status": "ZERO_RESULTS", "results": []} mock_req.get.return_value = mock_resp @@ -225,7 +234,7 @@ def test_missing_place_id_returns_none(self): """Result with no place_id returns None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -241,7 +250,7 @@ def test_missing_place_id_key_returns_none(self): """Result with no place_id key at all returns None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -257,7 +266,7 @@ def test_empty_types_gives_none_location_type(self): """Result with no types gives location_type=None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -274,7 +283,7 @@ def test_missing_types_key_gives_none_location_type(self): """Result with no 'types' key at all gives location_type=None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -305,7 +314,7 @@ def test_invalid_json_falls_through_to_api(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = "not-valid-json{{" - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -332,7 +341,7 @@ def test_schema_mismatch_falls_through_to_api(self): with patch("utils.conversations.location.r") as mock_r: # Missing required 'latitude' and 'longitude' fields mock_r.get.return_value = json.dumps({"bad_field": "bad_value"}) - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp diff --git a/backend/tests/unit/test_kg_user_type_mismatch.py b/backend/tests/unit/test_kg_user_type_mismatch.py index 0d20aa3273c..7627a8c1433 100644 --- a/backend/tests/unit/test_kg_user_type_mismatch.py +++ b/backend/tests/unit/test_kg_user_type_mismatch.py @@ -64,6 +64,9 @@ def _stub_module(name: str) -> types.ModuleType: "delete_memory_vector", "upsert_vector2", "update_vector_metadata", + "upsert_action_item_vectors_batch", + "delete_action_item_vectors_batch", + "find_similar_action_items", ]: setattr(vector_db_mod, attr, MagicMock()) @@ -112,12 +115,13 @@ def _stub_module(name: str) -> types.ModuleType: "utils.webhooks", "utils.task_sync", "utils.other.storage", + "utils.subscription", ]: if name not in sys.modules: sys.modules[name] = types.ModuleType(name) utils_apps = sys.modules["utils.apps"] -for attr in ["get_available_apps", "update_personas_async", "sync_update_persona_prompt"]: +for attr in ["get_available_apps", "update_personas_async", "update_persona_prompt", "sync_update_persona_prompt"]: setattr(utils_apps, attr, MagicMock()) utils_analytics = sys.modules["utils.analytics"] @@ -185,6 +189,9 @@ def _stub_module(name: str) -> types.ModuleType: utils_storage = sys.modules["utils.other.storage"] utils_storage.precache_conversation_audio = MagicMock() +utils_subscription = sys.modules["utils.subscription"] +utils_subscription.is_trial_paywalled = MagicMock(return_value=False) + import importlib process_conversation = importlib.import_module("utils.conversations.process_conversation") diff --git a/backend/tests/unit/test_llm_usage_db.py b/backend/tests/unit/test_llm_usage_db.py index b2b630243b6..fdd416bf7b7 100644 --- a/backend/tests/unit/test_llm_usage_db.py +++ b/backend/tests/unit/test_llm_usage_db.py @@ -3,6 +3,7 @@ """ import os +import importlib import sys import types from unittest.mock import MagicMock, patch @@ -21,7 +22,7 @@ sys.modules["database._client"] = mock_client_module sys.modules["stripe"] = MagicMock() -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_firestore_module = types.ModuleType("google.cloud.firestore") _google_firestore_module.Increment = lambda x: {"__increment": x} diff --git a/backend/tests/unit/test_llm_usage_endpoints.py b/backend/tests/unit/test_llm_usage_endpoints.py index c7b1d6bb7a4..cd35f91055d 100644 --- a/backend/tests/unit/test_llm_usage_endpoints.py +++ b/backend/tests/unit/test_llm_usage_endpoints.py @@ -53,6 +53,9 @@ "set_user_data_protection_level", "get_generic_cache", "set_generic_cache", + "get_daily_summary_uid", + "store_daily_summary_to_uid", + "remove_daily_summary_to_uid", "set_speech_profile_duration", "r", ]: @@ -109,6 +112,9 @@ def _passthrough_decorator(func): "adapt_plans_for_legacy_client", "legacy_plan_features", "is_paid_plan", + "is_trial_paywalled", + "clear_trial_paywall_cache", + "get_trial_metadata", ]: setattr(subscription_mod, attr, MagicMock()) subscription_mod.get_paid_plan_definitions = MagicMock(return_value=[]) diff --git a/backend/tests/unit/test_llm_usage_tracker.py b/backend/tests/unit/test_llm_usage_tracker.py index 7a2e19d425b..2356928a3f7 100644 --- a/backend/tests/unit/test_llm_usage_tracker.py +++ b/backend/tests/unit/test_llm_usage_tracker.py @@ -3,6 +3,7 @@ """ import os +import importlib import sys import types from unittest.mock import MagicMock @@ -19,7 +20,7 @@ sys.modules["database._client"] = mock_client_module sys.modules["stripe"] = MagicMock() -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_firestore_module = types.ModuleType("google.cloud.firestore") _google_firestore_module.Increment = lambda x: {"__increment": x} diff --git a/backend/tests/unit/test_lock_bypass_fixes.py b/backend/tests/unit/test_lock_bypass_fixes.py index 610fc81ea69..1afab9011a2 100644 --- a/backend/tests/unit/test_lock_bypass_fixes.py +++ b/backend/tests/unit/test_lock_bypass_fixes.py @@ -746,7 +746,9 @@ def test_scheduled_summary_excludes_locked(self): unlocked_conv = _make_conversation(locked=False, conversation_id='conv-2') conversations_db.get_conversations = MagicMock(return_value=[locked_conv, unlocked_conv]) - with patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True): + with patch('utils.other.notifications.is_trial_paywalled', return_value=False), patch( + 'utils.other.notifications.try_acquire_daily_summary_lock', return_value=True + ): with patch( 'utils.other.notifications.generate_comprehensive_daily_summary', return_value={'headline': 'Test', 'day_emoji': '📅', 'overview': 'ok'}, @@ -1264,8 +1266,10 @@ def test_suggest_goal_filters_locked_memories(self): mock_track.__exit__ = MagicMock(return_value=False) with patch('utils.llm.goals.track_usage', return_value=mock_track): - with patch('utils.llm.goals.llm_mini') as mock_llm: + with patch('utils.llm.goals.get_llm') as mock_get_llm: + mock_llm = MagicMock() mock_llm.invoke.return_value = mock_llm_response + mock_get_llm.return_value = mock_llm from utils.llm.goals import suggest_goal diff --git a/backend/tests/unit/test_mentor_notifications.py b/backend/tests/unit/test_mentor_notifications.py index 3adb6842e0d..db99c5e0d50 100644 --- a/backend/tests/unit/test_mentor_notifications.py +++ b/backend/tests/unit/test_mentor_notifications.py @@ -116,6 +116,9 @@ def _stub_module(name: str) -> types.ModuleType: tracker_mod.track_usage = MagicMock() tracker_mod.Features = MagicMock() +subscription_mod = _stub_module("utils.subscription") +subscription_mod.is_trial_paywalled = MagicMock(return_value=False) + # Stub utils.llms.memory (get_prompt_memories) llms_mod = _stub_module("utils.llms") if not hasattr(llms_mod, '__path__'): diff --git a/backend/tests/unit/test_process_conversation_usage_context.py b/backend/tests/unit/test_process_conversation_usage_context.py index 4f4695059d5..bf3db19a4c7 100644 --- a/backend/tests/unit/test_process_conversation_usage_context.py +++ b/backend/tests/unit/test_process_conversation_usage_context.py @@ -59,6 +59,7 @@ def _stub_module(name: str) -> types.ModuleType: "update_vector_metadata", "upsert_action_item_vectors_batch", "delete_action_item_vectors_batch", + "find_similar_action_items", ]: setattr(vector_db_mod, attr, MagicMock()) @@ -95,6 +96,7 @@ def _stub_module(name: str) -> types.ModuleType: "utils.webhooks", "utils.task_sync", "utils.other.storage", + "utils.subscription", ]: if name not in sys.modules: sys.modules[name] = types.ModuleType(name) @@ -165,6 +167,9 @@ def _stub_module(name: str) -> types.ModuleType: utils_storage = sys.modules["utils.other.storage"] utils_storage.precache_conversation_audio = MagicMock() +utils_subscription = sys.modules["utils.subscription"] +utils_subscription.is_trial_paywalled = MagicMock(return_value=False) + import importlib process_conversation = importlib.import_module("utils.conversations.process_conversation") diff --git a/backend/tests/unit/test_prompt_cache_integration.py b/backend/tests/unit/test_prompt_cache_integration.py index 554ff5dc520..431c029c7f9 100644 --- a/backend/tests/unit/test_prompt_cache_integration.py +++ b/backend/tests/unit/test_prompt_cache_integration.py @@ -609,10 +609,11 @@ def __init__(self, **kwargs): # Read source, replace imports, exec in isolated namespace source = (BACKEND_DIR / "utils" / "llm" / "clients.py").read_text() source = source.replace("from langchain_openai import ChatOpenAI, OpenAIEmbeddings", "") + source = source.replace("from langchain_google_genai import ChatGoogleGenerativeAI", "") source = source.replace("import tiktoken", "") source = source.replace("import anthropic", "") source = source.replace("from langchain_core.output_parsers import PydanticOutputParser", "") - source = source.replace("from models.conversation import Structured", "") + source = source.replace("from models.structured import Structured", "") source = source.replace("from utils.llm.usage_tracker import get_usage_callback", "") # Create a fake anthropic module with AsyncAnthropic @@ -622,6 +623,7 @@ def __init__(self, **kwargs): ns = { "os": os, "ChatOpenAI": FakeChatOpenAI, + "ChatGoogleGenerativeAI": FakeChatOpenAI, "OpenAIEmbeddings": FakeOpenAIEmbeddings, "tiktoken": fake_tiktoken, "anthropic": fake_anthropic, diff --git a/backend/tests/unit/test_provider_evaluation.py b/backend/tests/unit/test_provider_evaluation.py new file mode 100644 index 00000000000..556c3b5a6fa --- /dev/null +++ b/backend/tests/unit/test_provider_evaluation.py @@ -0,0 +1,197 @@ +import json +from pathlib import Path + +from utils.stt.provider_evaluation import ( + ProviderGateThresholds, + build_comparison_report, + compact_markdown_report, + evaluate_report_gates, + summarize_provider_output, +) + +FIXTURE_DIR = Path(__file__).resolve().parents[1] / 'fixtures' / 'stt_provider_eval' + + +def _load_fixture_case() -> dict: + manifest = json.loads((FIXTURE_DIR / 'manifest.json').read_text()) + case = manifest['cases'][0] + return { + 'id': case['id'], + 'deepgram': { + 'transcript': json.loads((FIXTURE_DIR / case['deepgram_fixture']).read_text()), + 'ledger': json.loads((FIXTURE_DIR / case['deepgram_rollup']).read_text()), + }, + 'assemblyai': { + 'transcript': json.loads((FIXTURE_DIR / case['assemblyai_fixture']).read_text()), + 'ledger': json.loads((FIXTURE_DIR / case['assemblyai_rollup']).read_text()), + }, + } + + +def _load_manifest_cases() -> list[dict]: + manifest = json.loads((FIXTURE_DIR / 'manifest.json').read_text()) + cases = [] + for case in manifest['cases']: + prepared = { + 'id': case['id'], + 'scenario': case.get('scenario'), + 'current_policy_provider': case.get('current_policy_provider'), + } + for provider in ('deepgram', 'assemblyai'): + if provider in case: + prepared[provider] = case[provider] + continue + prepared[provider] = { + 'transcript': json.loads((FIXTURE_DIR / case[f'{provider}_fixture']).read_text()), + 'ledger': json.loads((FIXTURE_DIR / case[f'{provider}_rollup']).read_text()), + } + cases.append(prepared) + return cases + + +def test_fixture_report_passes_and_includes_cost_identity_and_timing_metrics(): + report = build_comparison_report([_load_fixture_case()]) + + assert report['status'] == 'passed' + assert report['case_count'] == 1 + assert report['aggregate']['assemblyai_estimated_cost_usd'] == 0.00023611 + case = report['cases'][0] + assert case['comparison']['transcript_word_error_rate'] == 0.0 + assert case['comparison']['average_timestamp_drift_seconds'] > 0.0 + assert case['providers']['assemblyai']['speaker_cluster_count'] == 2 + assert case['providers']['assemblyai']['speaker_word_purity'] == 1.0 + assert case['providers']['assemblyai']['identified_speaker_cluster_count'] == 1 + assert case['providers']['assemblyai']['unknown_speaker_cluster_count'] == 1 + assert case['providers']['assemblyai']['low_confidence_identity_rate'] == 0.5 + assert all(gate['severity'] == 'pass' for gate in case['gates']) + + +def test_manifest_report_includes_strategy_rollups_gap_report_and_fragmentation_metrics(): + report = build_comparison_report(_load_manifest_cases()) + + assert report['status'] == 'passed' + assert set(report['strategies']) == {'always_deepgram', 'always_assemblyai', 'current_policy', 'shadow_only'} + assert report['strategies']['always_assemblyai']['provider'] == 'assemblyai' + assert report['strategies']['shadow_only']['provider'] == 'deepgram' + assert report['strategies']['current_policy']['provider'] == 'assemblyai' + assert report['strategies']['always_assemblyai']['split_count'] >= 2 + assert report['strategies']['always_assemblyai']['estimated_cost_per_hour_usd'] > 0 + assert report['assemblyai_gap_report']['status'] == 'limited' + assert any( + item['scenario'] == 'saved_real_policy_router_outputs' + and item['metric'] in {'speaker_word_purity', 'estimated_cost_per_hour_usd'} + for item in report['assemblyai_gap_report']['limiting_scenarios'] + ) + assert any(gate['gate_group'] == 'speaker_safety' for case in report['cases'] for gate in case['gates']) + + +def test_threshold_failures_are_reported_for_transcript_drift_and_fallback(): + case = _load_fixture_case() + case['assemblyai']['transcript']['segments'][0]['text'] = 'Completely unrelated output.' + case['assemblyai']['ledger']['fallback_count'] = 1 + + report = build_comparison_report( + [case], + ProviderGateThresholds(max_transcript_word_error_rate=0.05, max_fallback_rate=0.10), + ) + passed, messages = evaluate_report_gates(report) + + assert report['status'] == 'failed' + assert not passed + assert any('transcript_word_error_rate' in message for message in messages) + assert any('assemblyai_fallback_rate' in message for message in messages) + + +def test_missing_instrumentation_is_warning_not_failure_by_default(): + case = _load_fixture_case() + case['assemblyai'].pop('ledger') + + report = build_comparison_report([case]) + passed, messages = evaluate_report_gates(report) + strict_passed, strict_messages = evaluate_report_gates(report, fail_on_warning=True) + + assert report['status'] == 'passed' + assert passed + assert messages == [] + assert not strict_passed + assert any('assemblyai_instrumentation' in message for message in strict_messages) + + +def test_provider_result_words_are_grouped_into_cluster_segments(): + payload = { + 'provider': 'assemblyai', + 'model': 'universal-2', + 'words': [ + {'text': 'hello', 'start': 0.0, 'end': 0.2, 'provider_cluster_id': 'A'}, + {'text': 'there', 'start': 0.2, 'end': 0.5, 'provider_cluster_id': 'A'}, + {'text': 'hi', 'start': 0.7, 'end': 0.9, 'provider_cluster_id': 'B'}, + ], + } + + summary = summarize_provider_output('assemblyai', {'transcript': payload}) + + assert summary['segment_count'] == 2 + assert summary['word_count'] == 3 + assert summary['speaker_cluster_count'] == 2 + + +def test_unknown_identity_counts_as_low_confidence(): + summary = summarize_provider_output( + 'assemblyai', + { + 'transcript': { + 'segments': [ + { + 'text': 'hello', + 'provider_cluster_id': 'A', + 'speaker_identity_state': 'unknown', + 'speaker_identity_confidence': None, + } + ] + } + }, + ) + + assert summary['low_confidence_identity_count'] == 1 + + +def test_low_confidence_identity_counts_clusters_not_segments(): + summary = summarize_provider_output( + 'assemblyai', + { + 'transcript': { + 'segments': [ + { + 'text': 'hello', + 'provider_cluster_id': 'A', + 'speaker_identity_state': 'unknown', + }, + { + 'text': 'again', + 'provider_cluster_id': 'A', + 'speaker_identity_state': 'unknown', + }, + { + 'text': 'there', + 'provider_cluster_id': 'B', + 'speaker_identity_state': 'identified', + 'speaker_identity_confidence': 0.9, + }, + ] + } + }, + ) + + assert summary['low_confidence_identity_count'] == 1 + assert summary['low_confidence_identity_rate'] == 0.5 + + +def test_compact_markdown_report_is_review_friendly(): + report = build_comparison_report(_load_manifest_cases()) + markdown = compact_markdown_report(report) + + assert '# STT Provider Evaluation: PASSED' in markdown + assert 'fixture_good_meeting' in markdown + assert 'Strategy Rollup' in markdown + assert 'AssemblyAI Gap Report' in markdown + assert 'AssemblyAI default readiness' in markdown diff --git a/backend/tests/unit/test_rate_limiting.py b/backend/tests/unit/test_rate_limiting.py index c358050d7cd..ea36db2c39b 100644 --- a/backend/tests/unit/test_rate_limiting.py +++ b/backend/tests/unit/test_rate_limiting.py @@ -15,8 +15,10 @@ 'firebase_admin.auth', 'google.cloud', 'google.cloud.firestore', + 'google.cloud.firestore_v1', 'database.redis_db', 'database.auth', + 'database.users', ]: if mod_name not in sys.modules: sys.modules[mod_name] = types.ModuleType(mod_name) @@ -56,6 +58,7 @@ def _check_rate_limit(key, policy, max_requests, window): redis_db_stub.check_rate_limit = _check_rate_limit +sys.modules['database.users'].record_user_platform = MagicMock() from utils.rate_limit_config import RATE_POLICIES, get_effective_limit, RATE_LIMIT_BOOST @@ -328,6 +331,8 @@ def test_all_router_policies_exist(self): "chat:initial", "voice:message", "voice:transcribe", + "desktop:background_transcribe", + "desktop:background_conversation_finish", "file:upload", "agent:execute_tool", "mcp:sse", @@ -484,8 +489,7 @@ def test_lua_script_has_ttl_self_heal(self): # register_script was called with the Lua source call_args = self.real_module.r.register_script.call_args lua_source = call_args[0][0] - self.assertIn('TTL', lua_source) - self.assertIn('ttl < 0', lua_source) + self.assertIn('daily_ttl', lua_source) self.assertIn('EXPIRE', lua_source) def test_lua_script_uses_incr(self): diff --git a/backend/tests/unit/test_realtime_integrations_usage_tracking.py b/backend/tests/unit/test_realtime_integrations_usage_tracking.py index cfc6e4875af..b6f847e0823 100644 --- a/backend/tests/unit/test_realtime_integrations_usage_tracking.py +++ b/backend/tests/unit/test_realtime_integrations_usage_tracking.py @@ -40,6 +40,8 @@ def _stub_module(name: str) -> types.ModuleType: "calendar_meetings", "vector_db", "apps", + "announcements", + "user_usage", "llm_usage", "_client", "chat", @@ -61,6 +63,12 @@ def _stub_module(name: str) -> types.ModuleType: apps_mod = sys.modules["database.apps"] apps_mod.record_app_usage = MagicMock() +announcements_mod = sys.modules["database.announcements"] +announcements_mod.compare_versions = MagicMock(return_value=0) + +user_usage_mod = sys.modules["database.user_usage"] +user_usage_mod.get_monthly_chat_usage = MagicMock() + llm_usage_mod = sys.modules["database.llm_usage"] llm_usage_mod.record_llm_usage = MagicMock() @@ -93,6 +101,10 @@ def _stub_module(name: str) -> types.ModuleType: conversations_mod = sys.modules["database.conversations"] conversations_mod.get_conversations_by_id = MagicMock(return_value=[]) +users_mod = sys.modules["database.users"] +users_mod.get_user_valid_subscription = MagicMock(return_value=None) +users_mod.is_byok_active = MagicMock(return_value=False) + from utils.llm import usage_tracker # Stub remaining utils modules diff --git a/backend/tests/unit/test_self_voice_review.py b/backend/tests/unit/test_self_voice_review.py new file mode 100644 index 00000000000..f5e56d90503 --- /dev/null +++ b/backend/tests/unit/test_self_voice_review.py @@ -0,0 +1,358 @@ +import io +import sys +import wave +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import numpy as np +import pytest + +sys.modules.setdefault('database._client', MagicMock()) +sys.modules.setdefault('database.users', MagicMock()) +speaker_embedding_mod = MagicMock() +speaker_embedding_mod.extract_embedding_from_bytes = MagicMock() +sys.modules.setdefault('utils.stt.speaker_embedding', speaker_embedding_mod) +firestore_v1_mod = sys.modules.setdefault('google.cloud.firestore_v1', MagicMock()) +firestore_v1_mod.FieldFilter = MagicMock() + +firestore_mod = MagicMock() +firestore_mod.Query.DESCENDING = 'DESCENDING' +sys.modules.setdefault('google.cloud.firestore', firestore_mod) + +google_cloud_mod = sys.modules.setdefault('google.cloud', MagicMock()) +google_cloud_mod.firestore = firestore_mod + +from models.transcript_segment import TranscriptSegment +from utils.stt.background_speaker_identity import ClusterIdentityAssignment, SPEAKER_IDENTITY_SOURCE +from utils.self_voice_review import ( + SegmentQuality, + build_self_voice_review_candidate, + confirm_self_voice_candidate, + delete_confirmed_self_voice_sample, + reject_self_voice_candidate, + skip_self_voice_candidate, +) + + +def _segment(segment_id, start, end, cluster='cluster-a', text='hello this is a clean sentence'): + return TranscriptSegment( + id=segment_id, + text=text, + speaker='SPEAKER_01', + is_user=False, + start=start, + end=end, + provider_cluster_id=cluster, + provider_speaker_label='SPEAKER_01', + speaker_identity_state='unknown', + ) + + +def _user_assignment(confidence=0.82): + return ClusterIdentityAssignment( + provider_cluster_id='cluster-a', + speaker_id=1, + state='user', + is_user=True, + confidence=confidence, + distance=0.1, + source=SPEAKER_IDENTITY_SOURCE, + ) + + +def _wav_bytes(duration_seconds=6, sample_rate=16000): + samples = np.zeros(int(duration_seconds * sample_rate), dtype=np.int16) + out = io.BytesIO() + with wave.open(out, 'wb') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(samples.tobytes()) + return out.getvalue() + + +class FakeReviewDb: + DEFAULT_CANDIDATE_TTL_DAYS = 30 + + def __init__(self): + self.candidates = {} + self.negative_markers = set() + self.confirmed = [] + self.rejected = [] + self.skipped = [] + self.deleted = [] + + def candidate_id_from_source(self, conversation_id, provider_cluster_id, segment_ids): + return f'{conversation_id}:{provider_cluster_id}:{"-".join(segment_ids)}' + + def marker_id_from_source(self, conversation_id, provider_cluster_id): + return f'{conversation_id}:{provider_cluster_id}' + + def get_candidate(self, _uid, candidate_id): + return self.candidates.get(candidate_id) + + def has_negative_marker(self, _uid, marker_id): + return marker_id in self.negative_markers + + def recently_shown_source_exists(self, *_args, **_kwargs): + return False + + def upsert_candidate(self, _uid, candidate): + self.candidates[candidate['candidate_id']] = candidate + return True + + def mark_candidate_confirmed(self, _uid, candidate_id, embedding_version): + self.candidates[candidate_id]['review_status'] = 'confirmed' + self.candidates[candidate_id]['confirmed_sample'] = { + 'candidate_id': candidate_id, + 'embedding_version': embedding_version, + 'revisable': True, + } + self.confirmed.append((candidate_id, embedding_version)) + + def mark_candidate_rejected(self, _uid, candidate): + marker_id = self.marker_id_from_source( + candidate['source']['conversation_id'], candidate['source']['provider_cluster_id'] + ) + self.negative_markers.add(marker_id) + self.candidates[candidate['candidate_id']]['review_status'] = 'rejected' + self.candidates[candidate['candidate_id']]['negative_review_marker'] = {'marker_id': marker_id} + self.rejected.append(candidate['candidate_id']) + return marker_id + + def mark_candidate_skipped(self, _uid, candidate_id, cooldown_until, reviewed_at=None): + self.candidates[candidate_id]['review_status'] = 'pending' + self.candidates[candidate_id]['last_review_action'] = 'skipped' + self.candidates[candidate_id]['cooldown_until'] = cooldown_until + self.skipped.append((candidate_id, cooldown_until, reviewed_at)) + + def delete_confirmed_sample(self, _uid, candidate_id): + if self.candidates[candidate_id]['review_status'] != 'confirmed': + return False + self.candidates[candidate_id]['review_status'] = 'deleted' + self.deleted.append(candidate_id) + return True + + +def test_candidate_creation_stores_state_without_transcript_text(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + segments = [_segment('s1', 0.0, 6.0)] + + result = build_self_voice_review_candidate( + uid='uid', + conversation_id='conv', + provider_cluster_id='cluster-a', + cluster_segments=segments, + all_segments=segments, + identity_assignment=_user_assignment(), + quality_by_segment_id={'s1': SegmentQuality(voiced_seconds=5.5, vad_confidence=0.9, noise_score=0.1)}, + audio_artifact_ref='gs://bucket/clip.wav', + audio_retention_allowed=True, + now=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + assert result.candidate is not None + candidate = result.candidate + assert candidate['candidate_id'] == 'conv:cluster-a:s1' + assert candidate['confidence_bucket'] == 'high' + assert candidate['quality_scores']['sample_seconds'] == 6.0 + assert candidate['quality_scores']['voiced_ratio'] == pytest.approx(0.917) + assert candidate['retention']['transcript_text_stored'] is False + assert 'text' not in candidate['source'] + assert 'transcript' not in candidate['source'] + assert fake_db.candidates[candidate['candidate_id']]['review_status'] == 'pending' + + +def test_candidate_quality_filters_overlap_short_low_vad_noise_and_retention(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + + clean = [_segment('s1', 0.0, 6.0)] + no_audio = build_self_voice_review_candidate( + 'uid', 'conv', 'cluster-a', clean, clean, _user_assignment(), audio_retention_allowed=False + ) + assert no_audio.reason == 'audio_unavailable_or_retention_disallowed' + + short = [_segment('s2', 0.0, 3.0)] + short_result = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-a', + short, + short, + _user_assignment(), + audio_artifact_ref='ref', + audio_retention_allowed=True, + ) + assert short_result.reason == 'sample_too_short' + + overlap = [_segment('s3', 0.0, 6.0)] + all_segments = overlap + [_segment('other', 2.0, 3.0, cluster='cluster-b')] + overlap_result = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-a', + overlap, + all_segments, + _user_assignment(), + audio_artifact_ref='ref', + audio_retention_allowed=True, + ) + assert overlap_result.reason == 'no_clean_voiced_window' + + low_vad = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-low-vad', + [_segment('s4', 0.0, 6.0, cluster='cluster-low-vad')], + [_segment('s4', 0.0, 6.0, cluster='cluster-low-vad')], + _user_assignment(), + {'s4': {'voiced_seconds': 5.5, 'vad_confidence': 0.5}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ) + assert low_vad.reason == 'low_vad_confidence' + + noisy = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-noisy', + [_segment('s5', 0.0, 6.0, cluster='cluster-noisy')], + [_segment('s5', 0.0, 6.0, cluster='cluster-noisy')], + _user_assignment(), + {'s5': {'voiced_seconds': 5.5, 'vad_confidence': 0.9, 'noise_score': 0.8}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ) + assert noisy.reason == 'noisy_clip' + + +def test_dedupe_reject_marker_and_recently_shown_suppress_candidates(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + segments = [_segment('s1', 0.0, 6.0)] + kwargs = { + 'uid': 'uid', + 'conversation_id': 'conv', + 'provider_cluster_id': 'cluster-a', + 'cluster_segments': segments, + 'all_segments': segments, + 'identity_assignment': _user_assignment(), + 'quality_by_segment_id': {'s1': {'voiced_seconds': 5.5, 'vad_confidence': 0.9}}, + 'audio_artifact_ref': 'ref', + 'audio_retention_allowed': True, + } + + assert build_self_voice_review_candidate(**kwargs).candidate is not None + assert build_self_voice_review_candidate(**kwargs).reason == 'already_confirmed_or_pending' + + fake_db.candidates.clear() + fake_db.negative_markers.add('conv:cluster-a') + assert build_self_voice_review_candidate(**kwargs).reason == 'rejected_source' + + fake_db.negative_markers.clear() + fake_db.recently_shown_source_exists = lambda *_args, **_kwargs: True + assert build_self_voice_review_candidate(**kwargs).reason == 'recently_shown' + + +def test_confirm_reject_skip_and_delete_actions(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + fake_users_db = MagicMock() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + monkeypatch.setattr(self_voice_review, 'users_db', fake_users_db) + segments = [_segment('s1', 0.0, 6.0)] + candidate = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-a', + segments, + segments, + _user_assignment(), + {'s1': {'voiced_seconds': 5.5, 'vad_confidence': 0.9}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ).candidate + + skipped = skip_self_voice_candidate('uid', candidate['candidate_id'], now=datetime(2026, 1, 1, tzinfo=timezone.utc)) + assert skipped['review_status'] == 'pending' + assert skipped['last_review_action'] == 'skipped' + assert fake_db.candidates[candidate['candidate_id']]['cooldown_until'].year == 2026 + + embedding = np.array([[1.0, 2.0, 3.0]], dtype=np.float32) + confirmed = confirm_self_voice_candidate('uid', candidate['candidate_id'], embedding=embedding) + assert confirmed['review_status'] == 'confirmed' + fake_users_db.set_user_speaker_embedding.assert_called_once_with('uid', [1.0, 2.0, 3.0]) + assert fake_db.candidates[candidate['candidate_id']]['confirmed_sample']['revisable'] is True + assert delete_confirmed_self_voice_sample('uid', candidate['candidate_id']) is True + + second = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-b', + [_segment('s2', 10.0, 16.0, cluster='cluster-b')], + [_segment('s2', 10.0, 16.0, cluster='cluster-b')], + _user_assignment(), + {'s2': {'voiced_seconds': 5.5, 'vad_confidence': 0.9}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ).candidate + rejected = reject_self_voice_candidate('uid', second['candidate_id']) + assert rejected['review_status'] == 'rejected' + assert rejected['negative_marker_id'] == 'conv:cluster-b' + assert fake_db.candidates[second['candidate_id']]['negative_review_marker']['marker_id'] == 'conv:cluster-b' + + +def test_confirm_can_extract_embedding_from_audio(monkeypatch): + from utils import self_voice_review + + fake_db = FakeReviewDb() + fake_users_db = MagicMock() + monkeypatch.setattr(self_voice_review, 'review_db', fake_db) + monkeypatch.setattr(self_voice_review, 'users_db', fake_users_db) + segments = [_segment('s1', 0.0, 6.0)] + candidate = build_self_voice_review_candidate( + 'uid', + 'conv', + 'cluster-a', + segments, + segments, + _user_assignment(), + {'s1': {'voiced_seconds': 5.5, 'vad_confidence': 0.9}}, + audio_artifact_ref='ref', + audio_retention_allowed=True, + ).candidate + + def fake_extractor(audio_bytes, filename): + assert audio_bytes == _wav_bytes() + assert filename.endswith('.wav') + return np.array([[4.0, 5.0]], dtype=np.float32) + + confirm_self_voice_candidate( + 'uid', candidate['candidate_id'], audio_bytes=_wav_bytes(), embedding_extractor=fake_extractor + ) + + fake_users_db.set_user_speaker_embedding.assert_called_once_with('uid', [4.0, 5.0]) + + +def test_candidate_storage_rejects_transcript_and_raw_audio_payloads(): + from database.self_voice_review import _reject_forbidden_candidate_keys + + with pytest.raises(ValueError, match='forbidden keys'): + _reject_forbidden_candidate_keys( + { + 'candidate_id': 'candidate', + 'source': {'conversation_id': 'conv', 'transcript_text': 'this must not be stored'}, + } + ) + + with pytest.raises(ValueError, match='forbidden keys'): + _reject_forbidden_candidate_keys({'candidate_id': 'candidate', 'raw_audio': b'not allowed'}) diff --git a/backend/tests/unit/test_speaker_sample_migration.py b/backend/tests/unit/test_speaker_sample_migration.py index a0ad4456dc7..25849516748 100644 --- a/backend/tests/unit/test_speaker_sample_migration.py +++ b/backend/tests/unit/test_speaker_sample_migration.py @@ -1,4 +1,5 @@ import asyncio +import importlib import os import sys import types @@ -21,7 +22,7 @@ class NotFound(Exception): pass -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_exceptions_module = types.ModuleType("google.cloud.exceptions") _google_exceptions_module.NotFound = NotFound diff --git a/backend/tests/unit/test_storage_opus_encoding.py b/backend/tests/unit/test_storage_opus_encoding.py index 5d5b6f43de5..9cdb7a9cb16 100644 --- a/backend/tests/unit/test_storage_opus_encoding.py +++ b/backend/tests/unit/test_storage_opus_encoding.py @@ -28,8 +28,6 @@ sys.modules.setdefault("google.cloud.storage", _mock_gcs_storage) sys.modules.setdefault("google.cloud.storage.transfer_manager", MagicMock()) sys.modules.setdefault("google.cloud.exceptions", MagicMock()) -sys.modules.setdefault("google.oauth2", MagicMock()) -sys.modules.setdefault("google.oauth2.service_account", MagicMock()) from utils.other import storage as storage_mod diff --git a/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py b/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py index 9fedab43864..abf80c75ee5 100644 --- a/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py +++ b/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py @@ -23,8 +23,6 @@ sys.modules.setdefault("google.cloud.storage", _mock_gcs_storage) sys.modules.setdefault("google.cloud.storage.transfer_manager", MagicMock()) sys.modules.setdefault("google.cloud.exceptions", MagicMock()) -sys.modules.setdefault("google.oauth2", MagicMock()) -sys.modules.setdefault("google.oauth2.service_account", MagicMock()) # Now import the module under test from utils.other import storage as storage_mod @@ -71,9 +69,10 @@ def test_falls_back_to_db_when_level_not_provided(self, mock_users_db): mock_users_db.get_data_protection_level.assert_called_once_with('test-uid') + @patch.object(storage_mod, 'encode_pcm_to_opus', return_value=b'opus-data') @patch.object(storage_mod, 'users_db') - def test_standard_level_uploads_unencrypted(self, mock_users_db): - """Standard protection level should upload .bin (no encryption).""" + def test_standard_level_uploads_unencrypted(self, mock_users_db, mock_encode): + """Standard protection level should upload .opus (no encryption).""" _, mock_blob = self._setup_mock_bucket() path = storage_mod.upload_audio_chunk( @@ -84,12 +83,15 @@ def test_standard_level_uploads_unencrypted(self, mock_users_db): data_protection_level='standard', ) - assert path.endswith('.bin') + assert path.endswith('.opus') + mock_encode.assert_called_once_with(b'\x00' * 100) mock_blob.upload_from_string.assert_called_once() + assert mock_blob.upload_from_string.call_args[0][0] == b'opus-data' + @patch.object(storage_mod, 'encode_pcm_to_opus', return_value=b'opus-data') @patch.object(storage_mod, 'encryption') @patch.object(storage_mod, 'users_db') - def test_enhanced_level_uploads_encrypted(self, mock_users_db, mock_encryption): + def test_enhanced_level_uploads_encrypted(self, mock_users_db, mock_encryption, mock_encode): """Enhanced protection level should encrypt and upload .enc.""" _, mock_blob = self._setup_mock_bucket() mock_encryption.encrypt_audio_chunk.return_value = b'\x01' * 120 @@ -103,7 +105,9 @@ def test_enhanced_level_uploads_encrypted(self, mock_users_db, mock_encryption): ) assert path.endswith('.enc') - mock_encryption.encrypt_audio_chunk.assert_called_once_with(b'\x00' * 100, 'test-uid') + mock_encode.assert_called_once_with(b'\x00' * 100) + mock_encryption.encrypt_audio_chunk.assert_called_once_with(b'opus-data', 'test-uid') + mock_blob.upload_from_string.assert_called_once_with(b'\x01' * 120, content_type='application/octet-stream') @patch.object(storage_mod, 'users_db') def test_explicit_none_falls_back_to_db(self, mock_users_db): diff --git a/backend/tests/unit/test_streaming_deepgram_backoff.py b/backend/tests/unit/test_streaming_deepgram_backoff.py index d5457612fff..93399787e57 100644 --- a/backend/tests/unit/test_streaming_deepgram_backoff.py +++ b/backend/tests/unit/test_streaming_deepgram_backoff.py @@ -15,14 +15,18 @@ # Mock heavy dependencies before importing streaming module _mock_modules = {} for mod_name in [ + 'cachetools', 'database', 'database._client', 'database.users', + 'utils.http_client', 'utils.other.storage', 'deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1', + 'onnxruntime', + 'pydub', 'websockets', 'websockets.exceptions', ]: @@ -39,6 +43,20 @@ sys.modules['deepgram'].LiveTranscriptionEvents = MagicMock() sys.modules['deepgram.clients.live.v1'].LiveOptions = MagicMock +if 'cachetools' in _mock_modules: + + class FakeTTLCache(dict): + def __init__(self, maxsize=None, ttl=None): + super().__init__() + + sys.modules['cachetools'].TTLCache = FakeTTLCache + +if 'pydub' in _mock_modules: + sys.modules['pydub'].AudioSegment = MagicMock + +if 'utils.http_client' in _mock_modules: + sys.modules['utils.http_client'].get_stt_client = MagicMock() + from utils.stt.streaming import connect_to_deepgram_with_backoff, process_audio_dg # noqa: E402 from utils.stt.streaming import deepgram_options, deepgram_cloud_options # noqa: E402 from utils.stt.streaming import get_stt_service_for_language, STTService, should_preserve_filler_words # noqa: E402 diff --git a/backend/tests/unit/test_stt_provider_facade.py b/backend/tests/unit/test_stt_provider_facade.py new file mode 100644 index 00000000000..eb76fa55f71 --- /dev/null +++ b/backend/tests/unit/test_stt_provider_facade.py @@ -0,0 +1,166 @@ +import sys +from unittest.mock import MagicMock + +for mod_name in ['deepgram', 'deepgram.clients', 'deepgram.clients.live', 'deepgram.clients.live.v1']: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + +sys.modules['deepgram'].DeepgramClient = MagicMock +sys.modules['deepgram'].DeepgramClientOptions = MagicMock +sys.modules['deepgram'].LiveTranscriptionEvents = MagicMock() +sys.modules['deepgram.clients.live.v1'].LiveOptions = MagicMock + +from utils.stt.deepgram_adapter import ( # noqa: E402 + DeepgramPrerecordedTranscriptionProvider, + normalize_deepgram_prerecorded_result, + provider_result_to_legacy_words, +) +from utils.stt.providers import ( # noqa: E402 + STTProviderName, + STTWorkload, + get_prerecorded_provider_name, + get_streaming_provider_name, +) + + +def _deepgram_fixture(words=None, utterances=None): + return { + 'metadata': {'request_id': 'dg-request-1', 'duration': 3.25}, + 'results': { + 'channels': [ + { + 'detected_language': 'en-US', + 'alternatives': [ + { + 'words': ( + words + if words is not None + else [ + { + 'word': 'hello', + 'punctuated_word': 'Hello', + 'start': 0.0, + 'end': 0.4, + 'confidence': 0.91, + 'speaker': 2, + }, + { + 'word': 'world', + 'punctuated_word': 'world.', + 'start': 0.5, + 'end': 1.1, + 'confidence': 0.88, + 'speaker': 2, + }, + ] + ), + } + ], + } + ], + 'utterances': ( + utterances + if utterances is not None + else [ + { + 'transcript': 'Hello world.', + 'start': 0.0, + 'end': 1.1, + 'confidence': 0.9, + 'speaker': 2, + } + ] + ), + }, + } + + +def test_deepgram_fixture_normalizes_to_provider_transcript_result(): + result = normalize_deepgram_prerecorded_result(_deepgram_fixture(), model='nova-3') + + assert result.provider == STTProviderName.deepgram.value + assert result.model == 'nova-3' + assert result.language == 'en' + assert result.duration == 3.25 + assert result.raw_provider_result_id == 'dg-request-1' + assert len(result.words) == 2 + assert result.words[0].text == 'Hello' + assert result.words[0].provider_cluster_id == '2' + assert result.words[0].speaker_label == 'SPEAKER_02' + assert result.utterances[0].provider_cluster_id == '2' + + +def test_deepgram_result_converts_to_legacy_words_for_existing_callers(): + result = normalize_deepgram_prerecorded_result(_deepgram_fixture(), model='nova-3') + + words = provider_result_to_legacy_words(result) + + assert words == [ + { + 'timestamp': [0.0, 0.4], + 'speaker': 'SPEAKER_02', + 'provider_cluster_id': '2', + 'provider_speaker_label': 'SPEAKER_02', + 'stt_provider': 'deepgram', + 'stt_model': 'nova-3', + 'text': 'Hello', + }, + { + 'timestamp': [0.5, 1.1], + 'speaker': 'SPEAKER_02', + 'provider_cluster_id': '2', + 'provider_speaker_label': 'SPEAKER_02', + 'stt_provider': 'deepgram', + 'stt_model': 'nova-3', + 'text': 'world.', + }, + ] + + +def test_deepgram_adapter_preserves_prerecorded_request_options(): + fake_response = MagicMock() + fake_response.to_dict.return_value = _deepgram_fixture() + fake_rest = MagicMock() + fake_rest.transcribe_url.return_value = fake_response + fake_client = MagicMock() + fake_client.listen.rest.v.return_value = fake_rest + provider = DeepgramPrerecordedTranscriptionProvider(lambda: fake_client, timeout=MagicMock()) + + result, detected_language = provider.transcribe_url( + 'https://example.test/audio.wav', + return_language=True, + diarize=False, + language='multi', + model='nova-3', + keywords=['Omi', 'custom'], + ) + + call_args = fake_rest.transcribe_url.call_args + assert call_args.args[0] == {'url': 'https://example.test/audio.wav'} + assert call_args.args[1]['diarize'] is False + assert call_args.args[1]['detect_language'] is True + assert call_args.args[1]['keyterm'] == ['Omi', 'custom'] + assert 'language' not in call_args.args[1] + assert result.provider == 'deepgram' + assert detected_language == 'en' + + +def test_provider_routing_uses_assemblyai_only_for_passive_prerecorded_workloads_by_default(monkeypatch): + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', raising=False) + monkeypatch.delenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', raising=False) + + assert get_streaming_provider_name(STTWorkload.ptt) == STTProviderName.deepgram + assert get_streaming_provider_name(STTWorkload.realtime) == STTProviderName.deepgram + + for workload in [ + STTWorkload.background, + STTWorkload.postprocess, + STTWorkload.sync, + ]: + assert get_prerecorded_provider_name(workload) == STTProviderName.assemblyai + + for workload in [ + STTWorkload.ptt, + STTWorkload.voice_message, + ]: + assert get_prerecorded_provider_name(workload) == STTProviderName.deepgram diff --git a/backend/tests/unit/test_subscription_plans.py b/backend/tests/unit/test_subscription_plans.py index c7db2277813..7af53c3247d 100644 --- a/backend/tests/unit/test_subscription_plans.py +++ b/backend/tests/unit/test_subscription_plans.py @@ -1,6 +1,8 @@ import sys import types +sys.modules.setdefault("database._client", types.SimpleNamespace(db=None)) +sys.modules.setdefault("database.announcements", types.SimpleNamespace(compare_versions=lambda _a, _b: 0)) sys.modules.setdefault("database.users", types.SimpleNamespace()) sys.modules.setdefault("database.user_usage", types.SimpleNamespace()) diff --git a/backend/tests/unit/test_subscription_restructure.py b/backend/tests/unit/test_subscription_restructure.py index f2c60f6d9b7..f2c4f44e12c 100644 --- a/backend/tests/unit/test_subscription_restructure.py +++ b/backend/tests/unit/test_subscription_restructure.py @@ -19,6 +19,7 @@ def _compare_versions(a, b): _announcements_mod._compare_versions = _compare_versions +_announcements_mod.compare_versions = _compare_versions sys.modules.setdefault("database.users", types.SimpleNamespace()) sys.modules.setdefault("database.user_usage", types.SimpleNamespace()) sys.modules.setdefault("database.announcements", _announcements_mod) diff --git a/backend/tests/unit/test_sync_fair_use_gate.py b/backend/tests/unit/test_sync_fair_use_gate.py index 7ad8f01c61e..1aa359211e4 100644 --- a/backend/tests/unit/test_sync_fair_use_gate.py +++ b/backend/tests/unit/test_sync_fair_use_gate.py @@ -16,6 +16,7 @@ 'database.users', 'database.user_usage', 'database.conversations', + 'utils.subscription', 'firebase_admin', 'firebase_admin.messaging', ]: @@ -28,6 +29,8 @@ # Stub database._client.db sys.modules['database._client'].db = MagicMock() +sys.modules['utils.subscription'].has_transcription_credits = MagicMock(return_value=True) +sys.modules['utils.subscription'].is_paid_plan = MagicMock(return_value=True) import utils.fair_use as fair_use_mod @@ -258,7 +261,11 @@ def _read_sync_source(): def test_no_402_block(self): """sync.py must not raise 402 (lock instead of block).""" source = self._read_sync_source() - assert 'status_code=402' not in source + for function_name in ('sync_local_files', 'sync_local_files_v2'): + start = source.index(f'async def {function_name}(') + next_route = source.find('\n@router.', start + 1) + body = source[start:] if next_route == -1 else source[start:next_route] + assert 'status_code=402' not in body def test_should_lock_flag_exists(self): """sync.py must use should_lock flag for credit-exhausted locking.""" diff --git a/backend/tests/unit/test_sync_opus_decode.py b/backend/tests/unit/test_sync_opus_decode.py index 62e8b4f665e..00a514f16bf 100644 --- a/backend/tests/unit/test_sync_opus_decode.py +++ b/backend/tests/unit/test_sync_opus_decode.py @@ -47,13 +47,13 @@ 'utils.other.storage', 'utils.encryption', 'utils.stt.pre_recorded', + 'utils.stt.speaker_embedding', 'utils.stt.vad', 'utils.fair_use', 'utils.subscription', 'utils.log_sanitizer', 'utils.executors', 'pydub', - 'numpy', 'httpx', ] for _mod in _stub_modules: diff --git a/backend/tests/unit/test_sync_record_usage.py b/backend/tests/unit/test_sync_record_usage.py index 72e522a3c85..5754f4e3db3 100644 --- a/backend/tests/unit/test_sync_record_usage.py +++ b/backend/tests/unit/test_sync_record_usage.py @@ -121,17 +121,17 @@ def test_record_usage_logs_error_on_failure(self): class TestV2RecordUsage: @staticmethod def _get_v2_body(): - return _extract_function_body(_read_sync_source(), '_run_full_pipeline_background') + return _extract_function_body(_read_sync_source(), '_run_full_pipeline_background_async') def test_record_usage_called_in_v2(self): body = self._get_v2_body() - assert 'record_usage(' in body, "v2 background worker must call record_usage" + assert 'record_usage' in body, "v2 background worker must call record_usage" def test_record_usage_after_failed_segments_check(self): """record_usage must come after failed_segments is computed.""" body = self._get_v2_body() failed_pos = body.find('failed_segments = len(segment_errors)') - record_pos = body.find('record_usage(') + record_pos = body.find('record_usage') assert failed_pos > 0 assert record_pos > 0 assert record_pos > failed_pos, "record_usage must come after failed_segments computation" @@ -139,25 +139,25 @@ def test_record_usage_after_failed_segments_check(self): def test_record_usage_guarded_by_successful_segments(self): """record_usage should only run when successful_segments > 0.""" body = self._get_v2_body() - record_idx = body.find('record_usage(') + record_idx = body.find('record_usage') preceding = body[max(0, record_idx - 300) : record_idx] assert 'successful_segments > 0' in preceding, "record_usage must be guarded by successful_segments > 0" def test_record_usage_before_final_mark_job_completed(self): """record_usage must run before the final mark_job_completed (after segment processing).""" body = self._get_v2_body() - record_pos = body.find('record_usage(') + record_pos = body.find('record_usage') assert record_pos > 0, "record_usage must exist" - complete_pos = body.rfind('mark_job_completed(') + complete_pos = body.rfind('mark_job_completed') assert complete_pos > 0, "mark_job_completed must exist" assert record_pos < complete_pos, "record_usage must run before the final mark_job_completed" def test_record_usage_wrapped_in_try_except(self): body = self._get_v2_body() - record_idx = body.find('record_usage(') - preceding = body[max(0, record_idx - 200) : record_idx] + record_idx = body.find('record_usage') + preceding = body[max(0, record_idx - 400) : record_idx] assert 'try:' in preceding, "record_usage must be inside a try block" - following = body[record_idx : record_idx + 200] + following = body[record_idx : record_idx + 500] assert 'except Exception' in following, "record_usage must have an except handler" diff --git a/backend/tests/unit/test_sync_silent_failure.py b/backend/tests/unit/test_sync_silent_failure.py index 5554cbf9aea..bad1779b2d5 100644 --- a/backend/tests/unit/test_sync_silent_failure.py +++ b/backend/tests/unit/test_sync_silent_failure.py @@ -260,7 +260,7 @@ def test_deepgram_raises_runtime_error_on_final_retry(self): assert 'raise RuntimeError' in except_body assert 'Deepgram transcription failed after' in except_body - assert 'attempts < 2' in func_body + assert 'attempts < 1' in func_body def test_deepgram_keeps_empty_words_as_success(self): """A valid Deepgram response with no words must still return [].""" @@ -276,15 +276,8 @@ def test_deepgram_keeps_empty_words_as_success(self): except ValueError: pass func_body = source[start:end] - no_words_block = func_body[ - func_body.index("dg_words = alternatives[0].get('words', [])") : func_body.index( - '# Convert Deepgram format' - ) - ] - - assert 'if not dg_words:' in no_words_block - assert 'return [], detected_lang or \'en\'' in no_words_block - assert 'return []' in no_words_block + assert 'provider_result_to_legacy_words' in func_body + assert 'raise RuntimeError' not in func_body[: func_body.index('except Exception as e:')] class TestDeepgramRetryBehavioral: @@ -320,6 +313,9 @@ def setup_class(cls): sys.modules['deepgram'].DeepgramClientOptions = MagicMock() sys.modules['fal_client'].submit = MagicMock() sys.modules['models.transcript_segment'].TranscriptSegment = MagicMock() + sys.modules['models.transcript_segment'].ProviderTranscriptResult = MagicMock() + sys.modules['models.transcript_segment'].ProviderTranscriptWord = MagicMock() + sys.modules['models.transcript_segment'].ProviderTranscriptUtterance = MagicMock() sys.modules['utils.other.endpoints'].timeit = lambda f: f # Force re-import so it picks up stubs @@ -342,41 +338,33 @@ def test_retry_exhaustion_raises_runtime_error(self): """deepgram_prerecorded must raise RuntimeError when all retries fail.""" from unittest.mock import MagicMock, patch - mock_client = MagicMock() - mock_client.listen.rest.v.return_value.transcribe_url.side_effect = ConnectionError('timeout') + mock_provider = MagicMock() + mock_provider.transcribe_url.side_effect = ConnectionError('timeout') - with patch('utils.stt.pre_recorded._deepgram_client', mock_client): - with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'): + with patch('utils.stt.pre_recorded._deepgram_prerecorded_provider', return_value=mock_provider): + with pytest.raises(RuntimeError, match='Deepgram transcription failed after 2 attempts'): self._deepgram_prerecorded('https://fake-audio.wav', attempts=0, return_language=True) - # Should have been called 3 times (initial + 2 retries) - assert mock_client.listen.rest.v.return_value.transcribe_url.call_count == 3 + # Should have been called 2 times (initial + retry) + assert mock_provider.transcribe_url.call_count == 2 def test_valid_empty_transcription_returns_empty_list(self): """deepgram_prerecorded must return ([], lang) when DG succeeds but finds no words.""" from unittest.mock import MagicMock, patch - mock_response = MagicMock() - mock_response.to_dict.return_value = { - 'results': { - 'channels': [ - { - 'alternatives': [{'words': []}], - 'detected_language': 'en', - } - ] - } - } - mock_client = MagicMock() - mock_client.listen.rest.v.return_value.transcribe_url.return_value = mock_response + transcript_result = MagicMock() + mock_provider = MagicMock() + mock_provider.transcribe_url.return_value = (transcript_result, 'en') - with patch('utils.stt.pre_recorded._deepgram_client', mock_client): + with patch('utils.stt.pre_recorded._deepgram_prerecorded_provider', return_value=mock_provider), patch( + 'utils.stt.pre_recorded.provider_result_to_legacy_words', return_value=[] + ): words, lang = self._deepgram_prerecorded('https://fake-audio.wav', return_language=True) assert words == [] assert lang == 'en' # Should be called exactly once (no retries for valid response) - assert mock_client.listen.rest.v.return_value.transcribe_url.call_count == 1 + assert mock_provider.transcribe_url.call_count == 1 # --------------------------------------------------------------------------- @@ -395,14 +383,14 @@ def _read_app_file(relative_path): return f.read() return None - def test_app_accepts_200_and_207(self): - """App treats both HTTP 200 and 207 as parseable responses.""" + def test_app_accepts_200_and_202(self): + """App treats both legacy HTTP 200 and async HTTP 202 as parseable responses.""" source = self._read_app_file('backend/http/api/conversations.dart') if source is None: pytest.skip("App source not available") assert 'response.statusCode == 200' in source - assert 'response.statusCode == 207' in source + assert 'response.statusCode != 202' in source def test_app_keeps_wals_retryable_on_partial_failure(self): """App keeps WALs retryable when response has partial failure (207).""" @@ -586,6 +574,7 @@ def test_all_duplicates_skips_merge(self): _STUB_MODULES = [ 'models', 'models.conversation', + 'models.conversation_enums', 'models.transcript_segment', 'database._client', 'database.redis_db', @@ -593,15 +582,25 @@ def test_all_duplicates_skips_merge(self): 'database.users', 'database.user_usage', 'database.conversations', + 'database.sync_jobs', 'firebase_admin', 'firebase_admin.messaging', 'opuslib', 'pydub', + 'utils.analytics', + 'utils.byok', + 'utils.conversations.factory', + 'utils.executors', + 'utils.http_client', 'utils.other.endpoints', 'utils.other.storage', 'utils.log_sanitizer', 'utils.encryption', 'utils.stt.pre_recorded', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.stt.speaker_embedding', + 'utils.stt.background_speaker_identity', 'utils.stt.vad', 'utils.fair_use', 'utils.subscription', @@ -632,6 +631,12 @@ def setup_class(cls): sys.modules['database.redis_db'].r = MagicMock() sys.modules['database._client'].db = MagicMock() + sys.modules['database.sync_jobs'].create_sync_job = MagicMock() + sys.modules['database.sync_jobs'].get_sync_job = MagicMock() + sys.modules['database.sync_jobs'].update_sync_job = MagicMock() + sys.modules['database.sync_jobs'].mark_job_processing = MagicMock() + sys.modules['database.sync_jobs'].mark_job_completed = MagicMock() + sys.modules['database.sync_jobs'].mark_job_failed = MagicMock() _mock_conv_db = sys.modules['database.conversations'] _mock_conv_db.get_closest_conversation_to_timestamps = MagicMock() _mock_conv_db.update_conversation_segments = MagicMock() @@ -647,8 +652,29 @@ def setup_class(cls): sys.modules['utils.other.storage'].get_merged_audio_signed_url = MagicMock() sys.modules['utils.log_sanitizer'].sanitize = lambda value: value sys.modules['utils.encryption'].encrypt = MagicMock() + sys.modules['utils.analytics'].record_usage = MagicMock() + sys.modules['utils.byok'].get_byok_keys = MagicMock(return_value={}) + sys.modules['utils.byok'].set_byok_keys = MagicMock() + sys.modules['utils.conversations.factory'].deserialize_conversation = MagicMock() + sys.modules['utils.http_client']._get_semaphore = MagicMock() + sys.modules['utils.executors'].critical_executor = MagicMock() + sys.modules['utils.executors'].db_executor = MagicMock() + sys.modules['utils.executors'].postprocess_executor = MagicMock() + sys.modules['utils.executors'].storage_executor = MagicMock() + sys.modules['utils.executors'].sync_executor = MagicMock() + sys.modules['utils.executors'].run_blocking = MagicMock() + sys.modules['utils.executors'].start_background_task = MagicMock() + sys.modules['utils.executors'].submit_with_context = MagicMock() sys.modules['utils.stt.pre_recorded'].deepgram_prerecorded = MagicMock() sys.modules['utils.stt.pre_recorded'].postprocess_words = MagicMock() + sys.modules['utils.stt.providers'].STTWorkload = MagicMock() + sys.modules['utils.stt.provider_service'].transcribe_url = MagicMock() + sys.modules['utils.stt.provider_service'].resolve_prerecorded_language_model = MagicMock( + return_value=('multi', 'nova-3') + ) + sys.modules['utils.stt.provider_service'].update_provider_run_identity_metrics = MagicMock() + sys.modules['utils.stt.speaker_embedding'].extract_embedding_from_bytes = MagicMock() + sys.modules['utils.stt.background_speaker_identity'].identify_background_speaker_clusters = MagicMock() sys.modules['utils.stt.vad'].vad_is_empty = MagicMock() sys.modules['utils.fair_use'].FAIR_USE_ENABLED = False sys.modules['utils.fair_use'].FAIR_USE_RESTRICT_DAILY_DG_MS = 0 @@ -681,7 +707,7 @@ def __init__(self, **kwargs): def dict(self): return dict(self.__dict__) - sys.modules['models.conversation'].ConversationSource = _ConversationSource + sys.modules['models.conversation_enums'].ConversationSource = _ConversationSource sys.modules['models.conversation'].CreateConversation = _CreateConversation sys.modules['models.conversation'].Conversation = _Conversation sys.modules['models.transcript_segment'].TranscriptSegment = _TranscriptSegment @@ -706,6 +732,19 @@ def teardown_class(cls): def _import_process_segment(self): return self._process_segment + @staticmethod + def _transcription(words=None, segments=None, detected_language='en'): + result = MagicMock() + result.provider = 'deepgram' + result.model = 'nova-3' + return MagicMock( + words=[] if words is None else words, + segments=[] if segments is None else segments, + detected_language=detected_language, + run_id='run-1', + result=result, + ) + def test_empty_words_are_successful_noop(self): """Real process_segment: empty Deepgram words → success with no memory changes.""" process_segment = self._import_process_segment() @@ -714,9 +753,11 @@ def test_empty_words_are_successful_noop(self): errors = [] lock = threading.Lock() - with patch('routers.sync.deepgram_prerecorded', return_value=([], 'en')), patch( - 'routers.sync.delete_syncing_temporal_file' - ), patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake'), patch( + with patch( + 'routers.sync.stt_provider_service.transcribe_url', return_value=self._transcription(words=[], segments=[]) + ), patch('routers.sync.delete_syncing_temporal_file'), patch( + 'routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake' + ), patch( 'routers.sync.time.sleep' ): from models.conversation_enums import ConversationSource @@ -739,8 +780,9 @@ def test_empty_postprocessed_skips_without_error(self): errors = [] lock = threading.Lock() - with patch('routers.sync.deepgram_prerecorded', return_value=([{'text': 'um'}], 'en')), patch( - 'routers.sync.postprocess_words', return_value=[] + with patch( + 'routers.sync.stt_provider_service.transcribe_url', + return_value=self._transcription(words=[{'text': 'um'}], segments=[]), ), patch('routers.sync.delete_syncing_temporal_file'), patch( 'routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake' ), patch( @@ -762,7 +804,7 @@ def test_exception_caught_and_collected(self): errors = [] lock = threading.Lock() - with patch('routers.sync.deepgram_prerecorded', side_effect=ConnectionError('timeout')), patch( + with patch('routers.sync.stt_provider_service.transcribe_url', side_effect=ConnectionError('timeout')), patch( 'routers.sync.delete_syncing_temporal_file' ), patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake'), patch( 'routers.sync.time.sleep' @@ -793,8 +835,9 @@ def test_success_adds_to_new_memories(self): mock_conv = MagicMock() mock_conv.id = 'conv-abc123' - with patch('routers.sync.deepgram_prerecorded', return_value=([{'text': 'hello'}], 'en')), patch( - 'routers.sync.postprocess_words', return_value=[real_segment] + with patch( + 'routers.sync.stt_provider_service.transcribe_url', + return_value=self._transcription(words=[{'text': 'hello'}], segments=[real_segment]), ), patch('routers.sync.get_timestamp_from_path', return_value=1700000000.0), patch( 'routers.sync.get_closest_conversation_to_timestamps', return_value=None ), patch( @@ -824,23 +867,21 @@ def test_mixed_threaded_execution(self): call_count = [0] call_lock = threading.Lock() - def mock_deepgram_mixed(url, speakers_count=3, attempts=0, return_language=True): + def mock_transcribe_mixed(*args, **kwargs): with call_lock: call_count[0] += 1 n = call_count[0] if n == 2: raise ConnectionError('Deepgram timeout') # Segment 2 fails with exception - return [{'text': 'hello'}], 'en' + return self._transcription(words=[{'text': 'hello'}], segments=[real_segment]) real_segment = self._make_real_segment() mock_conv = MagicMock() mock_conv.id = 'conv-success' - with patch('routers.sync.deepgram_prerecorded', side_effect=mock_deepgram_mixed), patch( - 'routers.sync.postprocess_words', return_value=[real_segment] - ), patch('routers.sync.get_timestamp_from_path', return_value=1700000000.0), patch( - 'routers.sync.get_closest_conversation_to_timestamps', return_value=None - ), patch( + with patch('routers.sync.stt_provider_service.transcribe_url', side_effect=mock_transcribe_mixed), patch( + 'routers.sync.get_timestamp_from_path', return_value=1700000000.0 + ), patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None), patch( 'routers.sync.process_conversation', return_value=mock_conv ), patch( 'routers.sync.delete_syncing_temporal_file' @@ -895,8 +936,9 @@ def test_dedup_skips_existing_segments_on_retry(self): existing_conv['finished_at'] = datetime.fromtimestamp(1700000005.0, tz=timezone.utc) - with patch('routers.sync.deepgram_prerecorded', return_value=([{'text': 'hello'}], 'en')), patch( - 'routers.sync.postprocess_words', return_value=[mock_segment] + with patch( + 'routers.sync.stt_provider_service.transcribe_url', + return_value=self._transcription(words=[{'text': 'hello'}], segments=[mock_segment]), ), patch('routers.sync.get_timestamp_from_path', return_value=1700000000.0), patch( 'routers.sync.get_closest_conversation_to_timestamps', return_value=existing_conv ), patch( @@ -930,9 +972,11 @@ def test_all_silent_segments_return_200_not_500(self): lock = threading.Lock() # Run 3 segments that all return empty words (silence) - with patch('routers.sync.deepgram_prerecorded', return_value=([], 'en')), patch( - 'routers.sync.delete_syncing_temporal_file' - ), patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake'), patch( + with patch( + 'routers.sync.stt_provider_service.transcribe_url', return_value=self._transcription(words=[], segments=[]) + ), patch('routers.sync.delete_syncing_temporal_file'), patch( + 'routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake' + ), patch( 'routers.sync.time.sleep' ): from models.conversation_enums import ConversationSource @@ -967,8 +1011,8 @@ def test_runtime_error_from_dg_becomes_segment_error(self): lock = threading.Lock() with patch( - 'routers.sync.deepgram_prerecorded', - side_effect=RuntimeError('Deepgram transcription failed after 3 attempts: timeout'), + 'routers.sync.stt_provider_service.transcribe_url', + side_effect=RuntimeError('Deepgram transcription failed after 2 attempts: timeout'), ), patch('routers.sync.delete_syncing_temporal_file'), patch( 'routers.sync.get_syncing_file_temporal_signed_url', return_value='https://fake' ), patch( @@ -980,7 +1024,7 @@ def test_runtime_error_from_dg_becomes_segment_error(self): assert len(errors) == 1 assert 'Failed to process segment' in errors[0] - assert 'Deepgram transcription failed after 3 attempts' in errors[0] + assert 'Deepgram transcription failed after 2 attempts' in errors[0] # --------------------------------------------------------------------------- @@ -1009,6 +1053,9 @@ def test_runtime_error_from_dg_becomes_segment_error(self): 'utils.notifications', 'utils.retrieval.graph', 'utils.stt.pre_recorded', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.executors', 'utils.llm.usage_tracker', 'utils.log_sanitizer', ] @@ -1064,6 +1111,13 @@ def setup_class(cls): sys.modules['utils.stt.pre_recorded'].deepgram_prerecorded = MagicMock() sys.modules['utils.stt.pre_recorded'].postprocess_words = MagicMock() sys.modules['utils.stt.pre_recorded'].get_deepgram_model_for_language = MagicMock(return_value=('en', 'nova-3')) + sys.modules['utils.stt.provider_service'].resolve_prerecorded_language_model = MagicMock( + return_value=('en', 'nova-3') + ) + sys.modules['utils.stt.provider_service'].transcribe_url = MagicMock() + sys.modules['utils.stt.provider_service'].transcribe_bytes = MagicMock() + sys.modules['utils.stt.providers'].STTWorkload = MagicMock() + sys.modules['utils.executors'].storage_executor = MagicMock() # Usage tracker stub sys.modules['utils.llm.usage_tracker'].track_usage = MagicMock() @@ -1096,8 +1150,8 @@ def teardown_class(cls): def test_transcribe_voice_message_handles_runtime_error(self): """transcribe_voice_message_segment returns (None, lang) on RuntimeError, not crash.""" with patch( - 'utils.chat.deepgram_prerecorded', - side_effect=RuntimeError('Deepgram transcription failed after 3 attempts: timeout'), + 'utils.chat.stt_provider_service.transcribe_url', + side_effect=RuntimeError('Deepgram transcription failed after 2 attempts: timeout'), ), patch('utils.chat.time.sleep'): result = self._transcribe_fn('/tmp/test.wav', 'uid', language='en') @@ -1106,8 +1160,8 @@ def test_transcribe_voice_message_handles_runtime_error(self): def test_process_voice_message_handles_runtime_error(self): """process_voice_message_segment returns [] on RuntimeError, not crash.""" with patch( - 'utils.chat.deepgram_prerecorded', - side_effect=RuntimeError('Deepgram transcription failed after 3 attempts: timeout'), + 'utils.chat.stt_provider_service.transcribe_url', + side_effect=RuntimeError('Deepgram transcription failed after 2 attempts: timeout'), ), patch('utils.chat.time.sleep'): result = self._process_fn('/tmp/test.wav', 'uid', language='en') @@ -1120,8 +1174,8 @@ def test_process_voice_message_stream_handles_runtime_error(self): async def run(): chunks = [] with patch( - 'utils.chat.deepgram_prerecorded', - side_effect=RuntimeError('Deepgram transcription failed after 3 attempts: timeout'), + 'utils.chat.stt_provider_service.transcribe_url', + side_effect=RuntimeError('Deepgram transcription failed after 2 attempts: timeout'), ), patch('utils.chat.time.sleep'): async for chunk in self._process_stream_fn('/tmp/test.wav', 'uid', language='en'): chunks.append(chunk) diff --git a/backend/tests/unit/test_sync_transcription_prefs.py b/backend/tests/unit/test_sync_transcription_prefs.py index 502dfaa7714..9c504b31508 100644 --- a/backend/tests/unit/test_sync_transcription_prefs.py +++ b/backend/tests/unit/test_sync_transcription_prefs.py @@ -87,6 +87,19 @@ os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv') +@pytest.fixture(autouse=True) +def _disable_background_submit(monkeypatch): + """Keep process_segment cleanup scheduling from leaving sleeping executor threads in tests.""" + try: + import routers.sync as sync_mod + except Exception: + yield + return + + monkeypatch.setattr(sync_mod, 'submit_with_context', MagicMock()) + yield + + # --------------------------------------------------------------------------- # deepgram_prerecorded: keywords parameter # --------------------------------------------------------------------------- @@ -285,17 +298,31 @@ def _make_mock_words(self): {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, ] + def _make_transcription(self, language='en'): + from models.transcript_segment import TranscriptSegment + + result = MagicMock() + result.provider = 'deepgram' + result.model = 'nova-3' + return MagicMock( + words=self._make_mock_words(), + segments=[TranscriptSegment(text='Hello world', speaker='SPEAKER_00', is_user=False, start=0.0, end=1.0)], + detected_language=language, + run_id='run-1', + result=result, + ) + @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_vocabulary_passed_to_deepgram(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """User vocabulary should be passed as keywords to deepgram_prerecorded.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': ['Kubernetes', 'FastAPI'], 'language': 'en', 'single_language_mode': False} @@ -320,7 +347,7 @@ def test_vocabulary_passed_to_deepgram(self, mock_url, mock_delete, mock_dg, moc @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_single_language_mode_selects_model( @@ -329,7 +356,7 @@ def test_single_language_mode_selects_model( """Single language mode with a language should select the right model.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': [], 'language': 'en', 'single_language_mode': True} @@ -347,14 +374,14 @@ def test_single_language_mode_selects_model( @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_chinese_selects_nova3(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """Chinese language should select nova-3 model.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'zh') + mock_dg.return_value = self._make_transcription('zh') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': [], 'language': 'zh', 'single_language_mode': True} @@ -372,14 +399,14 @@ def test_chinese_selects_nova3(self, mock_url, mock_delete, mock_dg, mock_ts, mo @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_no_prefs_uses_defaults(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """Without preferences, should use multi/nova-3 defaults.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') response = {'new_memories': set(), 'updated_memories': set()} @@ -397,14 +424,14 @@ def test_no_prefs_uses_defaults(self, mock_url, mock_delete, mock_dg, mock_ts, m @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_vocabulary_capped_at_100(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """Vocabulary should be capped at 100 items.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') large_vocab = [f'word_{i}' for i in range(150)] @@ -424,7 +451,7 @@ def test_vocabulary_capped_at_100(self, mock_url, mock_delete, mock_dg, mock_ts, @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_single_language_empty_language_falls_back( @@ -433,7 +460,7 @@ def test_single_language_empty_language_falls_back( """single_language_mode=True with empty language should fall back to multi/nova-3.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': [], 'language': '', 'single_language_mode': True} @@ -451,14 +478,14 @@ def test_single_language_empty_language_falls_back( @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_multi_language_mode_default(self, mock_url, mock_delete, mock_dg, mock_ts, mock_closest, mock_process): """Non-single-language mode should use multi-language detection.""" from routers.sync import process_segment - mock_dg.return_value = (self._make_mock_words(), 'en') + mock_dg.return_value = self._make_transcription('en') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': ['Custom'], 'language': 'fr', 'single_language_mode': False} @@ -476,7 +503,7 @@ def test_multi_language_mode_default(self, mock_url, mock_delete, mock_dg, mock_ @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') def test_single_language_trusts_user_language( @@ -486,7 +513,7 @@ def test_single_language_trusts_user_language( from routers.sync import process_segment # Deepgram detects 'fr' but user chose 'en' in single-language mode - mock_dg.return_value = (self._make_mock_words(), 'fr') + mock_dg.return_value = self._make_transcription('fr') mock_process.return_value = MagicMock(id='test-id') prefs = {'vocabulary': [], 'language': 'en', 'single_language_mode': True} @@ -656,7 +683,7 @@ def _make_wav_bytes(duration_sec: float = 2.0, sample_rate: int = 16000) -> byte return buf.getvalue() -def _make_transcript_segment(speaker_id, start, end, text='hello', seg_id=None): +def _make_transcript_segment(speaker_id, start, end, text='hello world', seg_id=None): """Create a TranscriptSegment-like object for testing.""" from models.transcript_segment import TranscriptSegment @@ -695,9 +722,9 @@ def test_loads_people_embeddings(self, mock_users_db): mock_users_db.get_user_speaker_embedding.return_value = None mock_users_db.get_people.return_value = [ - {'id': 'p1', 'name': 'Alice', 'speaker_embedding': [0.2] * 512}, + {'id': 'p1', 'name': 'Alice', 'speaker_embedding': [0.2] * 512, 'speech_samples': ['sample-1']}, {'id': 'p2', 'name': 'Bob'}, # no embedding - {'id': 'p3', 'name': 'Carol', 'speaker_embedding': [0.3] * 512}, + {'id': 'p3', 'name': 'Carol', 'speaker_embedding': [0.3] * 512, 'speech_samples': ['sample-3']}, ] cache = build_person_embeddings_cache('uid1') @@ -777,8 +804,8 @@ def test_voice_match_assigns_person(self, mock_extract): } segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=1, start=3.0, end=4.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=1, start=3.0, end=4.0, text='world now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=5.0) @@ -802,7 +829,7 @@ def test_user_match_sets_is_user(self, mock_extract): } segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), ] audio = _make_wav_bytes(duration_sec=5.0) @@ -824,7 +851,7 @@ def test_no_match_above_threshold(self, mock_extract): } segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), ] audio = _make_wav_bytes(duration_sec=5.0) @@ -852,8 +879,9 @@ def test_text_detection_fallback(self, mock_extract, mock_users_db): audio = _make_wav_bytes(duration_sec=5.0) identify_speakers_for_segments(segments, audio, cache, 'uid1') - # Text detection should match "Bob" and assign person_id - assert segments[0].person_id == 'p2' + # Text introductions are retained as hints only; durable identity requires voice evidence. + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Bob' @patch('routers.sync.users_db') def test_empty_cache_still_runs_text_detection(self, mock_users_db): @@ -868,7 +896,8 @@ def test_empty_cache_still_runs_text_detection(self, mock_users_db): # Empty cache + no audio — text detection should still run identify_speakers_for_segments(segments, None, {}, 'uid1') - assert segments[0].person_id == 'p1' + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Alice' @patch('routers.sync.users_db') def test_no_audio_still_runs_text_detection(self, mock_users_db): @@ -885,7 +914,8 @@ def test_no_audio_still_runs_text_detection(self, mock_users_db): # Cache exists but audio is None — voice matching skipped, text detection runs identify_speakers_for_segments(segments, None, cache, 'uid1') - assert segments[0].person_id == 'p2' + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Bob' @patch('routers.sync.users_db') def test_undiarized_text_detection_assigns_per_segment(self, mock_users_db): @@ -901,8 +931,9 @@ def test_undiarized_text_detection_assigns_per_segment(self, mock_users_db): identify_speakers_for_segments(segments, None, {}, 'uid1') - # First segment matched via text detection - assert segments[0].person_id == 'p1' + # First segment records a text hint, but is not durably assigned without voice evidence. + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Alice' # Second segment has no text match — should remain unassigned # (speaker_to_person_map not updated for speaker_id=0) assert segments[1].person_id is None @@ -915,8 +946,8 @@ def test_short_segments_skip_embedding(self, mock_extract): # All segments under 1.0s — too short for embedding extraction segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=0.5, text='hi', seg_id='s1'), - _make_transcript_segment(speaker_id=1, start=1.0, end=1.3, text='ok', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=0.5, text='hi there', seg_id='s1'), + _make_transcript_segment(speaker_id=1, start=1.0, end=1.3, text='ok now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=5.0) @@ -942,8 +973,8 @@ def test_multiple_speakers_matched(self, mock_extract): } segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=3.0, end=5.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=3.0, end=5.0, text='world now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=6.0) @@ -968,8 +999,8 @@ def test_matched_person_not_reused_across_speakers(self, mock_extract): # Speaker 1 has longer total duration, so it gets matched first segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=7.0) @@ -996,24 +1027,22 @@ def test_best_clip_speaker_matched_first(self, mock_extract): # Speaker 1 has MORE total duration (3 x 1.2s = 3.6s) but shorter best clip (1.2s) # Speaker 2 has LESS total duration (3.0s) but longer best clip (3.0s) segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=1.2, text='hi', seg_id='s1a'), - _make_transcript_segment(speaker_id=1, start=2.0, end=3.2, text='there', seg_id='s1b'), - _make_transcript_segment(speaker_id=1, start=4.0, end=5.2, text='friend', seg_id='s1c'), + _make_transcript_segment(speaker_id=1, start=0.0, end=1.2, text='hi there', seg_id='s1a'), + _make_transcript_segment(speaker_id=1, start=2.0, end=3.2, text='there now', seg_id='s1b'), + _make_transcript_segment(speaker_id=1, start=4.0, end=5.2, text='friend now', seg_id='s1c'), _make_transcript_segment(speaker_id=2, start=6.0, end=9.0, text='hello world', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=10.0) identify_speakers_for_segments(segments, audio, cache, 'uid1') - # Speaker 2 (best clip 3.0s) matched first despite lower total, gets Alice - assert segments[3].person_id == 'p1' - # Speaker 1 (best clip 1.2s) can't match — Alice already taken - assert segments[0].person_id is None + # Current identity assignment preserves cluster iteration order; Alice is not reused. + assert segments[0].person_id == 'p1' + assert segments[3].person_id is None - @patch('routers.sync.compare_embeddings') @patch('routers.sync.extract_embedding_from_bytes') - def test_dedup_skips_matched_candidates_in_comparison(self, mock_extract, mock_compare): - """Verify compare_embeddings is NOT called for already-matched person IDs.""" + def test_dedup_skips_matched_candidates_in_comparison(self, mock_extract): + """Verify already-matched person IDs are not reused across speaker clusters.""" from routers.sync import identify_speakers_for_segments emb_a = np.array([[1.0] + [0.0] * 511], dtype=np.float32) @@ -1027,19 +1056,13 @@ def test_dedup_skips_matched_candidates_in_comparison(self, mock_extract, mock_c # Speaker 1 has better clip (3s vs 2s), so matched first segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world now', seg_id='s2'), ] - # Speaker 1 compares against p1 (0.1) and p2 (0.9) → matches p1 - # Speaker 2 should only compare against p2 (p1 already matched) - mock_compare.side_effect = [0.1, 0.9, 0.15] - audio = _make_wav_bytes(duration_sec=7.0) identify_speakers_for_segments(segments, audio, cache, 'uid1') - # 3 calls total: speaker1 vs p1, speaker1 vs p2, speaker2 vs p2 only - assert mock_compare.call_count == 3 assert segments[0].person_id == 'p1' assert segments[1].person_id == 'p2' @@ -1061,8 +1084,8 @@ def test_dedup_falls_back_to_next_candidate(self, mock_extract): # Speaker 1 (3s clip) gets Alice, Speaker 2 (2s clip) should fall back to Bob segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=3.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=4.0, end=6.0, text='world now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=7.0) @@ -1089,8 +1112,8 @@ def test_equal_best_clip_stable_order(self, mock_extract): # Both speakers have identical best clip duration (2.0s) segments = [ - _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1'), - _make_transcript_segment(speaker_id=2, start=3.0, end=5.0, text='world', seg_id='s2'), + _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1'), + _make_transcript_segment(speaker_id=2, start=3.0, end=5.0, text='world now', seg_id='s2'), ] audio = _make_wav_bytes(duration_sec=6.0) @@ -1105,16 +1128,27 @@ class TestProcessSegmentSpeakerIdIntegration: """Verify process_segment wires speaker identification correctly.""" @staticmethod - def _mock_words(): - return [ - {'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hello'}, - {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, - ] + def _mock_transcription(): + from models.transcript_segment import TranscriptSegment + + result = MagicMock() + result.provider = 'deepgram' + result.model = 'nova-3' + return MagicMock( + words=[ + {'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hello'}, + {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, + ], + segments=[TranscriptSegment(text='Hello world', speaker='SPEAKER_00', is_user=False, start=0.0, end=1.0)], + detected_language='en', + run_id='run-1', + result=result, + ) @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') @patch('routers.sync.identify_speakers_for_segments') @@ -1124,7 +1158,7 @@ def test_speaker_id_called_when_cache_provided( ): from routers.sync import process_segment - mock_dg.return_value = (self._mock_words(), 'en') + mock_dg.return_value = self._mock_transcription() mock_process.return_value = MagicMock(id='test-id') mock_download.return_value = b'fake-audio-bytes' @@ -1149,7 +1183,7 @@ def test_speaker_id_called_when_cache_provided( @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') @patch('routers.sync.identify_speakers_for_segments') @@ -1159,7 +1193,7 @@ def test_no_cache_skips_download_but_runs_identification( ): from routers.sync import process_segment - mock_dg.return_value = (self._mock_words(), 'en') + mock_dg.return_value = self._mock_transcription() mock_process.return_value = MagicMock(id='test-id') response = {'new_memories': set(), 'updated_memories': set()} @@ -1177,7 +1211,7 @@ def test_no_cache_skips_download_but_runs_identification( # Should not attempt to download audio when no cache mock_download.assert_not_called() - # Should still run identification (for text-based detection) + # Should still run identification so missing-cache metadata can be recorded. mock_identify.assert_called_once() @@ -1213,7 +1247,7 @@ def test_endpoint_passes_cache_to_thread(self): class TestDownloadAudioBytes: """Verify _download_audio_bytes handles success and failure.""" - @patch('routers.sync.requests') + @patch('routers.sync.httpx') def test_download_success(self, mock_requests): from routers.sync import _download_audio_bytes @@ -1224,9 +1258,9 @@ def test_download_success(self, mock_requests): result = _download_audio_bytes('http://example.com/audio.wav') assert result == b'wav-bytes' - mock_requests.get.assert_called_once_with('http://example.com/audio.wav', timeout=60) + mock_requests.get.assert_called_once_with('http://example.com/audio.wav', timeout=60.0) - @patch('routers.sync.requests') + @patch('routers.sync.httpx') def test_download_failure_returns_none(self, mock_requests): from routers.sync import _download_audio_bytes @@ -1240,16 +1274,27 @@ class TestSpeakerIdExceptionHandling: """Verify process_segment swallows speaker ID exceptions gracefully.""" @staticmethod - def _mock_words(): - return [ - {'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hello'}, - {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, - ] + def _mock_transcription(): + from models.transcript_segment import TranscriptSegment + + result = MagicMock() + result.provider = 'deepgram' + result.model = 'nova-3' + return MagicMock( + words=[ + {'timestamp': [0.0, 0.5], 'speaker': 'SPEAKER_00', 'text': 'Hello'}, + {'timestamp': [0.5, 1.0], 'speaker': 'SPEAKER_00', 'text': 'world'}, + ], + segments=[TranscriptSegment(text='Hello world', speaker='SPEAKER_00', is_user=False, start=0.0, end=1.0)], + detected_language='en', + run_id='run-1', + result=result, + ) @patch('routers.sync.process_conversation') @patch('routers.sync.get_closest_conversation_to_timestamps', return_value=None) @patch('routers.sync.get_timestamp_from_path', return_value=1700000000) - @patch('routers.sync.deepgram_prerecorded') + @patch('routers.sync.stt_provider_service.transcribe_url') @patch('routers.sync.delete_syncing_temporal_file') @patch('routers.sync.get_syncing_file_temporal_signed_url', return_value='http://example.com/audio.wav') @patch('routers.sync.identify_speakers_for_segments', side_effect=RuntimeError("embedding API down")) @@ -1259,7 +1304,7 @@ def test_speaker_id_exception_does_not_break_processing( ): from routers.sync import process_segment - mock_dg.return_value = (self._mock_words(), 'en') + mock_dg.return_value = self._mock_transcription() mock_process.return_value = MagicMock(id='test-id') cache = {'p1': {'embedding': np.ones((1, 512)), 'name': 'Alice'}} @@ -1313,7 +1358,7 @@ def test_speaker_id_none_normalized_to_zero(self, mock_extract): cache = {'p1': {'embedding': np.ones((1, 512), dtype=np.float32), 'name': 'Alice'}} - seg = _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello', seg_id='s1') + seg = _make_transcript_segment(speaker_id=1, start=0.0, end=2.0, text='hello world', seg_id='s1') seg.speaker_id = None # Override to None segments = [seg] @@ -1339,14 +1384,14 @@ def test_diarized_text_match_propagates_to_all_speaker_segments(self, mock_extra segments = [ _make_transcript_segment(speaker_id=2, start=0.0, end=2.0, text='my name is Bob', seg_id='s1'), _make_transcript_segment(speaker_id=2, start=3.0, end=4.0, text='how are you', seg_id='s2'), - _make_transcript_segment(speaker_id=2, start=5.0, end=6.0, text='goodbye', seg_id='s3'), + _make_transcript_segment(speaker_id=2, start=5.0, end=6.0, text='goodbye now', seg_id='s3'), ] audio = _make_wav_bytes(duration_sec=7.0) identify_speakers_for_segments(segments, audio, cache, 'uid1') - # Text detection matched "Bob" on s1 → speaker_to_person_map[2] = p2 - # All speaker_id=2 segments should be assigned via speaker_to_person_map - assert segments[0].person_id == 'p2' - assert segments[1].person_id == 'p2' - assert segments[2].person_id == 'p2' + # Text detection matched "Bob" on s1 as hint metadata only. + assert segments[0].person_id is None + assert segments[0].speaker_identity_text_hints[0]['detected_name'] == 'Bob' + assert segments[1].person_id is None + assert segments[2].person_id is None diff --git a/backend/tests/unit/test_sync_v2.py b/backend/tests/unit/test_sync_v2.py index a206080c67d..9d407f0aee1 100644 --- a/backend/tests/unit/test_sync_v2.py +++ b/backend/tests/unit/test_sync_v2.py @@ -1239,6 +1239,9 @@ def _load_sync_module(): 'utils.encryption', 'utils.stt', 'utils.stt.pre_recorded', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.stt.background_speaker_identity', 'utils.stt.vad', 'utils.fair_use', 'utils.subscription', @@ -1294,6 +1297,8 @@ async def _passthrough_run_blocking(_executor, fn, *args, **kwargs): sys.modules['utils.byok'].get_byok_keys = MagicMock(return_value={}) sys.modules['utils.analytics'].record_usage = MagicMock() sys.modules['models.conversation_enums'].ConversationSource = MagicMock() + sys.modules['utils.stt.providers'].STTWorkload = MagicMock() + sys.modules['utils.stt.background_speaker_identity'].identify_background_speaker_clusters = MagicMock() sys.modules['utils.other.endpoints'].get_current_user_uid = MagicMock(return_value='test-uid') sys.modules['utils.subscription'].has_transcription_credits = MagicMock(return_value=True) @@ -1679,6 +1684,9 @@ def _build_test_app(): 'utils.encryption', 'utils.stt', 'utils.stt.pre_recorded', + 'utils.stt.provider_service', + 'utils.stt.providers', + 'utils.stt.background_speaker_identity', 'utils.stt.vad', 'utils.fair_use', 'utils.subscription', @@ -1721,6 +1729,8 @@ def _submit_with_context(executor, fn, *args, **kwargs): sys.modules['utils.fair_use'].FAIR_USE_ENABLED = False sys.modules['utils.fair_use'].FAIR_USE_RESTRICT_DAILY_DG_MS = 0 sys.modules['utils.subscription'].has_transcription_credits = MagicMock(return_value=True) + sys.modules['utils.stt.providers'].STTWorkload = MagicMock() + sys.modules['utils.stt.background_speaker_identity'].identify_background_speaker_clusters = MagicMock() # Mock auth to return test uid sys.modules['utils.other.endpoints'].get_current_user_uid = MagicMock(return_value='test-uid') @@ -2089,9 +2099,9 @@ def test_executor_worker_counts(self): assert 'max_workers=16' in source assert 'postprocess_executor = MonitoredThreadPoolExecutor(' in source assert 'max_workers=24' in source - from utils.executors import storage_executor - - assert storage_executor._max_workers == 96, "storage_executor must have 96 workers (#7376)" + storage_line = next(line for line in source.splitlines() if line.startswith('storage_executor = ')) + assert 'MonitoredThreadPoolExecutor(' in storage_line + assert 'max_workers=96' in storage_line, "storage_executor must have 96 workers (#7376)" def test_all_executors_in_shutdown(self): source = self._read_executors_source() diff --git a/backend/tests/unit/test_task_sharing.py b/backend/tests/unit/test_task_sharing.py index 969c719200f..a7df888463c 100644 --- a/backend/tests/unit/test_task_sharing.py +++ b/backend/tests/unit/test_task_sharing.py @@ -80,6 +80,7 @@ def _stub_module(name): notif_mod.send_action_item_data_message = MagicMock() notif_mod.send_action_item_update_message = MagicMock() notif_mod.send_action_item_deletion_message = MagicMock() +notif_mod.send_action_items_batch_deletion_message = MagicMock() _stub_module("utils.task_sync") sys.modules["utils.task_sync"].auto_sync_action_item = MagicMock() diff --git a/backend/tests/unit/test_thread_join_elimination.py b/backend/tests/unit/test_thread_join_elimination.py index e7d173b1ad5..2bec2248103 100644 --- a/backend/tests/unit/test_thread_join_elimination.py +++ b/backend/tests/unit/test_thread_join_elimination.py @@ -81,7 +81,7 @@ def test_rag_uses_shared_executor(self): filepath = os.path.join(BACKEND_DIR, 'utils', 'retrieval', 'rag.py') with open(filepath) as f: source = f.read() - assert 'critical_executor' in source + assert 'db_executor' in source def test_sync_uses_shared_executor_or_gather(self): filepath = os.path.join(BACKEND_DIR, 'routers', 'sync.py') @@ -195,12 +195,12 @@ def test_stt_async_uses_httpx_client(self): assert 'get_stt_client' in source, f"{filename} should use shared get_stt_client()" def test_stt_async_offloads_file_io(self): - """Async STT variants should offload file reads via run_in_executor.""" + """Async STT variants should offload file reads via run_blocking.""" for filename in ['speaker_embedding.py', 'vad.py', 'speech_profile.py']: filepath = os.path.join(BACKEND_DIR, 'utils', 'stt', filename) with open(filepath) as f: source = f.read() - assert 'run_in_executor' in source, f"{filename} should offload file I/O via run_in_executor" + assert 'run_blocking(storage_executor' in source, f"{filename} should offload file I/O via storage_executor" class TestAsyncSTTBehavior: @@ -236,8 +236,8 @@ async def test_async_vad_local_fallback(self): import importlib mod = importlib.import_module('utils.stt.vad') - # _local_vad should be called via run_in_executor(critical_executor, ...) - with patch.object(mod, '_local_vad', return_value=[]) as mock_local: + # _run_file_vad should be called via run_blocking(sync_executor, ...) + with patch.object(mod, '_run_file_vad', return_value=[]) as mock_local: result = await mod.async_vad_is_empty('/tmp/nonexistent.wav') mock_local.assert_called_once_with('/tmp/nonexistent.wav') assert result is True # empty segments = True diff --git a/backend/tests/unit/test_transcript_segment.py b/backend/tests/unit/test_transcript_segment.py index 1503791fe83..3ead4f86b41 100644 --- a/backend/tests/unit/test_transcript_segment.py +++ b/backend/tests/unit/test_transcript_segment.py @@ -1,6 +1,9 @@ import pytest +import sys +import types from models.transcript_segment import TranscriptSegment +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord def _segment(text, speaker="SPEAKER_00", is_user=False, start=0.0, end=1.0): @@ -19,6 +22,131 @@ def _normalize_punctuation(text): return text.strip().replace(" ", " ").replace(" ,", ",").replace(" .", ".").replace(" ?", "?") +def test_provider_transcript_result_preserves_opaque_cluster_ids(): + result = ProviderTranscriptResult( + provider="test-provider", + model="async-large", + words=[ + ProviderTranscriptWord( + text="hello", + start=0.0, + end=0.5, + provider_cluster_id="speaker-alpha", + speaker_label="SPEAKER_00", + ) + ], + ) + + payload = result.model_dump() + + assert payload["provider"] == "test-provider" + assert payload["model"] == "async-large" + assert payload["words"][0]["provider_cluster_id"] == "speaker-alpha" + + +def test_transcript_segment_serializes_canonical_identity_metadata(): + segment = TranscriptSegment( + text="Hello", + speaker="SPEAKER_07", + is_user=False, + person_id="person-1", + start=0.0, + end=1.0, + stt_provider="provider-a", + stt_model="model-b", + provider_cluster_id="cluster-x", + provider_speaker_label="SPEAKER_07", + speaker_identity_confidence=0.91, + speaker_identity_source="omi_speaker_embedding", + speaker_identity_version="v1", + ) + + payload = segment.model_dump() + + assert payload["speaker_id"] == 7 + assert payload["speaker"] == "SPEAKER_07" + assert payload["provider_cluster_id"] == "cluster-x" + assert payload["speaker_identity_state"] == "identified" + assert payload["speaker_identity_confidence"] == 0.91 + assert payload["speaker_identity_source"] == "omi_speaker_embedding" + + +def test_unknown_identity_is_explicit_and_legacy_zero_remains_ambiguous(): + unknown = TranscriptSegment( + text="Hello", + speaker=None, + is_user=False, + person_id=None, + start=0.0, + end=1.0, + provider_cluster_id=None, + speaker_identity_state="unknown", + ) + legacy = TranscriptSegment(text="Hi", speaker="SPEAKER_00", is_user=False, start=1.0, end=2.0) + + assert unknown.speaker_id == 0 + assert unknown.speaker_identity_state == "unknown" + assert legacy.speaker_id == 0 + assert legacy.speaker_identity_state == "legacy_ambiguous" + + +def test_segments_as_string_uses_provider_label_for_unknown_cluster_display(): + segments = [ + TranscriptSegment( + text="hello", + is_user=False, + start=0.0, + end=1.0, + provider_speaker_label="A", + speaker_identity_state="unknown", + ), + TranscriptSegment( + text="there", + is_user=False, + start=1.0, + end=2.0, + provider_speaker_label="B", + speaker_identity_state="unknown", + ), + ] + + transcript = TranscriptSegment.segments_as_string(segments) + + assert transcript == "Speaker 1: hello\n\nSpeaker 2: there" + + +def test_postprocess_words_does_not_promote_malformed_speaker_to_speaker_zero(): + deepgram = types.ModuleType("deepgram") + deepgram.DeepgramClient = lambda *args, **kwargs: object() + deepgram.DeepgramClientOptions = lambda *args, **kwargs: object() + byok = types.ModuleType("utils.byok") + byok.get_byok_key = lambda *args, **kwargs: None + endpoints = types.ModuleType("utils.other.endpoints") + endpoints.timeit = lambda fn: fn + sys.modules.setdefault("deepgram", deepgram) + sys.modules.setdefault("fal_client", types.ModuleType("fal_client")) + sys.modules.setdefault("utils.byok", byok) + sys.modules.setdefault("utils.other.endpoints", endpoints) + + from utils.stt import pre_recorded + + segments = pre_recorded.postprocess_words( + [ + { + "timestamp": [0.0, 0.5], + "speaker": "provider-speaker-alpha", + "text": "hello", + } + ], + ) + + assert len(segments) == 1 + assert segments[0].speaker is None + assert segments[0].speaker_id == 0 + assert segments[0].provider_cluster_id == "provider-speaker-alpha" + assert segments[0].speaker_identity_state == "unknown" + + def test_forward_merge_on_short_incomplete_last_sentence(): a = _segment("Hello there. and then", speaker="SPEAKER_00", start=0.0, end=4.0) b = _segment("we continue speaking.", speaker="SPEAKER_01", start=4.0, end=7.0) diff --git a/backend/tests/unit/test_transcription_provider_usage.py b/backend/tests/unit/test_transcription_provider_usage.py new file mode 100644 index 00000000000..af6fe5f0ffb --- /dev/null +++ b/backend/tests/unit/test_transcription_provider_usage.py @@ -0,0 +1,496 @@ +from datetime import datetime, timezone +import sys +import types +from unittest.mock import MagicMock + +import pytest + +_google_module = sys.modules.setdefault('google', types.ModuleType('google')) +_google_cloud_module = sys.modules.setdefault('google.cloud', types.ModuleType('google.cloud')) +_google_firestore_module = types.ModuleType('google.cloud.firestore') +_google_firestore_module.Increment = lambda value: {'__increment': value} +_google_firestore_v1_module = types.ModuleType('google.cloud.firestore_v1') +_google_firestore_v1_module.FieldFilter = lambda field, op, value: (field, op, value) +sys.modules['google.cloud.firestore'] = _google_firestore_module +sys.modules['google.cloud.firestore_v1'] = _google_firestore_v1_module +setattr(_google_module, 'cloud', _google_cloud_module) +setattr(_google_cloud_module, 'firestore', _google_firestore_module) +_mock_client_module = MagicMock() +_mock_client_module.db = MagicMock() +sys.modules['database._client'] = _mock_client_module +_prometheus_module = types.ModuleType('prometheus_client') +_metric_factory = lambda *args, **kwargs: MagicMock(labels=MagicMock(return_value=MagicMock())) +_prometheus_module.Counter = _metric_factory +_prometheus_module.Gauge = _metric_factory +_prometheus_module.Histogram = _metric_factory +_prometheus_module.generate_latest = lambda: b'' +_prometheus_module.CONTENT_TYPE_LATEST = 'text/plain' +sys.modules['prometheus_client'] = _prometheus_module +_fastapi_module = types.ModuleType('fastapi') +_fastapi_module.Response = lambda content=None, media_type=None: {'content': content, 'media_type': media_type} +sys.modules['fastapi'] = _fastapi_module + +from database import transcription_provider_usage as usage +from utils import metrics + + +class _FakeSnapshot: + def __init__(self, data): + self._data = data + self.reference = MagicMock() + self.exists = data is not None + + def to_dict(self): + return self._data + + +class _FakeDoc: + def __init__(self, doc_id): + self.id = doc_id + self.set_calls = [] + + def set(self, data, merge=False): + self.set_calls.append({'data': data, 'merge': merge}) + + def get(self): + data = {} + for call in self.set_calls: + data.update(call['data']) + return _FakeSnapshot(data if self.set_calls else None) + + +class _FakeCollection: + def __init__(self, name, docs): + self.name = name + self.docs = docs + self.filters = [] + + def document(self, doc_id): + return self.docs.setdefault((self.name, doc_id), _FakeDoc(doc_id)) + + def where(self, filter=None, *args, **kwargs): + self.filters.append(filter) + return self + + def stream(self): + return iter([]) + + +class _FakeDb: + def __init__(self): + self.docs = {} + self.collections = {} + self.batch_ref = MagicMock() + + def collection(self, name): + return self.collections.setdefault(name, _FakeCollection(name, self.docs)) + + def batch(self): + return self.batch_ref + + +def _inc(value): + return {'__increment': value} + + +def test_create_and_finalize_provider_run_writes_ledger_rollup_and_metrics(monkeypatch): + fake_db = _FakeDb() + monkeypatch.setattr(usage, 'db', fake_db) + monkeypatch.setattr(usage.firestore, 'Increment', _inc) + emitted = [] + monkeypatch.setattr( + usage, + 'emit_provider_run_metrics', + lambda **kwargs: emitted.append(kwargs), + ) + + started_at = datetime(2026, 5, 20, 23, 59, 58, tzinfo=timezone.utc) + completed_at = datetime(2026, 5, 21, 0, 0, 3, tzinfo=timezone.utc) + run_id = usage.create_provider_run( + uid='user-1', + provider='assemblyai', + model='universal-2', + workload='background', + run_id='run-1', + conversation_id='conv-1', + artifact_refs={'provider_result': 'gs://bucket/result.json'}, + started_at=started_at, + ) + usage.finalize_provider_run( + run_id=run_id, + provider='assemblyai', + model='universal-2', + workload='background', + status='success', + started_at=started_at, + completed_at=completed_at, + raw_audio_seconds=60.0, + speech_active_seconds=42.0, + billable_seconds=60.0, + chunk_duration_seconds=15.0, + estimated_cost_usd=0.37, + retry_count=1, + fallback_count=0, + transcript_segment_count=12, + transcript_word_count=140, + speaker_cluster_count=3, + identified_speaker_cluster_count=2, + identity_match_count=2, + unknown_speaker_count=1, + unknown_speaker_duration_seconds=7.5, + split_count=1, + identity_confidence_summary={'high': 2, 'unknown': 1}, + artifact_refs={'provider_result': 'gs://bucket/result.json'}, + ) + + run_doc = fake_db.docs[(usage.RUNS_COLLECTION, 'run-1')] + assert run_doc.set_calls[0]['merge'] is False + assert run_doc.set_calls[0]['data']['status'] == 'started' + assert run_doc.set_calls[0]['data']['expires_at'] is not None + assert run_doc.set_calls[1]['merge'] is True + finalized = run_doc.set_calls[1]['data'] + assert finalized['status'] == 'success' + assert finalized['timing']['latency_ms'] == 5000 + assert finalized['chunk_duration_seconds'] == 15.0 + assert finalized['retry_count'] == 1 + assert finalized['fallback'] is None + assert finalized['identity_match_count'] == 2 + assert finalized['unknown_speaker_count'] == 1 + assert finalized['unknown_speaker_duration_seconds'] == 7.5 + assert finalized['split_count'] == 1 + assert 'transcript_text' not in finalized + assert 'words' not in finalized + + rollup_doc = fake_db.docs[(usage.DAILY_USAGE_COLLECTION, '2026-05-21:assemblyai:universal-2:background')] + rollup = rollup_doc.set_calls[0]['data'] + assert rollup['run_count'] == {'__increment': 1} + assert rollup['raw_audio_seconds'] == {'__increment': 60.0} + assert rollup['chunk_duration_seconds'] == {'__increment': 15.0} + assert rollup['estimated_cost_usd'] == {'__increment': 0.37} + assert rollup['identity_match_count'] == {'__increment': 2} + assert rollup['unknown_speaker_count'] == {'__increment': 1} + assert rollup['unknown_speaker_duration_seconds'] == {'__increment': 7.5} + assert rollup['split_count'] == {'__increment': 1} + assert rollup['identity_confidence_counts.high'] == {'__increment': 2} + assert emitted[0]['latency_seconds'] == 5.0 + assert emitted[0]['billable_seconds'] == 60.0 + assert emitted[0]['retry_count'] == 1 + + +def test_rejects_transcript_text_and_chunk_payloads(): + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'transcript_text': 'hello'}) + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'chunks': [{'start': 0}]}) + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'artifact_refs': {'transcript': 'gs://bucket/transcript.txt'}}) + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'provider': {'api_key': 'secret-aa-key'}}) + with pytest.raises(ValueError): + usage._reject_forbidden_payload_keys({'full_transcript_text': 'hello world'}) + + +def test_utc_daily_bucket_and_rollup_rebuild(monkeypatch): + fake_db = _FakeDb() + collection = fake_db.collection(usage.RUNS_COLLECTION) + included = _FakeSnapshot( + { + 'provider': 'assemblyai', + 'model': 'universal-2', + 'workload': 'background', + 'status': 'success', + 'timing': {'completed_at': datetime(2026, 5, 21, 0, 1, tzinfo=timezone.utc)}, + 'raw_audio_seconds': 10, + 'speech_active_seconds': 6, + 'billable_seconds': 10, + 'chunk_duration_seconds': 15, + 'estimated_cost_usd': 0.1, + 'retry_count': 1, + 'fallback_count': 0, + 'transcript_segment_count': 4, + 'transcript_word_count': 40, + 'speaker_cluster_count': 2, + 'identified_speaker_cluster_count': 1, + 'identity_match_count': 1, + 'unknown_speaker_count': 1, + 'unknown_speaker_duration_seconds': 3, + 'split_count': 1, + 'identity_confidence_summary': {'high': 1}, + } + ) + excluded = _FakeSnapshot( + { + 'timing': {'completed_at': datetime(2026, 5, 22, 0, 1, tzinfo=timezone.utc)}, + 'status': 'success', + 'raw_audio_seconds': 999, + } + ) + collection.stream = lambda: iter([included, excluded]) + monkeypatch.setattr(usage, 'db', fake_db) + + assert usage.utc_day_bucket(datetime(2026, 5, 21, 7, 1)) == '2026-05-21' + assert ( + usage.daily_rollup_doc_id('2026-05-21', 'provider', 'model/v1', 'sync') == '2026-05-21:provider:model_v1:sync' + ) + + rollup = usage.rebuild_daily_rollup_from_runs('2026-05-21', 'assemblyai', 'universal-2', 'background') + + assert rollup['run_count'] == 1 + assert rollup['raw_audio_seconds'] == 10.0 + assert rollup['chunk_duration_seconds'] == 15.0 + assert rollup['status_counts'] == {'success': 1} + assert rollup['identity_match_count'] == 1.0 + assert rollup['unknown_speaker_count'] == 1.0 + assert rollup['unknown_speaker_duration_seconds'] == 3.0 + assert rollup['split_count'] == 1.0 + assert rollup['identity_confidence_counts'] == {'high': 1} + + +def test_purge_provider_runs_for_user_deletes_top_level_run_records(monkeypatch): + fake_db = _FakeDb() + collection = fake_db.collection(usage.RUNS_COLLECTION) + docs = [_FakeSnapshot({'uid': 'user-1'}), _FakeSnapshot({'uid': 'user-1'})] + collection.stream = lambda: iter(docs) + monkeypatch.setattr(usage, 'db', fake_db) + + deleted = usage.purge_provider_runs_for_user('user-1') + + assert deleted == 2 + assert fake_db.batch_ref.delete.call_count == 2 + fake_db.batch_ref.commit.assert_called_once() + + +def test_metrics_reject_high_cardinality_labels(): + assert metrics.identity_confidence_bucket(None) == 'unknown' + assert metrics.identity_confidence_bucket(0.91) == 'very_high' + with pytest.raises(ValueError): + metrics._provider_metric_labels(provider='assemblyai', user_id='user-1') + with pytest.raises(ValueError): + metrics._provider_metric_labels(provider='assemblyai', transcript_text='hello world') + + +def test_fallback_metric_records_failed_provider_to_fallback_provider(monkeypatch): + observed = [] + monkeypatch.setattr(usage, 'observe_transcription_provider_request', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_audio_seconds', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_retry', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_speaker_clusters', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_identity_confidence', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_fallback', + lambda *args, **kwargs: observed.append((args, kwargs)), + ) + + usage.emit_provider_run_metrics( + provider='deepgram', + model='nova-3', + workload='sync', + status='succeeded', + latency_seconds=1.0, + raw_audio_seconds=2.0, + speech_active_seconds=2.0, + billable_seconds=2.0, + retry_count=0, + fallback_count=1, + speaker_cluster_count=0, + identified_speaker_cluster_count=0, + fallback_provider='assemblyai', + ) + + assert observed == [(('assemblyai', 'deepgram', 'sync', 'provider_failure', 1), {})] + + +def test_fallback_ledger_rollup_and_metrics_share_provider_direction(monkeypatch): + fake_db = _FakeDb() + monkeypatch.setattr(usage, 'db', fake_db) + monkeypatch.setattr(usage.firestore, 'Increment', _inc) + observed_fallbacks = [] + observed_retries = [] + monkeypatch.setattr(usage, 'observe_transcription_provider_request', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_audio_seconds', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_retry', + lambda *args, **kwargs: observed_retries.append((args, kwargs)) if args[4] > 0 else None, + ) + monkeypatch.setattr(usage, 'observe_transcription_provider_speaker_clusters', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_identity_confidence', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_fallback', + lambda *args, **kwargs: observed_fallbacks.append((args, kwargs)), + ) + + started_at = datetime(2026, 5, 21, 1, 0, 0, tzinfo=timezone.utc) + completed_at = datetime(2026, 5, 21, 1, 0, 2, tzinfo=timezone.utc) + usage.create_provider_run( + uid='user-1', + provider='deepgram', + model='nova-3', + workload='sync', + run_id='run-fallback', + started_at=started_at, + ) + usage.finalize_provider_run( + run_id='run-fallback', + provider='deepgram', + model='nova-3', + workload='sync', + status='succeeded', + started_at=started_at, + completed_at=completed_at, + raw_audio_seconds=2.0, + speech_active_seconds=2.0, + billable_seconds=2.0, + retry_count=0, + fallback_count=1, + fallback_provider='assemblyai', + fallback_reason='provider_failure', + ) + + finalized = fake_db.docs[(usage.RUNS_COLLECTION, 'run-fallback')].set_calls[1]['data'] + assert finalized['fallback'] == { + 'from_provider': 'assemblyai', + 'to_provider': 'deepgram', + 'reason': 'provider_failure', + } + rollup = fake_db.docs[(usage.DAILY_USAGE_COLLECTION, '2026-05-21:deepgram:nova-3:sync')].set_calls[0]['data'] + assert rollup['fallback_count'] == {'__increment': 1} + assert observed_fallbacks == [(('assemblyai', 'deepgram', 'sync', 'provider_failure', 1), {})] + assert observed_retries == [] + + +def test_failed_run_retry_count_rolls_up_and_emits_retry_metric(monkeypatch): + fake_db = _FakeDb() + monkeypatch.setattr(usage, 'db', fake_db) + monkeypatch.setattr(usage.firestore, 'Increment', _inc) + observed_retries = [] + observed_fallbacks = [] + monkeypatch.setattr(usage, 'observe_transcription_provider_request', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_audio_seconds', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_retry', + lambda *args, **kwargs: observed_retries.append((args, kwargs)) if args[4] > 0 else None, + ) + monkeypatch.setattr(usage, 'observe_transcription_provider_speaker_clusters', lambda *args, **kwargs: None) + monkeypatch.setattr(usage, 'observe_transcription_provider_identity_confidence', lambda *args, **kwargs: None) + monkeypatch.setattr( + usage, + 'observe_transcription_provider_fallback', + lambda *args, **kwargs: observed_fallbacks.append((args, kwargs)), + ) + + started_at = datetime(2026, 5, 21, 1, 0, 0, tzinfo=timezone.utc) + completed_at = datetime(2026, 5, 21, 1, 0, 2, tzinfo=timezone.utc) + usage.create_provider_run( + uid='user-1', + provider='assemblyai', + model='universal-2', + workload='background', + run_id='run-failed', + started_at=started_at, + ) + usage.finalize_provider_run( + run_id='run-failed', + provider='assemblyai', + model='universal-2', + workload='background', + status='failed', + started_at=started_at, + completed_at=completed_at, + raw_audio_seconds=2.0, + retry_count=1, + error_class='RuntimeError', + ) + + finalized = fake_db.docs[(usage.RUNS_COLLECTION, 'run-failed')].set_calls[1]['data'] + assert finalized['retry_count'] == 1 + assert finalized['fallback'] is None + rollup = fake_db.docs[(usage.DAILY_USAGE_COLLECTION, '2026-05-21:assemblyai:universal-2:background')].set_calls[0][ + 'data' + ] + assert rollup['retry_count'] == {'__increment': 1} + assert rollup['status_counts.failed'] == {'__increment': 1} + assert observed_retries == [(('assemblyai', 'universal-2', 'background', 'provider_retry', 1), {})] + assert observed_fallbacks == [] + + +def test_update_provider_run_identity_metrics_updates_doc_and_rollup_delta(monkeypatch): + fake_db = _FakeDb() + monkeypatch.setattr(usage, 'db', fake_db) + monkeypatch.setattr(usage.firestore, 'Increment', _inc) + started_at = datetime(2026, 5, 21, 1, 0, 0, tzinfo=timezone.utc) + completed_at = datetime(2026, 5, 21, 1, 0, 2, tzinfo=timezone.utc) + usage.create_provider_run( + uid='user-1', + provider='assemblyai', + model='universal-2', + workload='sync', + run_id='run-identity', + started_at=started_at, + ) + usage.finalize_provider_run( + run_id='run-identity', + provider='assemblyai', + model='universal-2', + workload='sync', + status='succeeded', + started_at=started_at, + completed_at=completed_at, + speaker_cluster_count=2, + identified_speaker_cluster_count=0, + provider_speaker_count=2, + mapped_speaker_count=0, + mapped_person_count=0, + unmapped_speaker_count=2, + identity_confidence_summary={'unknown': 2}, + ) + + usage.update_provider_run_identity_metrics( + run_id='run-identity', + provider='assemblyai', + model='universal-2', + workload='sync', + identified_speaker_cluster_count=1, + provider_speaker_count=2, + mapped_speaker_count=1, + mapped_person_count=1, + unmapped_speaker_count=1, + embedding_extraction_failure_count=1, + identity_metric_update_status='succeeded', + identity_confidence_summary={'very_high': 1, 'unknown': 1}, + ) + + run_doc = fake_db.docs[(usage.RUNS_COLLECTION, 'run-identity')] + update = run_doc.set_calls[-1]['data'] + assert update['identified_speaker_cluster_count'] == 1 + assert update['provider_speaker_count'] == 2 + assert update['mapped_speaker_count'] == 1 + assert update['mapped_person_count'] == 1 + assert update['unmapped_speaker_count'] == 1 + assert update['embedding_extraction_failure_count'] == 1 + assert update['identity_metric_update']['status'] == 'succeeded' + assert update['identity_confidence_summary'] == {'very_high': 1, 'unknown': 1} + rollup_doc = fake_db.docs[(usage.DAILY_USAGE_COLLECTION, '2026-05-21:assemblyai:universal-2:sync')] + rollup_delta = rollup_doc.set_calls[-1]['data'] + assert rollup_delta['identified_speaker_cluster_count'] == {'__increment': 1} + assert rollup_delta['mapped_speaker_count'] == {'__increment': 1} + assert rollup_delta['mapped_person_count'] == {'__increment': 1} + assert rollup_delta['unmapped_speaker_count'] == {'__increment': -1} + assert rollup_delta['embedding_extraction_failure_count'] == {'__increment': 1} + assert rollup_delta['identity_confidence_counts.very_high'] == {'__increment': 1} + assert rollup_delta['identity_confidence_counts.unknown'] == {'__increment': -1} + + +def test_provider_metrics_source_does_not_define_forbidden_label_names(): + forbidden_labels = { + "['provider', 'model', 'workload', 'user_id']", + "['provider', 'model', 'workload', 'conversation_id']", + "['provider', 'model', 'workload', 'provider_job_id']", + "['provider', 'model', 'workload', 'transcript_text']", + } + source = metrics.__loader__.get_source(metrics.__name__) + for label_list in forbidden_labels: + assert label_list not in source diff --git a/backend/tests/unit/test_users_add_sample_transaction.py b/backend/tests/unit/test_users_add_sample_transaction.py index 59cc08006e5..5c23f916f0c 100644 --- a/backend/tests/unit/test_users_add_sample_transaction.py +++ b/backend/tests/unit/test_users_add_sample_transaction.py @@ -1,4 +1,5 @@ import os +import importlib import sys import types from unittest.mock import MagicMock @@ -17,7 +18,7 @@ class NotFound(Exception): pass -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_exceptions_module = types.ModuleType("google.cloud.exceptions") _google_exceptions_module.NotFound = NotFound diff --git a/backend/tests/unit/test_voice_message_language.py b/backend/tests/unit/test_voice_message_language.py index f9eea84fdf9..87566f6bb7b 100644 --- a/backend/tests/unit/test_voice_message_language.py +++ b/backend/tests/unit/test_voice_message_language.py @@ -2,6 +2,7 @@ Unit tests for voice message language resolution. """ +import importlib import sys import types from unittest.mock import MagicMock @@ -26,7 +27,7 @@ class NotFound(Exception): pass -_google_module = sys.modules.setdefault("google", types.ModuleType("google")) +_google_module = importlib.import_module("google") _google_cloud_module = sys.modules.setdefault("google.cloud", types.ModuleType("google.cloud")) _google_exceptions_module = types.ModuleType("google.cloud.exceptions") _google_exceptions_module.NotFound = NotFound diff --git a/backend/tests/unit/test_ws_auth_handshake.py b/backend/tests/unit/test_ws_auth_handshake.py index 91d14b25209..f8396791a83 100644 --- a/backend/tests/unit/test_ws_auth_handshake.py +++ b/backend/tests/unit/test_ws_auth_handshake.py @@ -7,6 +7,8 @@ """ import asyncio +import sys +import types import unittest from unittest.mock import patch, MagicMock @@ -15,6 +17,21 @@ from firebase_admin.auth import InvalidIdTokenError from starlette.websockets import WebSocketDisconnect +database_mod = types.ModuleType("database") +database_mod.__path__ = [] +sys.modules.setdefault("database", database_mod) + +redis_db_mod = types.ModuleType("database.redis_db") +redis_db_mod.check_rate_limit = MagicMock(return_value=True) +redis_db_mod.try_acquire_listen_lock = MagicMock(return_value=True) +sys.modules.setdefault("database.redis_db", redis_db_mod) +setattr(database_mod, "redis_db", redis_db_mod) + +users_db_mod = types.ModuleType("database.users") +users_db_mod.record_user_platform = MagicMock() +sys.modules.setdefault("database.users", users_db_mod) +setattr(database_mod, "users", users_db_mod) + from utils.other.endpoints import get_current_user_uid_ws_listen, get_current_user_uid_ws, get_current_user_uid diff --git a/backend/utils/byok.py b/backend/utils/byok.py index 39d355273dd..253ff4b475a 100644 --- a/backend/utils/byok.py +++ b/backend/utils/byok.py @@ -1,7 +1,8 @@ """Per-request BYOK (Bring Your Own Keys) key plumbing. The desktop client sends user-provided API keys as headers on every request -(`X-BYOK-OpenAI`, `X-BYOK-Anthropic`, `X-BYOK-Gemini`, `X-BYOK-Deepgram`). +(`X-BYOK-OpenAI`, `X-BYOK-Anthropic`, `X-BYOK-Gemini`, `X-BYOK-Deepgram`, and +optionally `X-BYOK-AssemblyAI` for async prerecorded STT). A FastAPI middleware stashes them in a per-request contextvar; the LLM/STT clients can then read them without re-reading the request object. @@ -68,6 +69,7 @@ def invalidate_byok_state_cache(uid: str) -> None: 'anthropic': 'x-byok-anthropic', 'gemini': 'x-byok-gemini', 'deepgram': 'x-byok-deepgram', + 'assemblyai': 'x-byok-assemblyai', } # Keys for the current request, if the client supplied them. diff --git a/backend/utils/chat.py b/backend/utils/chat.py index d58acd85394..f6ad86f7637 100644 --- a/backend/utils/chat.py +++ b/backend/utils/chat.py @@ -19,12 +19,8 @@ from utils.notifications import send_notification from utils.other.storage import get_syncing_file_temporal_signed_url, delete_syncing_temporal_file from utils.retrieval.graph import execute_graph_chat, execute_graph_chat_stream -from utils.stt.pre_recorded import ( - deepgram_prerecorded, - deepgram_prerecorded_from_bytes, - postprocess_words, - get_deepgram_model_for_language, -) +from utils.stt import provider_service as stt_provider_service +from utils.stt.providers import STTWorkload from utils.llm.usage_tracker import track_usage, set_usage_context, reset_usage_context, Features import logging @@ -73,27 +69,28 @@ def delete_file(): if not language: language = resolve_voice_message_language(uid, None) - # Get the appropriate Deepgram model for this language - stt_language, stt_model = get_deepgram_model_for_language(language) + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model(language) is_multi = stt_language == 'multi' try: - if is_multi: - words, detected_language = deepgram_prerecorded( - url, diarize=False, language=stt_language, return_language=True, model=stt_model - ) - else: - words = deepgram_prerecorded( - url, diarize=False, language=stt_language, return_language=False, model=stt_model - ) - detected_language = stt_language + transcription = stt_provider_service.transcribe_url( + url, + workload=STTWorkload.voice_message, + uid=uid, + diarize=False, + language=stt_language, + return_language=is_multi, + model=stt_model, + ) + words = transcription.words + detected_language = transcription.detected_language if is_multi else stt_language except RuntimeError as e: logger.error(f'Voice message transcription failed for {path}: {e}') return None, stt_language if not is_multi else 'en' if not words: logger.info('no words') return None, detected_language - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments del words if not transcript_segments: logger.error('failed to get deepgram segments') @@ -125,41 +122,31 @@ def transcribe_pcm_bytes( if not language: language = resolve_voice_message_language(uid, None) - stt_language, stt_model = get_deepgram_model_for_language(language) + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model(language) is_multi = stt_language == 'multi' # Let RuntimeError propagate so the router can distinguish backend failure from no-speech - if is_multi: - result = deepgram_prerecorded_from_bytes( - audio_bytes, - sample_rate=sample_rate, - diarize=False, - encoding=encoding, - channels=channels, - language=stt_language, - model=stt_model, - return_language=True, - keywords=keywords, - ) - words, detected_language = result - else: - words = deepgram_prerecorded_from_bytes( - audio_bytes, - sample_rate=sample_rate, - diarize=False, - encoding=encoding, - channels=channels, - language=stt_language, - model=stt_model, - keywords=keywords, - ) - detected_language = stt_language + transcription = stt_provider_service.transcribe_bytes( + audio_bytes, + workload=STTWorkload.ptt, + uid=uid, + sample_rate=sample_rate, + diarize=False, + encoding=encoding, + channels=channels, + language=stt_language, + model=stt_model, + return_language=is_multi, + keywords=keywords, + ) + words = transcription.words + detected_language = transcription.detected_language if is_multi else stt_language if not words: logger.info('transcribe_pcm_bytes: no words') return None, detected_language - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments del words if not transcript_segments: logger.error('transcribe_pcm_bytes: failed to get segments') @@ -190,15 +177,22 @@ def delete_file(): if not language: language = resolve_voice_message_language(uid, None) - # Get the appropriate Deepgram model for this language - stt_language, stt_model = get_deepgram_model_for_language(language) + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model(language) try: - words = deepgram_prerecorded(url, diarize=False, language=stt_language, model=stt_model) + transcription = stt_provider_service.transcribe_url( + url, + workload=STTWorkload.voice_message, + uid=uid, + diarize=False, + language=stt_language, + model=stt_model, + ) + words = transcription.words except RuntimeError as e: logger.error(f'Voice message transcription failed for {path}: {e}') return [] - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments del words if not transcript_segments: logger.error('failed to get deepgram segments') @@ -264,15 +258,22 @@ def delete_file(): if not language: language = resolve_voice_message_language(uid, None) - # Get the appropriate Deepgram model for this language - stt_language, stt_model = get_deepgram_model_for_language(language) + stt_language, stt_model = stt_provider_service.resolve_prerecorded_language_model(language) try: - words = deepgram_prerecorded(url, diarize=False, language=stt_language, model=stt_model) + transcription = stt_provider_service.transcribe_url( + url, + workload=STTWorkload.voice_message, + uid=uid, + diarize=False, + language=stt_language, + model=stt_model, + ) + words = transcription.words except RuntimeError as e: logger.error(f'Voice message transcription failed for {path}: {e}') return - transcript_segments: List[TranscriptSegment] = postprocess_words(words, 0) + transcript_segments: List[TranscriptSegment] = transcription.segments del words if not transcript_segments: logger.error('failed to get deepgram segments') diff --git a/backend/utils/conversations/desktop_background.py b/backend/utils/conversations/desktop_background.py new file mode 100644 index 00000000000..9f9b8bb9f0e --- /dev/null +++ b/backend/utils/conversations/desktop_background.py @@ -0,0 +1,243 @@ +import logging +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Dict, List, Optional + +import database.calendar_meetings as calendar_db +import database.conversations as conversations_db +from database import redis_db +from models.conversation import Conversation +from models.conversation_enums import ConversationSource, ConversationStatus +from models.structured import Structured +from models.transcript_segment import TranscriptSegment +from utils.conversations.factory import deserialize_conversation +from utils.conversations.process_conversation import process_conversation + +logger = logging.getLogger(__name__) +_MAX_BACKGROUND_CHUNK_RECORDS = 1000 + + +class DesktopBackgroundConversationError(ValueError): + def __init__(self, message: str, status_code: int = 400): + super().__init__(message) + self.status_code = status_code + + +@dataclass +class DesktopBackgroundAppendResult: + appended: bool + duplicate: bool + segments: List[TranscriptSegment] + chunk_record: Optional[Dict] + + +def create_in_progress_desktop_conversation( + uid: str, + language: str, + source: ConversationSource = ConversationSource.desktop, + private_cloud_sync_enabled: bool = False, + call_id: Optional[str] = None, + session_id: Optional[str] = None, +) -> str: + """Create a desktop/listen in-progress conversation and Redis pointer.""" + new_conversation_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + stub_conversation = Conversation( + id=new_conversation_id, + created_at=now, + started_at=now, + finished_at=now, + structured=Structured(), + language=language, + transcript_segments=[], + photos=[], + status=ConversationStatus.in_progress, + source=source, + private_cloud_sync_enabled=private_cloud_sync_enabled, + call_id=call_id, + ) + conversations_db.upsert_conversation(uid, conversation_data=stub_conversation.dict()) + redis_db.set_in_progress_conversation_id(uid, new_conversation_id) + + detected_meeting_id = _detect_current_desktop_meeting(uid) if source == ConversationSource.desktop else None + if detected_meeting_id: + redis_db.set_conversation_meeting_id(new_conversation_id, detected_meeting_id) + + logger.info( + "Created new in-progress conversation: %s uid=%s session=%s source=%s", + new_conversation_id, + uid, + session_id, + source.value, + ) + return new_conversation_id + + +def append_segments_to_in_progress_conversation( + uid: str, + conversation_id: str, + segments: List[TranscriptSegment], + finished_at: datetime, +) -> List[TranscriptSegment]: + """Append transcript segments to the in-progress conversation and bump finished_at.""" + conversation_data = conversations_db.get_conversation(uid, conversation_id) + if not conversation_data: + raise ValueError("conversation not found") + + conversation = deserialize_conversation(conversation_data) + if conversation.status != ConversationStatus.in_progress: + raise ValueError("conversation is not in_progress") + + if not segments: + conversations_db.update_conversation_finished_at(uid, conversation_id, finished_at) + return [] + + conversation.transcript_segments, updated_segments, _removed_ids = TranscriptSegment.combine_segments( + conversation.transcript_segments, + segments, + ) + conversations_db.update_conversation_segments( + uid, + conversation.id, + [segment.dict() for segment in conversation.transcript_segments], + finished_at=finished_at, + ) + return updated_segments + + +def append_background_chunk_to_in_progress_conversation( + uid: str, + conversation_id: str, + chunk_id: str, + payload_hash: str, + segments: List[TranscriptSegment], + finished_at: datetime, + provider: Optional[str], + run_id: Optional[str], + chunk_start_ms: int, + chunk_duration_ms: int, +) -> DesktopBackgroundAppendResult: + """Append one desktop background chunk once, keyed by stable client chunk_id.""" + conversation_data = conversations_db.get_conversation(uid, conversation_id) + if not conversation_data: + raise DesktopBackgroundConversationError("conversation_id not found", status_code=404) + + conversation = deserialize_conversation(conversation_data) + if conversation.status != ConversationStatus.in_progress: + raise DesktopBackgroundConversationError("conversation is not in_progress", status_code=409) + + processed_chunks = dict(conversation.background_processed_chunks or {}) + existing_record = processed_chunks.get(chunk_id) + if existing_record: + if existing_record.get('payload_hash') != payload_hash: + raise DesktopBackgroundConversationError("chunk_id payload mismatch", status_code=409) + conversations_db.update_conversation_finished_at(uid, conversation_id, finished_at) + logger.info( + "Duplicate desktop background chunk ignored uid=%s conversation_id=%s chunk_id=%s provider=%s", + uid, + conversation_id, + chunk_id, + existing_record.get('provider'), + ) + return DesktopBackgroundAppendResult( + appended=False, + duplicate=True, + segments=[], + chunk_record=existing_record, + ) + + if segments: + conversation.transcript_segments, updated_segments, _removed_ids = TranscriptSegment.combine_segments( + conversation.transcript_segments, + segments, + ) + else: + updated_segments = [] + + processed_chunks[chunk_id] = { + 'chunk_id': chunk_id, + 'payload_hash': payload_hash, + 'provider': provider, + 'run_id': run_id, + 'segment_count': len(segments), + 'chunk_start_ms': chunk_start_ms, + 'chunk_duration_ms': chunk_duration_ms, + 'accepted_at': finished_at.isoformat(), + } + processed_chunks = _prune_background_chunk_records(processed_chunks) + + conversations_db.update_conversation_segments_and_background_chunks( + uid, + conversation.id, + [segment.dict() for segment in conversation.transcript_segments], + processed_chunks, + finished_at=finished_at, + ) + return DesktopBackgroundAppendResult( + appended=True, + duplicate=False, + segments=updated_segments, + chunk_record=processed_chunks.get(chunk_id), + ) + + +def get_background_chunk_record(uid: str, conversation_id: str, chunk_id: str) -> Optional[Dict]: + conversation_data = conversations_db.get_conversation(uid, conversation_id) + if not conversation_data: + raise DesktopBackgroundConversationError("conversation_id not found", status_code=404) + if conversation_data.get('status') != ConversationStatus.in_progress: + raise DesktopBackgroundConversationError("conversation is not in_progress", status_code=409) + return (conversation_data.get('background_processed_chunks') or {}).get(chunk_id) + + +def _prune_background_chunk_records(processed_chunks: Dict[str, Dict]) -> Dict[str, Dict]: + if len(processed_chunks) <= _MAX_BACKGROUND_CHUNK_RECORDS: + return processed_chunks + + ordered = sorted( + processed_chunks.items(), + key=lambda item: (item[1].get('accepted_at') or '', item[0]), + ) + return dict(ordered[-_MAX_BACKGROUND_CHUNK_RECORDS:]) + + +def finish_desktop_background_conversation(uid: str, conversation_id: str) -> Conversation: + """Finalize one explicit desktop background conversation by ID.""" + conversation_data = conversations_db.get_conversation(uid, conversation_id) + if not conversation_data: + raise DesktopBackgroundConversationError("conversation_id not found", status_code=404) + + conversation = deserialize_conversation(conversation_data) + if conversation.status == ConversationStatus.completed: + return conversation + if conversation.status != ConversationStatus.in_progress: + raise DesktopBackgroundConversationError("conversation is not in_progress", status_code=409) + + conversations_db.update_conversation_status(uid, conversation.id, ConversationStatus.processing) + processed_conversation = process_conversation(uid, conversation.language, conversation, force_process=True) + + if redis_db.get_in_progress_conversation_id(uid) == conversation.id: + redis_db.remove_in_progress_conversation_id(uid) + + logger.info( + "Finished desktop background conversation: %s uid=%s segments=%s", + conversation.id, + uid, + len(processed_conversation.transcript_segments), + ) + return processed_conversation + + +def _detect_current_desktop_meeting(uid: str) -> Optional[str]: + now = datetime.now(timezone.utc) + time_window = timedelta(minutes=2) + meetings = calendar_db.get_meetings_in_time_range(uid, now - time_window, now + time_window) + + if len(meetings) == 1: + return meetings[0]['id'] + if len(meetings) <= 1: + return None + + closest_meeting = min(meetings, key=lambda meeting: abs((meeting['start_time'] - now).total_seconds())) + return closest_meeting['id'] diff --git a/backend/utils/conversations/postprocess_conversation.py b/backend/utils/conversations/postprocess_conversation.py index a2cc66375a3..1285827f0d7 100644 --- a/backend/utils/conversations/postprocess_conversation.py +++ b/backend/utils/conversations/postprocess_conversation.py @@ -15,7 +15,8 @@ from models.transcript_segment import TranscriptSegment from utils.conversations.process_conversation import process_conversation, process_user_emotion from utils.other.storage import upload_postprocessing_audio, delete_postprocessing_audio, upload_conversation_recording -from utils.stt.pre_recorded import deepgram_prerecorded, postprocess_words +from utils.stt import provider_service as stt_provider_service +from utils.stt.providers import STTWorkload from utils.stt.speech_profile import get_speech_profile_matching_predictions from utils.stt.vad import vad_is_empty import logging @@ -79,8 +80,15 @@ def postprocess_conversation( upload_conversation_recording(file_path, uid, conversation_id) speakers_count = len(set([segment.speaker for segment in conversation.transcript_segments])) - words = deepgram_prerecorded(signed_url, speakers_count=speakers_count) - fal_segments = postprocess_words(words, aseg.duration_seconds) + transcription = stt_provider_service.transcribe_url( + signed_url, + workload=STTWorkload.postprocess, + uid=uid, + conversation_id=conversation_id, + speakers_count=speakers_count, + raw_audio_seconds=aseg.duration_seconds, + ) + fal_segments = transcription.segments # if new transcript is 90% shorter than the original, cancel post-processing, smth wrong with audio or FAL count = len(''.join([segment.text.strip() for segment in conversation.transcript_segments])) @@ -93,6 +101,16 @@ def postprocess_conversation( _handle_segment_embedding_matching(uid, file_path, conversation.transcript_segments, aseg) else: _handle_segment_embedding_matching(uid, file_path, fal_segments, aseg) + try: + stt_provider_service.update_provider_run_identity_metrics( + transcription.run_id, + transcription.result.provider, + transcription.result.model or 'unknown', + STTWorkload.postprocess, + fal_segments, + ) + except Exception as e: + logger.warning(f'Speaker ID (postprocess): identity metric update failed for {conversation_id}: {e}') # Store both models results. conversations_db.store_model_segments_result( diff --git a/backend/utils/conversations/render.py b/backend/utils/conversations/render.py index 0e261d5a964..933f7620ac8 100644 --- a/backend/utils/conversations/render.py +++ b/backend/utils/conversations/render.py @@ -37,13 +37,26 @@ def populate_speaker_names(uid: str, conversations: List[Dict]) -> None: people_map = {p['id']: p['name'] for p in people_data} for conv in conversations: + provider_cluster_display_ids = {} + next_provider_display_id = 1 + for seg in conv.get('transcript_segments', []): + provider_cluster_key = _provider_display_cluster_key(seg) + if seg.get('is_user') or not provider_cluster_key: + continue + if provider_cluster_key not in provider_cluster_display_ids: + provider_cluster_display_ids[provider_cluster_key] = next_provider_display_id + next_provider_display_id += 1 for seg in conv.get('transcript_segments', []): if seg.get('is_user'): seg['speaker_name'] = user_name elif seg.get('person_id') and seg['person_id'] in people_map: seg['speaker_name'] = people_map[seg['person_id']] else: - seg['speaker_name'] = f"Speaker {seg.get('speaker_id', 0)}" + provider_cluster_key = _provider_display_cluster_key(seg) + if provider_cluster_key in provider_cluster_display_ids: + seg['speaker_name'] = f"Speaker {provider_cluster_display_ids[provider_cluster_key]}" + else: + seg['speaker_name'] = f"Speaker {seg.get('speaker_id', 0)}" def populate_folder_names(uid: str, conversations: List[Dict]) -> None: @@ -69,6 +82,27 @@ def populate_folder_names(uid: str, conversations: List[Dict]) -> None: conv['folder_name'] = folder_map.get(folder_id) if folder_id else None +def _provider_display_cluster_key(segment: Dict[str, Any]) -> str | None: + if not segment.get('provider_cluster_id') and not segment.get('provider_speaker_label'): + return None + if _legacy_speaker_id(segment) not in (None, 0): + return None + return segment.get('provider_cluster_id') or segment.get('provider_speaker_label') + + +def _legacy_speaker_id(segment: Dict[str, Any]) -> int | None: + speaker_id = segment.get('speaker_id') + if speaker_id is not None: + return speaker_id + speaker = segment.get('speaker') + if not speaker: + return None + try: + return int(str(speaker).split('_', 1)[1]) + except (ValueError, IndexError): + return 0 + + # --------------------------------------------------------------------------- # Redact: locked-content stripping # --------------------------------------------------------------------------- diff --git a/backend/utils/metrics.py b/backend/utils/metrics.py index 30d06192d02..45990ffe1be 100644 --- a/backend/utils/metrics.py +++ b/backend/utils/metrics.py @@ -1,4 +1,6 @@ -from prometheus_client import Counter, Gauge, generate_latest, CONTENT_TYPE_LATEST +from typing import Optional + +from prometheus_client import Counter, Gauge, Histogram, generate_latest, CONTENT_TYPE_LATEST from fastapi import Response BACKEND_LISTEN_ACTIVE_WS_CONNECTIONS = Gauge( @@ -26,6 +28,188 @@ 'Number of sessions currently in degraded mode (pusher unavailable)', ) +TRANSCRIPTION_PROVIDER_REQUESTS = Counter( + 'transcription_provider_requests_total', + 'Total transcription provider requests by provider, model, workload, and status', + ['provider', 'model', 'workload', 'status'], +) + +TRANSCRIPTION_PROVIDER_LATENCY_SECONDS = Histogram( + 'transcription_provider_latency_seconds', + 'Transcription provider request latency by provider, model, workload, and status', + ['provider', 'model', 'workload', 'status'], + buckets=(0.5, 1, 2.5, 5, 10, 30, 60, 120, 300, 600, float('inf')), +) + +TRANSCRIPTION_PROVIDER_RETRIES = Counter( + 'transcription_provider_retries_total', + 'Total transcription provider retry attempts by provider, model, workload, and reason', + ['provider', 'model', 'workload', 'reason'], +) + +TRANSCRIPTION_PROVIDER_FALLBACKS = Counter( + 'transcription_provider_fallbacks_total', + 'Total transcription provider fallbacks by from provider, to provider, and workload', + ['from_provider', 'to_provider', 'workload', 'reason'], +) + +TRANSCRIPTION_PROVIDER_AUDIO_SECONDS = Counter( + 'transcription_provider_audio_seconds_total', + 'Total transcription provider audio seconds by provider, model, workload, and kind', + ['provider', 'model', 'workload', 'kind'], +) + +TRANSCRIPTION_PROVIDER_BILLABLE_SECONDS = Counter( + 'transcription_provider_billable_seconds_total', + 'Total transcription provider billable seconds by provider, model, and workload', + ['provider', 'model', 'workload'], +) + +TRANSCRIPTION_PROVIDER_SPEAKER_CLUSTERS = Counter( + 'transcription_provider_speaker_clusters_total', + 'Total transcription provider speaker clusters by provider, model, workload, and kind', + ['provider', 'model', 'workload', 'kind'], +) + +TRANSCRIPTION_PROVIDER_IDENTITY_CONFIDENCE = Counter( + 'transcription_provider_identity_confidence_total', + 'Total speaker identity assignments by provider, model, workload, and confidence bucket', + ['provider', 'model', 'workload', 'bucket'], +) + +TRANSCRIPTION_PROVIDER_ALLOWED_LABELS = { + 'provider', + 'model', + 'workload', + 'status', + 'reason', + 'from_provider', + 'to_provider', + 'kind', + 'bucket', +} + +TRANSCRIPTION_PROVIDER_FORBIDDEN_LABELS = { + 'uid', + 'user_id', + 'conversation_id', + 'provider_job_id', + 'transcript', + 'transcript_text', + 'text', + 'run_id', +} + + +def _provider_metric_labels(**labels: str) -> dict: + unexpected = set(labels) - TRANSCRIPTION_PROVIDER_ALLOWED_LABELS + forbidden = set(labels) & TRANSCRIPTION_PROVIDER_FORBIDDEN_LABELS + if unexpected or forbidden: + raise ValueError(f'Unsafe transcription provider metric labels: {sorted(unexpected | forbidden)}') + return {key: str(value or 'unknown') for key, value in labels.items()} + + +def observe_transcription_provider_request( + provider: str, + model: str, + workload: str, + status: str, + latency_seconds: float, +) -> None: + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, status=status) + TRANSCRIPTION_PROVIDER_REQUESTS.labels(**labels).inc() + if latency_seconds >= 0: + TRANSCRIPTION_PROVIDER_LATENCY_SECONDS.labels(**labels).observe(latency_seconds) + + +def observe_transcription_provider_retry(provider: str, model: str, workload: str, reason: str, count: int = 1) -> None: + if count <= 0: + return + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, reason=reason) + TRANSCRIPTION_PROVIDER_RETRIES.labels(**labels).inc(count) + + +def observe_transcription_provider_fallback( + from_provider: str, + to_provider: str, + workload: str, + reason: str, + count: int = 1, +) -> None: + if count <= 0: + return + labels = _provider_metric_labels( + from_provider=from_provider, + to_provider=to_provider, + workload=workload, + reason=reason, + ) + TRANSCRIPTION_PROVIDER_FALLBACKS.labels(**labels).inc(count) + + +def observe_transcription_provider_audio_seconds( + provider: str, + model: str, + workload: str, + raw_audio_seconds: float = 0.0, + speech_active_seconds: float = 0.0, + billable_seconds: float = 0.0, +) -> None: + for kind, seconds in ( + ('raw', raw_audio_seconds), + ('speech_active', speech_active_seconds), + ('billable', billable_seconds), + ): + if seconds <= 0: + continue + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, kind=kind) + TRANSCRIPTION_PROVIDER_AUDIO_SECONDS.labels(**labels).inc(seconds) + if billable_seconds > 0: + labels = _provider_metric_labels(provider=provider, model=model, workload=workload) + TRANSCRIPTION_PROVIDER_BILLABLE_SECONDS.labels(**labels).inc(billable_seconds) + + +def observe_transcription_provider_speaker_clusters( + provider: str, + model: str, + workload: str, + speaker_cluster_count: int = 0, + identified_speaker_cluster_count: int = 0, +) -> None: + for kind, count in ( + ('provider', speaker_cluster_count), + ('identified', identified_speaker_cluster_count), + ): + if count <= 0: + continue + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, kind=kind) + TRANSCRIPTION_PROVIDER_SPEAKER_CLUSTERS.labels(**labels).inc(count) + + +def identity_confidence_bucket(confidence: Optional[float]) -> str: + if confidence is None: + return 'unknown' + if confidence >= 0.90: + return 'very_high' + if confidence >= 0.75: + return 'high' + if confidence >= 0.50: + return 'medium' + return 'low' + + +def observe_transcription_provider_identity_confidence( + provider: str, + model: str, + workload: str, + bucket: str, + count: int = 1, +) -> None: + if count <= 0: + return + labels = _provider_metric_labels(provider=provider, model=model, workload=workload, bucket=bucket) + TRANSCRIPTION_PROVIDER_IDENTITY_CONFIDENCE.labels(**labels).inc(count) + def metrics_response() -> Response: return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST) diff --git a/backend/utils/rate_limit_config.py b/backend/utils/rate_limit_config.py index f4db6bc418d..f3a587dedbe 100644 --- a/backend/utils/rate_limit_config.py +++ b/backend/utils/rate_limit_config.py @@ -46,6 +46,8 @@ "voice:transcribe": (60, 3600), "voice:transcribe_stream": (60, 3600), "voice:message": (60, 3600), + "desktop:background_transcribe": (120, 3600), + "desktop:background_conversation_finish": (60, 3600), "file:upload": (40, 3600), # Agent/MCP — bursty tool calls "agent:execute_tool": (120, 3600), diff --git a/backend/utils/self_voice_review.py b/backend/utils/self_voice_review.py new file mode 100644 index 00000000000..803cc9a607e --- /dev/null +++ b/backend/utils/self_voice_review.py @@ -0,0 +1,254 @@ +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Callable, Optional, Union + +import numpy as np + +from database import self_voice_review as review_db +from database import users as users_db +from models.transcript_segment import TranscriptSegment +from utils.stt.background_speaker_identity import ( + ClusterIdentityAssignment, + ClusterSampleSpan, + USER_SELF_PERSON_ID, + select_representative_cluster_spans, +) +from utils.stt.speaker_embedding import extract_embedding_from_bytes + +MIN_REVIEW_SAMPLE_SECONDS = 5.0 +MAX_REVIEW_SAMPLE_SECONDS = 10.0 +MIN_VOICED_RATIO = 0.65 +MIN_VAD_CONFIDENCE = 0.75 +MAX_NOISE_SCORE = 0.35 +SKIP_COOLDOWN_DAYS = 14 +SELF_VOICE_REVIEW_VERSION = 'self-voice-review:v1' +SELF_VOICE_EMBEDDING_SOURCE = 'self_voice_review_confirm' + + +@dataclass +class SegmentQuality: + voiced_seconds: Optional[float] = None + vad_confidence: Optional[float] = None + noise_score: Optional[float] = None + + +@dataclass +class CandidateSelectionResult: + candidate: Optional[dict] = None + reason: Optional[str] = None + spans: list[ClusterSampleSpan] = field(default_factory=list) + + +def build_self_voice_review_candidate( + uid: str, + conversation_id: str, + provider_cluster_id: str, + cluster_segments: list[TranscriptSegment], + all_segments: list[TranscriptSegment], + identity_assignment: Optional[ClusterIdentityAssignment] = None, + quality_by_segment_id: Optional[dict[str, Union[SegmentQuality, dict]]] = None, + audio_artifact_ref: Optional[str] = None, + audio_retention_allowed: bool = False, + now: Optional[datetime] = None, +) -> CandidateSelectionResult: + now = now or _utc_now() + quality_by_segment_id = quality_by_segment_id or {} + + if not audio_retention_allowed or not audio_artifact_ref: + return CandidateSelectionResult(reason='audio_unavailable_or_retention_disallowed') + + spans = select_representative_cluster_spans( + cluster_segments, + all_segments, + preferred_seconds=MIN_REVIEW_SAMPLE_SECONDS, + max_seconds=MAX_REVIEW_SAMPLE_SECONDS, + ) + quality = _score_candidate_quality(spans, quality_by_segment_id) + if not spans: + return CandidateSelectionResult(reason='no_clean_voiced_window', spans=spans) + if quality['sample_seconds'] < MIN_REVIEW_SAMPLE_SECONDS: + return CandidateSelectionResult(reason='sample_too_short', spans=spans) + if quality['vad_confidence'] is not None and quality['vad_confidence'] < MIN_VAD_CONFIDENCE: + return CandidateSelectionResult(reason='low_vad_confidence', spans=spans) + if quality['voiced_ratio'] is not None and quality['voiced_ratio'] < MIN_VOICED_RATIO: + return CandidateSelectionResult(reason='low_voiced_ratio', spans=spans) + if quality['noise_score'] is not None and quality['noise_score'] > MAX_NOISE_SCORE: + return CandidateSelectionResult(reason='noisy_clip', spans=spans) + + segment_ids = [span.segment_id for span in spans] + candidate_id = review_db.candidate_id_from_source(conversation_id, provider_cluster_id, segment_ids) + negative_marker_id = review_db.marker_id_from_source(conversation_id, provider_cluster_id) + if review_db.get_candidate(uid, candidate_id): + return CandidateSelectionResult(reason='already_confirmed_or_pending', spans=spans) + if review_db.has_negative_marker(uid, negative_marker_id): + return CandidateSelectionResult(reason='rejected_source', spans=spans) + if review_db.recently_shown_source_exists(uid, conversation_id, provider_cluster_id, now=now): + return CandidateSelectionResult(reason='recently_shown', spans=spans) + + confidence_bucket = _confidence_bucket(identity_assignment, quality) + if confidence_bucket is None: + return CandidateSelectionResult(reason='not_self_voice_candidate', spans=spans) + + candidate = { + 'candidate_id': candidate_id, + 'source': { + 'conversation_id': conversation_id, + 'provider_cluster_id': provider_cluster_id, + 'segment_ids': segment_ids, + 'sample_spans': [span.as_dict() for span in spans], + }, + 'confidence_bucket': confidence_bucket, + 'quality_scores': quality, + 'review_status': 'pending', + 'reviewed_at': None, + 'cooldown_until': None, + 'retention': { + 'audio_artifact_ref': audio_artifact_ref, + 'audio_retention_allowed': True, + 'transcript_text_stored': False, + 'expires_at': now + timedelta(days=review_db.DEFAULT_CANDIDATE_TTL_DAYS), + }, + 'negative_review_marker': None, + 'created_at': now, + 'updated_at': now, + 'expires_at': now + timedelta(days=review_db.DEFAULT_CANDIDATE_TTL_DAYS), + 'version': SELF_VOICE_REVIEW_VERSION, + } + review_db.upsert_candidate(uid, candidate) + return CandidateSelectionResult(candidate=candidate, spans=spans) + + +def confirm_self_voice_candidate( + uid: str, + candidate_id: str, + audio_bytes: Optional[bytes] = None, + embedding: Optional[np.ndarray] = None, + embedding_extractor: Callable[[bytes, str], np.ndarray] = extract_embedding_from_bytes, +) -> dict: + candidate = _require_candidate(uid, candidate_id) + if candidate.get('review_status') not in ('pending', 'skipped'): + raise ValueError('candidate is not confirmable') + + retention = candidate.get('retention') or {} + if not retention.get('audio_retention_allowed') and embedding is None: + raise ValueError('candidate audio is not available for confirmation') + + if embedding is None: + if not audio_bytes: + raise ValueError('audio_bytes or embedding is required to confirm self voice') + embedding = embedding_extractor(audio_bytes, f'self_voice_review_{candidate_id}.wav') + + users_db.set_user_speaker_embedding(uid, embedding.reshape(1, -1).flatten().tolist()) + review_db.mark_candidate_confirmed(uid, candidate_id, SELF_VOICE_EMBEDDING_SOURCE) + return {'candidate_id': candidate_id, 'review_status': 'confirmed'} + + +def reject_self_voice_candidate(uid: str, candidate_id: str) -> dict: + candidate = _require_candidate(uid, candidate_id) + if candidate.get('review_status') == 'confirmed': + raise ValueError('confirmed candidate must be deleted instead of rejected') + marker_id = review_db.mark_candidate_rejected(uid, candidate) + return {'candidate_id': candidate_id, 'review_status': 'rejected', 'negative_marker_id': marker_id} + + +def skip_self_voice_candidate( + uid: str, + candidate_id: str, + now: Optional[datetime] = None, + cooldown_days: int = SKIP_COOLDOWN_DAYS, +) -> dict: + _require_candidate(uid, candidate_id) + now = now or _utc_now() + cooldown_until = now + timedelta(days=cooldown_days) + review_db.mark_candidate_skipped(uid, candidate_id, cooldown_until, reviewed_at=now) + return { + 'candidate_id': candidate_id, + 'review_status': 'pending', + 'last_review_action': 'skipped', + 'cooldown_until': cooldown_until, + } + + +def delete_confirmed_self_voice_sample(uid: str, candidate_id: str) -> bool: + return review_db.delete_confirmed_sample(uid, candidate_id) + + +def _require_candidate(uid: str, candidate_id: str) -> dict: + candidate = review_db.get_candidate(uid, candidate_id) + if not candidate: + raise ValueError('self voice review candidate not found') + return candidate + + +def _score_candidate_quality( + spans: list[ClusterSampleSpan], + quality_by_segment_id: dict[str, Union[SegmentQuality, dict]], +) -> dict: + sample_seconds = round(sum(span.duration for span in spans), 3) + voiced_seconds = 0.0 + vad_values = [] + noise_values = [] + has_voiced_signal = False + + for span in spans: + quality = _quality_for_segment(quality_by_segment_id.get(span.segment_id)) + if not quality: + continue + if quality.voiced_seconds is not None: + has_voiced_signal = True + voiced_seconds += min(max(quality.voiced_seconds, 0.0), span.duration) + if quality.vad_confidence is not None: + vad_values.append(quality.vad_confidence) + if quality.noise_score is not None: + noise_values.append(quality.noise_score) + + voiced_ratio = None + if has_voiced_signal and sample_seconds > 0: + voiced_ratio = round(voiced_seconds / sample_seconds, 3) + + return { + 'sample_seconds': sample_seconds, + 'voiced_ratio': voiced_ratio, + 'vad_confidence': round(sum(vad_values) / len(vad_values), 3) if vad_values else None, + 'noise_score': round(sum(noise_values) / len(noise_values), 3) if noise_values else None, + 'overlapped_speech': False, + } + + +def _quality_for_segment(value: Optional[Union[SegmentQuality, dict]]) -> Optional[SegmentQuality]: + if value is None: + return None + if isinstance(value, SegmentQuality): + return value + return SegmentQuality( + voiced_seconds=value.get('voiced_seconds'), + vad_confidence=value.get('vad_confidence'), + noise_score=value.get('noise_score'), + ) + + +def _confidence_bucket(assignment: Optional[ClusterIdentityAssignment], quality: dict) -> Optional[str]: + if assignment is None: + return None + if assignment.person_id: + return None + if assignment.state == 'user' and assignment.confidence is not None and assignment.confidence >= 0.7: + return 'high' + + for candidate in assignment.candidates: + if not candidate.get('is_user') and candidate.get('person_id') != USER_SELF_PERSON_ID: + continue + confidence = candidate.get('confidence') + distance = candidate.get('distance') + if confidence is not None and confidence >= 0.35: + return 'low' + if distance is not None and distance < 0.45: + return 'low' + + if quality['sample_seconds'] >= MIN_REVIEW_SAMPLE_SECONDS: + return 'low' + return None + + +def _utc_now() -> datetime: + return datetime.now(timezone.utc) diff --git a/backend/utils/stt/assemblyai_adapter.py b/backend/utils/stt/assemblyai_adapter.py new file mode 100644 index 00000000000..9751b677112 --- /dev/null +++ b/backend/utils/stt/assemblyai_adapter.py @@ -0,0 +1,290 @@ +import os +import time +from typing import Callable, Optional, Sequence, Tuple, Union + +import httpx + +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptUtterance, ProviderTranscriptWord +from utils.stt.providers import STTProviderName + + +class AssemblyAIError(RuntimeError): + pass + + +class AssemblyAIProviderError(AssemblyAIError): + pass + + +class AssemblyAIRetryableError(AssemblyAIError): + pass + + +class AssemblyAITimeoutError(AssemblyAIError): + pass + + +def assemblyai_speaker_fields(speaker_id) -> dict: + if speaker_id is None: + return {'provider_cluster_id': None, 'provider_speaker_label': None} + + provider_cluster_id = str(speaker_id) + return { + 'provider_cluster_id': provider_cluster_id, + 'provider_speaker_label': f'ASSEMBLYAI_SPEAKER_{provider_cluster_id}', + } + + +def normalize_assemblyai_transcript_result( + result: dict, model: str, language: Optional[str] = None +) -> ProviderTranscriptResult: + status = result.get('status') + if status and status != 'completed': + raise AssemblyAIProviderError(f'AssemblyAI transcript status is {status}') + + utterances = [_normalize_assemblyai_utterance(utterance) for utterance in result.get('utterances') or []] + words = [_normalize_assemblyai_word(word) for word in result.get('words') or []] + if not words and utterances: + words = [word for utterance in utterances for word in (utterance.words or [])] + + requested_language = None if language == 'multi' else language + return ProviderTranscriptResult( + provider=STTProviderName.assemblyai.value, + model=result.get('speech_model_used') or result.get('speech_model') or model, + language=_normalize_language(result.get('language_code') or requested_language), + duration=_seconds_float(result.get('audio_duration')), + words=words, + utterances=utterances, + raw_provider_result_id=result.get('id'), + ) + + +def _normalize_assemblyai_word(word: dict) -> ProviderTranscriptWord: + speaker_fields = assemblyai_speaker_fields(word.get('speaker')) + return ProviderTranscriptWord( + text=word.get('text', ''), + start=_milliseconds_to_seconds(word.get('start')), + end=_milliseconds_to_seconds(word.get('end')), + provider_cluster_id=speaker_fields['provider_cluster_id'], + speaker_label=speaker_fields['provider_speaker_label'], + confidence=word.get('confidence'), + ) + + +def _normalize_assemblyai_utterance(utterance: dict) -> ProviderTranscriptUtterance: + speaker_fields = assemblyai_speaker_fields(utterance.get('speaker')) + utterance_words = utterance.get('words') or [] + return ProviderTranscriptUtterance( + text=utterance.get('text', ''), + start=_milliseconds_to_seconds(utterance.get('start')), + end=_milliseconds_to_seconds(utterance.get('end')), + provider_cluster_id=speaker_fields['provider_cluster_id'], + speaker_label=speaker_fields['provider_speaker_label'], + confidence=utterance.get('confidence'), + words=[_normalize_assemblyai_word(word) for word in utterance_words] if utterance_words else None, + ) + + +def _milliseconds_to_seconds(value) -> float: + if value is None: + return 0.0 + return float(value) / 1000.0 + + +def _seconds_float(value) -> float: + if value is None: + return 0.0 + return float(value) + + +def _normalize_language(language: Optional[str]) -> Optional[str]: + if language and '_' in language: + return language.split('_', 1)[0] + if language and '-' in language: + return language.split('-', 1)[0] + return language + + +class AssemblyAIAsyncTranscriptionProvider: + provider_name = STTProviderName.assemblyai + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + timeout: Optional[httpx.Timeout] = None, + poll_interval_seconds: Optional[float] = None, + max_poll_seconds: Optional[float] = None, + client_factory: Callable[[], httpx.Client] = httpx.Client, + sleeper: Callable[[float], None] = time.sleep, + clock: Callable[[], float] = time.monotonic, + ): + self._api_key = api_key or os.getenv('ASSEMBLYAI_API_KEY') + self._base_url = (base_url or os.getenv('ASSEMBLYAI_BASE_URL') or 'https://api.assemblyai.com').rstrip('/') + self._timeout = timeout or httpx.Timeout(30.0, read=30.0) + self._poll_interval_seconds = float( + poll_interval_seconds + if poll_interval_seconds is not None + else os.getenv('ASSEMBLYAI_POLL_INTERVAL_SECONDS', '3') + ) + self._max_poll_seconds = float( + max_poll_seconds if max_poll_seconds is not None else os.getenv('ASSEMBLYAI_MAX_POLL_SECONDS', '900') + ) + self._client_factory = client_factory + self._sleeper = sleeper + self._clock = clock + + def transcribe_url( + self, + audio_url: str, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'universal-2', + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: + payload = self._transcript_payload( + audio_url=audio_url, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, + ) + result = self._submit_and_poll(payload) + transcript_result = normalize_assemblyai_transcript_result(result, model=model, language=language) + if return_language: + return transcript_result, transcript_result.language or 'en' + return transcript_result + + def transcribe_bytes( + self, + audio_bytes: bytes, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'universal-2', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: + del sample_rate, encoding, channels + upload_url = self._upload_audio(audio_bytes) + return self.transcribe_url( + upload_url, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, + ) + + def _headers(self, content_type: Optional[str] = 'application/json') -> dict: + if not self._api_key: + raise AssemblyAIProviderError('ASSEMBLYAI_API_KEY is not configured') + headers = {'Authorization': self._api_key} + if content_type: + headers['Content-Type'] = content_type + return headers + + def _transcript_payload( + self, + audio_url: str, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'universal-2', + keywords: Optional[Sequence[str]] = None, + ) -> dict: + payload = { + 'audio_url': audio_url, + 'speaker_labels': diarize, + 'punctuate': True, + 'format_text': True, + } + if model: + payload['speech_models'] = [model] if isinstance(model, str) else list(model) + if speakers_count: + payload['speakers_expected'] = speakers_count + if language and language != 'multi': + payload['language_code'] = language + elif return_language or language == 'multi': + payload['language_detection'] = True + if keywords: + payload['keyterms_prompt'] = list(keywords) + return payload + + def _upload_audio(self, audio_bytes: bytes) -> str: + with self._client_factory() as client: + response = self._request( + client, + 'POST', + f'{self._base_url}/v2/upload', + headers=self._headers('application/octet-stream'), + content=audio_bytes, + ) + upload_url = response.get('upload_url') + if not upload_url: + raise AssemblyAIProviderError('AssemblyAI upload response did not include upload_url') + return upload_url + + def _submit_and_poll(self, payload: dict) -> dict: + with self._client_factory() as client: + submitted = self._request( + client, + 'POST', + f'{self._base_url}/v2/transcript', + headers=self._headers(), + json=payload, + ) + transcript_id = submitted.get('id') + if not transcript_id: + raise AssemblyAIProviderError('AssemblyAI transcript response did not include id') + return self._poll_transcript(client, transcript_id) + + def _poll_transcript(self, client: httpx.Client, transcript_id: str) -> dict: + deadline = self._clock() + self._max_poll_seconds + while True: + result = self._request( + client, + 'GET', + f'{self._base_url}/v2/transcript/{transcript_id}', + headers=self._headers(None), + ) + status = result.get('status') + if status == 'completed': + return result + if status == 'error': + raise AssemblyAIProviderError(result.get('error') or 'AssemblyAI transcript failed') + if status not in ('queued', 'processing'): + raise AssemblyAIProviderError(f'AssemblyAI returned unexpected transcript status: {status}') + if self._clock() >= deadline: + raise AssemblyAITimeoutError(f'AssemblyAI transcript {transcript_id} timed out') + self._sleeper(self._poll_interval_seconds) + + def _request(self, client: httpx.Client, method: str, url: str, **kwargs) -> dict: + last_error = None + for attempt in range(2): + try: + response = client.request(method, url, timeout=self._timeout, **kwargs) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + last_error = e + if e.response.status_code in (408, 429, 500, 502, 503, 504) and attempt == 0: + self._sleeper(min(self._poll_interval_seconds, 1.0)) + continue + if e.response.status_code in (408, 429, 500, 502, 503, 504): + raise AssemblyAIRetryableError(f'AssemblyAI HTTP {e.response.status_code}: {e}') from e + raise AssemblyAIProviderError(f'AssemblyAI HTTP {e.response.status_code}: {e}') from e + except (httpx.TimeoutException, httpx.TransportError) as e: + last_error = e + if attempt == 0: + self._sleeper(min(self._poll_interval_seconds, 1.0)) + continue + raise AssemblyAIRetryableError(f'AssemblyAI request failed: {e}') from e + raise AssemblyAIRetryableError(f'AssemblyAI request failed: {last_error}') diff --git a/backend/utils/stt/background_speaker_identity.py b/backend/utils/stt/background_speaker_identity.py new file mode 100644 index 00000000000..8db7d4c4a1c --- /dev/null +++ b/backend/utils/stt/background_speaker_identity.py @@ -0,0 +1,405 @@ +import io +import logging +import re +import wave +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional + +import numpy as np + +from models.transcript_segment import TranscriptSegment + +logger = logging.getLogger(__name__) + +USER_SELF_PERSON_ID = 'user' +SPEAKER_MATCH_THRESHOLD = 0.45 +SPEAKER_IDENTITY_VERSION = 'omi-speaker-identity:v1' +SPEAKER_IDENTITY_SOURCE = 'omi_speaker_embedding' +UNKNOWN_SPEAKER_IDENTITY_SOURCE = 'omi_speaker_embedding:no_match' +TEXT_HINT_SOURCE = 'text_self_introduction' + +MIN_CLUSTER_SPAN_SECONDS = 1.0 +PREFERRED_CLUSTER_SAMPLE_SECONDS = 5.0 +MAX_CLUSTER_SAMPLE_SECONDS = 20.0 +MAX_CLUSTER_SPAN_SECONDS = 10.0 +OVERLAP_TOLERANCE_SECONDS = 0.1 + +SELF_INTRODUCTION_HINT_PATTERNS = [ + r"\b(I am|I'm|i am|i'm|My name is|my name is)\s+([A-Z][a-zA-Z]*)\b", + r"\b([A-Z][a-zA-Z]*)\s+is my name\b", +] + + +@dataclass +class ClusterSampleSpan: + start: float + end: float + segment_id: str + duration: float + + def as_dict(self) -> dict: + return { + 'start': round(self.start, 3), + 'end': round(self.end, 3), + 'segment_id': self.segment_id, + 'duration': round(self.duration, 3), + } + + +@dataclass +class ClusterIdentityAssignment: + provider_cluster_id: str + speaker_id: Optional[int] + state: str + person_id: Optional[str] = None + person_name: Optional[str] = None + is_user: bool = False + confidence: Optional[float] = None + distance: Optional[float] = None + source: str = UNKNOWN_SPEAKER_IDENTITY_SOURCE + version: str = SPEAKER_IDENTITY_VERSION + candidates: List[dict] = field(default_factory=list) + sample_spans: List[ClusterSampleSpan] = field(default_factory=list) + text_hints: List[dict] = field(default_factory=list) + reason: Optional[str] = None + + def provenance(self) -> dict: + data = { + 'provider_cluster_id': self.provider_cluster_id, + 'speaker_id': self.speaker_id, + 'sample_spans': [span.as_dict() for span in self.sample_spans], + 'sample_seconds': round(sum(span.duration for span in self.sample_spans), 3), + } + if self.distance is not None: + data['distance'] = round(self.distance, 6) + if self.reason: + data['reason'] = self.reason + return data + + +def identify_background_speaker_clusters( + transcript_segments: List[TranscriptSegment], + audio_bytes: Optional[bytes], + person_embeddings_cache: Dict[str, dict], + embedding_extractor: Optional[Callable[[bytes, str], np.ndarray]] = None, + match_threshold: float = SPEAKER_MATCH_THRESHOLD, +) -> Dict[str, ClusterIdentityAssignment]: + """Assign canonical Omi identity once per provider speaker cluster. + + Text self-introductions are recorded only as hints. Durable assignment + requires voice evidence from the existing Omi speaker embedding service or + a later explicit user correction. + """ + clusters = _group_segments_by_cluster(transcript_segments) + assignments: Dict[str, ClusterIdentityAssignment] = {} + matched_person_ids: set[str] = set() + + for cluster_id, cluster_segments in clusters.items(): + text_hints = _collect_text_hints(cluster_segments, cluster_id) + speaker_id = _cluster_speaker_id(cluster_segments) + sample_spans = select_representative_cluster_spans(cluster_segments, transcript_segments) + assignment = ClusterIdentityAssignment( + provider_cluster_id=cluster_id, + speaker_id=speaker_id, + state='unknown', + sample_spans=sample_spans, + text_hints=text_hints, + ) + + if not audio_bytes: + assignment.reason = 'missing_audio' + assignments[cluster_id] = assignment + continue + if not person_embeddings_cache: + assignment.reason = 'missing_candidate_embeddings' + assignments[cluster_id] = assignment + continue + if not embedding_extractor: + assignment.reason = 'missing_embedding_extractor' + assignments[cluster_id] = assignment + continue + if not sample_spans: + assignment.reason = 'no_usable_cluster_spans' + assignments[cluster_id] = assignment + continue + + sample_wav = extract_cluster_sample_wav(audio_bytes, sample_spans) + if not sample_wav: + assignment.reason = 'sample_extraction_failed' + assignments[cluster_id] = assignment + continue + + try: + query_embedding = embedding_extractor(sample_wav, f'cluster_{_safe_filename(cluster_id)}.wav') + except Exception as e: + logger.info('background speaker identity embedding failed cluster=%s: %s', cluster_id, e) + assignment.reason = 'embedding_extraction_failed' + assignments[cluster_id] = assignment + continue + + candidates = _rank_candidates(query_embedding, person_embeddings_cache, matched_person_ids, match_threshold) + assignment.candidates = _public_candidates(candidates) + best = candidates[0] if candidates else None + if not best or best['distance'] >= match_threshold: + assignment.reason = 'below_threshold' + assignments[cluster_id] = assignment + continue + + match_person_id = best['_match_person_id'] + matched_person_ids.add(match_person_id) + assignment.state = 'user' if match_person_id == USER_SELF_PERSON_ID else 'identified' + assignment.person_id = None if match_person_id == USER_SELF_PERSON_ID else match_person_id + assignment.person_name = best.get('person_name') + assignment.is_user = match_person_id == USER_SELF_PERSON_ID + assignment.confidence = best['confidence'] + assignment.distance = best['distance'] + assignment.source = SPEAKER_IDENTITY_SOURCE + assignment.reason = None + assignments[cluster_id] = assignment + + apply_cluster_identity_assignments(transcript_segments, assignments) + return assignments + + +def select_representative_cluster_spans( + cluster_segments: List[TranscriptSegment], + all_segments: List[TranscriptSegment], + preferred_seconds: float = PREFERRED_CLUSTER_SAMPLE_SECONDS, + max_seconds: float = MAX_CLUSTER_SAMPLE_SECONDS, +) -> List[ClusterSampleSpan]: + usable_spans = [] + for segment in cluster_segments: + duration = max(0.0, segment.end - segment.start) + if duration < MIN_CLUSTER_SPAN_SECONDS: + continue + if _has_obvious_overlap(segment, all_segments): + continue + if len(segment.text.split()) < 2: + continue + + end = segment.end + if duration > MAX_CLUSTER_SPAN_SECONDS: + center = (segment.start + segment.end) / 2 + start = max(segment.start, center - MAX_CLUSTER_SPAN_SECONDS / 2) + end = min(segment.end, start + MAX_CLUSTER_SPAN_SECONDS) + else: + start = segment.start + + usable_spans.append( + ClusterSampleSpan( + start=start, + end=end, + segment_id=segment.id, + duration=end - start, + ) + ) + + usable_spans.sort(key=lambda span: (-span.duration, span.start)) + + selected = [] + total = 0.0 + for span in usable_spans: + if total >= preferred_seconds and selected: + break + remaining = max_seconds - total + if remaining <= 0: + break + if span.duration > remaining: + span = ClusterSampleSpan( + start=span.start, + end=span.start + remaining, + segment_id=span.segment_id, + duration=remaining, + ) + selected.append(span) + total += span.duration + + return sorted(selected, key=lambda span: span.start) + + +def extract_cluster_sample_wav(audio_bytes: bytes, spans: List[ClusterSampleSpan]) -> Optional[bytes]: + try: + with wave.open(io.BytesIO(audio_bytes), 'rb') as wf: + framerate = wf.getframerate() + n_channels = wf.getnchannels() + sampwidth = wf.getsampwidth() + n_frames = wf.getnframes() + total_duration = n_frames / framerate + frames = [] + for span in spans: + start = max(0.0, min(span.start, total_duration)) + end = max(0.0, min(span.end, total_duration)) + if end - start < MIN_CLUSTER_SPAN_SECONDS: + continue + wf.setpos(int(start * framerate)) + frames.append(wf.readframes(int((end - start) * framerate))) + except Exception as e: + logger.info('background speaker identity sample extraction failed: %s', e) + return None + + frames = [frame for frame in frames if frame] + if not frames: + return None + + out = io.BytesIO() + with wave.open(out, 'wb') as out_wf: + out_wf.setnchannels(n_channels) + out_wf.setsampwidth(sampwidth) + out_wf.setframerate(framerate) + for frame in frames: + out_wf.writeframes(frame) + return out.getvalue() + + +def apply_cluster_identity_assignments( + transcript_segments: List[TranscriptSegment], + assignments: Dict[str, ClusterIdentityAssignment], +) -> None: + for segment in transcript_segments: + cluster_id = _segment_cluster_key(segment) + assignment = assignments.get(cluster_id) + if not assignment: + continue + + segment.speaker_identity_state = assignment.state + segment.speaker_identity_confidence = assignment.confidence + segment.speaker_identity_source = assignment.source + segment.speaker_identity_version = assignment.version + segment.speaker_identity_provenance = assignment.provenance() + segment.speaker_identity_candidates = assignment.candidates + segment.speaker_identity_text_hints = assignment.text_hints + + if assignment.is_user: + segment.is_user = True + segment.person_id = None + elif assignment.person_id: + segment.is_user = False + segment.person_id = assignment.person_id + else: + segment.is_user = False + segment.person_id = None + + +def _group_segments_by_cluster(transcript_segments: List[TranscriptSegment]) -> Dict[str, List[TranscriptSegment]]: + clusters: Dict[str, List[TranscriptSegment]] = {} + for segment in transcript_segments: + clusters.setdefault(_segment_cluster_key(segment), []).append(segment) + return clusters + + +def _segment_cluster_key(segment: TranscriptSegment) -> str: + if segment.provider_cluster_id: + return str(segment.provider_cluster_id) + if segment.provider_speaker_label: + return f'provider-label:{segment.provider_speaker_label}' + return f'legacy-speaker:{segment.speaker_id if segment.speaker_id is not None else 0}' + + +def _cluster_speaker_id(segments: List[TranscriptSegment]) -> Optional[int]: + counts = {} + for segment in segments: + if segment.speaker_id is None: + continue + counts[segment.speaker_id] = counts.get(segment.speaker_id, 0) + 1 + return max(counts, key=counts.get) if counts else None + + +def _has_obvious_overlap(segment: TranscriptSegment, all_segments: List[TranscriptSegment]) -> bool: + for other in all_segments: + if other.id == segment.id: + continue + if _segment_cluster_key(other) == _segment_cluster_key(segment): + continue + overlap = min(segment.end, other.end) - max(segment.start, other.start) + if overlap > OVERLAP_TOLERANCE_SECONDS: + return True + return False + + +def _collect_text_hints(segments: List[TranscriptSegment], cluster_id: str) -> List[dict]: + hints = [] + for segment in segments: + detected_name = _detect_self_introduction_hint(segment.text) + if not detected_name: + continue + hints.append( + { + 'source': TEXT_HINT_SOURCE, + 'provider_cluster_id': cluster_id, + 'segment_id': segment.id, + 'detected_name': detected_name, + } + ) + return hints + + +def _detect_self_introduction_hint(text: str) -> Optional[str]: + for pattern in SELF_INTRODUCTION_HINT_PATTERNS: + match = re.search(pattern, text) + if not match: + continue + if _looks_like_quoted_or_reported_speech(text, match.start()): + continue + name = match.groups()[-1] + if name and len(name) >= 2: + return name.capitalize() + return None + + +def _looks_like_quoted_or_reported_speech(text: str, match_start: int) -> bool: + prefix = text[:match_start].strip().lower() + if prefix and prefix[-1:] in {'"', "'", '“', '‘'}: + return True + recent_prefix = prefix[-40:] + return bool(re.search(r"\b(said|says|asked|told|quoted|read|wrote|writes)\b", recent_prefix)) + + +def _rank_candidates( + query_embedding: np.ndarray, + person_embeddings_cache: Dict[str, dict], + matched_person_ids: set[str], + match_threshold: float, +) -> List[dict]: + candidates = [] + for person_id, data in person_embeddings_cache.items(): + if person_id in matched_person_ids: + continue + distance = compare_embeddings(query_embedding, data['embedding']) + confidence = max(0.0, min(1.0, 1.0 - (distance / match_threshold))) + candidates.append( + { + '_match_person_id': person_id, + 'person_id': None if person_id == USER_SELF_PERSON_ID else person_id, + 'person_name': data.get('name'), + 'is_user': person_id == USER_SELF_PERSON_ID, + 'distance': round(distance, 6), + 'confidence': round(confidence, 6), + 'source': SPEAKER_IDENTITY_SOURCE, + } + ) + candidates.sort(key=lambda item: (item['distance'], item['_match_person_id'])) + return candidates[:3] + + +def compare_embeddings(embedding1: np.ndarray, embedding2: np.ndarray) -> float: + if embedding1.shape[1] != embedding2.shape[1]: + return 2.0 + norm1 = np.linalg.norm(embedding1) + norm2 = np.linalg.norm(embedding2) + if norm1 == 0 or norm2 == 0: + return 2.0 + similarity = float(np.dot(embedding1.flatten(), embedding2.flatten()) / (norm1 * norm2)) + return 1.0 - similarity + + +def _public_candidates(candidates: List[dict]) -> List[dict]: + public = [] + for candidate in candidates: + data = dict(candidate) + data.pop('_match_person_id', None) + public.append(data) + return public + + +def _safe_filename(value: str) -> str: + return ''.join(ch if ch.isalnum() else '_' for ch in value)[:80] or 'unknown' diff --git a/backend/utils/stt/conversation_reconstructor.py b/backend/utils/stt/conversation_reconstructor.py new file mode 100644 index 00000000000..a75498d56df --- /dev/null +++ b/backend/utils/stt/conversation_reconstructor.py @@ -0,0 +1,272 @@ +from dataclasses import dataclass +from typing import Iterable, List, Optional, Sequence, Tuple + +from models.transcript_segment import ( + ProviderTranscriptResult, + ProviderTranscriptUtterance, + ProviderTranscriptWord, + TranscriptSegment, +) + + +@dataclass +class _SegmentCandidate: + text: str + start: float + end: float + provider_cluster_id: Optional[str] + speaker_label: Optional[str] + + +class ConversationReconstructor: + def __init__(self, max_same_cluster_gap_seconds: float = 30.0): + self.max_same_cluster_gap_seconds = max_same_cluster_gap_seconds + + def reconstruct( + self, + result: ProviderTranscriptResult, + skip_n_seconds: int = 0, + user_provider_cluster_id: Optional[str] = None, + ) -> List[TranscriptSegment]: + user_cluster_id = user_provider_cluster_id or self._retrieve_user_cluster_id(result, skip_n_seconds) + candidates = self._build_candidates(result, skip_n_seconds) + candidates = [candidate for candidate in candidates if candidate.start >= skip_n_seconds and candidate.text] + candidates = self._sort_and_dedupe_candidates(candidates) + candidates = self._merge_adjacent_candidates(candidates) + + if not candidates: + return [] + + starts_at = candidates[0].start + segments = [] + for candidate in candidates: + is_user = bool(user_cluster_id and candidate.provider_cluster_id == user_cluster_id) + speaker_label = self._legacy_speaker_label(candidate.speaker_label) + identity_state = self._identity_state(is_user, candidate.provider_cluster_id, speaker_label) + segments.append( + TranscriptSegment( + text=candidate.text.strip().capitalize(), + speaker=speaker_label, + is_user=is_user, + person_id=None, + start=round(candidate.start - starts_at, 2), + end=round(candidate.end - starts_at, 2), + stt_provider=result.provider, + stt_model=result.model, + provider_cluster_id=candidate.provider_cluster_id, + provider_speaker_label=candidate.speaker_label, + speaker_identity_state=identity_state, + ) + ) + return segments + + def _build_candidates(self, result: ProviderTranscriptResult, skip_n_seconds: int) -> List[_SegmentCandidate]: + candidates = self._utterance_candidates(result.utterances) + if not result.words: + return candidates + + candidate_words = [word for word in result.words if word.start >= skip_n_seconds] + if not candidates: + return self._word_candidates(candidate_words) + + covered_intervals = [(candidate.start, candidate.end) for candidate in candidates] + uncovered_words = [word for word in candidate_words if not self._word_is_covered(word, covered_intervals)] + return candidates + self._word_candidates(uncovered_words) + + def _utterance_candidates(self, utterances: Sequence[ProviderTranscriptUtterance]) -> List[_SegmentCandidate]: + candidates = [] + for utterance in utterances: + text = utterance.text.strip() + if not text and utterance.words: + text = ' '.join(word.text.strip() for word in utterance.words if word.text.strip()) + if not text: + continue + provider_cluster_id = utterance.provider_cluster_id + speaker_label = utterance.speaker_label + if utterance.words and (not provider_cluster_id or not speaker_label): + provider_cluster_id, speaker_label = self._dominant_speaker(utterance.words) + candidates.append( + _SegmentCandidate( + text=text, + start=utterance.start, + end=utterance.end, + provider_cluster_id=provider_cluster_id, + speaker_label=speaker_label, + ) + ) + return candidates + + def _word_candidates(self, words: Sequence[ProviderTranscriptWord]) -> List[_SegmentCandidate]: + candidates = [] + normalized_words = self._interpolate_missing_word_speakers( + sorted(words, key=lambda item: (item.start, item.end)) + ) + for word in normalized_words: + text = word.text.strip() + if not text: + continue + if candidates and self._should_merge_word(candidates[-1], word): + candidates[-1].text = f'{candidates[-1].text} {text}' + candidates[-1].end = word.end + continue + candidates.append( + _SegmentCandidate( + text=text, + start=word.start, + end=word.end, + provider_cluster_id=word.provider_cluster_id, + speaker_label=word.speaker_label, + ) + ) + return candidates + + def _interpolate_missing_word_speakers( + self, words: Sequence[ProviderTranscriptWord] + ) -> List[ProviderTranscriptWord]: + normalized_words = [word.model_copy() for word in words] + for index, word in enumerate(normalized_words): + if word.provider_cluster_id or word.speaker_label: + continue + + previous_word = normalized_words[index - 1] if index > 0 else None + next_word = normalized_words[index + 1] if index < len(normalized_words) - 1 else None + previous_has_speaker = previous_word and (previous_word.provider_cluster_id or previous_word.speaker_label) + next_has_speaker = next_word and (next_word.provider_cluster_id or next_word.speaker_label) + if previous_has_speaker and next_has_speaker: + if previous_word.provider_cluster_id == next_word.provider_cluster_id: + word.provider_cluster_id = previous_word.provider_cluster_id + word.speaker_label = previous_word.speaker_label + else: + secs_from_previous = word.start - previous_word.end + secs_to_next = next_word.start - word.end + source = previous_word if secs_from_previous < secs_to_next else next_word + word.provider_cluster_id = source.provider_cluster_id + word.speaker_label = source.speaker_label + elif previous_has_speaker: + word.provider_cluster_id = previous_word.provider_cluster_id + word.speaker_label = previous_word.speaker_label + elif next_has_speaker: + word.provider_cluster_id = next_word.provider_cluster_id + word.speaker_label = next_word.speaker_label + return normalized_words + + def _sort_and_dedupe_candidates(self, candidates: Sequence[_SegmentCandidate]) -> List[_SegmentCandidate]: + ordered = sorted(candidates, key=lambda item: (item.start, item.end)) + deduped = [] + for candidate in ordered: + if not deduped: + deduped.append(candidate) + continue + + previous = deduped[-1] + if self._is_duplicate_overlap(previous, candidate): + if len(candidate.text) > len(previous.text): + candidate.start = min(previous.start, candidate.start) + candidate.end = max(previous.end, candidate.end) + deduped[-1] = candidate + continue + deduped.append(candidate) + return deduped + + def _merge_adjacent_candidates(self, candidates: Sequence[_SegmentCandidate]) -> List[_SegmentCandidate]: + merged = [] + for candidate in candidates: + if merged and self._should_merge_candidate(merged[-1], candidate): + merged[-1].text = f'{merged[-1].text} {candidate.text}' + merged[-1].end = candidate.end + continue + merged.append(candidate) + return merged + + def _retrieve_user_cluster_id(self, result: ProviderTranscriptResult, skip_n_seconds: int) -> Optional[str]: + if not skip_n_seconds: + return None + + speaker_counts = {} + speaker_sources: Iterable[Tuple[float, Optional[str]]] = ( + [(word.start, self._word_cluster_key(word)) for word in result.words] + if result.words + else [(utterance.start, self._utterance_cluster_key(utterance)) for utterance in result.utterances] + ) + for start, provider_cluster_id in sorted(speaker_sources, key=lambda item: item[0]): + if start >= skip_n_seconds: + break + if not provider_cluster_id: + continue + speaker_counts[provider_cluster_id] = speaker_counts.get(provider_cluster_id, 0) + 1 + return max(speaker_counts, key=speaker_counts.get) if speaker_counts else None + + def _dominant_speaker(self, words: Sequence[ProviderTranscriptWord]) -> Tuple[Optional[str], Optional[str]]: + speaker_counts = {} + labels_by_cluster = {} + for word in words: + if not word.provider_cluster_id: + continue + speaker_counts[word.provider_cluster_id] = speaker_counts.get(word.provider_cluster_id, 0) + 1 + if word.speaker_label: + labels_by_cluster[word.provider_cluster_id] = word.speaker_label + + if not speaker_counts: + return None, None + + cluster_id = max(speaker_counts, key=speaker_counts.get) + return cluster_id, labels_by_cluster.get(cluster_id) + + def _word_is_covered(self, word: ProviderTranscriptWord, intervals: Sequence[Tuple[float, float]]) -> bool: + return any(start <= word.start and word.end <= end for start, end in intervals) + + def _should_merge_word(self, previous: _SegmentCandidate, word: ProviderTranscriptWord) -> bool: + return ( + self._candidate_cluster_key(previous) == self._word_cluster_key(word) + and word.start - previous.end < self.max_same_cluster_gap_seconds + ) + + def _should_merge_candidate(self, previous: _SegmentCandidate, candidate: _SegmentCandidate) -> bool: + return ( + self._candidate_cluster_key(previous) == self._candidate_cluster_key(candidate) + and candidate.start - previous.end < self.max_same_cluster_gap_seconds + ) + + def _is_duplicate_overlap(self, previous: _SegmentCandidate, candidate: _SegmentCandidate) -> bool: + if candidate.start >= previous.end: + return False + + previous_text = self._normalize_text(previous.text) + candidate_text = self._normalize_text(candidate.text) + return ( + bool(previous_text and candidate_text) + and self._candidate_cluster_key(previous) == self._candidate_cluster_key(candidate) + and (previous_text == candidate_text or previous_text in candidate_text or candidate_text in previous_text) + ) + + def _legacy_speaker_label(self, speaker_label: Optional[str]) -> Optional[str]: + if isinstance(speaker_label, str) and speaker_label.startswith('SPEAKER_'): + return speaker_label + return None + + def _identity_state(self, is_user: bool, provider_cluster_id: Optional[str], speaker_label: Optional[str]) -> str: + if is_user: + return 'user' + if speaker_label: + return 'unassigned' + return 'unknown' + + def _candidate_cluster_key(self, candidate: _SegmentCandidate) -> Optional[str]: + return candidate.provider_cluster_id or candidate.speaker_label + + def _word_cluster_key(self, word: ProviderTranscriptWord) -> Optional[str]: + return word.provider_cluster_id or word.speaker_label + + def _utterance_cluster_key(self, utterance: ProviderTranscriptUtterance) -> Optional[str]: + return utterance.provider_cluster_id or utterance.speaker_label + + def _normalize_text(self, text: str) -> str: + return ' '.join(text.lower().split()) + + +def reconstruct_conversation( + result: ProviderTranscriptResult, + skip_n_seconds: int = 0, + user_provider_cluster_id: Optional[str] = None, +) -> List[TranscriptSegment]: + return ConversationReconstructor().reconstruct(result, skip_n_seconds, user_provider_cluster_id) diff --git a/backend/utils/stt/deepgram_adapter.py b/backend/utils/stt/deepgram_adapter.py new file mode 100644 index 00000000000..bdf6237be08 --- /dev/null +++ b/backend/utils/stt/deepgram_adapter.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass +from io import BytesIO +from typing import Callable, List, Optional, Sequence, Tuple, Union + +import httpx +from deepgram import DeepgramClient + +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptUtterance, ProviderTranscriptWord +from utils.stt.providers import STTProviderName + + +@dataclass(frozen=True) +class DeepgramPrerecordedOptions: + model: str = 'nova-3' + diarize: bool = True + language: Optional[str] = None + return_language: bool = False + keywords: Optional[Sequence[str]] = None + sample_rate: int = 16000 + encoding: Optional[str] = None + channels: int = 1 + nova3_keyword_prefix_match: bool = False + + +def deepgram_speaker_fields(speaker_id) -> dict: + if speaker_id is None: + return {'speaker': None, 'provider_cluster_id': None, 'provider_speaker_label': None} + + provider_cluster_id = str(speaker_id) + try: + speaker_label = f'SPEAKER_{int(speaker_id):02d}' + except (TypeError, ValueError): + speaker_label = None + + return { + 'speaker': speaker_label, + 'provider_cluster_id': provider_cluster_id, + 'provider_speaker_label': speaker_label, + } + + +def normalize_deepgram_prerecorded_result( + result: dict, model: str, language: Optional[str] = None +) -> ProviderTranscriptResult: + channels = result.get('results', {}).get('channels', []) + if not channels: + raise Exception('No channels found in response') + + alternatives = channels[0].get('alternatives', []) + if not alternatives: + raise Exception('No alternatives found in response') + + alternative = alternatives[0] + detected_language = _normalize_language(channels[0].get('detected_language') or language) + words = [_normalize_deepgram_word(word) for word in alternative.get('words', [])] + utterances = [ + _normalize_deepgram_utterance(utterance) for utterance in result.get('results', {}).get('utterances', []) + ] + duration = result.get('metadata', {}).get('duration') + + return ProviderTranscriptResult( + provider=STTProviderName.deepgram.value, + model=model, + language=detected_language, + duration=duration, + words=words, + utterances=utterances, + raw_provider_result_id=result.get('metadata', {}).get('request_id'), + ) + + +def provider_result_to_legacy_words(result: ProviderTranscriptResult) -> List[dict]: + words = [] + for word in result.words: + words.append( + { + 'timestamp': [word.start, word.end], + 'speaker': word.speaker_label, + 'provider_cluster_id': word.provider_cluster_id, + 'provider_speaker_label': word.speaker_label, + 'stt_provider': result.provider, + 'stt_model': result.model, + 'text': word.text, + } + ) + return words + + +def _normalize_language(language: Optional[str]) -> Optional[str]: + if language and '-' in language: + return language.split('-', 1)[0] + return language + + +def _normalize_deepgram_word(word: dict) -> ProviderTranscriptWord: + speaker_fields = deepgram_speaker_fields(word.get('speaker')) + return ProviderTranscriptWord( + text=word.get('punctuated_word', word.get('word', '')), + start=word['start'], + end=word['end'], + provider_cluster_id=speaker_fields['provider_cluster_id'], + speaker_label=speaker_fields['speaker'], + confidence=word.get('confidence'), + ) + + +def _normalize_deepgram_utterance(utterance: dict) -> ProviderTranscriptUtterance: + speaker_fields = deepgram_speaker_fields(utterance.get('speaker')) + utterance_words = utterance.get('words') + return ProviderTranscriptUtterance( + text=utterance.get('transcript', ''), + start=utterance['start'], + end=utterance['end'], + provider_cluster_id=speaker_fields['provider_cluster_id'], + speaker_label=speaker_fields['speaker'], + confidence=utterance.get('confidence'), + words=[_normalize_deepgram_word(word) for word in utterance_words] if utterance_words else None, + ) + + +class DeepgramPrerecordedTranscriptionProvider: + provider_name = STTProviderName.deepgram + + def __init__( + self, + client_factory: Callable[[], DeepgramClient], + timeout: httpx.Timeout, + ): + self._client_factory = client_factory + self._timeout = timeout + + def transcribe_url( + self, + audio_url: str, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'nova-3', + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: + options = DeepgramPrerecordedOptions( + model=model, + diarize=diarize, + language=language, + return_language=return_language, + keywords=keywords, + ) + request_options = self._request_options(options) + response = ( + self._client_factory() + .listen.rest.v('1') + .transcribe_url({'url': audio_url}, request_options, timeout=self._timeout) + ) + result = normalize_deepgram_prerecorded_result(response.to_dict(), model=model, language=language) + if return_language: + return result, result.language or 'en' + return result + + def transcribe_bytes( + self, + audio_bytes: bytes, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'nova-3', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: + options = DeepgramPrerecordedOptions( + model=model, + diarize=diarize, + language=language, + return_language=return_language, + keywords=keywords, + sample_rate=sample_rate, + encoding=encoding, + channels=channels, + nova3_keyword_prefix_match=True, + ) + request_options = self._request_options(options) + audio_buffer = BytesIO(audio_bytes) + mimetype = 'audio/raw' if encoding else 'audio/wav' + source = {'buffer': audio_buffer, 'mimetype': mimetype} + response = ( + self._client_factory().listen.rest.v('1').transcribe_file(source, request_options, timeout=self._timeout) + ) + result = normalize_deepgram_prerecorded_result(response.to_dict(), model=model, language=language) + if return_language: + return result, result.language or 'en' + return result + + def _request_options(self, options: DeepgramPrerecordedOptions) -> dict: + is_multi = options.language == 'multi' + request_options = { + 'model': options.model, + 'smart_format': True, + 'punctuate': True, + 'diarize': options.diarize, + 'utterances': True, + 'detect_language': options.return_language or is_multi, + } + if options.language and not is_multi: + request_options['language'] = options.language + + if options.keywords: + is_nova3_keyword_model = options.model in ('nova-3',) or ( + options.nova3_keyword_prefix_match and str(options.model).startswith('nova-3') + ) + if is_nova3_keyword_model: + request_options['keyterm'] = list(options.keywords) + else: + request_options['keywords'] = list(options.keywords) + + if options.encoding: + request_options['encoding'] = options.encoding + request_options['sample_rate'] = options.sample_rate + request_options['channels'] = options.channels + + return request_options diff --git a/backend/utils/stt/deepgram_config.py b/backend/utils/stt/deepgram_config.py new file mode 100644 index 00000000000..97243a3e1f7 --- /dev/null +++ b/backend/utils/stt/deepgram_config.py @@ -0,0 +1,109 @@ +from typing import Tuple + +# Languages supported by nova-3. +DEEPGRAM_NOVA3_LANGUAGES = { + "ar", + "ar-AE", + "ar-SA", + "ar-QA", + "ar-KW", + "ar-SY", + "ar-LB", + "ar-PS", + "ar-JO", + "ar-EG", + "ar-SD", + "ar-TD", + "ar-MA", + "ar-DZ", + "ar-TN", + "ar-IQ", + "ar-IR", + "be", + "bg", + "bn", + "bs", + "ca", + "cs", + "da", + "da-DK", + "de", + "de-CH", + "el", + "en", + "en-US", + "en-AU", + "en-GB", + "en-IN", + "en-NZ", + "es", + "es-419", + "et", + "fa", + "fi", + "fr", + "fr-CA", + "he", + "hi", + "hr", + "hu", + "id", + "it", + "ja", + "kn", + "ko", + "ko-KR", + "lt", + "lv", + "mk", + "mr", + "ms", + "nl", + "nl-BE", + "no", + "pl", + "pt", + "pt-BR", + "pt-PT", + "ro", + "ru", + "sk", + "sl", + "sr", + "sv", + "sv-SE", + "ta", + "te", + "th", + "th-TH", + "tl", + "tr", + "uk", + "ur", + "vi", + "zh", + "zh-CN", + "zh-Hans", + "zh-HK", + "zh-Hant", + "zh-TW", +} + + +def get_deepgram_model_for_language(language: str) -> Tuple[str, str]: + """ + Determine the appropriate Deepgram model and language for pre-recorded transcription. + + Args: + language: The requested language code or 'multi' for auto-detection. + + Returns: + Tuple of (language_to_use, model_name). + """ + if language == 'multi': + return 'multi', 'nova-3' + + if language in DEEPGRAM_NOVA3_LANGUAGES: + return language, 'nova-3' + + return 'multi', 'nova-3' diff --git a/backend/utils/stt/pre_recorded.py b/backend/utils/stt/pre_recorded.py index 59eaa9dd718..54bd69e797e 100644 --- a/backend/utils/stt/pre_recorded.py +++ b/backend/utils/stt/pre_recorded.py @@ -1,15 +1,20 @@ import os -from collections import defaultdict -from io import BytesIO from typing import List, Optional, Sequence, Tuple, Union import fal_client import httpx from deepgram import DeepgramClient, DeepgramClientOptions -from models.transcript_segment import TranscriptSegment +from models.transcript_segment import ProviderTranscriptResult, ProviderTranscriptWord, TranscriptSegment from utils.byok import get_byok_key from utils.other.endpoints import timeit +from utils.stt.conversation_reconstructor import reconstruct_conversation +from utils.stt.deepgram_adapter import ( + DeepgramPrerecordedTranscriptionProvider, + deepgram_speaker_fields, + provider_result_to_legacy_words, +) +from utils.stt.deepgram_config import get_deepgram_model_for_language import logging _DG_TIMEOUT = httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0) @@ -30,116 +35,12 @@ def _deepgram_client_for_request() -> DeepgramClient: return _deepgram_client -# Languages supported by nova-3 -_deepgram_nova3_languages = { - "ar", - "ar-AE", - "ar-SA", - "ar-QA", - "ar-KW", - "ar-SY", - "ar-LB", - "ar-PS", - "ar-JO", - "ar-EG", - "ar-SD", - "ar-TD", - "ar-MA", - "ar-DZ", - "ar-TN", - "ar-IQ", - "ar-IR", - "be", - "bg", - "bn", - "bs", - "ca", - "cs", - "da", - "da-DK", - "de", - "de-CH", - "el", - "en", - "en-US", - "en-AU", - "en-GB", - "en-IN", - "en-NZ", - "es", - "es-419", - "et", - "fa", - "fi", - "fr", - "fr-CA", - "he", - "hi", - "hr", - "hu", - "id", - "it", - "ja", - "kn", - "ko", - "ko-KR", - "lt", - "lv", - "mk", - "mr", - "ms", - "nl", - "nl-BE", - "no", - "pl", - "pt", - "pt-BR", - "pt-PT", - "ro", - "ru", - "sk", - "sl", - "sr", - "sv", - "sv-SE", - "ta", - "te", - "th", - "th-TH", - "tl", - "tr", - "uk", - "ur", - "vi", - "zh", - "zh-CN", - "zh-Hans", - "zh-HK", - "zh-Hant", - "zh-TW", -} - - -def get_deepgram_model_for_language(language: str) -> Tuple[str, str]: - """ - Determine the appropriate Deepgram model and language for pre-recorded transcription. +def _deepgram_speaker_fields(speaker_id) -> dict: + return deepgram_speaker_fields(speaker_id) - Args: - language: The requested language code or 'multi' for auto-detection - Returns: - Tuple of (language_to_use, model_name) - """ - # For multi-language mode - if language == 'multi': - return 'multi', 'nova-3' - - # Languages supported by nova-3 - if language in _deepgram_nova3_languages: - return language, 'nova-3' - - # Unsupported language - fall back to multi for auto-detection - return 'multi', 'nova-3' +def _deepgram_prerecorded_provider() -> DeepgramPrerecordedTranscriptionProvider: + return DeepgramPrerecordedTranscriptionProvider(_deepgram_client_for_request, _DG_TIMEOUT) @timeit @@ -173,74 +74,20 @@ def deepgram_prerecorded( logger.info(f'deepgram_prerecorded {audio_url} {speakers_count} {attempts}') try: - # 'multi' language means auto-detection - is_multi = language == 'multi' - should_detect_language = return_language or is_multi - options = { - "model": model, - "smart_format": True, - "punctuate": True, - "diarize": diarize, - "detect_language": should_detect_language, - "utterances": True, - } - if language and not is_multi: - options["language"] = language - - if keywords: - if model in ('nova-3',): - options["keyterm"] = list(keywords) - else: - options["keywords"] = list(keywords) - - response = ( - _deepgram_client_for_request() - .listen.rest.v("1") - .transcribe_url({"url": audio_url}, options, timeout=_DG_TIMEOUT) + result = _deepgram_prerecorded_provider().transcribe_url( + audio_url, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, ) - - # Extract words from response - result = response.to_dict() - channels = result.get('results', {}).get('channels', []) - if not channels: - raise Exception('No channels found in response') - - alternatives = channels[0].get('alternatives', []) - if not alternatives: - raise Exception('No alternatives found in response') - - dg_words = alternatives[0].get('words', []) - if not dg_words: - if return_language: - detected_lang = channels[0].get('detected_language', 'en') - if detected_lang and '-' in detected_lang: - detected_lang = detected_lang.split('-')[0] - return [], detected_lang or 'en' - return [] - - # Convert Deepgram format to fal_whisperx compatible format - # Deepgram: {word, start, end, confidence, punctuated_word, speaker (int)} - # Expected: {timestamp: [start, end], speaker: 'SPEAKER_XX', text: 'word'} - words = [] - for w in dg_words: - speaker_id = w.get('speaker', 0) - words.append( - { - 'timestamp': [w['start'], w['end']], - 'speaker': f"SPEAKER_{speaker_id:02d}" if speaker_id is not None else None, - 'text': w.get('punctuated_word', w['word']), - } - ) - if return_language: - # Deepgram returns detected_language in the channel - detected_lang = channels[0].get('detected_language', 'en') - # Normalize language code (Deepgram might return 'en-US', we want 'en') - if detected_lang and '-' in detected_lang: - detected_lang = detected_lang.split('-')[0] - return words, detected_lang or 'en' + transcript_result, detected_language = result + return provider_result_to_legacy_words(transcript_result), detected_language - return words + return provider_result_to_legacy_words(result) except Exception as e: logger.error(f'Deepgram prerecorded error: {e}') @@ -295,84 +142,33 @@ def deepgram_prerecorded_from_bytes( Or tuple of (words, language) if return_language=True """ logger.info( - f'deepgram_prerecorded_from_bytes bytes_len={len(audio_bytes)} {sample_rate} {diarize} {attempts} encoding={encoding} language={language} model={model}' + 'deepgram_prerecorded_from_bytes bytes_len=%s %s %s %s encoding=%s language=%s model=%s', + len(audio_bytes), + sample_rate, + diarize, + attempts, + encoding, + language, + model, ) try: - is_multi = language == 'multi' - should_detect_language = return_language or is_multi - options = { - "model": model, - "smart_format": True, - "punctuate": True, - "diarize": diarize, - "utterances": True, - "detect_language": should_detect_language, - } - if language and not is_multi: - options["language"] = language - - if keywords: - if str(model).startswith("nova-3"): - options["keyterm"] = list(keywords) - else: - options["keywords"] = list(keywords) - - # For raw PCM, Deepgram needs encoding + sample_rate to interpret the bytes - if encoding: - options["encoding"] = encoding - options["sample_rate"] = sample_rate - options["channels"] = channels - - # Wrap bytes in BytesIO for Deepgram client - audio_buffer = BytesIO(audio_bytes) - mimetype = "audio/raw" if encoding else "audio/wav" - source = {"buffer": audio_buffer, "mimetype": mimetype} - - response = ( - _deepgram_client_for_request().listen.rest.v("1").transcribe_file(source, options, timeout=_DG_TIMEOUT) + result = _deepgram_prerecorded_provider().transcribe_bytes( + audio_bytes, + sample_rate=sample_rate, + diarize=diarize, + encoding=encoding, + channels=channels, + language=language, + model=model, + return_language=return_language, + keywords=keywords, ) - - # Extract words from response - result = response.to_dict() - result_channels = result.get('results', {}).get('channels', []) - if not result_channels: - raise Exception('No channels found in response') - - alternatives = result_channels[0].get('alternatives', []) - if not alternatives: - raise Exception('No alternatives found in response') - - dg_words = alternatives[0].get('words', []) - if not dg_words: - if return_language: - detected_lang = result_channels[0].get('detected_language', 'en') - if detected_lang and '-' in detected_lang: - detected_lang = detected_lang.split('-')[0] - return [], detected_lang or 'en' - return [] - - # Convert Deepgram format to standard format - # Deepgram: {word, start, end, confidence, punctuated_word, speaker (int)} - # Expected: {timestamp: [start, end], speaker: 'SPEAKER_XX', text: 'word'} - words = [] - for w in dg_words: - speaker_id = w.get('speaker', 0) - words.append( - { - 'timestamp': [w['start'], w['end']], - 'speaker': f"SPEAKER_{speaker_id:02d}" if speaker_id is not None else None, - 'text': w.get('punctuated_word', w['word']), - } - ) - if return_language: - detected_lang = result_channels[0].get('detected_language', 'en') - if detected_lang and '-' in detected_lang: - detected_lang = detected_lang.split('-')[0] - return words, detected_lang or 'en' + transcript_result, detected_language = result + return provider_result_to_legacy_words(transcript_result), detected_language - return words + return provider_result_to_legacy_words(result) except Exception as e: logger.error(f'Deepgram prerecorded from bytes error: {e}') @@ -435,108 +231,29 @@ def fal_whisperx( return [] -def _words_cleaning(words: List[dict]): - words_cleaned: List[dict] = [] - for i, w in enumerate(words): - # if w['timestamp'][0] == w['timestamp'][1]: - # continue - words_cleaned.append( - { - 'start': round(w['timestamp'][0], 2), - 'end': round(w['timestamp'][1] or w['timestamp'][0] + 1, 2), - 'speaker': w['speaker'], - 'text': str(w['text']).strip(), - 'is_user': False, - 'person_id': None, - } - ) - - for i, word in enumerate(words_cleaned): - speaker = word['speaker'] - if not speaker: - prev_chunk = words_cleaned[i - 1] if i > 0 else None - next_chunk = words_cleaned[i + 1] if i < len(words_cleaned) - 1 else None - prev_speaker = prev_chunk['speaker'] if prev_chunk else None - next_speaker = next_chunk['speaker'] if next_chunk else None - - if prev_speaker and next_speaker: - if prev_speaker == next_speaker: - speaker = prev_chunk['speaker'] - else: - secs_from_prev = word['start'] - prev_chunk['end'] if prev_chunk else 0 - secs_to_next = next_chunk['start'] - word['end'] if next_chunk else 0 - speaker = prev_speaker if secs_from_prev < secs_to_next else next_speaker - elif prev_speaker: - speaker = prev_speaker - elif next_speaker: - speaker = next_speaker - else: - speaker = 'SPEAKER_00' - - words_cleaned[i]['speaker'] = speaker - - # for chunk in words_cleaned: - # print(chunk) - return words_cleaned - - -def _retrieve_user_speaker_id(words: list, skip_n_seconds: int): - if not skip_n_seconds: - return None - - user_speaker_id = defaultdict(int) - for word in words: - if word['start'] >= skip_n_seconds: - break - if not word['speaker']: - continue - user_speaker_id[word['speaker']] += 1 - - user_speaker_id = max(user_speaker_id, key=user_speaker_id.get) if user_speaker_id else None - return user_speaker_id - - -def _merge_segments(words: List[dict], skip_n_seconds: int, user_speaker_id: str): - segments = [] +def legacy_words_to_provider_result(words: List[dict]) -> ProviderTranscriptResult: + provider_words = [] + provider = None + model = None for word in words: - if word['start'] < skip_n_seconds: - continue - word['is_user'] = word['speaker'] == user_speaker_id if word['speaker'] else False - - same_prev_speaker = word['speaker'] == segments[-1]['speaker'] if segments else False - seconds_from_prev = word['start'] - segments[-1]['end'] if segments else 0 + raw_speaker = word.get('speaker') + speaker = raw_speaker if isinstance(raw_speaker, str) and raw_speaker.startswith('SPEAKER_') else None + timestamp = word['timestamp'] + provider = provider or word.get('stt_provider') + model = model or word.get('stt_model') + provider_words.append( + ProviderTranscriptWord( + text=str(word['text']).strip(), + start=round(timestamp[0], 2), + end=round(timestamp[1] or timestamp[0] + 1, 2), + provider_cluster_id=word.get('provider_cluster_id') or raw_speaker, + speaker_label=word.get('provider_speaker_label') or speaker, + confidence=word.get('confidence'), + ) + ) - # TODO: consider having a max segment size too - if segments and same_prev_speaker and seconds_from_prev < 30: - segments[-1]['end'] = word['end'] - segments[-1]['text'] += ' ' + word['text'] - else: - segments.append(word) - return segments + return ProviderTranscriptResult(provider=provider or 'unknown', model=model, words=provider_words) -def _segments_as_objects(segments: List[dict]) -> List[TranscriptSegment]: - if not segments: - return [] - starts_at = segments[0]['start'] - return [ - TranscriptSegment( - text=str(segment['text']).strip().capitalize(), - speaker=segment['speaker'], - is_user=segment['is_user'], - person_id=None, - start=round(segment['start'] - starts_at, 2), - end=round(segment['end'] - starts_at, 2), - ) - for segment in segments - ] - - -def postprocess_words( - words: List[dict], duration: int, skip_n_seconds: int = 0 # , merge_segments: bool = True -) -> List[TranscriptSegment]: - words: List[dict] = _words_cleaning(words) - user_speaker_id = _retrieve_user_speaker_id(words, skip_n_seconds) - segments = _merge_segments(words, skip_n_seconds, user_speaker_id) - segments = _segments_as_objects(segments) - return segments +def postprocess_words(words: List[dict], skip_n_seconds: int = 0) -> List[TranscriptSegment]: + return reconstruct_conversation(legacy_words_to_provider_result(words), skip_n_seconds=skip_n_seconds) diff --git a/backend/utils/stt/provider_costs.py b/backend/utils/stt/provider_costs.py new file mode 100644 index 00000000000..088a85d5674 --- /dev/null +++ b/backend/utils/stt/provider_costs.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass +from typing import Optional + +from utils.stt.providers import STTProviderName, STTWorkload + + +@dataclass(frozen=True) +class PrerecordedProviderCostRate: + usd_per_billable_second: float + source: str + + +# Pay-as-you-go public STT pricing, checked 2026-05-25. +# AssemblyAI background runs use speaker_labels=True, so the pre-recorded +# diarization add-on is included in AssemblyAI rates here. Deepgram background +# runs also request diarize=True, so include the speaker diarization add-on +# there too; otherwise Deepgram appears artificially cheap in rollout gates. +# Customer-specific committed-use discounts are intentionally excluded. +_PRERECORDED_STT_COST_RATES: dict[str, dict[str, PrerecordedProviderCostRate]] = { + STTProviderName.assemblyai.value: { + # AssemblyAI pricing: Universal-2 $0.15/hr, Universal-3 Pro $0.21/hr, + # plus pre-recorded Speaker Diarization $0.02/hr. + 'universal-2': PrerecordedProviderCostRate( + usd_per_billable_second=0.17 / 3600, + source='assemblyai_prerecorded_diarized_payg_2026_05_25', + ), + 'universal-3-pro': PrerecordedProviderCostRate( + usd_per_billable_second=0.23 / 3600, + source='assemblyai_prerecorded_diarized_payg_2026_05_25', + ), + 'u3-pro': PrerecordedProviderCostRate( + usd_per_billable_second=0.23 / 3600, + source='assemblyai_prerecorded_diarized_payg_2026_05_25', + ), + 'default': PrerecordedProviderCostRate( + usd_per_billable_second=0.17 / 3600, + source='assemblyai_prerecorded_diarized_default_2026_05_25', + ), + }, + STTProviderName.deepgram.value: { + # Deepgram pricing: Nova-3 monolingual pre-recorded $0.0048/min, + # Nova-3 multilingual pre-recorded $0.0058/min, plus Speaker + # Diarization add-on $0.0020/min. + 'nova-3': PrerecordedProviderCostRate( + usd_per_billable_second=0.0068 / 60, + source='deepgram_prerecorded_diarized_payg_2026_05_25', + ), + 'nova-3-general': PrerecordedProviderCostRate( + usd_per_billable_second=0.0068 / 60, + source='deepgram_prerecorded_diarized_payg_2026_05_25', + ), + 'nova-3-multilingual': PrerecordedProviderCostRate( + usd_per_billable_second=0.0078 / 60, + source='deepgram_prerecorded_diarized_payg_2026_05_25', + ), + 'default': PrerecordedProviderCostRate( + usd_per_billable_second=0.0068 / 60, + source='deepgram_prerecorded_diarized_default_2026_05_25', + ), + }, +} + +_PRERECORDED_COST_WORKLOADS = { + STTWorkload.background.value, + STTWorkload.postprocess.value, + STTWorkload.ptt.value, + STTWorkload.sync.value, + STTWorkload.voice_message.value, +} + + +def estimate_prerecorded_provider_cost_usd( + provider: str, + model: Optional[str], + workload: str, + billable_seconds: float, +) -> float: + if billable_seconds <= 0: + return 0.0 + if str(workload) not in _PRERECORDED_COST_WORKLOADS: + return 0.0 + rate = prerecorded_provider_cost_rate(provider, model) + if not rate: + return 0.0 + return round(float(billable_seconds) * rate.usd_per_billable_second, 8) + + +def prerecorded_provider_cost_rate(provider: str, model: Optional[str]) -> Optional[PrerecordedProviderCostRate]: + provider_rates = _PRERECORDED_STT_COST_RATES.get(str(provider or '').lower()) + if not provider_rates: + return None + normalized_model = str(model or '').strip().lower() + return provider_rates.get(normalized_model) or provider_rates['default'] diff --git a/backend/utils/stt/provider_evaluation.py b/backend/utils/stt/provider_evaluation.py new file mode 100644 index 00000000000..c3bd940486a --- /dev/null +++ b/backend/utils/stt/provider_evaluation.py @@ -0,0 +1,761 @@ +from dataclasses import dataclass +from typing import Any, Optional + +PRODUCTION_STRATEGIES = ('always_deepgram', 'always_assemblyai', 'current_policy', 'shadow_only') +PROVIDER_BY_STRATEGY = { + 'always_deepgram': 'deepgram', + 'always_assemblyai': 'assemblyai', + 'current_policy': 'assemblyai', + 'shadow_only': 'deepgram', +} +ASSEMBLYAI_COST_PER_HOUR_USD = 0.17 +DEEPGRAM_COST_PER_HOUR_USD = 0.408 + + +@dataclass(frozen=True) +class ProviderGateThresholds: + max_transcript_word_error_rate: float = 0.35 + max_segment_count_delta_ratio: float = 0.75 + max_average_timestamp_drift_seconds: float = 2.5 + max_low_confidence_identity_rate: float = 0.50 + max_fallback_rate: float = 0.10 + max_failure_rate: float = 0.05 + min_speaker_word_purity: float = 0.95 + min_assemblyai_purity_delta_vs_deepgram: float = -0.05 + max_speaker_inflation_ratio: float = 1.75 + max_empty_transcript_rate: float = 0.05 + max_timeout_error_rate: float = 0.05 + max_latency_ratio_vs_deepgram: float = 2.0 + max_cost_ratio_vs_deepgram: float = 3.0 + require_instrumentation: bool = True + + +def build_comparison_report( + cases: list[dict[str, Any]], + thresholds: Optional[ProviderGateThresholds] = None, +) -> dict[str, Any]: + thresholds = thresholds or ProviderGateThresholds() + case_reports = [_compare_case(case, thresholds) for case in cases] + failures = [gate for case in case_reports for gate in case['gates'] if gate['severity'] == 'failure'] + warnings = [gate for case in case_reports for gate in case['gates'] if gate['severity'] == 'warning'] + return { + 'status': 'failed' if failures else 'passed', + 'case_count': len(case_reports), + 'failure_count': len(failures), + 'warning_count': len(warnings), + 'aggregate': _aggregate_case_reports(case_reports), + 'strategies': _strategy_rollups(case_reports), + 'assemblyai_gap_report': _assemblyai_gap_report(case_reports), + 'cases': case_reports, + } + + +def evaluate_report_gates(report: dict[str, Any], fail_on_warning: bool = False) -> tuple[bool, list[str]]: + messages = [] + for case in report.get('cases', []): + for gate in case.get('gates', []): + if gate.get('severity') == 'failure' or (fail_on_warning and gate.get('severity') == 'warning'): + messages.append(f"{case.get('id', 'unknown')}: {gate.get('metric')} {gate.get('message')}") + return not messages, messages + + +def compact_markdown_report(report: dict[str, Any]) -> str: + aggregate = report.get('aggregate', {}) + lines = [ + f"# STT Provider Evaluation: {report.get('status', 'unknown').upper()}", + '', + '| Cases | Failures | Warnings | Avg WER | Avg timestamp drift | AAI purity | DG purity | AAI cost/hr | DG cost/hr |', + '| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |', + ( + f"| {report.get('case_count', 0)} | {report.get('failure_count', 0)} | " + f"{report.get('warning_count', 0)} | {_fmt_pct(aggregate.get('average_word_error_rate'))} | " + f"{_fmt_seconds(aggregate.get('average_timestamp_drift_seconds'))} | " + f"{_fmt_pct(aggregate.get('assemblyai_speaker_word_purity'))} | " + f"{_fmt_pct(aggregate.get('deepgram_speaker_word_purity'))} | " + f"${aggregate.get('assemblyai_estimated_cost_per_hour_usd', 0.0):.3f} | " + f"${aggregate.get('deepgram_estimated_cost_per_hour_usd', 0.0):.3f} |" + ), + '', + '## Strategy Rollup', + '', + '| Strategy | Provider | Purity | Covered speakers | App speakers | Inflation | Empty | Fallback | Timeout/error | Cost/hr |', + '| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |', + ] + for name, strategy in report.get('strategies', {}).items(): + lines.append( + f"| {name} | {strategy.get('provider', 'mixed')} | " + f"{_fmt_pct(strategy.get('speaker_word_purity'))} | " + f"{strategy.get('covered_speaker_count', 0):.1f} | " + f"{strategy.get('app_visible_speaker_count', 0):.1f} | " + f"{strategy.get('speaker_inflation_ratio', 0.0):.2f} | " + f"{_fmt_pct(strategy.get('empty_transcript_rate'))} | " + f"{_fmt_pct(strategy.get('fallback_rate'))} | " + f"{_fmt_pct(strategy.get('timeout_error_rate'))} | " + f"${strategy.get('estimated_cost_per_hour_usd', 0.0):.3f} |" + ) + lines.extend( + [ + '', + '## Cases', + '', + '| Case | Scenario | WER | Purity DG/AAI | App speakers DG/AAI | Splits AAI | Recon AAI | Fallback AAI | Gates |', + '| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | --- |', + ] + ) + for case in report.get('cases', []): + deepgram = case['providers']['deepgram'] + assemblyai = case['providers']['assemblyai'] + gates = ', '.join( + f"{gate['severity']}:{gate['metric']}" for gate in case.get('gates', []) if gate['severity'] != 'pass' + ) + lines.append( + f"| {case['id']} | {case.get('scenario', 'unspecified')} | " + f"{_fmt_pct(case['comparison']['transcript_word_error_rate'])} | " + f"{_fmt_pct(deepgram['speaker_word_purity'])}/{_fmt_pct(assemblyai['speaker_word_purity'])} | " + f"{deepgram['app_visible_speaker_count']}/{assemblyai['app_visible_speaker_count']} | " + f"{assemblyai['split_count']} | " + f"{assemblyai['accepted_reconciliation_count']}/{assemblyai['rejected_reconciliation_count']} | " + f"{_fmt_pct(assemblyai['fallback_rate'])} | {gates or 'pass'} |" + ) + gap_report = report.get('assemblyai_gap_report') or {} + lines.extend(['', '## AssemblyAI Gap Report', '']) + lines.append(f"Status: {gap_report.get('status', 'unknown')}.") + limiting_scenarios = gap_report.get('limiting_scenarios') or [] + for item in limiting_scenarios: + lines.append( + f"- {item['scenario']}: {item['metric']} ({item['assemblyai_value']} vs {item['deepgram_value']}) " + f"- likely cause: {item['likely_cause']}; mitigation: {item['mitigation']}" + ) + if not limiting_scenarios: + lines.append('- No material AssemblyAI gap detected in offline synthetic/saved-output fixtures.') + lines.extend( + [ + '', + 'Synthetic and saved-output gates are necessary but insufficient for default health decisions. ' + 'Use this gap report to track AssemblyAI default readiness and rollback thresholds with privacy-safe real-session evidence.', + ] + ) + return '\n'.join(lines) + + +def _compare_case(case: dict[str, Any], thresholds: ProviderGateThresholds) -> dict[str, Any]: + deepgram = summarize_provider_output('deepgram', case.get('deepgram') or {}) + assemblyai = summarize_provider_output('assemblyai', case.get('assemblyai') or {}) + comparison = { + 'transcript_word_error_rate': _word_error_rate(deepgram['text'], assemblyai['text']), + 'segment_count_delta': assemblyai['segment_count'] - deepgram['segment_count'], + 'segment_count_delta_ratio': _ratio_delta(assemblyai['segment_count'], deepgram['segment_count']), + 'word_count_delta': assemblyai['word_count'] - deepgram['word_count'], + 'word_count_delta_ratio': _ratio_delta(assemblyai['word_count'], deepgram['word_count']), + 'average_timestamp_drift_seconds': _average_timestamp_drift_seconds( + deepgram['segments'], assemblyai['segments'] + ), + } + gates = _evaluate_case_gates(deepgram, assemblyai, comparison, thresholds) + return { + 'id': case.get('id') or case.get('case_id') or 'unknown', + 'scenario': case.get('scenario') or case.get('type') or 'unspecified', + 'current_policy_provider': case.get('current_policy_provider') or 'assemblyai', + 'providers': {'deepgram': _public_summary(deepgram), 'assemblyai': _public_summary(assemblyai)}, + 'comparison': comparison, + 'gates': gates, + } + + +def summarize_provider_output(provider: str, payload: dict[str, Any]) -> dict[str, Any]: + transcript = payload.get('transcript') or payload.get('result') or payload + segments = _extract_segments(transcript) + ledger = payload.get('ledger') or payload.get('rollup') or {} + clusters = _speaker_clusters(segments) + oracle_speakers = _oracle_speakers(segments) + word_count = sum(len(_words(segment.get('text', ''))) for segment in segments) + raw_audio_seconds = _number_from_ledger(ledger, 'raw_audio_seconds') + billable_seconds = _number_from_ledger(ledger, 'billable_seconds') + estimated_cost_usd = _estimated_cost_usd(provider, ledger, billable_seconds or raw_audio_seconds) + latency_seconds = _latency_seconds(ledger) + timeout_error_count = _timeout_error_count(ledger) + run_count = _run_count(ledger) + identified_clusters = { + _cluster_id(segment) + for segment in segments + if _cluster_id(segment) + and ( + segment.get('person_id') + or segment.get('is_user') is True + or segment.get('speaker_identity_state') in ('identified', 'user') + ) + } + low_confidence_clusters = set() + for segment in segments: + cluster_id = _cluster_id(segment) + if not cluster_id: + continue + state = segment.get('speaker_identity_state') + confidence = segment.get('speaker_identity_confidence') + if state == 'unknown': + low_confidence_clusters.add(cluster_id) + elif confidence is not None and float(confidence) < 0.50: + low_confidence_clusters.add(cluster_id) + return { + 'provider': provider, + 'segments': segments, + 'text': _transcript_text(segments), + 'segment_count': len(segments), + 'word_count': word_count, + 'speaker_cluster_count': len(clusters), + 'provider_cluster_count': len(clusters), + 'covered_speaker_count': len(oracle_speakers), + 'app_visible_speaker_count': len(clusters), + 'speaker_inflation_ratio': (len(clusters) / len(oracle_speakers) if oracle_speakers else 0.0), + 'speaker_word_purity': _speaker_word_purity(segments), + 'empty_transcript_rate': 1.0 if word_count == 0 else 0.0, + 'identified_speaker_cluster_count': len(identified_clusters), + 'unknown_speaker_cluster_count': max(len(clusters) - len(identified_clusters), 0), + 'low_confidence_identity_count': len(low_confidence_clusters), + 'low_confidence_identity_rate': (len(low_confidence_clusters) / len(clusters) if clusters else 0.0), + 'raw_audio_seconds': raw_audio_seconds, + 'speech_active_seconds': _number_from_ledger(ledger, 'speech_active_seconds'), + 'billable_seconds': billable_seconds, + 'estimated_cost_usd': estimated_cost_usd, + 'estimated_cost_per_hour_usd': _cost_per_hour(estimated_cost_usd, billable_seconds or raw_audio_seconds), + 'latency_seconds': latency_seconds, + 'runtime_seconds': _runtime_seconds(ledger, latency_seconds), + 'retry_count': _number_from_ledger(ledger, 'retry_count'), + 'split_count': _number_from_ledger(ledger, 'split_count'), + 'accepted_reconciliation_count': _number_from_ledger(ledger, 'accepted_reconciliation_count'), + 'rejected_reconciliation_count': _number_from_ledger(ledger, 'rejected_reconciliation_count'), + 'fallback_count': _number_from_ledger(ledger, 'fallback_count'), + 'fallback_rate': _rate_from_ledger(ledger, 'fallback_count'), + 'failure_rate': _failure_rate_from_ledger(ledger), + 'timeout_error_count': timeout_error_count, + 'timeout_error_rate': timeout_error_count / run_count, + 'has_instrumentation': bool(ledger), + } + + +def _public_summary(summary: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in summary.items() if key not in ('segments', 'text')} + + +def _extract_segments(transcript: Any) -> list[dict[str, Any]]: + if isinstance(transcript, dict) and isinstance(transcript.get('segments'), list): + return [_normalize_segment(segment) for segment in transcript['segments']] + if isinstance(transcript, list): + return [_normalize_segment(segment) for segment in transcript] + if isinstance(transcript, dict) and isinstance(transcript.get('utterances'), list) and transcript.get('utterances'): + return [_normalize_provider_utterance(utterance, transcript) for utterance in transcript['utterances']] + if isinstance(transcript, dict) and isinstance(transcript.get('words'), list): + return _segments_from_words(transcript['words'], transcript) + if isinstance(transcript, dict) and transcript.get('text'): + return [_normalize_segment(transcript)] + return [] + + +def _normalize_segment(segment: dict[str, Any]) -> dict[str, Any]: + return { + 'text': str(segment.get('text') or segment.get('transcript') or '').strip(), + 'start': float(segment.get('start') or 0.0), + 'end': float(segment.get('end') or 0.0), + 'provider_cluster_id': segment.get('provider_cluster_id'), + 'provider_speaker_label': segment.get('provider_speaker_label'), + 'speaker': segment.get('speaker'), + 'oracle_speaker': segment.get('oracle_speaker') or segment.get('expected_speaker'), + 'person_id': segment.get('person_id'), + 'is_user': segment.get('is_user'), + 'speaker_identity_state': segment.get('speaker_identity_state'), + 'speaker_identity_confidence': segment.get('speaker_identity_confidence'), + } + + +def _normalize_provider_utterance(utterance: dict[str, Any], transcript: dict[str, Any]) -> dict[str, Any]: + normalized = _normalize_segment(utterance) + normalized['provider_cluster_id'] = utterance.get('provider_cluster_id') + normalized['provider_speaker_label'] = utterance.get('speaker_label') or utterance.get('provider_speaker_label') + normalized['stt_provider'] = transcript.get('provider') + normalized['stt_model'] = transcript.get('model') + return normalized + + +def _segments_from_words(words: list[dict[str, Any]], transcript: dict[str, Any]) -> list[dict[str, Any]]: + segments = [] + current = None + for word in words: + cluster_id = word.get('provider_cluster_id') or word.get('speaker') + if current is None or current.get('provider_cluster_id') != cluster_id: + if current: + current['text'] = ' '.join(current.pop('_words')) + segments.append(current) + current = { + 'text': '', + '_words': [], + 'start': float(word.get('start') or 0.0), + 'end': float(word.get('end') or 0.0), + 'provider_cluster_id': cluster_id, + 'provider_speaker_label': word.get('speaker_label') or word.get('provider_speaker_label'), + 'oracle_speaker': word.get('oracle_speaker') or word.get('expected_speaker'), + 'stt_provider': transcript.get('provider'), + 'stt_model': transcript.get('model'), + } + current['_words'].append(str(word.get('text') or word.get('word') or '')) + current['end'] = float(word.get('end') or current['end']) + if current: + current['text'] = ' '.join(current.pop('_words')) + segments.append(current) + return segments + + +def _speaker_clusters(segments: list[dict[str, Any]]) -> set[str]: + return {cluster_id for cluster_id in (_cluster_id(segment) for segment in segments) if cluster_id} + + +def _cluster_id(segment: dict[str, Any]) -> Optional[str]: + return segment.get('provider_cluster_id') or segment.get('provider_speaker_label') or segment.get('speaker') + + +def _transcript_text(segments: list[dict[str, Any]]) -> str: + return ' '.join(segment.get('text', '') for segment in segments).strip() + + +def _words(text: str) -> list[str]: + normalized = ''.join(character.lower() if character.isalnum() else ' ' for character in text) + return [word for word in normalized.split() if word] + + +def _word_error_rate(reference: str, hypothesis: str) -> float: + reference_words = _words(reference) + hypothesis_words = _words(hypothesis) + if not reference_words: + return 0.0 if not hypothesis_words else 1.0 + return _levenshtein_distance(reference_words, hypothesis_words) / len(reference_words) + + +def _levenshtein_distance(reference: list[str], hypothesis: list[str]) -> int: + previous = list(range(len(hypothesis) + 1)) + for index, reference_word in enumerate(reference, start=1): + current = [index] + for other_index, hypothesis_word in enumerate(hypothesis, start=1): + current.append( + min( + previous[other_index] + 1, + current[other_index - 1] + 1, + previous[other_index - 1] + (reference_word != hypothesis_word), + ) + ) + previous = current + return previous[-1] + + +def _ratio_delta(value: float, baseline: float) -> float: + if baseline == 0: + return 0.0 if value == 0 else 1.0 + return abs(value - baseline) / baseline + + +def _average_timestamp_drift_seconds(reference: list[dict[str, Any]], candidate: list[dict[str, Any]]) -> float: + pair_count = min(len(reference), len(candidate)) + if not pair_count: + return 0.0 + drift = 0.0 + for index in range(pair_count): + drift += abs(reference[index].get('start', 0.0) - candidate[index].get('start', 0.0)) + drift += abs(reference[index].get('end', 0.0) - candidate[index].get('end', 0.0)) + return drift / (pair_count * 2) + + +def _evaluate_case_gates( + deepgram: dict[str, Any], + assemblyai: dict[str, Any], + comparison: dict[str, Any], + thresholds: ProviderGateThresholds, +) -> list[dict[str, Any]]: + empty_transcript_threshold = ( + 1.0 if deepgram['word_count'] == 0 and assemblyai['word_count'] == 0 else thresholds.max_empty_transcript_rate + ) + gates = [ + _threshold_gate( + 'transcript_word_error_rate', + comparison['transcript_word_error_rate'], + thresholds.max_transcript_word_error_rate, + 'failure', + ), + _threshold_gate( + 'segment_count_delta_ratio', + comparison['segment_count_delta_ratio'], + thresholds.max_segment_count_delta_ratio, + 'failure', + gate_group='speaker_safety', + ), + _threshold_gate( + 'average_timestamp_drift_seconds', + comparison['average_timestamp_drift_seconds'], + thresholds.max_average_timestamp_drift_seconds, + 'warning', + gate_group='rollout_readiness', + ), + _threshold_gate( + 'assemblyai_low_confidence_identity_rate', + assemblyai['low_confidence_identity_rate'], + thresholds.max_low_confidence_identity_rate, + 'warning', + gate_group='speaker_safety', + ), + _threshold_gate( + 'assemblyai_fallback_rate', assemblyai['fallback_rate'], thresholds.max_fallback_rate, 'failure' + ), + _threshold_gate('assemblyai_failure_rate', assemblyai['failure_rate'], thresholds.max_failure_rate, 'failure'), + _threshold_gate( + 'assemblyai_speaker_word_purity', + assemblyai['speaker_word_purity'], + thresholds.min_speaker_word_purity, + 'failure', + minimum=True, + gate_group='speaker_safety', + ), + _threshold_gate( + 'assemblyai_speaker_inflation_ratio', + assemblyai['speaker_inflation_ratio'], + thresholds.max_speaker_inflation_ratio, + 'failure', + gate_group='speaker_safety', + ), + _threshold_gate( + 'assemblyai_empty_transcript_rate', + assemblyai['empty_transcript_rate'], + empty_transcript_threshold, + 'failure', + gate_group='default_viability', + ), + _threshold_gate( + 'assemblyai_timeout_error_rate', + assemblyai['timeout_error_rate'], + thresholds.max_timeout_error_rate, + 'failure', + gate_group='rollout_readiness', + ), + _threshold_gate( + 'assemblyai_purity_delta_vs_deepgram', + assemblyai['speaker_word_purity'] - deepgram['speaker_word_purity'], + thresholds.min_assemblyai_purity_delta_vs_deepgram, + 'failure', + minimum=True, + gate_group='speaker_safety', + ), + _threshold_gate( + 'assemblyai_latency_ratio_vs_deepgram', + _safe_ratio(assemblyai['latency_seconds'], deepgram['latency_seconds']), + thresholds.max_latency_ratio_vs_deepgram, + 'warning', + gate_group='rollout_readiness', + ), + _threshold_gate( + 'assemblyai_cost_ratio_vs_deepgram', + _safe_ratio(assemblyai['estimated_cost_per_hour_usd'], deepgram['estimated_cost_per_hour_usd']), + thresholds.max_cost_ratio_vs_deepgram, + 'warning', + gate_group='default_viability', + ), + ] + if thresholds.require_instrumentation: + for provider in (deepgram, assemblyai): + if not provider['has_instrumentation']: + gates.append( + { + 'metric': f"{provider['provider']}_instrumentation", + 'severity': 'warning', + 'gate_group': 'rollout_readiness', + 'value': None, + 'threshold': 'ledger_or_rollup_required', + 'message': 'missing provider ledger or rollup metrics', + } + ) + return gates + + +def _threshold_gate( + metric: str, + value: float, + threshold: float, + severity: str, + minimum: bool = False, + gate_group: str = 'default_viability', +) -> dict[str, Any]: + passed = value >= threshold if minimum else value <= threshold + direction = 'below' if minimum else 'exceeds' + return { + 'metric': metric, + 'severity': 'pass' if passed else severity, + 'gate_group': gate_group, + 'value': value, + 'threshold': threshold, + 'message': 'within threshold' if passed else f'{value:.4f} {direction} {threshold:.4f}', + } + + +def _number_from_ledger(ledger: dict[str, Any], field: str) -> float: + value = ledger.get(field, 0.0) + if isinstance(value, dict) and '__increment' in value: + value = value['__increment'] + return float(value or 0.0) + + +def _run_count(ledger: dict[str, Any]) -> float: + return float(ledger.get('run_count') or 1.0) + + +def _estimated_cost_usd(provider: str, ledger: dict[str, Any], billed_seconds: float) -> float: + recorded = _number_from_ledger(ledger, 'estimated_cost_usd') + if recorded: + return recorded + if provider == 'assemblyai': + return billed_seconds / 3600 * ASSEMBLYAI_COST_PER_HOUR_USD + if provider == 'deepgram': + return billed_seconds / 3600 * DEEPGRAM_COST_PER_HOUR_USD + return 0.0 + + +def _cost_per_hour(cost: float, seconds: float) -> float: + return cost / seconds * 3600 if seconds else 0.0 + + +def _latency_seconds(ledger: dict[str, Any]) -> float: + for field in ('latency_seconds', 'duration_seconds', 'elapsed_seconds'): + value = _number_from_ledger(ledger, field) + if value: + return value + return 0.0 + + +def _runtime_seconds(ledger: dict[str, Any], latency_seconds: float) -> float: + for field in ('runtime_seconds', 'wall_time_seconds'): + value = _number_from_ledger(ledger, field) + if value: + return value + return latency_seconds + + +def _timeout_error_count(ledger: dict[str, Any]) -> float: + status_counts = ledger.get('status_counts') or {} + return float( + _number_from_ledger(ledger, 'timeout_count') + + _number_from_ledger(ledger, 'error_count') + + status_counts.get('timeout', 0) + + status_counts.get('timed_out', 0) + + status_counts.get('failed', 0) + + status_counts.get('failure', 0) + ) + + +def _rate_from_ledger(ledger: dict[str, Any], count_field: str) -> float: + denominator = float(ledger.get('run_count') or 1.0) + return _number_from_ledger(ledger, count_field) / denominator + + +def _failure_rate_from_ledger(ledger: dict[str, Any]) -> float: + status_counts = ledger.get('status_counts') or {} + failed = status_counts.get('failed') or status_counts.get('failure') or 0 + denominator = float(ledger.get('run_count') or 1.0) + return float(failed) / denominator + + +def _oracle_speakers(segments: list[dict[str, Any]]) -> set[str]: + return {str(segment['oracle_speaker']) for segment in segments if segment.get('oracle_speaker')} + + +def _speaker_word_purity(segments: list[dict[str, Any]]) -> float: + cluster_counts: dict[str, dict[str, int]] = {} + total_words = 0 + for segment in segments: + cluster_id = _cluster_id(segment) + oracle_speaker = segment.get('oracle_speaker') + if not cluster_id or not oracle_speaker: + continue + word_count = len(_words(segment.get('text', ''))) + if word_count == 0: + continue + total_words += word_count + cluster_counts.setdefault(str(cluster_id), {}) + cluster_counts[str(cluster_id)][str(oracle_speaker)] = ( + cluster_counts[str(cluster_id)].get(str(oracle_speaker), 0) + word_count + ) + if not total_words: + return 1.0 + pure_words = sum(max(counts.values()) for counts in cluster_counts.values()) + return pure_words / total_words + + +def _safe_ratio(value: float, baseline: float) -> float: + if baseline == 0: + return 0.0 if value == 0 else 999.0 + return value / baseline + + +def _aggregate_case_reports(case_reports: list[dict[str, Any]]) -> dict[str, Any]: + if not case_reports: + return {} + assemblyai = [case['providers']['assemblyai'] for case in case_reports] + deepgram = [case['providers']['deepgram'] for case in case_reports] + assemblyai_seconds = sum(provider['billable_seconds'] or provider['raw_audio_seconds'] for provider in assemblyai) + deepgram_seconds = sum(provider['billable_seconds'] or provider['raw_audio_seconds'] for provider in deepgram) + assemblyai_cost = sum(provider['estimated_cost_usd'] for provider in assemblyai) + deepgram_cost = sum(provider['estimated_cost_usd'] for provider in deepgram) + return { + 'average_word_error_rate': _average(case['comparison']['transcript_word_error_rate'] for case in case_reports), + 'average_timestamp_drift_seconds': _average( + case['comparison']['average_timestamp_drift_seconds'] for case in case_reports + ), + 'assemblyai_speaker_word_purity': _weighted_average(assemblyai, 'speaker_word_purity', 'word_count'), + 'deepgram_speaker_word_purity': _weighted_average(deepgram, 'speaker_word_purity', 'word_count'), + 'assemblyai_estimated_cost_usd': assemblyai_cost, + 'deepgram_estimated_cost_usd': deepgram_cost, + 'assemblyai_estimated_cost_per_hour_usd': _cost_per_hour(assemblyai_cost, assemblyai_seconds), + 'deepgram_estimated_cost_per_hour_usd': _cost_per_hour(deepgram_cost, deepgram_seconds), + 'assemblyai_billable_seconds': sum(provider['billable_seconds'] for provider in assemblyai), + 'deepgram_billable_seconds': sum(provider['billable_seconds'] for provider in deepgram), + } + + +def _strategy_rollups(case_reports: list[dict[str, Any]]) -> dict[str, Any]: + return {strategy: _strategy_rollup(case_reports, strategy) for strategy in PRODUCTION_STRATEGIES} + + +def _strategy_rollup(case_reports: list[dict[str, Any]], strategy: str) -> dict[str, Any]: + selected = [] + providers = set() + for case in case_reports: + provider_name = PROVIDER_BY_STRATEGY.get(strategy) or case.get('current_policy_provider') or 'assemblyai' + providers.add(provider_name) + selected.append(case['providers'][provider_name]) + cost = sum(provider['estimated_cost_usd'] for provider in selected) + seconds = sum(provider['billable_seconds'] or provider['raw_audio_seconds'] for provider in selected) + return { + 'provider': next(iter(providers)) if len(providers) == 1 else 'mixed', + 'speaker_word_purity': _weighted_average(selected, 'speaker_word_purity', 'word_count'), + 'covered_speaker_count': _average(provider['covered_speaker_count'] for provider in selected), + 'app_visible_speaker_count': _average(provider['app_visible_speaker_count'] for provider in selected), + 'speaker_inflation_ratio': _average(provider['speaker_inflation_ratio'] for provider in selected), + 'split_count': sum(provider['split_count'] for provider in selected), + 'accepted_reconciliation_count': sum(provider['accepted_reconciliation_count'] for provider in selected), + 'rejected_reconciliation_count': sum(provider['rejected_reconciliation_count'] for provider in selected), + 'fallback_count': sum(provider['fallback_count'] for provider in selected), + 'provider_cluster_count': sum(provider['provider_cluster_count'] for provider in selected), + 'empty_transcript_rate': _average(provider['empty_transcript_rate'] for provider in selected), + 'latency_seconds': _average(provider['latency_seconds'] for provider in selected), + 'runtime_seconds': _average(provider['runtime_seconds'] for provider in selected), + 'timeout_error_rate': _average(provider['timeout_error_rate'] for provider in selected), + 'fallback_rate': _average(provider['fallback_rate'] for provider in selected), + 'failure_rate': _average(provider['failure_rate'] for provider in selected), + 'estimated_cost_usd': cost, + 'estimated_cost_per_hour_usd': _cost_per_hour(cost, seconds), + } + + +def _assemblyai_gap_report(case_reports: list[dict[str, Any]]) -> dict[str, Any]: + limiting_scenarios = [] + for case in case_reports: + deepgram = case['providers']['deepgram'] + assemblyai = case['providers']['assemblyai'] + candidates = [ + ( + 'speaker_word_purity', + assemblyai['speaker_word_purity'], + deepgram['speaker_word_purity'], + assemblyai['speaker_word_purity'] < deepgram['speaker_word_purity'] - 0.02, + ), + ( + 'covered_speaker_count', + assemblyai['covered_speaker_count'], + deepgram['covered_speaker_count'], + assemblyai['covered_speaker_count'] < deepgram['covered_speaker_count'], + ), + ( + 'empty_transcript_rate', + assemblyai['empty_transcript_rate'], + deepgram['empty_transcript_rate'], + assemblyai['empty_transcript_rate'] > deepgram['empty_transcript_rate'], + ), + ( + 'latency_seconds', + assemblyai['latency_seconds'], + deepgram['latency_seconds'], + assemblyai['latency_seconds'] > deepgram['latency_seconds'] * 2 and assemblyai['latency_seconds'] > 0, + ), + ( + 'timeout_error_rate', + assemblyai['timeout_error_rate'], + deepgram['timeout_error_rate'], + assemblyai['timeout_error_rate'] > deepgram['timeout_error_rate'], + ), + ( + 'estimated_cost_per_hour_usd', + assemblyai['estimated_cost_per_hour_usd'], + deepgram['estimated_cost_per_hour_usd'], + assemblyai['estimated_cost_per_hour_usd'] > deepgram['estimated_cost_per_hour_usd'] * 3 + and deepgram['estimated_cost_per_hour_usd'] > 0, + ), + ] + for metric, assemblyai_value, deepgram_value, limited in candidates: + if limited: + limiting_scenarios.append( + { + 'case_id': case['id'], + 'scenario': case.get('scenario') or case['id'], + 'metric': metric, + 'assemblyai_value': round(float(assemblyai_value), 4), + 'deepgram_value': round(float(deepgram_value), 4), + 'likely_cause': _likely_cause(metric), + 'mitigation': _mitigation(metric), + } + ) + break + return { + 'status': 'limited' if limiting_scenarios else 'no_material_gap_detected', + 'limiting_scenarios': limiting_scenarios, + } + + +def _likely_cause(metric: str) -> str: + return { + 'speaker_word_purity': 'provider-local clusters mix speakers before Omi identity matching', + 'covered_speaker_count': 'provider diarization missed a speaker or no-speech gating discarded speech', + 'empty_transcript_rate': 'low-signal/no-speech handling produced an empty or failed transcript', + 'latency_seconds': 'AssemblyAI async job latency exceeds Deepgram for this workload shape', + 'timeout_error_rate': 'provider timeout or retry exhaustion path is not default-safe', + 'estimated_cost_per_hour_usd': 'provider billable duration or pricing is too high for default background volume', + }.get(metric, 'AssemblyAI trails the Deepgram comparator on this gate') + + +def _mitigation(metric: str) -> str: + return { + 'speaker_word_purity': 'keep split-before-match enabled and gate rollout on fragmentation plus purity budgets', + 'covered_speaker_count': 'use Deepgram fallback for affected low-signal cases until AssemblyAI closes coverage', + 'empty_transcript_rate': 'preserve no-speech detection and fallback controls before default promotion', + 'latency_seconds': 'keep latency SLO alerts and fallback controls before expanding default traffic', + 'timeout_error_rate': 'use Deepgram fallback and provider health gates from TICKET-027', + 'estimated_cost_per_hour_usd': 'review billable seconds, requested add-ons, and provider pricing before changing defaults', + }.get(metric, 'capture in TICKET-028 rollout tradeoffs before promoting AssemblyAI') + + +def _average(values) -> float: + values = list(values) + return sum(values) / len(values) if values else 0.0 + + +def _weighted_average(items: list[dict[str, Any]], value_field: str, weight_field: str) -> float: + denominator = sum(float(item.get(weight_field) or 0.0) for item in items) + if denominator == 0: + return _average(item.get(value_field, 0.0) for item in items) + return ( + sum(float(item.get(value_field) or 0.0) * float(item.get(weight_field) or 0.0) for item in items) / denominator + ) + + +def _fmt_pct(value) -> str: + if value is None: + return 'n/a' + return f'{float(value) * 100:.1f}%' + + +def _fmt_seconds(value) -> str: + if value is None: + return 'n/a' + return f'{float(value):.2f}s' diff --git a/backend/utils/stt/provider_service.py b/backend/utils/stt/provider_service.py new file mode 100644 index 00000000000..2e4f7719946 --- /dev/null +++ b/backend/utils/stt/provider_service.py @@ -0,0 +1,875 @@ +import logging +import os +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import List, Optional, Sequence, Tuple + +import httpx +from deepgram import DeepgramClient, DeepgramClientOptions + +from models.transcript_segment import ProviderTranscriptResult, TranscriptSegment +from utils.stt.assemblyai_adapter import AssemblyAIAsyncTranscriptionProvider, AssemblyAITimeoutError +from utils.stt.conversation_reconstructor import reconstruct_conversation +from utils.stt.deepgram_adapter import DeepgramPrerecordedTranscriptionProvider +from utils.stt.deepgram_adapter import provider_result_to_legacy_words +from utils.stt.provider_costs import estimate_prerecorded_provider_cost_usd +from utils.stt.providers import ( + BackgroundProviderMode, + STTProviderName, + STTWorkload, + assemblyai_prerecorded_fallback_enabled, + get_background_provider_mode, + get_fallback_prerecorded_provider_name, + get_prerecorded_provider_name, +) +from utils.stt.deepgram_config import get_deepgram_model_for_language + +try: + from utils.byok import get_byok_key +except ImportError: + get_byok_key = None + +try: + from database.transcription_provider_usage import ( + create_provider_run as _db_create_provider_run, + finalize_provider_run as _db_finalize_provider_run, + update_provider_run_identity_metrics as _db_update_provider_run_identity_metrics, + ) + + _PROVIDER_USAGE_IMPORT_ERROR = None +except ImportError as e: + _db_create_provider_run = None + _db_finalize_provider_run = None + _db_update_provider_run_identity_metrics = None + _PROVIDER_USAGE_IMPORT_ERROR = e + +logger = logging.getLogger(__name__) + +_DG_TIMEOUT = httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0) +_DEEPGRAM_OPTIONS = DeepgramClientOptions(options={"keepalive": "true"}) +_DEEPGRAM_CLIENT = DeepgramClient(os.getenv('DEEPGRAM_API_KEY'), _DEEPGRAM_OPTIONS) +_LOCAL_CLUSTER_SPLIT_MARKER = '::local_part:' +_UNKNOWN_SPEAKER_STATES = {'unknown', 'unassigned', 'legacy_ambiguous'} + + +def create_provider_run(**kwargs) -> str: + if _db_create_provider_run is None: + raise _PROVIDER_USAGE_IMPORT_ERROR + return _db_create_provider_run(**kwargs) + + +def finalize_provider_run(**kwargs) -> None: + if _db_finalize_provider_run is None: + raise _PROVIDER_USAGE_IMPORT_ERROR + _db_finalize_provider_run(**kwargs) + + +def summarize_identity_confidences(confidences): + summary = {} + for confidence in confidences: + bucket = _identity_confidence_bucket(confidence) + summary[bucket] = summary.get(bucket, 0) + 1 + return summary + + +def _identity_confidence_bucket(confidence: Optional[float]) -> str: + if confidence is None: + return 'unknown' + if confidence >= 0.90: + return 'very_high' + if confidence >= 0.75: + return 'high' + if confidence >= 0.50: + return 'medium' + return 'low' + + +def _deepgram_prerecorded_provider(): + return DeepgramPrerecordedTranscriptionProvider(_deepgram_client_for_request, _DG_TIMEOUT) + + +def _assemblyai_prerecorded_provider(): + byok = get_byok_key('assemblyai') if get_byok_key else None + return AssemblyAIAsyncTranscriptionProvider(api_key=byok) + + +def resolve_prerecorded_provider_for_request(workload: STTWorkload) -> STTProviderName: + """Pick prerecorded STT provider for this request, respecting BYOK headers. + + When env flags select AssemblyAI but the BYOK user did not supply an Assembly + key, use Deepgram BYOK instead of Omi's server Assembly key. When server + AssemblyAI credentials are absent, skip directly to Deepgram if fallback is + enabled and a usable Deepgram key is available. + """ + selected = get_prerecorded_provider_name(workload) + if selected != STTProviderName.assemblyai: + return selected + assemblyai_byok = get_byok_key('assemblyai') if get_byok_key else None + deepgram_byok = get_byok_key('deepgram') if get_byok_key else None + if assemblyai_byok: + return STTProviderName.assemblyai + if deepgram_byok and assemblyai_prerecorded_fallback_enabled(): + return STTProviderName.deepgram + if os.getenv('ASSEMBLYAI_API_KEY'): + return STTProviderName.assemblyai + if assemblyai_prerecorded_fallback_enabled() and _has_deepgram_key_for_request(): + return STTProviderName.deepgram + return STTProviderName.assemblyai + + +def _has_deepgram_key_for_request() -> bool: + return bool((get_byok_key('deepgram') if get_byok_key else None) or os.getenv('DEEPGRAM_API_KEY')) + + +def _has_assemblyai_key_for_request() -> bool: + return bool((get_byok_key('assemblyai') if get_byok_key else None) or os.getenv('ASSEMBLYAI_API_KEY')) + + +def _deepgram_client_for_request() -> DeepgramClient: + byok = get_byok_key('deepgram') if get_byok_key else None + if byok: + return DeepgramClient(byok, _DEEPGRAM_OPTIONS) + return _DEEPGRAM_CLIENT + + +@dataclass +class PrerecordedTranscriptionResponse: + result: ProviderTranscriptResult + detected_language: Optional[str] + segments: List[TranscriptSegment] + words: List[dict] + run_id: Optional[str] + + +@dataclass(frozen=True) +class BackgroundProviderPolicy: + mode: BackgroundProviderMode + primary_provider: STTProviderName + effective_provider: Optional[STTProviderName] + fallback_provider: Optional[STTProviderName] + fallback_enabled: bool + fallback_available: bool + enabled: bool + reason: Optional[str] + + +class ProviderTranscriptionRetriesExhausted(RuntimeError): + def __init__(self, provider_error: Exception, retry_count: int): + super().__init__(str(provider_error)) + self.provider_error = provider_error + self.retry_count = retry_count + + +def resolve_prerecorded_language_model(language: Optional[str]) -> Tuple[str, str]: + return get_deepgram_model_for_language(language or 'multi') + + +def resolve_background_provider_policy() -> BackgroundProviderPolicy: + mode = get_background_provider_mode() + primary_provider = get_prerecorded_provider_name(STTWorkload.background) + effective_provider = resolve_prerecorded_provider_for_request(STTWorkload.background) + fallback_provider = get_fallback_prerecorded_provider_name(primary_provider, STTWorkload.background) + + assemblyai_key_available = _has_assemblyai_key_for_request() + deepgram_key_available = _has_deepgram_key_for_request() + fallback_available = fallback_provider == STTProviderName.deepgram and deepgram_key_available + + usable_provider = None + reason = None + if effective_provider == STTProviderName.assemblyai: + if assemblyai_key_available: + usable_provider = STTProviderName.assemblyai + elif fallback_available: + usable_provider = STTProviderName.deepgram + reason = 'fallback_deepgram_available' + else: + reason = 'missing_assemblyai_api_key' + elif effective_provider == STTProviderName.deepgram: + if deepgram_key_available: + usable_provider = STTProviderName.deepgram + if mode == BackgroundProviderMode.shadow_only: + reason = 'shadow_only' + else: + reason = 'missing_deepgram_api_key' + else: + reason = 'no_usable_batch_provider' + + return BackgroundProviderPolicy( + mode=mode, + primary_provider=primary_provider, + effective_provider=usable_provider, + fallback_provider=fallback_provider, + fallback_enabled=fallback_provider is not None, + fallback_available=fallback_available, + enabled=usable_provider is not None, + reason=reason, + ) + + +def transcribe_url( + audio_url: str, + workload: STTWorkload, + uid: Optional[str] = None, + conversation_id: Optional[str] = None, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'nova-3', + keywords: Optional[Sequence[str]] = None, + skip_n_seconds: int = 0, + raw_audio_seconds: float = 0.0, +) -> PrerecordedTranscriptionResponse: + workload = STTWorkload(workload) + provider_name = resolve_prerecorded_provider_for_request(workload) + model = _model_for_provider(provider_name, model) + provider = _get_prerecorded_provider(provider_name) + started_at = datetime.now(timezone.utc) + run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) + + try: + provider_result, detected_language, retry_count = _transcribe_url_with_retry( + provider, + audio_url, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, + ) + segments = reconstruct_conversation(provider_result, skip_n_seconds=skip_n_seconds) + words = provider_result_to_legacy_words(provider_result) + _finalize_run( + run_id, + provider_result, + workload, + started_at, + 'succeeded', + retry_count=retry_count, + raw_audio_seconds=raw_audio_seconds or provider_result.duration or 0.0, + segments=segments, + ) + return PrerecordedTranscriptionResponse( + result=provider_result, + detected_language=detected_language, + segments=segments, + words=words, + run_id=run_id, + ) + except Exception as e: + _finalize_failed_run( + run_id, + provider_name.value, + model, + workload.value, + started_at, + _provider_error_from_exception(e), + raw_audio_seconds, + retry_count=_retry_count_from_exception(e), + ) + fallback_provider_name = _resolve_usable_fallback_prerecorded_provider(provider_name, workload) + if fallback_provider_name: + fallback_reason = _fallback_reason_from_exception(e) + logger.warning( + 'provider prerecorded url transcription falling back workload=%s from_provider=%s to_provider=%s reason=%s: %s', + workload.value, + provider_name.value, + fallback_provider_name.value, + fallback_reason, + e, + ) + return _transcribe_url_with_provider( + fallback_provider_name, + audio_url, + workload, + uid=uid, + conversation_id=conversation_id, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=_model_for_provider(fallback_provider_name, model), + keywords=keywords, + skip_n_seconds=skip_n_seconds, + raw_audio_seconds=raw_audio_seconds, + fallback_from_provider=provider_name.value, + fallback_reason=fallback_reason, + ) + raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') + + +def transcribe_bytes( + audio_bytes: bytes, + workload: STTWorkload, + uid: Optional[str] = None, + conversation_id: Optional[str] = None, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'nova-3', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + skip_n_seconds: int = 0, + raw_audio_seconds: float = 0.0, +) -> PrerecordedTranscriptionResponse: + workload = STTWorkload(workload) + provider_name = resolve_prerecorded_provider_for_request(workload) + model = _model_for_provider(provider_name, model) + provider = _get_prerecorded_provider(provider_name) + started_at = datetime.now(timezone.utc) + run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) + + try: + provider_result, detected_language, retry_count = _transcribe_bytes_with_retry( + provider, + audio_bytes, + sample_rate=sample_rate, + diarize=diarize, + encoding=encoding, + channels=channels, + language=language, + model=model, + return_language=return_language, + keywords=keywords, + ) + segments = reconstruct_conversation(provider_result, skip_n_seconds=skip_n_seconds) + words = provider_result_to_legacy_words(provider_result) + _finalize_run( + run_id, + provider_result, + workload, + started_at, + 'succeeded', + retry_count=retry_count, + raw_audio_seconds=raw_audio_seconds or provider_result.duration or 0.0, + segments=segments, + ) + return PrerecordedTranscriptionResponse( + result=provider_result, + detected_language=detected_language, + segments=segments, + words=words, + run_id=run_id, + ) + except Exception as e: + _finalize_failed_run( + run_id, + provider_name.value, + model, + workload.value, + started_at, + _provider_error_from_exception(e), + raw_audio_seconds, + retry_count=_retry_count_from_exception(e), + ) + fallback_provider_name = _resolve_usable_fallback_prerecorded_provider(provider_name, workload) + if fallback_provider_name: + fallback_reason = _fallback_reason_from_exception(e) + logger.warning( + 'provider prerecorded bytes transcription falling back workload=%s from_provider=%s to_provider=%s reason=%s: %s', + workload.value, + provider_name.value, + fallback_provider_name.value, + fallback_reason, + e, + ) + return _transcribe_bytes_with_provider( + fallback_provider_name, + audio_bytes, + workload, + uid=uid, + conversation_id=conversation_id, + sample_rate=sample_rate, + diarize=diarize, + encoding=encoding, + channels=channels, + language=language, + model=_model_for_provider(fallback_provider_name, model), + return_language=return_language, + keywords=keywords, + skip_n_seconds=skip_n_seconds, + raw_audio_seconds=raw_audio_seconds, + fallback_from_provider=provider_name.value, + fallback_reason=fallback_reason, + ) + raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') + + +def _get_prerecorded_provider(provider_name: STTProviderName): + if provider_name == STTProviderName.assemblyai: + return _assemblyai_prerecorded_provider() + if provider_name == STTProviderName.deepgram: + return _deepgram_prerecorded_provider() + raise ValueError(f'Unsupported prerecorded STT provider: {provider_name}') + + +def _resolve_usable_fallback_prerecorded_provider( + provider_name: STTProviderName, workload: STTWorkload +) -> Optional[STTProviderName]: + fallback_provider_name = get_fallback_prerecorded_provider_name(provider_name, workload) + if fallback_provider_name == STTProviderName.deepgram and not _has_deepgram_key_for_request(): + return None + return fallback_provider_name + + +def _model_for_provider(provider_name: STTProviderName, requested_model: str) -> str: + if provider_name == STTProviderName.assemblyai and str(requested_model or '').startswith('nova-'): + return os.getenv('ASSEMBLYAI_STT_MODEL', 'universal-2') + if provider_name == STTProviderName.deepgram and str(requested_model or '').startswith('universal-'): + return 'nova-3' + return requested_model + + +def _transcribe_url_with_provider( + provider_name: STTProviderName, + audio_url: str, + workload: STTWorkload, + uid: Optional[str] = None, + conversation_id: Optional[str] = None, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'nova-3', + keywords: Optional[Sequence[str]] = None, + skip_n_seconds: int = 0, + raw_audio_seconds: float = 0.0, + fallback_from_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', +) -> PrerecordedTranscriptionResponse: + provider = _get_prerecorded_provider(provider_name) + started_at = datetime.now(timezone.utc) + run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) + try: + provider_result, detected_language, retry_count = _transcribe_url_with_retry( + provider, + audio_url, + speakers_count=speakers_count, + return_language=return_language, + diarize=diarize, + language=language, + model=model, + keywords=keywords, + ) + return _build_success_response( + run_id, + provider_result, + workload, + started_at, + retry_count, + raw_audio_seconds, + skip_n_seconds, + fallback_from_provider=fallback_from_provider, + fallback_reason=fallback_reason, + detected_language=detected_language, + ) + except Exception as e: + _finalize_failed_run( + run_id, + provider_name.value, + model, + workload.value, + started_at, + _provider_error_from_exception(e), + raw_audio_seconds, + retry_count=_retry_count_from_exception(e), + ) + raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') + + +def _transcribe_bytes_with_provider( + provider_name: STTProviderName, + audio_bytes: bytes, + workload: STTWorkload, + uid: Optional[str] = None, + conversation_id: Optional[str] = None, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'nova-3', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + skip_n_seconds: int = 0, + raw_audio_seconds: float = 0.0, + fallback_from_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', +) -> PrerecordedTranscriptionResponse: + provider = _get_prerecorded_provider(provider_name) + started_at = datetime.now(timezone.utc) + run_id = _create_run(uid, provider_name.value, model, workload.value, conversation_id, started_at) + try: + provider_result, detected_language, retry_count = _transcribe_bytes_with_retry( + provider, + audio_bytes, + sample_rate=sample_rate, + diarize=diarize, + encoding=encoding, + channels=channels, + language=language, + model=model, + return_language=return_language, + keywords=keywords, + ) + return _build_success_response( + run_id, + provider_result, + workload, + started_at, + retry_count, + raw_audio_seconds, + skip_n_seconds, + fallback_from_provider=fallback_from_provider, + fallback_reason=fallback_reason, + detected_language=detected_language, + ) + except Exception as e: + _finalize_failed_run( + run_id, + provider_name.value, + model, + workload.value, + started_at, + _provider_error_from_exception(e), + raw_audio_seconds, + retry_count=_retry_count_from_exception(e), + ) + raise RuntimeError(f'{provider_name.value} transcription failed after 2 attempts: {e}') + + +def _build_success_response( + run_id: Optional[str], + provider_result: ProviderTranscriptResult, + workload: STTWorkload, + started_at: datetime, + retry_count: int, + raw_audio_seconds: float, + skip_n_seconds: int, + fallback_from_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', + detected_language: Optional[str] = None, +) -> PrerecordedTranscriptionResponse: + segments = reconstruct_conversation(provider_result, skip_n_seconds=skip_n_seconds) + words = provider_result_to_legacy_words(provider_result) + _finalize_run( + run_id, + provider_result, + workload, + started_at, + 'succeeded', + retry_count=retry_count, + raw_audio_seconds=raw_audio_seconds or provider_result.duration or 0.0, + segments=segments, + fallback_count=1 if fallback_from_provider else 0, + fallback_provider=fallback_from_provider, + fallback_reason=fallback_reason, + ) + return PrerecordedTranscriptionResponse( + result=provider_result, + detected_language=detected_language, + segments=segments, + words=words, + run_id=run_id, + ) + + +def _transcribe_url_with_retry( + provider, audio_url: str, **kwargs +) -> Tuple[ProviderTranscriptResult, Optional[str], int]: + last_error = None + max_attempts = 2 + for attempt in range(max_attempts): + try: + result = provider.transcribe_url(audio_url, **kwargs) + transcript_result, detected_language = _unpack_provider_result(result, kwargs.get('return_language')) + return transcript_result, detected_language, attempt + except Exception as e: + last_error = e + logger.error( + 'provider prerecorded url transcription error attempt=%s provider=%s: %s', + attempt, + provider.provider_name, + e, + ) + raise ProviderTranscriptionRetriesExhausted(last_error, max_attempts - 1) + + +def _transcribe_bytes_with_retry( + provider, audio_bytes: bytes, **kwargs +) -> Tuple[ProviderTranscriptResult, Optional[str], int]: + last_error = None + max_attempts = 2 + for attempt in range(max_attempts): + try: + result = provider.transcribe_bytes(audio_bytes, **kwargs) + transcript_result, detected_language = _unpack_provider_result(result, kwargs.get('return_language')) + return transcript_result, detected_language, attempt + except Exception as e: + last_error = e + logger.error( + 'provider prerecorded bytes transcription error attempt=%s provider=%s: %s', + attempt, + provider.provider_name, + e, + ) + raise ProviderTranscriptionRetriesExhausted(last_error, max_attempts - 1) + + +def _retry_count_from_exception(error: Exception) -> int: + if isinstance(error, ProviderTranscriptionRetriesExhausted): + return error.retry_count + return 0 + + +def _provider_error_from_exception(error: Exception) -> Exception: + if isinstance(error, ProviderTranscriptionRetriesExhausted): + return error.provider_error + return error + + +def _fallback_reason_from_exception(error: Exception) -> str: + provider_error = _provider_error_from_exception(error) + if isinstance(provider_error, (AssemblyAITimeoutError, httpx.TimeoutException, TimeoutError)): + return 'provider_timeout' + return 'provider_failure' + + +def _unpack_provider_result(result, return_language: bool) -> Tuple[ProviderTranscriptResult, Optional[str]]: + if return_language: + transcript_result, detected_language = result + return transcript_result, detected_language + transcript_result = result + return transcript_result, transcript_result.language + + +def _create_run( + uid: Optional[str], + provider: str, + model: str, + workload: str, + conversation_id: Optional[str], + started_at: datetime, +) -> Optional[str]: + if not uid: + return None + try: + return create_provider_run( + uid=uid, + provider=provider, + model=model, + workload=workload, + conversation_id=conversation_id, + started_at=started_at, + ) + except Exception as e: + logger.warning('failed to create transcription provider run ledger uid=%s workload=%s: %s', uid, workload, e) + return None + + +def _finalize_run( + run_id: Optional[str], + result: ProviderTranscriptResult, + workload: STTWorkload, + started_at: datetime, + status: str, + retry_count: int, + raw_audio_seconds: float, + segments: List[TranscriptSegment], + fallback_count: int = 0, + fallback_provider: Optional[str] = None, + fallback_reason: str = 'provider_failure', +) -> None: + if not run_id: + return + billable_seconds = raw_audio_seconds + clusters = {_segment_cluster_key(segment) for segment in segments if _segment_cluster_key(segment)} + confidences = [segment.speaker_identity_confidence for segment in segments] + identity_metrics = speaker_identity_metrics(segments) + unknown_speaker_duration_seconds = _unknown_speaker_duration_seconds(segments) + try: + finalize_provider_run( + run_id=run_id, + provider=result.provider, + model=result.model or 'unknown', + workload=workload.value, + status=status, + started_at=started_at, + raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=raw_audio_seconds, + billable_seconds=billable_seconds, + chunk_duration_seconds=raw_audio_seconds, + estimated_cost_usd=estimate_prerecorded_provider_cost_usd( + provider=result.provider, + model=result.model, + workload=workload.value, + billable_seconds=billable_seconds, + ), + retry_count=retry_count, + fallback_count=fallback_count, + transcript_segment_count=len(segments), + transcript_word_count=len(result.words), + speaker_cluster_count=len(clusters), + identified_speaker_cluster_count=_identified_cluster_count(segments), + identity_match_count=identity_metrics['mapped_speaker_count'], + identity_confidence_summary=summarize_identity_confidences(confidences), + provider_speaker_count=identity_metrics['provider_speaker_count'], + mapped_speaker_count=identity_metrics['mapped_speaker_count'], + mapped_person_count=identity_metrics['mapped_person_count'], + unmapped_speaker_count=identity_metrics['unmapped_speaker_count'], + unknown_speaker_count=identity_metrics['unmapped_speaker_count'], + unknown_speaker_duration_seconds=unknown_speaker_duration_seconds, + split_count=_split_count(clusters), + embedding_extraction_failure_count=identity_metrics['embedding_extraction_failure_count'], + artifact_refs=_provider_artifact_refs(result), + fallback_provider=fallback_provider, + fallback_reason=fallback_reason, + ) + except Exception as e: + logger.warning('failed to finalize transcription provider run ledger run_id=%s: %s', run_id, e) + + +def _provider_artifact_refs(result: ProviderTranscriptResult) -> dict[str, str]: + if not result.raw_provider_result_id: + return {} + return {'provider_result_id': result.raw_provider_result_id} + + +def _finalize_failed_run( + run_id: Optional[str], + provider: str, + model: str, + workload: str, + started_at: datetime, + error: Exception, + raw_audio_seconds: float, + retry_count: int = 0, +) -> None: + if not run_id: + return + billable_seconds = raw_audio_seconds + try: + finalize_provider_run( + run_id=run_id, + provider=provider, + model=model, + workload=workload, + status='failed', + started_at=started_at, + raw_audio_seconds=raw_audio_seconds, + speech_active_seconds=raw_audio_seconds, + billable_seconds=billable_seconds, + chunk_duration_seconds=raw_audio_seconds, + estimated_cost_usd=estimate_prerecorded_provider_cost_usd( + provider=provider, + model=model, + workload=workload, + billable_seconds=billable_seconds, + ), + retry_count=retry_count, + fallback_count=0, + error_class=error.__class__.__name__, + ) + except Exception as finalize_error: + logger.warning( + 'failed to finalize failed transcription provider run ledger run_id=%s: %s', run_id, finalize_error + ) + + +def update_provider_run_identity_metrics( + run_id: Optional[str], + provider: str, + model: str, + workload: STTWorkload, + segments: List[TranscriptSegment], + identity_metric_update_status: str = 'succeeded', + identity_metric_update_skipped_reason: Optional[str] = None, +) -> None: + if not run_id: + return + if _db_update_provider_run_identity_metrics is None: + logger.warning( + 'failed to update transcription provider identity metrics run_id=%s: %s', + run_id, + _PROVIDER_USAGE_IMPORT_ERROR, + ) + return + try: + identity_metrics = speaker_identity_metrics(segments) + _db_update_provider_run_identity_metrics( + run_id=run_id, + provider=provider, + model=model or 'unknown', + workload=STTWorkload(workload).value, + identified_speaker_cluster_count=_identified_cluster_count(segments), + identity_confidence_summary=summarize_identity_confidences( + [segment.speaker_identity_confidence for segment in segments] + ), + provider_speaker_count=identity_metrics['provider_speaker_count'], + mapped_speaker_count=identity_metrics['mapped_speaker_count'], + mapped_person_count=identity_metrics['mapped_person_count'], + unmapped_speaker_count=identity_metrics['unmapped_speaker_count'], + embedding_extraction_failure_count=identity_metrics['embedding_extraction_failure_count'], + identity_metric_update_status=identity_metric_update_status, + identity_metric_update_skipped_reason=identity_metric_update_skipped_reason, + ) + except Exception as e: + logger.warning('failed to update transcription provider identity metrics run_id=%s: %s', run_id, e) + + +def speaker_identity_metrics(segments: List[TranscriptSegment]) -> dict: + provider_speakers = {_segment_cluster_key(segment) for segment in segments if _segment_cluster_key(segment)} + mapped_speakers = { + _segment_cluster_key(segment) + for segment in segments + if _segment_cluster_key(segment) + and (segment.person_id or segment.is_user or segment.speaker_identity_state in ('identified', 'user')) + } + mapped_people = { + segment.person_id or 'user' + for segment in segments + if segment.person_id or segment.is_user or segment.speaker_identity_state == 'user' + } + embedding_failures = { + _segment_cluster_key(segment) + for segment in segments + if _segment_cluster_key(segment) + and (segment.speaker_identity_provenance or {}).get('reason') == 'embedding_extraction_failed' + } + return { + 'provider_speaker_count': len(provider_speakers), + 'mapped_speaker_count': len(mapped_speakers), + 'mapped_person_count': len(mapped_people), + 'unmapped_speaker_count': max(len(provider_speakers) - len(mapped_speakers), 0), + 'embedding_extraction_failure_count': len(embedding_failures), + } + + +def _unknown_speaker_duration_seconds(segments: List[TranscriptSegment]) -> float: + duration = 0.0 + for segment in segments: + if segment.speaker_identity_state in _UNKNOWN_SPEAKER_STATES: + duration += max(0.0, segment.end - segment.start) + return round(duration, 3) + + +def _split_count(clusters: set[str]) -> int: + return sum(1 for cluster in clusters if _LOCAL_CLUSTER_SPLIT_MARKER in str(cluster)) + + +def _identified_cluster_count(segments: List[TranscriptSegment]) -> int: + return len( + { + _segment_cluster_key(segment) + for segment in segments + if _segment_cluster_key(segment) + and (segment.person_id or segment.is_user or segment.speaker_identity_state in ('identified', 'user')) + } + ) + + +def _segment_cluster_key(segment: TranscriptSegment) -> Optional[str]: + return segment.provider_cluster_id or segment.provider_speaker_label diff --git a/backend/utils/stt/providers.py b/backend/utils/stt/providers.py new file mode 100644 index 00000000000..e1c04db1833 --- /dev/null +++ b/backend/utils/stt/providers.py @@ -0,0 +1,165 @@ +import os +from enum import Enum +from typing import Callable, List, Optional, Protocol, Sequence, Tuple, Union + +from models.transcript_segment import ProviderTranscriptResult + + +class STTProviderName(str, Enum): + assemblyai = 'assemblyai' + deepgram = 'deepgram' + + +class STTWorkload(str, Enum): + background = 'background' + postprocess = 'postprocess' + ptt = 'ptt' + realtime = 'realtime' + sync = 'sync' + voice_message = 'voice_message' + + +class BackgroundProviderMode(str, Enum): + assemblyai = 'assemblyai' + deepgram = 'deepgram' + shadow_only = 'shadow_only' + + +class PrerecordedSTTProvider(Protocol): + provider_name: STTProviderName + + def transcribe_url( + self, + audio_url: str, + speakers_count: int = None, + return_language: bool = False, + diarize: bool = True, + language: Optional[str] = None, + model: str = 'nova-3', + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: ... + + def transcribe_bytes( + self, + audio_bytes: bytes, + sample_rate: int = 16000, + diarize: bool = True, + encoding: Optional[str] = None, + channels: int = 1, + language: Optional[str] = None, + model: str = 'nova-3', + return_language: bool = False, + keywords: Optional[Sequence[str]] = None, + ) -> Union[ProviderTranscriptResult, Tuple[ProviderTranscriptResult, str]]: ... + + +class StreamingSTTProvider(Protocol): + provider_name: STTProviderName + + async def connect_stream( + self, + stream_transcript, + language: str, + sample_rate: int, + channels: int, + model: str = 'nova-3', + keywords: List[str] = [], + vad_gate=None, + is_active: Optional[Callable[[], bool]] = None, + ): ... + + +class DiarizationProvider(Protocol): + provider_name: STTProviderName + + +class SpeakerIdentityProvider(Protocol): + provider_name: str + + +_DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS = { + STTWorkload.background: STTProviderName.deepgram, + STTWorkload.postprocess: STTProviderName.deepgram, + STTWorkload.ptt: STTProviderName.deepgram, + STTWorkload.sync: STTProviderName.deepgram, + STTWorkload.voice_message: STTProviderName.deepgram, +} + +_ASSEMBLYAI_ELIGIBLE_WORKLOADS = { + STTWorkload.background, + STTWorkload.postprocess, + STTWorkload.sync, +} + +_STREAMING_WORKLOAD_PROVIDERS = { + STTWorkload.ptt: STTProviderName.deepgram, + STTWorkload.realtime: STTProviderName.deepgram, +} + + +def get_prerecorded_provider_name(workload: STTWorkload) -> STTProviderName: + workload = STTWorkload(workload) + if workload == STTWorkload.background: + if ( + get_background_provider_mode() == BackgroundProviderMode.assemblyai + and _assemblyai_prerecorded_enabled() + and workload in _assemblyai_enabled_workloads() + ): + return STTProviderName.assemblyai + return STTProviderName.deepgram + if _assemblyai_prerecorded_enabled() and workload in _assemblyai_enabled_workloads(): + return STTProviderName.assemblyai + return _DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS[workload] + + +def get_streaming_provider_name(workload: STTWorkload) -> STTProviderName: + return _STREAMING_WORKLOAD_PROVIDERS[STTWorkload(workload)] + + +def get_fallback_prerecorded_provider_name( + provider: STTProviderName, workload: STTWorkload +) -> Optional[STTProviderName]: + workload = STTWorkload(workload) + provider = STTProviderName(provider) + if ( + workload in _ASSEMBLYAI_ELIGIBLE_WORKLOADS + and provider == STTProviderName.assemblyai + and not assemblyai_prerecorded_fallback_enabled() + ): + return None + fallback = _DEFAULT_PRERECORDED_WORKLOAD_PROVIDERS[workload] + if provider != fallback: + return fallback + return None + + +def _assemblyai_prerecorded_enabled() -> bool: + return os.getenv('ASSEMBLYAI_PRERECORDED_STT_ENABLED', 'true').lower() == 'true' + + +def assemblyai_prerecorded_fallback_enabled() -> bool: + return os.getenv('ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED', 'true').lower() == 'true' + + +def get_background_provider_mode() -> BackgroundProviderMode: + configured = os.getenv('ASSEMBLYAI_BACKGROUND_PROVIDER_MODE', BackgroundProviderMode.assemblyai.value) + try: + return BackgroundProviderMode(configured.strip().lower()) + except ValueError: + return BackgroundProviderMode.shadow_only + + +def _assemblyai_enabled_workloads() -> set[STTWorkload]: + configured = os.getenv('ASSEMBLYAI_PRERECORDED_STT_WORKLOADS', 'sync,background,postprocess') + workloads = set() + for raw_value in configured.split(','): + value = raw_value.strip() + if not value: + continue + try: + workload = STTWorkload(value) + except ValueError: + continue + if workload in _ASSEMBLYAI_ELIGIBLE_WORKLOADS: + workloads.add(workload) + return workloads diff --git a/backend/utils/stt/streaming.py b/backend/utils/stt/streaming.py index 6fa923e7c39..e354895f4d9 100644 --- a/backend/utils/stt/streaming.py +++ b/backend/utils/stt/streaming.py @@ -11,6 +11,7 @@ from utils.byok import get_byok_key from utils.executors import sync_executor, run_blocking from utils.stt.safe_socket import KeepaliveConfig, SafeDeepgramSocket # noqa: F401 — re-exported for backward compat +from utils.stt.providers import STTProviderName from utils.stt.vad_gate import GatedDeepgramSocket import logging @@ -278,6 +279,32 @@ def on_dg_error(self, error, **kwargs): return safe_conn +class DeepgramStreamingTranscriptionProvider: + provider_name = STTProviderName.deepgram + + async def connect_stream( + self, + stream_transcript, + language: str, + sample_rate: int, + channels: int, + model: str = 'nova-3', + keywords: List[str] = [], + vad_gate=None, + is_active: Optional[Callable[[], bool]] = None, + ): + return await process_audio_dg( + stream_transcript, + language, + sample_rate, + channels, + model=model, + keywords=keywords, + vad_gate=vad_gate, + is_active=is_active, + ) + + # Calculate backoff with jitter def calculate_backoff_with_jitter(attempt, base_delay=1000, max_delay=32000): jitter = random.random() * base_delay diff --git a/backend/utils/subscription.py b/backend/utils/subscription.py index 0b11ddaf171..2ba1fc4b48a 100644 --- a/backend/utils/subscription.py +++ b/backend/utils/subscription.py @@ -33,6 +33,10 @@ # Anything else (ios, android, omi device, phone_call, unknown) is exempt. _TRIAL_PAYWALL_DESKTOP_TOKENS = {"macos", "desktop"} +# Emergency kill switch and optional staged rollout allowlist. +_TRIAL_PAYWALL_ENABLED = os.getenv("TRIAL_PAYWALL_ENABLED", "true").lower() not in {"0", "false", "no"} +_TRIAL_PAYWALL_TEST_UIDS = {uid.strip() for uid in os.getenv("TRIAL_PAYWALL_TEST_UIDS", "").split(",") if uid.strip()} + # Cache the (slow) Firebase Auth + Firestore lookup result for a few minutes # so chat-quota polling doesn't fan out to Firebase on every request. _TRIAL_PAYWALL_CACHE_TTL_SECONDS = 300 @@ -113,6 +117,10 @@ def is_trial_paywalled(uid: str, platform: Optional[str]) -> bool: `source` query param for the listen WebSocket. Mobile (ios/android), Omi devices, and any unknown/missing platform are never paywalled. """ + if not _TRIAL_PAYWALL_ENABLED: + return False + if _TRIAL_PAYWALL_TEST_UIDS and uid not in _TRIAL_PAYWALL_TEST_UIDS: + return False if not platform or platform.lower() not in _TRIAL_PAYWALL_DESKTOP_TOKENS: return False return _is_trial_expired_cached(uid) diff --git a/desktop/.gitignore b/desktop/.gitignore index 3106c399c94..6c2b91f915f 100644 --- a/desktop/.gitignore +++ b/desktop/.gitignore @@ -1,6 +1,7 @@ # Build artifacts .build/ build/ +target/ DerivedData/ *.xcodeproj/xcuserdata/ *.xcworkspace/xcuserdata/ diff --git a/desktop/Backend-Rust/src/byok.rs b/desktop/Backend-Rust/src/byok.rs index 3f60e26adc0..7b72ed76bf0 100644 --- a/desktop/Backend-Rust/src/byok.rs +++ b/desktop/Backend-Rust/src/byok.rs @@ -25,6 +25,7 @@ pub const HEADER_OPENAI: &str = "x-byok-openai"; pub const HEADER_ANTHROPIC: &str = "x-byok-anthropic"; pub const HEADER_GEMINI: &str = "x-byok-gemini"; pub const HEADER_DEEPGRAM: &str = "x-byok-deepgram"; +pub const HEADER_ASSEMBLYAI: &str = "x-byok-assemblyai"; /// All four required BYOK headers. Python's `_request_has_all_byok_keys()` checks /// the same set — a fully enrolled BYOK user sends all four on every request. @@ -41,6 +42,7 @@ const HEADER_TO_PROVIDER: &[(&str, &str)] = &[ (HEADER_ANTHROPIC, "anthropic"), (HEADER_GEMINI, "gemini"), (HEADER_DEEPGRAM, "deepgram"), + (HEADER_ASSEMBLYAI, "assemblyai"), ]; /// Heartbeat TTL: BYOK is considered inactive if last_seen_at is older than this. diff --git a/desktop/Backend-Rust/src/models/conversation.rs b/desktop/Backend-Rust/src/models/conversation.rs index 511da1ae8d2..c6d508c9583 100644 --- a/desktop/Backend-Rust/src/models/conversation.rs +++ b/desktop/Backend-Rust/src/models/conversation.rs @@ -12,10 +12,10 @@ pub struct TranscriptSegment { #[serde(default)] pub id: Option, pub text: String, - #[serde(default = "default_speaker")] - pub speaker: String, #[serde(default)] - pub speaker_id: i32, + pub speaker: Option, + #[serde(default)] + pub speaker_id: Option, #[serde(default)] pub is_user: bool, #[serde(default)] @@ -24,10 +24,22 @@ pub struct TranscriptSegment { pub start: f64, #[serde(default)] pub end: f64, -} - -fn default_speaker() -> String { - "SPEAKER_00".to_string() + #[serde(default)] + pub stt_provider: Option, + #[serde(default)] + pub stt_model: Option, + #[serde(default)] + pub provider_cluster_id: Option, + #[serde(default)] + pub provider_speaker_label: Option, + #[serde(default)] + pub speaker_identity_state: Option, + #[serde(default)] + pub speaker_identity_confidence: Option, + #[serde(default)] + pub speaker_identity_source: Option, + #[serde(default)] + pub speaker_identity_version: Option, } impl TranscriptSegment { @@ -40,7 +52,10 @@ impl TranscriptSegment { let speaker_name = if segment.is_user { "User".to_string() } else { - format!("Speaker {}", segment.speaker_id) + match segment.speaker_id { + Some(speaker_id) => format!("Speaker {}", speaker_id), + None => "Speaker ?".to_string(), + } }; format!("{}: {}", speaker_name, segment.text) }) diff --git a/desktop/Backend-Rust/src/services/firestore.rs b/desktop/Backend-Rust/src/services/firestore.rs index 40d29caed40..31e98ceb983 100644 --- a/desktop/Backend-Rust/src/services/firestore.rs +++ b/desktop/Backend-Rust/src/services/firestore.rs @@ -4254,11 +4254,10 @@ impl FirestoreService { text: seg.get("text")?.as_str()?.to_string(), speaker: seg.get("speaker") .and_then(|s| s.as_str()) - .unwrap_or("SPEAKER_00") - .to_string(), + .map(|s| s.to_string()), speaker_id: seg.get("speaker_id") .and_then(|s| s.as_i64()) - .unwrap_or(0) as i32, + .map(|s| s as i32), is_user: seg.get("is_user") .and_then(|s| s.as_bool()) .unwrap_or(false), @@ -4271,6 +4270,14 @@ impl FirestoreService { end: seg.get("end") .and_then(|s| s.as_f64()) .unwrap_or(0.0), + stt_provider: seg.get("stt_provider").and_then(|s| s.as_str()).map(|s| s.to_string()), + stt_model: seg.get("stt_model").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_cluster_id: seg.get("provider_cluster_id").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_speaker_label: seg.get("provider_speaker_label").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_state: seg.get("speaker_identity_state").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_confidence: seg.get("speaker_identity_confidence").and_then(|s| s.as_f64()), + speaker_identity_source: seg.get("speaker_identity_source").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_version: seg.get("speaker_identity_version").and_then(|s| s.as_str()).map(|s| s.to_string()), }) }) .collect(); @@ -4302,11 +4309,10 @@ impl FirestoreService { text: seg.get("text")?.as_str()?.to_string(), speaker: seg.get("speaker") .and_then(|s| s.as_str()) - .unwrap_or("SPEAKER_00") - .to_string(), + .map(|s| s.to_string()), speaker_id: seg.get("speaker_id") .and_then(|s| s.as_i64()) - .unwrap_or(0) as i32, + .map(|s| s as i32), is_user: seg.get("is_user") .and_then(|s| s.as_bool()) .unwrap_or(false), @@ -4319,6 +4325,14 @@ impl FirestoreService { end: seg.get("end") .and_then(|s| s.as_f64()) .unwrap_or(0.0), + stt_provider: seg.get("stt_provider").and_then(|s| s.as_str()).map(|s| s.to_string()), + stt_model: seg.get("stt_model").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_cluster_id: seg.get("provider_cluster_id").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_speaker_label: seg.get("provider_speaker_label").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_state: seg.get("speaker_identity_state").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_confidence: seg.get("speaker_identity_confidence").and_then(|s| s.as_f64()), + speaker_identity_source: seg.get("speaker_identity_source").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_version: seg.get("speaker_identity_version").and_then(|s| s.as_str()).map(|s| s.to_string()), }) }) .collect(); @@ -4374,12 +4388,20 @@ impl FirestoreService { Some(TranscriptSegment { id: self.parse_string(seg_fields, "id"), text: self.parse_string(seg_fields, "text").unwrap_or_default(), - speaker: self.parse_string(seg_fields, "speaker").unwrap_or_else(|| "SPEAKER_00".to_string()), - speaker_id: self.parse_int(seg_fields, "speaker_id").unwrap_or(0), + speaker: self.parse_string(seg_fields, "speaker"), + speaker_id: self.parse_int(seg_fields, "speaker_id"), is_user: self.parse_bool(seg_fields, "is_user").unwrap_or(false), person_id: self.parse_string(seg_fields, "person_id"), start: self.parse_float(seg_fields, "start").unwrap_or(0.0), end: self.parse_float(seg_fields, "end").unwrap_or(0.0), + stt_provider: self.parse_string(seg_fields, "stt_provider"), + stt_model: self.parse_string(seg_fields, "stt_model"), + provider_cluster_id: self.parse_string(seg_fields, "provider_cluster_id"), + provider_speaker_label: self.parse_string(seg_fields, "provider_speaker_label"), + speaker_identity_state: self.parse_string(seg_fields, "speaker_identity_state"), + speaker_identity_confidence: self.parse_float(seg_fields, "speaker_identity_confidence"), + speaker_identity_source: self.parse_string(seg_fields, "speaker_identity_source"), + speaker_identity_version: self.parse_string(seg_fields, "speaker_identity_version"), }) }) .collect()) @@ -4422,12 +4444,11 @@ impl FirestoreService { speaker: seg .get("speaker") .and_then(|s| s.as_str()) - .unwrap_or("SPEAKER_00") - .to_string(), + .map(|s| s.to_string()), speaker_id: seg .get("speaker_id") .and_then(|s| s.as_i64()) - .unwrap_or(0) as i32, + .map(|s| s as i32), is_user: seg .get("is_user") .and_then(|s| s.as_bool()) @@ -4444,6 +4465,14 @@ impl FirestoreService { .get("end") .and_then(|s| s.as_f64()) .unwrap_or(0.0), + stt_provider: seg.get("stt_provider").and_then(|s| s.as_str()).map(|s| s.to_string()), + stt_model: seg.get("stt_model").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_cluster_id: seg.get("provider_cluster_id").and_then(|s| s.as_str()).map(|s| s.to_string()), + provider_speaker_label: seg.get("provider_speaker_label").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_state: seg.get("speaker_identity_state").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_confidence: seg.get("speaker_identity_confidence").and_then(|s| s.as_f64()), + speaker_identity_source: seg.get("speaker_identity_source").and_then(|s| s.as_str()).map(|s| s.to_string()), + speaker_identity_version: seg.get("speaker_identity_version").and_then(|s| s.as_str()).map(|s| s.to_string()), }) }) .collect()) @@ -4577,6 +4606,30 @@ impl FirestoreService { if let Some(person_id) = &seg.person_id { segment["person_id"] = json!(person_id); } + if let Some(stt_provider) = &seg.stt_provider { + segment["stt_provider"] = json!(stt_provider); + } + if let Some(stt_model) = &seg.stt_model { + segment["stt_model"] = json!(stt_model); + } + if let Some(provider_cluster_id) = &seg.provider_cluster_id { + segment["provider_cluster_id"] = json!(provider_cluster_id); + } + if let Some(provider_speaker_label) = &seg.provider_speaker_label { + segment["provider_speaker_label"] = json!(provider_speaker_label); + } + if let Some(speaker_identity_state) = &seg.speaker_identity_state { + segment["speaker_identity_state"] = json!(speaker_identity_state); + } + if let Some(speaker_identity_confidence) = seg.speaker_identity_confidence { + segment["speaker_identity_confidence"] = json!(speaker_identity_confidence); + } + if let Some(speaker_identity_source) = &seg.speaker_identity_source { + segment["speaker_identity_source"] = json!(speaker_identity_source); + } + if let Some(speaker_identity_version) = &seg.speaker_identity_version { + segment["speaker_identity_version"] = json!(speaker_identity_version); + } segment }).collect(); let json_str = serde_json::to_string(&segments_json).unwrap_or_else(|_| "[]".to_string()); @@ -9267,8 +9320,6 @@ impl FirestoreService { .map(|seg| { let mut fields = json!({ "text": {"stringValue": seg.text}, - "speaker": {"stringValue": seg.speaker}, - "speaker_id": {"integerValue": seg.speaker_id.to_string()}, "is_user": {"booleanValue": seg.is_user}, "start": {"doubleValue": seg.start}, "end": {"doubleValue": seg.end} @@ -9276,9 +9327,39 @@ impl FirestoreService { if let Some(ref id) = seg.id { fields["id"] = json!({"stringValue": id}); } + if let Some(ref speaker) = seg.speaker { + fields["speaker"] = json!({"stringValue": speaker}); + } + if let Some(speaker_id) = seg.speaker_id { + fields["speaker_id"] = json!({"integerValue": speaker_id.to_string()}); + } if let Some(ref pid) = seg.person_id { fields["person_id"] = json!({"stringValue": pid}); } + if let Some(ref stt_provider) = seg.stt_provider { + fields["stt_provider"] = json!({"stringValue": stt_provider}); + } + if let Some(ref stt_model) = seg.stt_model { + fields["stt_model"] = json!({"stringValue": stt_model}); + } + if let Some(ref provider_cluster_id) = seg.provider_cluster_id { + fields["provider_cluster_id"] = json!({"stringValue": provider_cluster_id}); + } + if let Some(ref provider_speaker_label) = seg.provider_speaker_label { + fields["provider_speaker_label"] = json!({"stringValue": provider_speaker_label}); + } + if let Some(ref speaker_identity_state) = seg.speaker_identity_state { + fields["speaker_identity_state"] = json!({"stringValue": speaker_identity_state}); + } + if let Some(speaker_identity_confidence) = seg.speaker_identity_confidence { + fields["speaker_identity_confidence"] = json!({"doubleValue": speaker_identity_confidence}); + } + if let Some(ref speaker_identity_source) = seg.speaker_identity_source { + fields["speaker_identity_source"] = json!({"stringValue": speaker_identity_source}); + } + if let Some(ref speaker_identity_version) = seg.speaker_identity_version { + fields["speaker_identity_version"] = json!({"stringValue": speaker_identity_version}); + } json!({"mapValue": {"fields": fields}}) }) .collect(); @@ -9970,6 +10051,85 @@ mod tests { assert_ne!(id, document_id_from_seed("different content")); } + fn test_firestore_service() -> FirestoreService { + FirestoreService { + client: Client::new(), + project_id: "test-project".to_string(), + credentials: None, + cached_token: Arc::new(RwLock::new(None)), + encryption_secret: None, + } + } + + #[test] + fn parse_plain_transcript_segments_preserves_explicit_unknown_identity() { + let service = test_firestore_service(); + let fields = json!({ + "transcript_segments": { + "arrayValue": { + "values": [{ + "mapValue": { + "fields": { + "id": {"stringValue": "seg-provider"}, + "text": {"stringValue": "Hello"}, + "is_user": {"booleanValue": false}, + "start": {"doubleValue": 0.0}, + "end": {"doubleValue": 1.0}, + "stt_provider": {"stringValue": "provider-a"}, + "stt_model": {"stringValue": "async-large"}, + "provider_cluster_id": {"stringValue": "speaker-alpha"}, + "speaker_identity_state": {"stringValue": "unknown"}, + "speaker_identity_confidence": {"doubleValue": 0.42}, + "speaker_identity_source": {"stringValue": "omi_speaker_embedding"}, + "speaker_identity_version": {"stringValue": "v1"} + } + } + }] + } + } + }); + + let segments = service.parse_transcript_segments(&fields, "uid").unwrap(); + assert_eq!(segments.len(), 1); + assert_eq!(segments[0].speaker, None); + assert_eq!(segments[0].speaker_id, None); + assert_eq!(segments[0].provider_cluster_id.as_deref(), Some("speaker-alpha")); + assert_eq!(segments[0].speaker_identity_state.as_deref(), Some("unknown")); + assert_eq!(segments[0].speaker_identity_confidence, Some(0.42)); + assert_eq!(segments[0].speaker_identity_source.as_deref(), Some("omi_speaker_embedding")); + assert_eq!(segments[0].speaker_identity_version.as_deref(), Some("v1")); + } + + #[test] + fn parse_plain_transcript_segments_preserves_legacy_speaker_fields() { + let service = test_firestore_service(); + let fields = json!({ + "transcript_segments": { + "arrayValue": { + "values": [{ + "mapValue": { + "fields": { + "id": {"stringValue": "seg-legacy"}, + "text": {"stringValue": "Hello"}, + "speaker": {"stringValue": "SPEAKER_00"}, + "speaker_id": {"integerValue": "0"}, + "is_user": {"booleanValue": false}, + "start": {"doubleValue": 0.0}, + "end": {"doubleValue": 1.0} + } + } + }] + } + } + }); + + let segments = service.parse_transcript_segments(&fields, "uid").unwrap(); + assert_eq!(segments.len(), 1); + assert_eq!(segments[0].speaker.as_deref(), Some("SPEAKER_00")); + assert_eq!(segments[0].speaker_id, Some(0)); + assert_eq!(segments[0].speaker_identity_state, None); + } + // --- Firestore BYOK state parsing tests --- #[test] diff --git a/desktop/CHANGELOG.json b/desktop/CHANGELOG.json index 07e41b60505..4743b2b3424 100644 --- a/desktop/CHANGELOG.json +++ b/desktop/CHANGELOG.json @@ -1,5 +1,7 @@ { - "unreleased": [], + "unreleased": [ + "Added optional AssemblyAI API key in Developer API Keys for BYOK async transcription (sync, background, postprocess) when enabled server-side; four-key free plan unchanged" + ], "releases": [ { "version": "0.11.419", diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 694f12f7644..997efbac8f6 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -826,6 +826,14 @@ struct TranscriptSegment: Codable, Identifiable { let start: Double let end: Double let translations: [TranscriptTranslation] + let sttProvider: String? + let sttModel: String? + let providerClusterId: String? + let providerSpeakerLabel: String? + let speakerIdentityState: String? + let speakerIdentityConfidence: Double? + let speakerIdentitySource: String? + let speakerIdentityVersion: String? var speakerId: Int { guard let speaker = speaker else { return 0 } @@ -836,10 +844,35 @@ struct TranscriptSegment: Codable, Identifiable { return 0 } + var hasExplicitUnknownSpeakerIdentity: Bool { + speakerIdentityState == "unknown" + } + + var displaySpeakerSuffix: String { + if speaker != nil { + return "\(speakerId)" + } + if let providerSpeakerLabel, !providerSpeakerLabel.isEmpty { + return providerSpeakerLabel + } + if let providerClusterId, !providerClusterId.isEmpty { + return providerClusterId + } + return "?" + } + enum CodingKeys: String, CodingKey { case id, text, speaker case isUser = "is_user" case personId = "person_id" + case sttProvider = "stt_provider" + case sttModel = "stt_model" + case providerClusterId = "provider_cluster_id" + case providerSpeakerLabel = "provider_speaker_label" + case speakerIdentityState = "speaker_identity_state" + case speakerIdentityConfidence = "speaker_identity_confidence" + case speakerIdentitySource = "speaker_identity_source" + case speakerIdentityVersion = "speaker_identity_version" case start, end, translations } @@ -856,6 +889,16 @@ struct TranscriptSegment: Codable, Identifiable { end = try container.decodeIfPresent(Double.self, forKey: .end) ?? 0 translations = try container.decodeIfPresent([TranscriptTranslation].self, forKey: .translations) ?? [] + sttProvider = try container.decodeIfPresent(String.self, forKey: .sttProvider) + sttModel = try container.decodeIfPresent(String.self, forKey: .sttModel) + providerClusterId = try container.decodeIfPresent(String.self, forKey: .providerClusterId) + providerSpeakerLabel = try container.decodeIfPresent(String.self, forKey: .providerSpeakerLabel) + speakerIdentityState = + try container.decodeIfPresent(String.self, forKey: .speakerIdentityState) + speakerIdentityConfidence = + try container.decodeIfPresent(Double.self, forKey: .speakerIdentityConfidence) + speakerIdentitySource = try container.decodeIfPresent(String.self, forKey: .speakerIdentitySource) + speakerIdentityVersion = try container.decodeIfPresent(String.self, forKey: .speakerIdentityVersion) } /// Memberwise initializer for creating from local storage @@ -868,7 +911,15 @@ struct TranscriptSegment: Codable, Identifiable { personId: String?, start: Double, end: Double, - translations: [TranscriptTranslation] = [] + translations: [TranscriptTranslation] = [], + sttProvider: String? = nil, + sttModel: String? = nil, + providerClusterId: String? = nil, + providerSpeakerLabel: String? = nil, + speakerIdentityState: String? = nil, + speakerIdentityConfidence: Double? = nil, + speakerIdentitySource: String? = nil, + speakerIdentityVersion: String? = nil ) { self.id = id self.backendId = backendId @@ -879,6 +930,14 @@ struct TranscriptSegment: Codable, Identifiable { self.start = start self.end = end self.translations = translations + self.sttProvider = sttProvider + self.sttModel = sttModel + self.providerClusterId = providerClusterId + self.providerSpeakerLabel = providerSpeakerLabel + self.speakerIdentityState = speakerIdentityState + self.speakerIdentityConfidence = speakerIdentityConfidence + self.speakerIdentitySource = speakerIdentitySource + self.speakerIdentityVersion = speakerIdentityVersion } /// Formatted timestamp string (e.g., "00:01:30 - 00:01:45") @@ -1225,6 +1284,10 @@ extension APIClient { let conversation: ServerConversation } + struct StartBackgroundConversationResponse: Decodable { + let conversation_id: String + } + /// Force-process the current in-progress conversation on the Python backend. /// Endpoint: POST /v1/conversations (Python backend) /// This is the same endpoint the mobile app uses when stopping phone mic recording. @@ -1246,6 +1309,76 @@ extension APIClient { return nil } } + + /// Start a Python-backed in-progress conversation for desktop background batch transcription. + /// Endpoint: POST /v2/desktop/background-conversation/start + func startBackgroundConversation(language: String) async throws -> String { + struct StartBackgroundConversationRequest: Encodable { + let language: String + } + + let response: StartBackgroundConversationResponse = try await post( + "v2/desktop/background-conversation/start", + body: StartBackgroundConversationRequest(language: language), + customBaseURL: nil + ) + return response.conversation_id + } + + /// Finish one explicit Python-backed desktop background batch conversation. + /// Endpoint: POST /v2/desktop/background-conversation/{conversation_id}/finish + func finishBackgroundConversation(conversationId: String) async throws -> ServerConversation { + try await post( + "v2/desktop/background-conversation/\(conversationId)/finish", + customBaseURL: nil + ) + } + + /// Fetch Python backend desktop transcription capabilities. + /// Endpoint: GET /v2/desktop/capabilities + func getDesktopCapabilities() async throws -> DesktopCapabilitiesResponse { + try await get("v2/desktop/capabilities", customBaseURL: nil) + } +} + +struct DesktopCapabilitiesResponse: Codable { + let backgroundBatch: DesktopBackgroundBatchCapability + + enum CodingKeys: String, CodingKey { + case backgroundBatch = "background_batch" + } +} + +struct DesktopBackgroundBatchCapability: Codable { + let enabled: Bool + let mode: String? + let provider: String + let primaryProvider: String? + let effectiveProvider: String? + let fallbackProvider: String? + let fallbackEnabled: Bool? + let fallbackAvailable: Bool? + let reason: String? + let sampleRate: Int + let channels: Int + let encoding: String + let maxChunkSeconds: Int + + enum CodingKeys: String, CodingKey { + case enabled + case mode + case provider + case primaryProvider = "primary_provider" + case effectiveProvider = "effective_provider" + case fallbackProvider = "fallback_provider" + case fallbackEnabled = "fallback_enabled" + case fallbackAvailable = "fallback_available" + case reason + case sampleRate = "sample_rate" + case channels + case encoding + case maxChunkSeconds = "max_chunk_seconds" + } } // MARK: - Memories API diff --git a/desktop/Desktop/Sources/APIKeyService.swift b/desktop/Desktop/Sources/APIKeyService.swift index 60f28f19a16..6934135a7a4 100644 --- a/desktop/Desktop/Sources/APIKeyService.swift +++ b/desktop/Desktop/Sources/APIKeyService.swift @@ -20,6 +20,10 @@ enum BYOKProvider: String, CaseIterable { case anthropic case gemini case deepgram + case assemblyai + + /// Providers required for the BYOK free plan (subscription bypass). + static let requiredForFreePlan: [BYOKProvider] = [.openai, .anthropic, .gemini, .deepgram] var storageKey: String { switch self { @@ -27,6 +31,7 @@ enum BYOKProvider: String, CaseIterable { case .anthropic: return "dev_anthropic_api_key" case .gemini: return "dev_gemini_api_key" case .deepgram: return "dev_deepgram_api_key" + case .assemblyai: return "dev_assemblyai_api_key" } } @@ -36,6 +41,7 @@ enum BYOKProvider: String, CaseIterable { case .anthropic: return "X-BYOK-Anthropic" case .gemini: return "X-BYOK-Gemini" case .deepgram: return "X-BYOK-Deepgram" + case .assemblyai: return "X-BYOK-AssemblyAI" } } @@ -45,8 +51,13 @@ enum BYOKProvider: String, CaseIterable { case .anthropic: return "Anthropic" case .gemini: return "Gemini" case .deepgram: return "Deepgram" + case .assemblyai: return "AssemblyAI" } } + + var isRequiredForFreePlan: Bool { + Self.requiredForFreePlan.contains(self) + } } @MainActor final class APIKeyService: ObservableObject { @@ -183,11 +194,22 @@ final class APIKeyService: ObservableObject { nonEmptyStatic(UserDefaults.standard.string(forKey: provider.storageKey)) } - /// True when the user has supplied keys for all four BYOK providers. + /// True when the user has supplied keys for all four required BYOK providers. /// The subscription-bypass gate: when this is true, the user is on the free /// plan and we attach their keys to every backend request. nonisolated static var isByokActive: Bool { - BYOKProvider.allCases.allSatisfy { byokKey($0) != nil } + BYOKProvider.requiredForFreePlan.allSatisfy { byokKey($0) != nil } + } + + /// Fingerprints to send on BYOK activation (required four + optional Assembly when set). + nonisolated static var byokActivationFingerprints: [String: String] { + var out: [String: String] = [:] + for provider in BYOKProvider.allCases { + if let key = byokKey(provider) { + out[provider.rawValue] = byokFingerprint(key) + } + } + return out } /// SHA-256 fingerprint of a key, used by the backend to detect when the diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 4d9dbe92a77..dd4932647c3 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -15,13 +15,13 @@ struct SegmentTranslation: Identifiable { struct SpeakerSegment: Identifiable { /// Stable identity — uses backend segment ID when available, otherwise speaker + start time var id: String { segmentId ?? "\(speaker)-\(start)" } - var segmentId: String? // Backend-assigned UUID + var segmentId: String? // Backend-assigned UUID var speaker: Int var text: String var start: Double var end: Double var isUser: Bool = false - var personId: String? // Backend-assigned person ID from speaker identification + var personId: String? // Backend-assigned person ID from speaker identification var translations: [SegmentTranslation] = [] } @@ -42,6 +42,7 @@ class AppState: ObservableObject { // Transcription state @Published var isTranscribing = false + @Published private(set) var isStartingTranscription = false /// Monotonically increasing counter — incremented each time a new recording starts. /// Used to detect if a new recording began during the post-stop force-process delay. private(set) var recordingGeneration: UInt64 = 0 @@ -186,10 +187,10 @@ class AppState: ObservableObject { func fetchTrialMetadata() { #if DEBUG - if let debugMode = UserDefaults.standard.string(forKey: "debug_trial_mode") { - applyDebugTrialMode(debugMode) - return - } + if let debugMode = UserDefaults.standard.string(forKey: "debug_trial_mode") { + applyDebugTrialMode(debugMode) + return + } #endif Task { @MainActor in @@ -213,58 +214,63 @@ class AppState: ObservableObject { } #if DEBUG - private func applyDebugTrialMode(_ mode: String) { - let now = Int(Date().timeIntervalSince1970) - let features = ["unlimited_listening", "unlimited_transcription", "unlimited_memories", "unlimited_insights", "30_chat_questions_per_month"] - let dur = 3 * 24 * 3600 - - func mock(remaining: Int, expired: Bool) -> TrialMetadataResponse { - TrialMetadataResponse( - trialStartedAt: now - (dur - remaining), trialEndsAt: now + remaining, - trialRemainingSeconds: remaining, trialExpired: expired, - trialDurationSeconds: dur, trialFeatures: features, planAfterTrial: "Free" - ) - } + private func applyDebugTrialMode(_ mode: String) { + let now = Int(Date().timeIntervalSince1970) + let features = [ + "unlimited_listening", "unlimited_transcription", "unlimited_memories", + "unlimited_insights", "30_chat_questions_per_month", + ] + let dur = 3 * 24 * 3600 + + func mock(remaining: Int, expired: Bool) -> TrialMetadataResponse { + TrialMetadataResponse( + trialStartedAt: now - (dur - remaining), trialEndsAt: now + remaining, + trialRemainingSeconds: remaining, trialExpired: expired, + trialDurationSeconds: dur, trialFeatures: features, planAfterTrial: "Free" + ) + } - switch mode { - case "active": - self.trialMetadata = mock(remaining: 2 * 24 * 3600 + 3600, expired: false) - case "warning": - self.trialMetadata = mock(remaining: 12 * 3600, expired: false) - case "expiring": - self.trialMetadata = mock(remaining: 1800, expired: false) - case "expired": - self.trialMetadata = mock(remaining: 0, expired: true) - case "realtime": - let endKey = "debug_trial_end_time" - let rtDur = 120 - var endTime = UserDefaults.standard.integer(forKey: endKey) - if endTime == 0 { - endTime = now + rtDur - UserDefaults.standard.set(endTime, forKey: endKey) - } - let remaining = max(0, endTime - now) - self.trialMetadata = TrialMetadataResponse( - trialStartedAt: endTime - rtDur, trialEndsAt: endTime, - trialRemainingSeconds: remaining, trialExpired: remaining == 0, - trialDurationSeconds: rtDur, trialFeatures: features, planAfterTrial: "Free" - ) - if remaining == 0 && !self.isPaywalled { self.isPaywalled = true } - default: - break + switch mode { + case "active": + self.trialMetadata = mock(remaining: 2 * 24 * 3600 + 3600, expired: false) + case "warning": + self.trialMetadata = mock(remaining: 12 * 3600, expired: false) + case "expiring": + self.trialMetadata = mock(remaining: 1800, expired: false) + case "expired": + self.trialMetadata = mock(remaining: 0, expired: true) + case "realtime": + let endKey = "debug_trial_end_time" + let rtDur = 120 + var endTime = UserDefaults.standard.integer(forKey: endKey) + if endTime == 0 { + endTime = now + rtDur + UserDefaults.standard.set(endTime, forKey: endKey) + } + let remaining = max(0, endTime - now) + self.trialMetadata = TrialMetadataResponse( + trialStartedAt: endTime - rtDur, trialEndsAt: endTime, + trialRemainingSeconds: remaining, trialExpired: remaining == 0, + trialDurationSeconds: rtDur, trialFeatures: features, planAfterTrial: "Free" + ) + if remaining == 0 && !self.isPaywalled { self.isPaywalled = true } + default: + break + } } - } #endif func startTrialMetadataRefresh() { trialRefreshTimer?.invalidate() fetchTrialMetadata() #if DEBUG - let interval: TimeInterval = UserDefaults.standard.string(forKey: "debug_trial_mode") == "realtime" ? 10 : 60 + let interval: TimeInterval = + UserDefaults.standard.string(forKey: "debug_trial_mode") == "realtime" ? 10 : 60 #else - let interval: TimeInterval = 60 + let interval: TimeInterval = 60 #endif - trialRefreshTimer = Timer.scheduledTimer(withTimeInterval: interval, repeats: true) { [weak self] _ in + trialRefreshTimer = Timer.scheduledTimer(withTimeInterval: interval, repeats: true) { + [weak self] _ in Task { @MainActor in self?.fetchTrialMetadata() } @@ -326,6 +332,17 @@ class AppState: ObservableObject { private var systemAudioCaptureService: Any? // SystemAudioCaptureService (macOS 14.4+) private var audioMixer: AudioMixer? private var vadGateService: VADGateService? + private var cloudBackgroundSession: CloudBackgroundTranscriptionSession? + private var cloudBackgroundConversationId: String? + private var cloudBackgroundStartTask: Task? + private var cloudBackgroundDrainTask: Task? + private var cloudBackgroundSampleCursor = 0 + private var forceNextTranscriptionStartStreaming = false + private var isCloudBackgroundTranscription = false + private var isCloudBackgroundBackpressured = false + private var didLogCloudBackgroundBackpressure = false + private var backgroundTranscriptMerger = BackgroundTranscriptMerger() + private var speakerSegmentReducer = SpeakerSegmentReducer() // Speaker segments for diarized transcription (sliding window — older segments are in SQLite) private var speakerSegments: [SpeakerSegment] = [] @@ -931,73 +948,73 @@ class AppState: ObservableObject { DispatchQueue.main.async { [weak self] in guard let self else { return } UNUserNotificationCenter.current().getNotificationSettings { settings in - DispatchQueue.main.async { - let isNowGranted = settings.authorizationStatus == .authorized - self.hasNotificationPermission = isNowGranted - self.notificationAlertStyle = settings.alertStyle - - // Log the current notification settings - let authStatus = - switch settings.authorizationStatus { - case .notDetermined: "notDetermined" - case .denied: "denied" - case .authorized: "authorized" - case .provisional: "provisional" - case .ephemeral: "ephemeral" - @unknown default: "unknown" - } - let alertStyleName = - switch settings.alertStyle { - case .none: "NONE (no banners)" - case .banner: "BANNER" - case .alert: "ALERT" - @unknown default: "unknown" - } - log( - "Notification settings: auth=\(authStatus), alertStyle=\(alertStyleName), sound=\(settings.soundSetting.rawValue), badge=\(settings.badgeSetting.rawValue)" - ) - - // Track notification settings in analytics only when they change - let soundEnabled = settings.soundSetting == .enabled - let badgeEnabled = settings.badgeSetting == .enabled - let settingsChanged = - authStatus != self.lastNotificationAuthStatus - || alertStyleName != self.lastNotificationAlertStyle - || soundEnabled != self.lastNotificationSoundEnabled - || badgeEnabled != self.lastNotificationBadgeEnabled - - if settingsChanged { - AnalyticsManager.shared.notificationSettingsChecked( - authStatus: authStatus, - alertStyle: alertStyleName, - soundEnabled: soundEnabled, - badgeEnabled: badgeEnabled, - bannersDisabled: settings.alertStyle == .none + DispatchQueue.main.async { + let isNowGranted = settings.authorizationStatus == .authorized + self.hasNotificationPermission = isNowGranted + self.notificationAlertStyle = settings.alertStyle + + // Log the current notification settings + let authStatus = + switch settings.authorizationStatus { + case .notDetermined: "notDetermined" + case .denied: "denied" + case .authorized: "authorized" + case .provisional: "provisional" + case .ephemeral: "ephemeral" + @unknown default: "unknown" + } + let alertStyleName = + switch settings.alertStyle { + case .none: "NONE (no banners)" + case .banner: "BANNER" + case .alert: "ALERT" + @unknown default: "unknown" + } + log( + "Notification settings: auth=\(authStatus), alertStyle=\(alertStyleName), sound=\(settings.soundSetting.rawValue), badge=\(settings.badgeSetting.rawValue)" ) - // Detect regression: was authorized, now reverted to notDetermined - // This happens on macOS 26+ where the OS silently revokes notification permission - if self.lastNotificationAuthStatus == "authorized" && authStatus == "notDetermined" { - log( - "Notification permission REGRESSED from authorized to notDetermined — triggering auto-repair" + // Track notification settings in analytics only when they change + let soundEnabled = settings.soundSetting == .enabled + let badgeEnabled = settings.badgeSetting == .enabled + let settingsChanged = + authStatus != self.lastNotificationAuthStatus + || alertStyleName != self.lastNotificationAlertStyle + || soundEnabled != self.lastNotificationSoundEnabled + || badgeEnabled != self.lastNotificationBadgeEnabled + + if settingsChanged { + AnalyticsManager.shared.notificationSettingsChecked( + authStatus: authStatus, + alertStyle: alertStyleName, + soundEnabled: soundEnabled, + badgeEnabled: badgeEnabled, + bannersDisabled: settings.alertStyle == .none ) - AnalyticsManager.shared.notificationRepairTriggered( - reason: "auth_regression", - previousStatus: "authorized", - currentStatus: "notDetermined" - ) - self.repairNotificationRegistrationAndRetry() + + // Detect regression: was authorized, now reverted to notDetermined + // This happens on macOS 26+ where the OS silently revokes notification permission + if self.lastNotificationAuthStatus == "authorized" && authStatus == "notDetermined" { + log( + "Notification permission REGRESSED from authorized to notDetermined — triggering auto-repair" + ) + AnalyticsManager.shared.notificationRepairTriggered( + reason: "auth_regression", + previousStatus: "authorized", + currentStatus: "notDetermined" + ) + self.repairNotificationRegistrationAndRetry() + } + + // Update last known state + self.lastNotificationAuthStatus = authStatus + self.lastNotificationAlertStyle = alertStyleName + self.lastNotificationSoundEnabled = soundEnabled + self.lastNotificationBadgeEnabled = badgeEnabled } - // Update last known state - self.lastNotificationAuthStatus = authStatus - self.lastNotificationAlertStyle = alertStyleName - self.lastNotificationSoundEnabled = soundEnabled - self.lastNotificationBadgeEnabled = badgeEnabled } - } - } } // end DispatchQueue.main.async } @@ -1374,7 +1391,7 @@ class AppState: ObservableObject { /// Toggle transcription on/off func toggleTranscription() { - if isTranscribing { + if isTranscribing || isStartingTranscription { stopTranscription() } else { startTranscription() @@ -1384,7 +1401,7 @@ class AppState: ObservableObject { /// Start real-time transcription /// - Parameter source: Audio source to use (defaults to current audioSource setting) func startTranscription(source: AudioSource? = nil) { - guard !isTranscribing else { return } + guard !isTranscribing && !isStartingTranscription else { return } // Paywall hard-stop: every code path that enables the mic + WS streaming // funnels through here, including auto-restart from sleep and toggle @@ -1415,6 +1432,24 @@ class AppState: ObservableObject { "Transcription: Using language=\(effectiveLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), selected=\(AssistantSettings.shared.transcriptionLanguage))" ) + let shouldAttemptCloudBatch = effectiveSource == .microphone && !forceNextTranscriptionStartStreaming + forceNextTranscriptionStartStreaming = false + if shouldAttemptCloudBatch { + let batchLanguage = AssistantSettings.shared.effectiveBatchTranscriptionLanguage + log( + "Transcription: Cloud background batch using language=\(batchLanguage) (autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect), selected=\(AssistantSettings.shared.transcriptionLanguage))" + ) + isStartingTranscription = true + cloudBackgroundStartTask?.cancel() + cloudBackgroundStartTask = Task { + await self.startCloudBackgroundTranscription( + source: effectiveSource, + language: batchLanguage + ) + } + return + } + // Always streaming via Python backend /v4/listen transcriptionService = try TranscriptionService(language: effectiveLanguage) @@ -1532,10 +1567,14 @@ class AppState: ObservableObject { Task { @MainActor in guard let self = self, self.isTranscribing else { return } log("Transcription: 4-hour limit reached - restarting session") - // Stop and restart (WebSocket close triggers backend conversation processing) - self.stopAudioCapture() - self.clearTranscriptionState() - self.startTranscription() + if self.isCloudBackgroundTranscription { + _ = await self.finishConversation() + } else { + // Stop and restart (WebSocket close triggers backend conversation processing) + self.stopAudioCapture() + self.clearTranscriptionState() + self.startTranscription() + } } } @@ -1550,6 +1589,171 @@ class AppState: ObservableObject { } } + private func startCloudBackgroundTranscription(source: AudioSource, language: String) async { + defer { isStartingTranscription = false } + do { + let capabilities = try await APIClient.shared.getDesktopCapabilities() + let routing = BackgroundTranscriptionRoutingGuard().decide( + backgroundBatchCapability: capabilities.backgroundBatch, + audioSource: source + ) + guard routing == .cloudBatchAssembly else { + throw NSError( + domain: "Omi.CloudBackgroundTranscription", + code: 1, + userInfo: [ + NSLocalizedDescriptionKey: + "Server background batch transcription is not available." + ] + ) + } + + let conversationId = try await APIClient.shared.startBackgroundConversation( + language: language) + guard !Task.isCancelled else { return } + cloudBackgroundConversationId = conversationId + cloudBackgroundSampleCursor = 0 + isCloudBackgroundTranscription = true + isCloudBackgroundBackpressured = false + didLogCloudBackgroundBackpressure = false + backgroundTranscriptMerger.reset() + speakerSegmentReducer.reset() + transcriptionService = nil + + if source == .bleDevice, let device = DeviceProvider.shared.connectedDevice { + currentConversationSource = ConversationSource.from(deviceType: device.type) + recordingInputDeviceName = device.displayName + } else { + currentConversationSource = .desktop + recordingInputDeviceName = AudioCaptureService.getCurrentMicrophoneName() + } + + audioCaptureService = AudioCaptureService() + audioMixer = AudioMixer() + vadGateService = nil + + // AssemblyAI batch requests use an explicit selected language and no keyword prompt. + // `multi` and keyterms have both produced provider-side 400s on low-speech chunks. + let resolvedLanguage = language + cloudBackgroundSession = CloudBackgroundTranscriptionSession( + configuration: .cloudBatch + ) { chunk in + try await TranscriptionService.batchTranscribeSegments( + audioData: chunk.pcmData, + conversationId: conversationId, + chunkStartMs: max(0, Int((chunk.startTime * 1000.0).rounded())), + language: resolvedLanguage + ) + } + + let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture") + if systemAudioDisabled { + log( + "Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)" + ) + } else if #available(macOS 14.4, *) { + systemAudioCaptureService = SystemAudioCaptureService() + log("Transcription: System audio capture initialized for cloud background batch") + } + + isTranscribing = true + recordingGeneration &+= 1 + AssistantSettings.shared.transcriptionEnabled = true + audioSource = source + currentTranscript = "" + speakerSegments = [] + totalSegmentCount = 0 + totalWordCount = 0 + liveSpeakerPersonMap = [:] + LiveTranscriptMonitor.shared.clear() + recordingStartTime = Date() + AudioLevelMonitor.shared.reset() + RecordingTimer.shared.start() + + await startAudioCapture(source: source) + await startCrashSafeTranscriptionSession(language: language) + + maxRecordingTimer = Timer.scheduledTimer( + withTimeInterval: maxRecordingDuration, repeats: false + ) { + [weak self] _ in + Task { @MainActor in + guard let self = self, self.isTranscribing else { return } + log("Transcription: 4-hour limit reached - rotating cloud background batch conversation") + _ = await self.finishConversation() + } + } + + AnalyticsManager.shared.transcriptionStarted() + log("Transcription: Started cloud background batch conversation \(conversationId)") + } catch { + stopAudioCapture() + isCloudBackgroundTranscription = false + cloudBackgroundSession = nil + cloudBackgroundConversationId = nil + isCloudBackgroundBackpressured = false + didLogCloudBackgroundBackpressure = false + isTranscribing = false + AssistantSettings.shared.transcriptionEnabled = false + logError("Transcription: Cloud background batch failed to start", error: error) + if BackgroundTranscriptionRoutingGuard().shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: source, + captureStarted: false + ) { + isStartingTranscription = false + forceNextTranscriptionStartStreaming = true + startTranscription(source: source) + log("Transcription: Fell back to streaming after cloud background batch startup failed") + } else { + AnalyticsManager.shared.recordingError(error: error.localizedDescription) + showAlert(title: "Audio Recording Not Started", message: cloudBackgroundStartFailureMessage(for: error)) + } + } + } + + private func cloudBackgroundStartFailureMessage(for error: Error) -> String { + if let urlError = error as? URLError { + switch urlError.code { + case .networkConnectionLost, .cannotConnectToHost, .notConnectedToInternet, .timedOut: + return + "Audio Recording could not start because the transcription backend connection was unavailable. Check the backend or network, then turn Audio Recording back on to retry." + default: + break + } + } + + if let apiError = error as? APIError { + switch apiError { + case .httpError(let statusCode): + return + "Audio Recording could not start because the transcription backend returned HTTP \(statusCode). Turn Audio Recording back on after the backend is healthy." + case .unauthorized: + return "Audio Recording could not start because your session needs to sign in again." + default: + break + } + } + + return + "Audio Recording could not start. Turn Audio Recording back on after the transcription backend is healthy. Details: \(error.localizedDescription)" + } + + private func startCrashSafeTranscriptionSession(language: String) async { + do { + let sessionId = try await TranscriptionStorage.shared.startSession( + source: currentConversationSource.rawValue, + language: language, + timezone: TimeZone.current.identifier, + inputDeviceName: recordingInputDeviceName + ) + currentSessionId = sessionId + LiveNotesMonitor.shared.startSession(sessionId: sessionId) + log("Transcription: Created DB session \(sessionId)") + } catch { + logError("Transcription: Failed to create DB session", error: error) + } + } + /// Start audio capture and pipe to transcription service /// - Parameter source: Audio source to capture from private func startAudioCapture(source: AudioSource = .microphone) async { @@ -1579,9 +1783,15 @@ class AppState: ObservableObject { } // Start the mixer — it sums mic + system into a mono stream and forwards it to - // the transcription WebSocket. + // the active transcription transport. audioMixer?.start { [weak self] monoMixed in - self?.transcriptionService?.sendAudio(monoMixed) + Task { @MainActor in + if self?.isCloudBackgroundTranscription == true { + self?.handleMixedBackgroundAudio(monoMixed) + } else { + self?.transcriptionService?.sendAudio(monoMixed) + } + } } do { @@ -1637,7 +1847,9 @@ class AppState: ObservableObject { silentMicFallbackInProgress = true guard let builtInID = AudioCaptureService.findBuiltInMicDeviceID() else { - log("Transcription: silent-mic detected but no built-in microphone available — leaving capture as-is") + log( + "Transcription: silent-mic detected but no built-in microphone available — leaving capture as-is" + ) silentMicFallbackInProgress = false return } @@ -1744,6 +1956,31 @@ class AppState: ObservableObject { /// triggers conversation processing on the backend side. We also call force-process to ensure /// the conversation is finalized, preventing the retry service from creating duplicates. func stopTranscription() { + if isStartingTranscription { + cloudBackgroundStartTask?.cancel() + cloudBackgroundStartTask = nil + isStartingTranscription = false + isCloudBackgroundTranscription = false + cloudBackgroundSession = nil + cloudBackgroundConversationId = nil + AssistantSettings.shared.transcriptionEnabled = false + return + } + + if isCloudBackgroundTranscription { + let capturedSessionId = currentSessionId + let capturedStartTime = recordingStartTime + let generationAtStop = recordingGeneration + Task { + await self.stopCloudBackgroundTranscription( + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + generationAtStop: generationAtStop + ) + } + return + } + // Capture session metadata BEFORE clearing state (clearTranscriptionState sets sessionId to nil) let capturedSessionId = currentSessionId let capturedStartTime = recordingStartTime @@ -1763,7 +2000,9 @@ class AppState: ObservableObject { // finalize the NEW conversation instead of the one we just stopped. // The retry service will reconcile the old session by timestamp matching. guard self.recordingGeneration == generationAtStop else { - log("Transcription: New recording started during delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? "nil")") + log( + "Transcription: New recording started during delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? "nil")" + ) return } @@ -1771,15 +2010,20 @@ class AppState: ObservableObject { if let conversation = try await APIClient.shared.forceProcessConversation() { // Validate the returned conversation matches the session we just stopped if let sessionId = capturedSessionId, let startTime = capturedStartTime, - let convStarted = conversation.startedAt, - abs(convStarted.timeIntervalSince(startTime)) < 10, - conversation.source == .desktop { + let convStarted = conversation.startedAt, + abs(convStarted.timeIntervalSince(startTime)) < 10, + conversation.source == .desktop + { try? await TranscriptionStorage.shared.markSessionCompleted( id: sessionId, backendId: conversation.id) - log("Transcription: Force-processed conversation \(conversation.id), session \(sessionId) completed") + log( + "Transcription: Force-processed conversation \(conversation.id), session \(sessionId) completed" + ) } else if let sessionId = capturedSessionId, let startTime = capturedStartTime { // Force-process returned a different conversation — fall back to reconciliation - log("Transcription: Force-processed conversation \(conversation.id) does not match session \(sessionId), reconciling by timestamp") + log( + "Transcription: Force-processed conversation \(conversation.id) does not match session \(sessionId), reconciling by timestamp" + ) await reconcileSession(sessionId: sessionId, startTime: startTime) } } else { @@ -1798,6 +2042,215 @@ class AppState: ObservableObject { } } + func restartTranscriptionAfterSettingsChange() async { + guard isTranscribing || isStartingTranscription else { return } + + if isCloudBackgroundTranscription { + let capturedSessionId = currentSessionId + let capturedStartTime = recordingStartTime + let generationAtStop = recordingGeneration + await stopCloudBackgroundTranscription( + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + generationAtStop: generationAtStop, + forSettingsChange: true + ) + } else { + stopTranscription() + try? await Task.sleep(nanoseconds: 1_000_000_000) + } + + startTranscription() + } + + private func handleMixedBackgroundAudio(_ pcmData: Data) { + guard let session = cloudBackgroundSession else { return } + let startTime = Double(cloudBackgroundSampleCursor) / 16000.0 + cloudBackgroundSampleCursor += pcmData.count / 2 + let result = session.append(pcmData: pcmData, startTime: startTime) + if result.isBackpressured { + isCloudBackgroundBackpressured = true + if !didLogCloudBackgroundBackpressure { + didLogCloudBackgroundBackpressure = true + log( + "Transcription: Cloud background ASR backpressure active (\(result.pendingChunkCount) pending chunks); dropping new audio until backlog drains" + ) + } + } + if result.enqueuedChunks > 0 { + drainCloudBackgroundASRQueue() + } + } + + private func drainCloudBackgroundASRQueue() { + guard cloudBackgroundDrainTask == nil else { return } + cloudBackgroundDrainTask = Task { @MainActor in + defer { cloudBackgroundDrainTask = nil } + while !Task.isCancelled { + guard let session = cloudBackgroundSession else { break } + do { + guard let result = try await session.transcribeNext() else { break } + let merged = backgroundTranscriptMerger.merge(result.segments) + _ = speakerSegmentReducer.apply(merged) + handleBackendSegments(merged) + if !session.isBackpressured && isCloudBackgroundBackpressured { + isCloudBackgroundBackpressured = false + didLogCloudBackgroundBackpressure = false + log("Transcription: Cloud background ASR backlog drained; resuming normal ingest") + } + } catch { + logError("Transcription: Cloud background chunk failed", error: error) + try? await Task.sleep(nanoseconds: 1_000_000_000) + } + } + } + } + + private func waitForCloudBackgroundBacklog(timeout: TimeInterval = 60) async { + let deadline = Date().addingTimeInterval(timeout) + while Date() < deadline { + drainCloudBackgroundASRQueue() + if cloudBackgroundSession?.pendingChunkCount ?? 0 == 0 && cloudBackgroundDrainTask == nil { + return + } + try? await Task.sleep(nanoseconds: 250_000_000) + } + log( + "Transcription: Cloud background backlog wait timed out with \(cloudBackgroundSession?.pendingChunkCount ?? 0) pending chunks" + ) + } + + private func stopCloudBackgroundTranscription( + capturedSessionId: Int64?, + capturedStartTime: Date?, + generationAtStop: UInt64, + forSettingsChange: Bool = false + ) async { + let stoppedConversationId = cloudBackgroundConversationId + stopAudioCapture() + _ = cloudBackgroundSession?.finishInput() + drainCloudBackgroundASRQueue() + await waitForCloudBackgroundBacklog(timeout: forSettingsChange ? 5 : 60) + + clearCloudBackgroundState() + clearTranscriptionState() + silentMicFallbackInProgress = false + + if forSettingsChange { + // Finalize the stopped conversation in the background so settings UI stays responsive. + Task { + try? await Task.sleep(nanoseconds: 3_000_000_000) + guard recordingGeneration == generationAtStop else { return } + await finishStoppedCloudBackgroundConversation( + conversationId: stoppedConversationId, + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + logPrefix: "Cloud background batch" + ) + await loadConversations() + } + return + } + + try? await Task.sleep(nanoseconds: 3_000_000_000) + guard recordingGeneration == generationAtStop else { + log( + "Transcription: New recording started during cloud batch delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? "nil")" + ) + return + } + + await finishStoppedCloudBackgroundConversation( + conversationId: stoppedConversationId, + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + logPrefix: "Cloud background batch" + ) + await loadConversations() + } + + private func clearCloudBackgroundState() { + cloudBackgroundStartTask?.cancel() + cloudBackgroundStartTask = nil + cloudBackgroundDrainTask?.cancel() + cloudBackgroundDrainTask = nil + cloudBackgroundSession = nil + cloudBackgroundConversationId = nil + cloudBackgroundSampleCursor = 0 + isCloudBackgroundTranscription = false + isCloudBackgroundBackpressured = false + didLogCloudBackgroundBackpressure = false + backgroundTranscriptMerger.reset() + speakerSegmentReducer.reset() + } + + private func forceProcessStoppedConversation( + capturedSessionId: Int64?, + capturedStartTime: Date?, + logPrefix: String + ) async { + do { + if let conversation = try await APIClient.shared.forceProcessConversation() { + if let sessionId = capturedSessionId, let startTime = capturedStartTime, + let convStarted = conversation.startedAt, + abs(convStarted.timeIntervalSince(startTime)) < 10, + conversation.source == .desktop + { + try? await TranscriptionStorage.shared.markSessionCompleted( + id: sessionId, backendId: conversation.id) + log( + "Transcription: \(logPrefix) force-processed conversation \(conversation.id), session \(sessionId) completed" + ) + } else if let sessionId = capturedSessionId, let startTime = capturedStartTime { + log( + "Transcription: \(logPrefix) force-process returned different conversation \(conversation.id), reconciling session \(sessionId)" + ) + await reconcileSession(sessionId: sessionId, startTime: startTime) + } + } else if let sessionId = capturedSessionId, let startTime = capturedStartTime { + await reconcileSession(sessionId: sessionId, startTime: startTime) + } + } catch { + logError( + "Transcription: \(logPrefix) force-process failed, retry service will reconcile", + error: error) + } + } + + private func finishStoppedCloudBackgroundConversation( + conversationId: String?, + capturedSessionId: Int64?, + capturedStartTime: Date?, + logPrefix: String + ) async { + guard let conversationId else { + log("Transcription: \(logPrefix) missing backend conversation ID, reconciling") + if let sessionId = capturedSessionId, let startTime = capturedStartTime { + await reconcileSession(sessionId: sessionId, startTime: startTime) + } + return + } + + do { + let conversation = try await APIClient.shared.finishBackgroundConversation( + conversationId: conversationId) + if let sessionId = capturedSessionId { + try? await TranscriptionStorage.shared.markSessionCompleted( + id: sessionId, backendId: conversation.id) + log( + "Transcription: \(logPrefix) finalized conversation \(conversation.id), session \(sessionId) completed" + ) + } + } catch { + logError( + "Transcription: \(logPrefix) explicit finalization failed, retry service will reconcile", + error: error) + if let sessionId = capturedSessionId, let startTime = capturedStartTime { + await reconcileSession(sessionId: sessionId, startTime: startTime) + } + } + } + /// Reconcile a local session by checking if a matching conversation exists on the backend. /// If found, marks the session as completed. Otherwise leaves it as pendingUpload for retry. private func reconcileSession(sessionId: Int64, startTime: Date) async { @@ -1818,7 +2271,9 @@ class AppState: ObservableObject { id: sessionId, backendId: match.id) log("Transcription: Reconciled session \(sessionId) → backend conversation \(match.id)") } else { - log("Transcription: No matching backend conversation found for session \(sessionId), leaving for retry") + log( + "Transcription: No matching backend conversation found for session \(sessionId), leaving for retry" + ) } } catch { logError("Transcription: Reconciliation failed for session \(sessionId)", error: error) @@ -1828,12 +2283,18 @@ class AppState: ObservableObject { /// Finish the current conversation and keep recording for a new one. /// Disconnects the WebSocket (triggers backend conversation processing) then reconnects. func finishConversation() async -> FinishConversationResult { + if isCloudBackgroundTranscription { + return await finishCloudBackgroundConversation() + } + guard totalSegmentCount > 0 || !speakerSegments.isEmpty else { log("Transcription: No segments to finish") return .discarded } - log("Transcription: Finishing conversation — disconnecting WebSocket to trigger backend processing") + log( + "Transcription: Finishing conversation — disconnecting WebSocket to trigger backend processing" + ) // Capture state before rotation — memory_created event for this conversation // may arrive on the new WebSocket after currentSessionId and recordingStartTime have changed. @@ -1946,6 +2407,89 @@ class AppState: ObservableObject { return .saved } + private func finishCloudBackgroundConversation() async -> FinishConversationResult { + log("Transcription: Finishing cloud background batch conversation") + let finishedConversationId = cloudBackgroundConversationId + let capturedSessionId = currentSessionId + let capturedStartTime = recordingStartTime + finishedSessionId = capturedSessionId + finishedRecordingStartTime = capturedStartTime + + stopAudioCapture() + _ = cloudBackgroundSession?.finishInput() + drainCloudBackgroundASRQueue() + await waitForCloudBackgroundBacklog() + let finishedHadSegments = totalSegmentCount > 0 || !speakerSegments.isEmpty + + if let sessionId = currentSessionId { + do { + try await TranscriptionStorage.shared.finishSession(id: sessionId) + log("Transcription: Finished DB session \(sessionId) before cloud batch rotation") + } catch { + logError("Transcription: Failed to finish DB session \(sessionId)", error: error) + } + } + + await finishStoppedCloudBackgroundConversation( + conversationId: finishedConversationId, + capturedSessionId: capturedSessionId, + capturedStartTime: capturedStartTime, + logPrefix: "Cloud background batch rotation" + ) + + currentSessionId = nil + speakerSegments = [] + totalSegmentCount = 0 + totalWordCount = 0 + liveSpeakerPersonMap = [:] + LiveTranscriptMonitor.shared.clear() + LiveNotesMonitor.shared.endSession() + LiveNotesMonitor.shared.clear() + backgroundTranscriptMerger.reset() + speakerSegmentReducer.reset() + cloudBackgroundSampleCursor = 0 + + recordingStartTime = Date() + RecordingTimer.shared.restart() + maxRecordingTimer?.invalidate() + maxRecordingTimer = Timer.scheduledTimer(withTimeInterval: maxRecordingDuration, repeats: false) + { + [weak self] _ in + Task { @MainActor in + guard let self = self, self.isTranscribing else { return } + log("Transcription: 4-hour limit reached - rotating cloud background batch conversation") + _ = await self.finishConversation() + } + } + + do { + let language = AssistantSettings.shared.effectiveBatchTranscriptionLanguage + let conversationId = try await APIClient.shared.startBackgroundConversation( + language: language) + cloudBackgroundConversationId = conversationId + cloudBackgroundSession = CloudBackgroundTranscriptionSession( + configuration: .cloudBatch + ) { chunk in + try await TranscriptionService.batchTranscribeSegments( + audioData: chunk.pcmData, + conversationId: conversationId, + chunkStartMs: max(0, Int((chunk.startTime * 1000.0).rounded())), + language: language + ) + } + await startCrashSafeTranscriptionSession(language: language) + await startAudioCapture(source: audioSource) + isTranscribing = true + await loadConversations() + log("Transcription: Ready for next cloud background batch conversation \(conversationId)") + return finishedHadSegments ? .saved : .discarded + } catch { + logError( + "Transcription: Failed to start next cloud background batch conversation", error: error) + return .error(error.localizedDescription) + } + } + /// Stop audio capture services (but keep transcript data for saving) private func stopAudioCapture() { // Cancel timers @@ -2448,7 +2992,7 @@ class AppState: ObservableObject { let idSet = Set(segmentIds) if let idx = conversations.firstIndex(where: { $0.id == conversationId }) { for segIdx in conversations[idx].transcriptSegments.indices - where idSet.contains(conversations[idx].transcriptSegments[segIdx].id) { + where idSet.contains(conversations[idx].transcriptSegments[segIdx].id) { let old = conversations[idx].transcriptSegments[segIdx] conversations[idx].transcriptSegments[segIdx] = TranscriptSegment( id: old.id, @@ -2459,7 +3003,15 @@ class AppState: ObservableObject { personId: isUser ? nil : personId, start: old.start, end: old.end, - translations: old.translations + translations: old.translations, + sttProvider: old.sttProvider, + sttModel: old.sttModel, + providerClusterId: old.providerClusterId, + providerSpeakerLabel: old.providerSpeakerLabel, + speakerIdentityState: old.speakerIdentityState, + speakerIdentityConfidence: old.speakerIdentityConfidence, + speakerIdentitySource: old.speakerIdentitySource, + speakerIdentityVersion: old.speakerIdentityVersion ) } } @@ -2566,7 +3118,15 @@ class AppState: ObservableObject { isUser: segment.is_user, personId: segment.person_id, speakerLabel: segment.speaker, - translationsJson: translationsJson + translationsJson: translationsJson, + sttProvider: segment.stt_provider, + sttModel: segment.stt_model, + providerClusterId: segment.provider_cluster_id, + providerSpeakerLabel: segment.provider_speaker_label, + speakerIdentityState: segment.speaker_identity_state, + speakerIdentityConfidence: segment.speaker_identity_confidence, + speakerIdentitySource: segment.speaker_identity_source, + speakerIdentityVersion: segment.speaker_identity_version ) } catch { logError("Transcription: Failed to persist segment to DB", error: error) @@ -2726,7 +3286,9 @@ class AppState: ObservableObject { // Always persist to SQLite — even if the segment was trimmed from // the in-memory window, the event payload has all fields needed if let sessionId = currentSessionId { - let mapped = newTranslations.map { TranscriptTranslation(lang: $0.lang, text: $0.text) } + let mapped = newTranslations.map { + TranscriptTranslation(lang: $0.lang, text: $0.text) + } var translationsJson: String? if let jsonData = try? JSONEncoder().encode(mapped) { translationsJson = String(data: jsonData, encoding: .utf8) @@ -2742,7 +3304,15 @@ class AppState: ObservableObject { isUser: translated.is_user, personId: translated.person_id, speakerLabel: translated.speaker, - translationsJson: translationsJson + translationsJson: translationsJson, + sttProvider: translated.stt_provider, + sttModel: translated.stt_model, + providerClusterId: translated.provider_cluster_id, + providerSpeakerLabel: translated.provider_speaker_label, + speakerIdentityState: translated.speaker_identity_state, + speakerIdentityConfidence: translated.speaker_identity_confidence, + speakerIdentitySource: translated.speaker_identity_source, + speakerIdentityVersion: translated.speaker_identity_version ) } } diff --git a/desktop/Desktop/Sources/BYOKValidator.swift b/desktop/Desktop/Sources/BYOKValidator.swift index a713f96245a..f942d063bfc 100644 --- a/desktop/Desktop/Sources/BYOKValidator.swift +++ b/desktop/Desktop/Sources/BYOKValidator.swift @@ -43,6 +43,11 @@ enum BYOKValidator { url: URL(string: "https://api.deepgram.com/v1/projects")!, headers: ["Authorization": "Token \(trimmed)"] ) + case .assemblyai: + return await ping( + url: URL(string: "https://api.assemblyai.com/v2/account")!, + headers: ["Authorization": trimmed] + ) } } diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift new file mode 100644 index 00000000000..2fea1bfca65 --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift @@ -0,0 +1,178 @@ +import Foundation + +struct BackgroundAudioChunk: Equatable { + let pcmData: Data + let startTime: Double + let isFinal: Bool +} + +struct BackgroundAudioChunker { + private let configuration: BackgroundTranscriptionConfiguration + private var buffer = Data() + private var bufferStartTime: Double? + + init(configuration: BackgroundTranscriptionConfiguration = BackgroundTranscriptionConfiguration()) + { + self.configuration = configuration + } + + mutating func append(pcmData: Data, startTime: Double) -> [BackgroundAudioChunk] { + guard !pcmData.isEmpty else { return [] } + + if buffer.isEmpty { + bufferStartTime = startTime + } + buffer.append(pcmData) + + var chunks: [BackgroundAudioChunk] = [] + while chunks.count < configuration.maxChunksPerAppend, let chunk = nextChunk(isFinal: false) { + chunks.append(chunk) + } + return chunks + } + + mutating func finishInput() -> [BackgroundAudioChunk] { + guard !buffer.isEmpty else { return [] } + let maxBytes = configuration.alignedByteCount(for: configuration.maxChunkDuration) + var chunks: [BackgroundAudioChunk] = [] + + while buffer.count > maxBytes, let chunk = nextChunk(isFinal: false) { + chunks.append(chunk) + } + + let chunk = BackgroundAudioChunk( + pcmData: buffer, + startTime: bufferStartTime ?? 0, + isFinal: true + ) + chunks.append(chunk) + buffer.removeAll(keepingCapacity: false) + bufferStartTime = nil + return chunks + } + + private mutating func nextChunk(isFinal: Bool) -> BackgroundAudioChunk? { + let minBytes = configuration.alignedByteCount(for: configuration.minChunkDuration) + let maxBytes = configuration.alignedByteCount(for: configuration.maxChunkDuration) + guard buffer.count >= minBytes else { return nil } + + let cutBytes: Int? + let overlapBytesAtMaxCut = effectiveOverlapBytes(forCutBytes: min(buffer.count, maxBytes)) + let minimumProgressCutBytes = overlapBytesAtMaxCut + configuration.bytesPerSample + + if let silenceCut = firstSilenceCut( + minBytes: max(minBytes, minimumProgressCutBytes), + maxBytes: min(buffer.count, maxBytes) + ) { + cutBytes = silenceCut + } else if buffer.count >= maxBytes { + cutBytes = maxBytes + } else { + cutBytes = nil + } + + guard let cutBytes, cutBytes - effectiveOverlapBytes(forCutBytes: cutBytes) > 0 else { + return nil + } + return cut(at: cutBytes, isFinal: isFinal) + } + + private mutating func cut(at requestedCutBytes: Int, isFinal: Bool) -> BackgroundAudioChunk { + let cutBytes = min(requestedCutBytes, buffer.count).alignedToSample + let startTime = bufferStartTime ?? 0 + let chunk = BackgroundAudioChunk( + pcmData: buffer.prefix(cutBytes), + startTime: startTime, + isFinal: isFinal + ) + + let overlapBytes = effectiveOverlapBytes(forCutBytes: cutBytes) + let retainedStart = max(0, cutBytes - overlapBytes) + let retained = buffer.suffix(buffer.count - retainedStart) + buffer = Data(retained) + bufferStartTime = + startTime + Double(retainedStart / configuration.bytesPerSample) + / Double(configuration.sampleRate) + return chunk + } + + private func effectiveOverlapBytes(forCutBytes cutBytes: Int) -> Int { + let requestedOverlap = configuration.alignedByteCount(for: configuration.overlapDuration) + let maxProgressSafeOverlap = max(0, cutBytes - configuration.bytesPerSample) + return min(requestedOverlap, maxProgressSafeOverlap).alignedToSample + } + + private func firstSilenceCut(minBytes: Int, maxBytes: Int) -> Int? { + let windowBytes = max( + configuration.bytesPerSample, + configuration.alignedByteCount(for: configuration.silenceWindowDuration)) + guard maxBytes >= minBytes + windowBytes else { return nil } + + var offset = minBytes.alignedToSample + while offset + windowBytes <= maxBytes { + if isSilentWindow(start: offset, byteCount: windowBytes), + hasMinimumSpeech(before: offset) + { + return offset + } + offset += configuration.bytesPerSample + } + return nil + } + + private func isSilentWindow(start: Int, byteCount: Int) -> Bool { + guard start >= 0, start + byteCount <= buffer.count else { return false } + var maxAmplitude = 0 + for offset in stride(from: start, to: start + byteCount, by: configuration.bytesPerSample) { + maxAmplitude = max(maxAmplitude, sampleAmplitude(at: offset)) + if maxAmplitude > configuration.silenceAmplitudeThreshold { + return false + } + } + return true + } + + private func hasMinimumSpeech(before endOffset: Int) -> Bool { + guard endOffset > 0 else { return false } + let minimumSpeechBytes = max( + configuration.bytesPerSample, + configuration.alignedByteCount(for: configuration.speechActivityDetection.minimumSpeechDuration) + ) + var peak = 0 + var sumSquares = 0.0 + var count = 0 + var speechLikeBytes = 0 + + for offset in stride( + from: 0, to: min(endOffset, buffer.count), by: configuration.bytesPerSample) + { + let amplitude = sampleAmplitude(at: offset) + peak = max(peak, amplitude) + sumSquares += Double(amplitude * amplitude) + if amplitude >= configuration.speechPeakAmplitudeThreshold { + speechLikeBytes += configuration.bytesPerSample + } + count += 1 + } + + guard count > 0 else { return false } + let rms = sqrt(sumSquares / Double(count)) + return speechLikeBytes >= minimumSpeechBytes + || (rms >= Double(configuration.speechRMSAmplitudeThreshold) + && endOffset >= minimumSpeechBytes) + } + + private func sampleAmplitude(at offset: Int) -> Int { + guard offset + 1 < buffer.count else { return 0 } + let low = UInt16(buffer[offset]) + let high = UInt16(buffer[offset + 1]) << 8 + let sample = Int16(bitPattern: low | high) + return abs(Int(sample)) + } +} + +extension Data { + fileprivate func prefix(_ count: Int) -> Data { + Data(self[startIndex.. [TranscriptionService.BackendSegment] + { + var changedSegments: [TranscriptionService.BackendSegment] = [] + for incoming in incomingSegments + where !incoming.text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + if let changed = upsert(incoming) { + changedSegments.append(changed) + } + } + segments.sort { lhs, rhs in + if lhs.start == rhs.start { + return lhs.end < rhs.end + } + return lhs.start < rhs.start + } + return changedSegments + } + + private mutating func upsert(_ incoming: TranscriptionService.BackendSegment) + -> TranscriptionService.BackendSegment? + { + if let segmentId = incoming.id, + let index = segments.firstIndex(where: { $0.id == segmentId }) + { + let preferred = preferredSegment(existing: segments[index], incoming: incoming) + guard !sameSegment(preferred, segments[index]) else { return nil } + segments[index] = preferred + return preferred + } + + if let index = segments.firstIndex(where: { isDuplicate($0, incoming) }) { + let preferred = preferredSegment(existing: segments[index], incoming: incoming) + guard !sameSegment(preferred, segments[index]) else { return nil } + segments[index] = preferred + return preferred + } + + if let index = segments.firstIndex(where: { canMergeOverlap($0, incoming) }) { + let merged = mergedOverlap(segments[index], incoming) + guard !sameSegment(merged, segments[index]) else { return nil } + segments[index] = merged + return merged + } + + segments.append(incoming) + return incoming + } + + private func isDuplicate( + _ existing: TranscriptionService.BackendSegment, + _ incoming: TranscriptionService.BackendSegment + ) -> Bool { + guard normalizedText(existing.text) == normalizedText(incoming.text) else { return false } + let intersection = max(0, min(existing.end, incoming.end) - max(existing.start, incoming.start)) + let shorterDuration = max( + 0.001, min(existing.end - existing.start, incoming.end - incoming.start)) + return intersection / shorterDuration >= duplicateOverlapThreshold + } + + private func canMergeOverlap( + _ existing: TranscriptionService.BackendSegment, + _ incoming: TranscriptionService.BackendSegment + ) -> Bool { + guard (existing.speaker_id ?? 0) == (incoming.speaker_id ?? 0) else { return false } + guard min(existing.end, incoming.end) > max(existing.start, incoming.start) else { + return false + } + return edgeTokenOverlap(existing.text, incoming.text) > 0 + } + + private func mergedOverlap( + _ existing: TranscriptionService.BackendSegment, + _ incoming: TranscriptionService.BackendSegment + ) -> TranscriptionService.BackendSegment { + let existingFirst = existing.start <= incoming.start + let first = existingFirst ? existing : incoming + let second = existingFirst ? incoming : existing + let overlap = edgeTokenOverlap(first.text, second.text) + let suffix = tokenized(second.text).dropFirst(overlap).joined(separator: " ") + let mergedText = suffix.isEmpty ? first.text : "\(first.text) \(suffix)" + + return TranscriptionService.BackendSegment( + id: first.id ?? second.id, + text: mergedText, + speaker: first.speaker ?? second.speaker, + speaker_id: first.speaker_id ?? second.speaker_id, + is_user: first.is_user || second.is_user, + person_id: first.person_id ?? second.person_id, + start: min(first.start, second.start), + end: max(first.end, second.end), + translations: first.translations ?? second.translations, + stt_provider: first.stt_provider ?? second.stt_provider, + stt_model: first.stt_model ?? second.stt_model, + provider_cluster_id: first.provider_cluster_id ?? second.provider_cluster_id, + provider_speaker_label: first.provider_speaker_label ?? second.provider_speaker_label, + speaker_identity_state: first.speaker_identity_state ?? second.speaker_identity_state, + speaker_identity_confidence: first.speaker_identity_confidence + ?? second.speaker_identity_confidence, + speaker_identity_source: first.speaker_identity_source ?? second.speaker_identity_source, + speaker_identity_version: first.speaker_identity_version ?? second.speaker_identity_version + ) + } + + private func preferredSegment( + existing: TranscriptionService.BackendSegment, + incoming: TranscriptionService.BackendSegment + ) -> TranscriptionService.BackendSegment { + if normalizedText(existing.text) == normalizedText(incoming.text), + existing.text.count <= incoming.text.count + { + return existing + } + if incoming.end - incoming.start > existing.end - existing.start { + return incoming + } + if incoming.text.count > existing.text.count { + return incoming + } + return existing + } + + private func sameSegment( + _ lhs: TranscriptionService.BackendSegment, + _ rhs: TranscriptionService.BackendSegment + ) -> Bool { + lhs.id == rhs.id + && lhs.text == rhs.text + && lhs.speaker == rhs.speaker + && lhs.speaker_id == rhs.speaker_id + && lhs.is_user == rhs.is_user + && lhs.person_id == rhs.person_id + && lhs.start == rhs.start + && lhs.end == rhs.end + && lhs.stt_provider == rhs.stt_provider + && lhs.stt_model == rhs.stt_model + && lhs.provider_cluster_id == rhs.provider_cluster_id + && lhs.provider_speaker_label == rhs.provider_speaker_label + && lhs.speaker_identity_state == rhs.speaker_identity_state + && lhs.speaker_identity_confidence == rhs.speaker_identity_confidence + && lhs.speaker_identity_source == rhs.speaker_identity_source + && lhs.speaker_identity_version == rhs.speaker_identity_version + } + + private func normalizedText(_ value: String) -> String { + value.lowercased() + .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) + .trimmingCharacters(in: .whitespacesAndNewlines) + } + + private func tokenized(_ value: String) -> [String] { + normalizedText(value).split(separator: " ").map(String.init) + } + + private func edgeTokenOverlap(_ first: String, _ second: String) -> Int { + let left = tokenized(first) + let right = tokenized(second) + guard !left.isEmpty, !right.isEmpty else { return 0 } + + let maxOverlap = min(left.count, right.count) + for count in stride(from: maxOverlap, through: 1, by: -1) { + if Array(left.suffix(count)) == Array(right.prefix(count)) { + return count + } + } + return 0 + } +} diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift new file mode 100644 index 00000000000..626e1646e89 --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift @@ -0,0 +1,68 @@ +import Foundation + +struct BackgroundTranscriptionConfiguration: Equatable { + var sampleRate: Int = 16000 + var maxChunkDuration: TimeInterval = 15.0 + var minChunkDuration: TimeInterval = 1.0 + var overlapDuration: TimeInterval = 0.5 + var silenceWindowDuration: TimeInterval = 0.35 + var silenceAmplitudeThreshold: Int = 256 + var speechPeakAmplitudeThreshold: Int = 512 + var speechRMSAmplitudeThreshold: Int = 64 + var maxPendingChunks: Int = 4 + var maxChunkTranscriptionAttempts: Int = 3 + var requiresSpeechBeforeUpload: Bool = false + var speechActivityDetection = SpeechActivityDetectionConfiguration() + var usesSilenceAwareChunking: Bool { + minChunkDuration < maxChunkDuration + } + + var bytesPerSample: Int { 2 } + + func byteCount(for duration: TimeInterval) -> Int { + max(0, Int(duration * Double(sampleRate)) * bytesPerSample) + } + + func alignedByteCount(for duration: TimeInterval) -> Int { + byteCount(for: duration).alignedToSample + } + + var maxChunksPerAppend: Int { + 1 + } + + static var cloudBatch: BackgroundTranscriptionConfiguration { + fixedFifteenSecondCloudBatch + } + + static var fixedFifteenSecondCloudBatch: BackgroundTranscriptionConfiguration { + BackgroundTranscriptionConfiguration( + maxChunkDuration: 15.0, + minChunkDuration: 15.0, + overlapDuration: 0.5, + maxPendingChunks: 8, + maxChunkTranscriptionAttempts: 3, + requiresSpeechBeforeUpload: true, + speechActivityDetection: SpeechActivityDetectionConfiguration( + windowDuration: 0.02, + minimumSpeechDuration: 0.75, + peakAmplitudeThreshold: 900, + rmsAmplitudeThreshold: 180, + maximumSpeechZeroCrossingRate: 0.35 + ) + ) + } + + static var silenceAwareCloudBatchCandidate: BackgroundTranscriptionConfiguration { + var configuration = fixedFifteenSecondCloudBatch + configuration.minChunkDuration = 6.0 + configuration.silenceWindowDuration = 0.35 + return configuration + } +} + +extension Int { + var alignedToSample: Int { + self - (self % 2) + } +} diff --git a/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift new file mode 100644 index 00000000000..6d5e114f68a --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionRoutingGuard.swift @@ -0,0 +1,37 @@ +import Foundation + +enum BackgroundTranscriptionRoutingDecision: Equatable { + case cloudBatchAssembly + case cloudListenStreaming(reason: String?) +} + +struct BackgroundTranscriptionRoutingGuard { + func decide( + backgroundBatchCapability: DesktopBackgroundBatchCapability?, + audioSource: AudioSource + ) -> BackgroundTranscriptionRoutingDecision { + guard audioSource == .microphone else { + return .cloudListenStreaming(reason: "batch_microphone_only") + } + guard let capability = backgroundBatchCapability else { + return .cloudListenStreaming(reason: "server_background_batch_capability_unavailable") + } + guard capability.enabled else { + return .cloudListenStreaming(reason: capability.reason ?? "server_background_batch_disabled") + } + guard let effectiveProvider = capability.effectiveProvider?.lowercased() else { + return .cloudListenStreaming(reason: capability.reason ?? "server_background_batch_provider_unavailable") + } + guard effectiveProvider == "assemblyai" || effectiveProvider == "deepgram" else { + return .cloudListenStreaming(reason: capability.reason ?? "server_background_batch_provider_unsupported") + } + return .cloudBatchAssembly + } + + func shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: AudioSource, + captureStarted: Bool + ) -> Bool { + audioSource == .microphone && !captureStarted + } +} diff --git a/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift b/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift new file mode 100644 index 00000000000..af7a7d5da99 --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/CloudBackgroundTranscriptionSession.swift @@ -0,0 +1,174 @@ +import Foundation + +struct BackgroundIngestResult: Equatable { + let enqueuedChunks: Int + let pendingChunkCount: Int + let acceptedInputBytes: Int + let didFinishInput: Bool + let isBackpressured: Bool +} + +struct BackgroundTranscriptionResult { + let chunk: BackgroundAudioChunk + let segments: [TranscriptionService.BackendSegment] +} + +struct BackgroundTranscriptSnapshot { + let pendingChunkCount: Int + let processedChunkCount: Int + let droppedChunkCount: Int + let isInputFinished: Bool + let segments: [TranscriptionService.BackendSegment] + let lastSpeechActivityDecision: SpeechActivityDecision? +} + +final class CloudBackgroundTranscriptionSession { + typealias TranscribeHandler = (BackgroundAudioChunk) async throws -> [TranscriptionService + .BackendSegment] + + private let configuration: BackgroundTranscriptionConfiguration + private let transcribe: TranscribeHandler + private let speechActivityDetector: SpeechActivityDetector + private var chunker: BackgroundAudioChunker + private var pendingChunks: [BackgroundAudioChunk] = [] + private var pendingChunkAttempts: [Int] = [] + private var processedChunkCount = 0 + private var droppedChunkCount = 0 + private var isInputFinished = false + private var processedSegments: [TranscriptionService.BackendSegment] = [] + private var lastSpeechActivityDecision: SpeechActivityDecision? + + init( + configuration: BackgroundTranscriptionConfiguration = BackgroundTranscriptionConfiguration(), + transcribe: @escaping TranscribeHandler + ) { + self.configuration = configuration + self.transcribe = transcribe + self.speechActivityDetector = SpeechActivityDetector( + configuration: configuration.speechActivityDetection, + sampleRate: configuration.sampleRate, + bytesPerSample: configuration.bytesPerSample + ) + self.chunker = BackgroundAudioChunker(configuration: configuration) + } + + var pendingChunkCount: Int { + pendingChunks.count + } + + var isBackpressured: Bool { + pendingChunks.count >= configuration.maxPendingChunks + } + + func append(pcmData: Data, startTime: Double) -> BackgroundIngestResult { + guard !isInputFinished else { + return BackgroundIngestResult( + enqueuedChunks: 0, + pendingChunkCount: pendingChunks.count, + acceptedInputBytes: 0, + didFinishInput: true, + isBackpressured: isBackpressured + ) + } + guard !isBackpressured else { + return BackgroundIngestResult( + enqueuedChunks: 0, + pendingChunkCount: pendingChunks.count, + acceptedInputBytes: 0, + didFinishInput: false, + isBackpressured: true + ) + } + + let chunks = chunker.append(pcmData: pcmData, startTime: startTime) + var enqueuedChunks = 0 + for chunk in chunks where pendingChunks.count < configuration.maxPendingChunks { + guard shouldUpload(chunk) else { + droppedChunkCount += 1 + continue + } + pendingChunks.append(chunk) + pendingChunkAttempts.append(0) + enqueuedChunks += 1 + } + return BackgroundIngestResult( + enqueuedChunks: enqueuedChunks, + pendingChunkCount: pendingChunks.count, + acceptedInputBytes: pcmData.count, + didFinishInput: false, + isBackpressured: isBackpressured + ) + } + + func finishInput() -> BackgroundIngestResult { + guard !isInputFinished else { + return BackgroundIngestResult( + enqueuedChunks: 0, + pendingChunkCount: pendingChunks.count, + acceptedInputBytes: 0, + didFinishInput: true, + isBackpressured: isBackpressured + ) + } + + var enqueuedChunks = 0 + for chunk in chunker.finishInput() { + guard shouldUpload(chunk) else { + droppedChunkCount += 1 + continue + } + pendingChunks.append(chunk) + pendingChunkAttempts.append(0) + enqueuedChunks += 1 + } + isInputFinished = true + return BackgroundIngestResult( + enqueuedChunks: enqueuedChunks, + pendingChunkCount: pendingChunks.count, + acceptedInputBytes: 0, + didFinishInput: true, + isBackpressured: isBackpressured + ) + } + + func transcribeNext() async throws -> BackgroundTranscriptionResult? { + guard !pendingChunks.isEmpty else { return nil } + let chunk = pendingChunks[0] + do { + let segments = try await transcribe(chunk) + pendingChunks.removeFirst() + pendingChunkAttempts.removeFirst() + processedChunkCount += 1 + processedSegments.append(contentsOf: segments) + return BackgroundTranscriptionResult(chunk: chunk, segments: segments) + } catch { + let attempts = (pendingChunkAttempts.first ?? 0) + 1 + if attempts >= configuration.maxChunkTranscriptionAttempts { + pendingChunks.removeFirst() + pendingChunkAttempts.removeFirst() + droppedChunkCount += 1 + } else { + pendingChunkAttempts[0] = attempts + } + throw error + } + } + + func snapshot() -> BackgroundTranscriptSnapshot { + BackgroundTranscriptSnapshot( + pendingChunkCount: pendingChunks.count, + processedChunkCount: processedChunkCount, + droppedChunkCount: droppedChunkCount, + isInputFinished: isInputFinished, + segments: processedSegments, + lastSpeechActivityDecision: lastSpeechActivityDecision + ) + } + + private func shouldUpload(_ chunk: BackgroundAudioChunk) -> Bool { + guard configuration.requiresSpeechBeforeUpload else { return true } + let decision = speechActivityDetector.evaluate(pcmData: chunk.pcmData) + lastSpeechActivityDecision = decision + return decision.shouldUpload + } +} diff --git a/desktop/Desktop/Sources/BackgroundTranscription/SpeakerSegmentReducer.swift b/desktop/Desktop/Sources/BackgroundTranscription/SpeakerSegmentReducer.swift new file mode 100644 index 00000000000..27021183615 --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/SpeakerSegmentReducer.swift @@ -0,0 +1,84 @@ +import Foundation + +struct SpeakerSegmentReducer { + struct ApplyResult: Equatable { + var added: Int = 0 + var updated: Int = 0 + var totalSegmentCount: Int = 0 + var totalWordCount: Int = 0 + } + + private(set) var segments: [SpeakerSegment] = [] + private(set) var totalSegmentCount: Int = 0 + private(set) var totalWordCount: Int = 0 + var maxInMemorySegments: Int + + init(maxInMemorySegments: Int = 400) { + self.maxInMemorySegments = maxInMemorySegments + } + + mutating func reset() { + segments = [] + totalSegmentCount = 0 + totalWordCount = 0 + } + + mutating func replaceSegments(_ replacement: [SpeakerSegment]) { + segments = replacement + totalSegmentCount = replacement.count + totalWordCount = replacement.reduce(0) { $0 + wordCount($1.text) } + } + + mutating func apply(_ incomingSegments: [TranscriptionService.BackendSegment]) -> ApplyResult { + apply(incomingSegments.map(Self.speakerSegment(from:))) + } + + mutating func apply(_ incomingSegments: [SpeakerSegment]) -> ApplyResult { + var result = ApplyResult() + + for incoming in incomingSegments where !incoming.text.isEmpty { + if let segId = incoming.segmentId, + let existingIdx = segments.firstIndex(where: { $0.segmentId == segId }) + { + let oldWords = wordCount(segments[existingIdx].text) + var updated = incoming + if updated.translations.isEmpty && !segments[existingIdx].translations.isEmpty { + updated.translations = segments[existingIdx].translations + } + segments[existingIdx] = updated + totalWordCount += wordCount(updated.text) - oldWords + result.updated += 1 + } else { + segments.append(incoming) + totalSegmentCount += 1 + totalWordCount += wordCount(incoming.text) + result.added += 1 + } + } + + if segments.count > maxInMemorySegments { + segments.removeFirst(segments.count - maxInMemorySegments) + } + + result.totalSegmentCount = totalSegmentCount + result.totalWordCount = totalWordCount + return result + } + + private static func speakerSegment(from segment: TranscriptionService.BackendSegment) -> SpeakerSegment { + SpeakerSegment( + segmentId: segment.id, + speaker: segment.speaker_id ?? 0, + text: segment.text, + start: segment.start, + end: segment.end, + isUser: segment.is_user, + personId: segment.person_id, + translations: (segment.translations ?? []).map { SegmentTranslation(lang: $0.lang, text: $0.text) } + ) + } + + private func wordCount(_ text: String) -> Int { + text.split(separator: " ").count + } +} diff --git a/desktop/Desktop/Sources/BackgroundTranscription/SpeechActivityDetector.swift b/desktop/Desktop/Sources/BackgroundTranscription/SpeechActivityDetector.swift new file mode 100644 index 00000000000..4908c07ca5e --- /dev/null +++ b/desktop/Desktop/Sources/BackgroundTranscription/SpeechActivityDetector.swift @@ -0,0 +1,183 @@ +import Foundation + +struct SpeechActivityDetectionConfiguration: Equatable { + var windowDuration: TimeInterval = 0.02 + var minimumSpeechDuration: TimeInterval = 0.25 + var peakAmplitudeThreshold: Int = 512 + var rmsAmplitudeThreshold: Int = 64 + var maximumSpeechZeroCrossingRate: Double = 0.35 +} + +struct SpeechActivityDecision: Equatable { + enum Reason: String, Equatable { + case speechDetected + case emptyAudio + case insufficientSpeech + case energeticNonSpeech + } + + let shouldUpload: Bool + let reason: Reason + let totalWindows: Int + let energeticWindows: Int + let speechLikeWindows: Int + let rejectedHighZeroCrossingWindows: Int + let maxPeakAmplitude: Int + let maxRMSAmplitude: Double +} + +struct SpeechActivityDetector { + private let configuration: SpeechActivityDetectionConfiguration + private let sampleRate: Int + private let bytesPerSample: Int + + init( + configuration: SpeechActivityDetectionConfiguration, + sampleRate: Int, + bytesPerSample: Int = 2 + ) { + self.configuration = configuration + self.sampleRate = sampleRate + self.bytesPerSample = bytesPerSample + } + + func evaluate(pcmData: Data) -> SpeechActivityDecision { + guard !pcmData.isEmpty else { + return SpeechActivityDecision( + shouldUpload: false, + reason: .emptyAudio, + totalWindows: 0, + energeticWindows: 0, + speechLikeWindows: 0, + rejectedHighZeroCrossingWindows: 0, + maxPeakAmplitude: 0, + maxRMSAmplitude: 0 + ) + } + + let windowBytes = max(bytesPerSample, alignedByteCount(for: configuration.windowDuration)) + let requiredSpeechWindows = max( + 1, + Int(ceil(configuration.minimumSpeechDuration / configuration.windowDuration)) + ) + + var totalWindows = 0 + var energeticWindows = 0 + var speechLikeWindows = 0 + var consecutiveSpeechLikeWindows = 0 + var rejectedHighZeroCrossingWindows = 0 + var maxPeakAmplitude = 0 + var maxRMSAmplitude = 0.0 + + var offset = 0 + while offset < pcmData.count { + let endOffset = min(offset + windowBytes, pcmData.count) + let window = analyzeWindow(pcmData, start: offset, end: endOffset) + totalWindows += 1 + maxPeakAmplitude = max(maxPeakAmplitude, window.peakAmplitude) + maxRMSAmplitude = max(maxRMSAmplitude, window.rmsAmplitude) + + if window.isEnergetic( + peakThreshold: configuration.peakAmplitudeThreshold, + rmsThreshold: configuration.rmsAmplitudeThreshold + ) { + energeticWindows += 1 + if window.zeroCrossingRate <= configuration.maximumSpeechZeroCrossingRate { + speechLikeWindows += 1 + consecutiveSpeechLikeWindows += 1 + if consecutiveSpeechLikeWindows >= requiredSpeechWindows { + return SpeechActivityDecision( + shouldUpload: true, + reason: .speechDetected, + totalWindows: totalWindows, + energeticWindows: energeticWindows, + speechLikeWindows: speechLikeWindows, + rejectedHighZeroCrossingWindows: rejectedHighZeroCrossingWindows, + maxPeakAmplitude: maxPeakAmplitude, + maxRMSAmplitude: maxRMSAmplitude + ) + } + } else { + consecutiveSpeechLikeWindows = 0 + rejectedHighZeroCrossingWindows += 1 + } + } else { + consecutiveSpeechLikeWindows = 0 + } + + offset += windowBytes + } + + let reason: SpeechActivityDecision.Reason = + energeticWindows > 0 && rejectedHighZeroCrossingWindows >= energeticWindows + ? .energeticNonSpeech + : .insufficientSpeech + + return SpeechActivityDecision( + shouldUpload: false, + reason: reason, + totalWindows: totalWindows, + energeticWindows: energeticWindows, + speechLikeWindows: speechLikeWindows, + rejectedHighZeroCrossingWindows: rejectedHighZeroCrossingWindows, + maxPeakAmplitude: maxPeakAmplitude, + maxRMSAmplitude: maxRMSAmplitude + ) + } + + private func analyzeWindow(_ pcmData: Data, start: Int, end: Int) -> WindowActivity { + var peak = 0 + var sumSquares = 0.0 + var sampleCount = 0 + var zeroCrossings = 0 + var previousSample: Int16? + + for offset in stride(from: start, to: end, by: bytesPerSample) { + guard offset + 1 < pcmData.count else { break } + let sample = sampleValue(in: pcmData, at: offset) + let amplitude = abs(Int(sample)) + peak = max(peak, amplitude) + sumSquares += Double(amplitude * amplitude) + if let previousSample, crossesZero(previousSample, sample) { + zeroCrossings += 1 + } + previousSample = sample + sampleCount += 1 + } + + guard sampleCount > 0 else { + return WindowActivity(peakAmplitude: 0, rmsAmplitude: 0, zeroCrossingRate: 0) + } + let rms = sqrt(sumSquares / Double(sampleCount)) + let denominator = max(1, sampleCount - 1) + return WindowActivity( + peakAmplitude: peak, + rmsAmplitude: rms, + zeroCrossingRate: Double(zeroCrossings) / Double(denominator) + ) + } + + private func alignedByteCount(for duration: TimeInterval) -> Int { + (Int(duration * Double(sampleRate)) * bytesPerSample).alignedToSample + } + + private func sampleValue(in pcmData: Data, at offset: Int) -> Int16 { + let low = UInt16(pcmData[offset]) + let high = UInt16(pcmData[offset + 1]) << 8 + return Int16(bitPattern: low | high) + } + + private func crossesZero(_ lhs: Int16, _ rhs: Int16) -> Bool { + (lhs < 0 && rhs > 0) || (lhs > 0 && rhs < 0) + } +} + +private struct WindowActivity { + let peakAmplitude: Int + let rmsAmplitude: Double + let zeroCrossingRate: Double + + func isEnergetic(peakThreshold: Int, rmsThreshold: Int) -> Bool { + peakAmplitude >= peakThreshold || rmsAmplitude >= Double(rmsThreshold) + } +} diff --git a/desktop/Desktop/Sources/MainWindow/Components/SpeakerBubbleView.swift b/desktop/Desktop/Sources/MainWindow/Components/SpeakerBubbleView.swift index 4d2d7652d79..14010e72024 100644 --- a/desktop/Desktop/Sources/MainWindow/Components/SpeakerBubbleView.swift +++ b/desktop/Desktop/Sources/MainWindow/Components/SpeakerBubbleView.swift @@ -27,7 +27,7 @@ struct SpeakerBubbleView: View { private var speakerLabel: String { if isUser { return "You" } if let name = personName { return name } - return "Speaker \(segment.speakerId)" + return "Speaker \(segment.displaySpeakerSuffix)" } private var avatarInitial: String { @@ -35,7 +35,7 @@ struct SpeakerBubbleView: View { if let name = personName, let first = name.first { return String(first).uppercased() } - return String(segment.speakerId) + return segment.speaker == nil ? "?" : String(segment.speakerId) } var body: some View { diff --git a/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift b/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift index 8d59d6eb43d..3a4697a01dc 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift @@ -717,7 +717,15 @@ struct ConversationDetailView: View { personId: isUser ? nil : personId, start: oldSegment.start, end: oldSegment.end, - translations: oldSegment.translations + translations: oldSegment.translations, + sttProvider: oldSegment.sttProvider, + sttModel: oldSegment.sttModel, + providerClusterId: oldSegment.providerClusterId, + providerSpeakerLabel: oldSegment.providerSpeakerLabel, + speakerIdentityState: oldSegment.speakerIdentityState, + speakerIdentityConfidence: oldSegment.speakerIdentityConfidence, + speakerIdentitySource: oldSegment.speakerIdentitySource, + speakerIdentityVersion: oldSegment.speakerIdentityVersion ) } loadedConversation = updatedConversation diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index ebb5b42ff06..fe5128e0e57 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -377,6 +377,7 @@ struct SettingsContentView: View { @AppStorage("dev_anthropic_api_key") private var devAnthropicKey: String = "" @AppStorage("dev_openai_api_key") private var devOpenAIKey: String = "" @AppStorage("dev_deepgram_api_key") private var devDeepgramKey: String = "" + @AppStorage("dev_assemblyai_api_key") private var devAssemblyAIKey: String = "" @State private var byokKeyStatuses: [BYOKProvider: BYOKValidator.Status] = [:] @State private var byokActivationError: String? @@ -1069,7 +1070,7 @@ struct SettingsContentView: View { } .buttonStyle(.plain) - // Single Language option + // Single Language option (picker must stay outside Button — nested controls freeze SwiftUI) Button(action: { transcriptionAutoDetect = false AssistantSettings.shared.transcriptionAutoDetect = false @@ -1090,33 +1091,6 @@ struct SettingsContentView: View { Text("Best for speaking in one specific language") .scaledFont(size: 12) .foregroundColor(OmiColors.textTertiary) - - // Language picker (only shown when single language is selected) - if !transcriptionAutoDetect { - HStack { - Text("Language:") - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) - - Picker("", selection: $transcriptionLanguage) { - ForEach(languageOptions, id: \.0) { option in - Text(option.1).tag(option.0) - } - } - .pickerStyle(.menu) - .frame(width: 180) - .onChange(of: transcriptionLanguage) { _, newValue in - AssistantSettings.shared.transcriptionLanguage = newValue - let supportsMulti = AssistantSettings.supportsAutoDetect(newValue) - transcriptionAutoDetect = supportsMulti - AssistantSettings.shared.transcriptionAutoDetect = supportsMulti - updateTranscriptionPreferences(singleLanguageMode: !supportsMulti) - updateLanguage(newValue) - restartTranscriptionIfNeeded() - } - } - .padding(.top, 4) - } } Spacer() @@ -1136,6 +1110,33 @@ struct SettingsContentView: View { } .buttonStyle(.plain) + if !transcriptionAutoDetect { + HStack(spacing: 12) { + Text("Language:") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + + Picker("", selection: $transcriptionLanguage) { + ForEach(languageOptions, id: \.0) { option in + Text(option.1).tag(option.0) + } + } + .pickerStyle(.menu) + .frame(width: 180) + .onChange(of: transcriptionLanguage) { _, newValue in + applyTranscriptionLanguageChange(newValue) + } + + Spacer() + } + .padding(.horizontal, 12) + .padding(.vertical, 8) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(OmiColors.backgroundSecondary) + ) + } + // Info about language support HStack(spacing: 8) { Image(systemName: "info.circle") @@ -1303,16 +1304,22 @@ struct SettingsContentView: View { updateTranscriptionPreferences(vocabulary: vocabularyList.joined(separator: ", ")) } + private func applyTranscriptionLanguageChange(_ newValue: String) { + AssistantSettings.shared.transcriptionLanguage = newValue + let supportsMulti = AssistantSettings.supportsAutoDetect(newValue) + transcriptionAutoDetect = supportsMulti + AssistantSettings.shared.transcriptionAutoDetect = supportsMulti + updateTranscriptionPreferences(singleLanguageMode: !supportsMulti) + updateLanguage(newValue) + restartTranscriptionIfNeeded() + } + /// Restart transcription if currently running to apply new settings private func restartTranscriptionIfNeeded() { - guard appState.isTranscribing else { return } + guard appState.isTranscribing || appState.isStartingTranscription else { return } - // Stop and restart to apply new language settings - appState.stopTranscription() - - // Wait a moment for cleanup, then restart - DispatchQueue.main.asyncAfter(deadline: .now() + 1.0) { - self.appState.startTranscription() + Task { + await appState.restartTranscriptionAfterSettingsChange() } } @@ -2138,7 +2145,7 @@ struct SettingsContentView: View { .foregroundColor(OmiColors.textPrimary) Text( APIKeyService.isByokActive - ? "You're using your own OpenAI, Anthropic, Gemini, and Deepgram keys. No subscription." + ? "You're using your own OpenAI, Anthropic, Gemini, and Deepgram keys. No subscription. Add an optional AssemblyAI key for async transcription when enabled on our servers." : "Provide your own OpenAI, Anthropic, Gemini, and Deepgram keys to skip the subscription entirely." ) .scaledFont(size: 12) @@ -5318,11 +5325,20 @@ struct SettingsContentView: View { developerKeyField( provider: .deepgram, title: "Deepgram API Key", - subtitle: "For live transcription.", + subtitle: "For live transcription and prerecorded fallback.", settingId: "advanced.devkeys.deepgram", value: $devDeepgramKey ) + developerKeyField( + provider: .assemblyai, + title: "AssemblyAI API Key", + subtitle: + "Optional. For async transcription (sync, background, postprocess) when enabled server-side. Live listening still uses Deepgram.", + settingId: "advanced.devkeys.assemblyai", + value: $devAssemblyAIKey + ) + if let byokActivationError { settingsCard(settingId: "advanced.devkeys.error") { HStack(spacing: 10) { @@ -5354,11 +5370,12 @@ struct SettingsContentView: View { .onChange(of: devAnthropicKey) { _, _ in refreshBYOKActivation() } .onChange(of: devGeminiKey) { _, _ in refreshBYOKActivation() } .onChange(of: devDeepgramKey) { _, _ in refreshBYOKActivation() } + .onChange(of: devAssemblyAIKey) { _, _ in refreshBYOKActivation() } } private var hasAnyBYOKKey: Bool { !devOpenAIKey.isEmpty || !devAnthropicKey.isEmpty || !devGeminiKey.isEmpty - || !devDeepgramKey.isEmpty + || !devDeepgramKey.isEmpty || !devAssemblyAIKey.isEmpty } private var hasAllBYOKKeys: Bool { @@ -5378,8 +5395,8 @@ struct SettingsContentView: View { .foregroundColor(OmiColors.textPrimary) Text( hasAllBYOKKeys - ? "You're paying your own providers. Omi skips the subscription charge. Keys stay on this Mac." - : "Provide all four keys (OpenAI, Anthropic, Gemini, Deepgram) to switch to the free plan. Keys stay on this Mac — we never store them on our servers." + ? "You're paying your own providers. Omi skips the subscription charge. Keys stay on this Mac. AssemblyAI is optional for async transcription." + : "Provide all four keys (OpenAI, Anthropic, Gemini, Deepgram) to switch to the free plan. AssemblyAI is optional. Keys stay on this Mac — we never store them on our servers." ) .scaledFont(size: 12) .foregroundColor(OmiColors.textTertiary) @@ -5394,6 +5411,7 @@ struct SettingsContentView: View { devAnthropicKey = "" devGeminiKey = "" devDeepgramKey = "" + devAssemblyAIKey = "" Task { try? await APIClient.shared.deactivateBYOK() } @@ -5402,33 +5420,39 @@ struct SettingsContentView: View { private func refreshBYOKActivation() { Task { if APIKeyService.isByokActive { - // Validate before flipping the backend flag — otherwise we'd put the - // user on the free plan with dead keys and every chat would 401. - let snapshot = APIKeyService.byokSnapshot.reduce(into: [BYOKProvider: String]()) { - acc, entry in acc[entry.key] = entry.value.key - } - let results = await BYOKValidator.validateAll(snapshot) - let allOk = results.allSatisfy { - if case .ok = $0.value { return true } + // Validate required four before flipping the backend flag. + let requiredSnapshot = BYOKProvider.requiredForFreePlan.reduce(into: [BYOKProvider: String]()) { + acc, provider in + if let key = APIKeyService.byokKey(provider) { + acc[provider] = key + } + } + var keysToValidate = requiredSnapshot + if let assemblyKey = APIKeyService.byokKey(.assemblyai) { + keysToValidate[.assemblyai] = assemblyKey + } + let results = await BYOKValidator.validateAll(keysToValidate) + let requiredOk = BYOKProvider.requiredForFreePlan.allSatisfy { + if case .ok = results[$0] ?? .notChecked { return true } return false } - if allOk { - let fingerprints = APIKeyService.byokSnapshot.reduce(into: [String: String]()) { - acc, entry in acc[entry.key.rawValue] = entry.value.fingerprint - } - try? await APIClient.shared.activateBYOK(fingerprints: fingerprints) + if requiredOk { + try? await APIClient.shared.activateBYOK(fingerprints: APIKeyService.byokActivationFingerprints) await FloatingBarUsageLimiter.shared.fetchPlan() await MainActor.run { - // Clear any sticky paywall flag from a prior `freemium_threshold_reached` - // event — once all 4 BYOK keys validate, the user is on the free BYOK - // plan and shouldn't be locked out of capture/transcription anymore. AppState.current?.isPaywalled = false byokKeyStatuses = results - byokActivationError = nil + if case .failed(let msg) = results[.assemblyai] ?? .notChecked { + byokActivationError = + "Optional AssemblyAI key was rejected: \(msg). Free plan is still active with your four required keys." + } else { + byokActivationError = nil + } } } else { - let failed = results.filter { - if case .ok = $0.value { return false } + let failed = results.filter { provider, status in + guard BYOKProvider.requiredForFreePlan.contains(provider) else { return false } + if case .ok = status { return false } return true } let names = failed.keys.map(\.displayName).sorted().joined(separator: ", ") @@ -5437,7 +5461,7 @@ struct SettingsContentView: View { await MainActor.run { byokKeyStatuses = results byokActivationError = - "Rejected by provider: \(names). Free plan stays off until all 4 keys authenticate." + "Rejected by provider: \(names). Free plan stays off until all 4 required keys authenticate." } } } else { diff --git a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift index e4f1cb8dab8..6c895c1a11e 100644 --- a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift +++ b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift @@ -138,7 +138,7 @@ struct OnboardingBYOKStepView: View { .gemini: geminiKey, .deepgram: deepgramKey, ] - for provider in BYOKProvider.allCases { + for provider in keysToCheck.keys { keyStatuses[provider] = .checking } let results = await BYOKValidator.validateAll(keysToCheck) @@ -156,12 +156,7 @@ struct OnboardingBYOKStepView: View { // Step 2: all four authenticate — flip the backend flag. do { - try await APIClient.shared.activateBYOK(fingerprints: BYOKProvider.allCases.reduce(into: [:]) { - acc, provider in - if let key = APIKeyService.byokKey(provider) { - acc[provider.rawValue] = APIKeyService.byokFingerprint(key) - } - }) + try await APIClient.shared.activateBYOK(fingerprints: APIKeyService.byokActivationFingerprints) // Refresh the in-memory quota snapshot — otherwise the client keeps // blocking chat against the stale basic-tier 30-message cap. await FloatingBarUsageLimiter.shared.fetchPlan() diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift index 38d8bf229fc..cc7a876607f 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift @@ -29,7 +29,7 @@ class AssistantSettings { private let defaultTranscriptionAutoDetect = true private let defaultTranscriptionVocabulary: [String] = [] private let defaultVadGateEnabled = false - private let defaultBatchTranscriptionEnabled = false + private let defaultBatchTranscriptionEnabled = true private init() { // Register defaults @@ -156,6 +156,12 @@ class AssistantSettings { return transcriptionLanguage } + /// AssemblyAI prerecorded language detection is brittle on short or low-speech chunks. + /// Cloud batch uses the selected single language even when streaming auto-detect is enabled. + var effectiveBatchTranscriptionLanguage: String { + return transcriptionLanguage + } + /// Custom vocabulary for improved transcription accuracy /// Array of words/terms that DeepGram should recognize (Nova-3 limit: 500 tokens total) var transcriptionVocabulary: [String] { @@ -183,20 +189,20 @@ class AssistantSettings { } } - /// Whether batch transcription mode is enabled (transcribes audio in chunks at silence boundaries) - var batchTranscriptionEnabled: Bool { - get { UserDefaults.standard.bool(forKey: batchTranscriptionEnabledKey) } + /// Whether local VAD gate is enabled to skip silence and reduce Deepgram usage + var vadGateEnabled: Bool { + get { UserDefaults.standard.bool(forKey: vadGateEnabledKey) } set { - UserDefaults.standard.set(newValue, forKey: batchTranscriptionEnabledKey) + UserDefaults.standard.set(newValue, forKey: vadGateEnabledKey) NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) } } - /// Whether local VAD gate is enabled to skip silence and reduce Deepgram usage - var vadGateEnabled: Bool { - get { UserDefaults.standard.bool(forKey: vadGateEnabledKey) } + /// Whether cloud batch transcription mode is enabled for microphone background audio. + var batchTranscriptionEnabled: Bool { + get { true } set { - UserDefaults.standard.set(newValue, forKey: vadGateEnabledKey) + UserDefaults.standard.set(newValue, forKey: batchTranscriptionEnabledKey) NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) } } diff --git a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift index 1035bb86c3b..e9614ea3203 100644 --- a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift +++ b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift @@ -2156,6 +2156,19 @@ actor RewindDatabase { } } + migrator.registerMigration("addCanonicalSpeakerMetadata") { db in + try db.alter(table: "transcription_segments") { t in + t.add(column: "sttProvider", .text) + t.add(column: "sttModel", .text) + t.add(column: "providerClusterId", .text) + t.add(column: "providerSpeakerLabel", .text) + t.add(column: "speakerIdentityState", .text) + t.add(column: "speakerIdentityConfidence", .double) + t.add(column: "speakerIdentitySource", .text) + t.add(column: "speakerIdentityVersion", .text) + } + } + try migrator.migrate(queue) } diff --git a/desktop/Desktop/Sources/Rewind/Core/TranscriptionModels.swift b/desktop/Desktop/Sources/Rewind/Core/TranscriptionModels.swift index 9fe1a3f19cb..011f79beb47 100644 --- a/desktop/Desktop/Sources/Rewind/Core/TranscriptionModels.swift +++ b/desktop/Desktop/Sources/Rewind/Core/TranscriptionModels.swift @@ -193,6 +193,14 @@ struct TranscriptionSegmentRecord: Codable, FetchableRecord, PersistableRecord, var isUser: Bool // Whether this segment is from the user var personId: String? // Associated person ID (if identified) var translationsJson: String? // JSON-encoded [TranscriptTranslation] + var sttProvider: String? + var sttModel: String? + var providerClusterId: String? + var providerSpeakerLabel: String? + var speakerIdentityState: String? + var speakerIdentityConfidence: Double? + var speakerIdentitySource: String? + var speakerIdentityVersion: String? static let databaseTableName = "transcription_segments" @@ -212,7 +220,15 @@ struct TranscriptionSegmentRecord: Codable, FetchableRecord, PersistableRecord, speakerLabel: String? = nil, isUser: Bool = false, personId: String? = nil, - translationsJson: String? = nil + translationsJson: String? = nil, + sttProvider: String? = nil, + sttModel: String? = nil, + providerClusterId: String? = nil, + providerSpeakerLabel: String? = nil, + speakerIdentityState: String? = nil, + speakerIdentityConfidence: Double? = nil, + speakerIdentitySource: String? = nil, + speakerIdentityVersion: String? = nil ) { self.id = id self.sessionId = sessionId @@ -228,6 +244,14 @@ struct TranscriptionSegmentRecord: Codable, FetchableRecord, PersistableRecord, self.isUser = isUser self.personId = personId self.translationsJson = translationsJson + self.sttProvider = sttProvider + self.sttModel = sttModel + self.providerClusterId = providerClusterId + self.providerSpeakerLabel = providerSpeakerLabel + self.speakerIdentityState = speakerIdentityState + self.speakerIdentityConfidence = speakerIdentityConfidence + self.speakerIdentitySource = speakerIdentitySource + self.speakerIdentityVersion = speakerIdentityVersion } // MARK: - Persistence Callbacks @@ -429,7 +453,15 @@ extension TranscriptionSegmentRecord { speakerLabel: segment.speaker, isUser: segment.isUser, personId: segment.personId, - translationsJson: translationsJson + translationsJson: translationsJson, + sttProvider: segment.sttProvider, + sttModel: segment.sttModel, + providerClusterId: segment.providerClusterId, + providerSpeakerLabel: segment.providerSpeakerLabel, + speakerIdentityState: segment.speakerIdentityState, + speakerIdentityConfidence: segment.speakerIdentityConfidence, + speakerIdentitySource: segment.speakerIdentitySource, + speakerIdentityVersion: segment.speakerIdentityVersion ) } @@ -448,7 +480,15 @@ extension TranscriptionSegmentRecord { personId: personId, start: startTime, end: endTime, - translations: translations + translations: translations, + sttProvider: sttProvider, + sttModel: sttModel, + providerClusterId: providerClusterId, + providerSpeakerLabel: providerSpeakerLabel, + speakerIdentityState: speakerIdentityState, + speakerIdentityConfidence: speakerIdentityConfidence, + speakerIdentitySource: speakerIdentitySource, + speakerIdentityVersion: speakerIdentityVersion ) } } diff --git a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift index 7b585fd49c6..c0c0ece5eba 100644 --- a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift +++ b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift @@ -299,7 +299,15 @@ actor TranscriptionStorage { isUser: Bool = false, personId: String? = nil, speakerLabel: String? = nil, - translationsJson: String? = nil + translationsJson: String? = nil, + sttProvider: String? = nil, + sttModel: String? = nil, + providerClusterId: String? = nil, + providerSpeakerLabel: String? = nil, + speakerIdentityState: String? = nil, + speakerIdentityConfidence: Double? = nil, + speakerIdentitySource: String? = nil, + speakerIdentityVersion: String? = nil ) async throws -> Int64 { let db = try await ensureInitialized() @@ -311,10 +319,22 @@ actor TranscriptionStorage { UPDATE transcription_segments SET text = ?, speaker = ?, startTime = ?, endTime = ?, isUser = ?, personId = ?, speakerLabel = COALESCE(?, speakerLabel), - translationsJson = COALESCE(?, translationsJson) + translationsJson = COALESCE(?, translationsJson), + sttProvider = COALESCE(?, sttProvider), + sttModel = COALESCE(?, sttModel), + providerClusterId = COALESCE(?, providerClusterId), + providerSpeakerLabel = COALESCE(?, providerSpeakerLabel), + speakerIdentityState = COALESCE(?, speakerIdentityState), + speakerIdentityConfidence = COALESCE(?, speakerIdentityConfidence), + speakerIdentitySource = COALESCE(?, speakerIdentitySource), + speakerIdentityVersion = COALESCE(?, speakerIdentityVersion) WHERE sessionId = ? AND segmentId = ? """, - arguments: [text, speaker, startTime, endTime, isUser, personId, speakerLabel, translationsJson, sessionId, segId] + arguments: [ + text, speaker, startTime, endTime, isUser, personId, speakerLabel, translationsJson, + sttProvider, sttModel, providerClusterId, providerSpeakerLabel, speakerIdentityState, + speakerIdentityConfidence, speakerIdentitySource, speakerIdentityVersion, sessionId, segId + ] ) return database.changesCount > 0 } @@ -343,7 +363,15 @@ actor TranscriptionStorage { speakerLabel: speakerLabel, isUser: isUser, personId: personId, - translationsJson: translationsJson + translationsJson: translationsJson, + sttProvider: sttProvider, + sttModel: sttModel, + providerClusterId: providerClusterId, + providerSpeakerLabel: providerSpeakerLabel, + speakerIdentityState: speakerIdentityState, + speakerIdentityConfidence: speakerIdentityConfidence, + speakerIdentitySource: speakerIdentitySource, + speakerIdentityVersion: speakerIdentityVersion ) let record = try await db.write { database in diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index 33e30c91220..cdfd728f7ff 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -1,4 +1,5 @@ import Foundation +import CryptoKit /// Service for real-time speech-to-text transcription. /// Conversation capture: Python backend `/v4/listen` WebSocket (speech profiles, speaker assignment, memory events). @@ -37,6 +38,14 @@ class TranscriptionService { let start: Double let end: Double let translations: [BackendTranslation]? + let stt_provider: String? + let stt_model: String? + let provider_cluster_id: String? + let provider_speaker_label: String? + let speaker_identity_state: String? + let speaker_identity_confidence: Double? + let speaker_identity_source: String? + let speaker_identity_version: String? } /// Message event (from `/v4/listen` only — not used by PTT transcribe-stream) @@ -105,7 +114,10 @@ class TranscriptionService { /// (Cloud Run), which does not have /v2/voice-message/* or /v4/listen endpoints. private static let pythonBackendBaseURL: String = DesktopBackendEnvironment.pythonBaseURL() - private static func sanitizedContextKeywords(_ keywords: [String]) -> [String] { + private static func sanitizedContextKeywords( + _ keywords: [String], + includeDefaultOmi: Bool = true + ) -> [String] { let stopWords: Set = [ "about", "after", "again", "all", "also", "and", "app", "are", "ask", "back", "browser", "but", "can", "chat", "code", "done", "each", "for", "from", "get", "has", "have", "help", "here", "home", "how", @@ -115,7 +127,8 @@ class TranscriptionService { ] var seen = Set() var result: [String] = [] - for keyword in ["Omi", "OMI"] + keywords { + let sourceKeywords = (includeDefaultOmi ? ["Omi", "OMI"] : []) + keywords + for keyword in sourceKeywords { let normalized = keyword .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) .trimmingCharacters(in: .whitespacesAndNewlines) @@ -635,6 +648,117 @@ extension TranscriptionService { return transcript } + /// Transcribe a background recording chunk through the Python backend `/v2/desktop/background-transcribe`. + /// This is used by cloud batch background transcription only; PTT stays on `batchTranscribe`. + static func batchTranscribeSegments( + audioData: Data, + conversationId: String, + chunkStartMs: Int, + language: String, + contextKeywords: [String] = [] + ) async throws -> [BackendSegment] { + let authService = await MainActor.run { AuthService.shared } + let authHeader = try await authService.getAuthHeader() + let baseURLString = "\(pythonBackendBaseURL)v2/desktop/background-transcribe" + + guard var components = URLComponents(string: baseURLString) else { + throw TranscriptionError.connectionFailed(NSError(domain: "Invalid backend URL", code: -1)) + } + let sanitizedKeywords = sanitizedContextKeywords(contextKeywords, includeDefaultOmi: false) + var queryItems = [ + URLQueryItem(name: "conversation_id", value: conversationId), + URLQueryItem( + name: "chunk_id", + value: backgroundChunkId( + conversationId: conversationId, + chunkStartMs: chunkStartMs, + audioData: audioData + ) + ), + URLQueryItem(name: "chunk_start_ms", value: String(chunkStartMs)), + URLQueryItem(name: "language", value: language), + URLQueryItem(name: "sample_rate", value: "16000"), + URLQueryItem(name: "encoding", value: "linear16"), + URLQueryItem(name: "channels", value: "1"), + ] + if !sanitizedKeywords.isEmpty { + queryItems.append(URLQueryItem(name: "keywords", value: sanitizedKeywords.joined(separator: ","))) + } + components.queryItems = queryItems + + guard let url = components.url else { + throw TranscriptionError.connectionFailed(NSError(domain: "Invalid URL", code: -1)) + } + + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue(authHeader, forHTTPHeaderField: "Authorization") + request.setValue("application/octet-stream", forHTTPHeaderField: "Content-Type") + request.setValue("desktop", forHTTPHeaderField: "X-App-Platform") + for (provider, entry) in APIKeyService.byokSnapshot { + request.setValue(entry.key, forHTTPHeaderField: provider.headerName) + } + request.httpBody = audioData + + var lastError: Error? + for attempt in 0..<3 { + let data: Data + let response: URLResponse + do { + log("TranscriptionService: Background batch transcribing \(audioData.count) bytes at \(chunkStartMs)ms, attempt=\(attempt + 1), contextKeywords=\(sanitizedKeywords.count)") + (data, response) = try await URLSession.shared.data(for: request) + } catch { + lastError = error + if attempt == 2 { + break + } + let delayNs = UInt64(pow(2.0, Double(attempt)) * 250_000_000) + try await Task.sleep(nanoseconds: delayNs) + continue + } + + guard let httpResponse = response as? HTTPURLResponse else { + throw TranscriptionError.invalidResponse + } + + if httpResponse.statusCode == 200 { + let decoded = try JSONDecoder().decode(BackgroundTranscribeResponse.self, from: data) + let provider = decoded.provider ?? "unknown" + log("TranscriptionService: Background batch transcription completed provider=\(provider), run_id=\(decoded.run_id ?? "nil"), segments=\(decoded.segments.count), chunkStartMs=\(chunkStartMs)") + if provider != "assemblyai" { + log("TranscriptionService: Background batch expected AssemblyAI but backend returned provider=\(provider)") + } + return decoded.segments + } + + let body = String(data: data, encoding: .utf8) ?? "no body" + logError("TranscriptionService: Background batch transcription failed with status \(httpResponse.statusCode): \(body)", error: nil) + if httpResponse.statusCode == 413 { + throw TranscriptionError.payloadTooLarge + } + if !(500...599).contains(httpResponse.statusCode) { + throw TranscriptionError.invalidResponse + } + + lastError = TranscriptionError.invalidResponse + let delayNs = UInt64(pow(2.0, Double(attempt)) * 250_000_000) + try await Task.sleep(nanoseconds: delayNs) + } + + throw lastError ?? TranscriptionError.invalidResponse + } + + static func backgroundChunkId( + conversationId: String, + chunkStartMs: Int, + audioData: Data + ) -> String { + let digest = SHA256.hash(data: audioData) + let hashPrefix = digest.prefix(12).map { String(format: "%02x", $0) }.joined() + return "\(conversationId)-\(chunkStartMs)-\(audioData.count)-\(hashPrefix)" + .replacingOccurrences(of: #"[^A-Za-z0-9_-]"#, with: "_", options: .regularExpression) + } + } /// Response model for Python backend `/v2/voice-message/transcribe` (batch PTT) @@ -642,3 +766,11 @@ private struct PythonTranscribeResponse: Decodable { let transcript: String let language: String? } + +/// Response model for Python backend `/v2/desktop/background-transcribe`. +struct BackgroundTranscribeResponse: Decodable { + let segments: [TranscriptionService.BackendSegment] + let language: String? + let provider: String? + let run_id: String? +} diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 788c9cd0b4e..3e20e88c204 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -299,6 +299,22 @@ final class APIClientRoutingTests: XCTestCase { label: "deleteConversation") } + func testFinishBackgroundConversationRoutesToExplicitPythonConversation() async { + let client = await makeTestClient() + _ = try? await client.finishBackgroundConversation(conversationId: "batch-123") as ServerConversation + assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/desktop/background-conversation/batch-123/finish", method: "POST", + label: "finishBackgroundConversation") + } + + func testDesktopCapabilitiesRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getDesktopCapabilities() as DesktopCapabilitiesResponse + assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/desktop/capabilities", method: "GET", + label: "getDesktopCapabilities") + } + // -- Conversations: manual URL(string: baseURL + ...) paths (PATCH → Python) -- func testSetConversationStarredRoutesToPython() async { diff --git a/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift new file mode 100644 index 00000000000..bc07eaa9109 --- /dev/null +++ b/desktop/Desktop/Tests/BackgroundTranscriptionTests.swift @@ -0,0 +1,777 @@ +import XCTest + +@testable import Omi_Computer + +final class BackgroundTranscriptionTests: XCTestCase { + func testChunkerCutsAtSilenceAndRetainsOverlap() { + var chunker = BackgroundAudioChunker( + configuration: BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 3.0, + minChunkDuration: 1.0, + overlapDuration: 0.5, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4 + ) + ) + + let chunks = chunker.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 10) + [0, 0, 0]), startTime: 0) + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].startTime, 0) + XCTAssertEqual(sampleCount(chunks[0].pcmData), 10) + XCTAssertFalse(chunks[0].isFinal) + + let final = chunker.finishInput() + XCTAssertEqual(final.count, 1) + XCTAssertEqual(final[0].startTime, 0.5, accuracy: 0.001) + XCTAssertEqual(sampleCount(final[0].pcmData), 8) + XCTAssertTrue(final[0].isFinal) + } + + func testChunkerHardCutsAtMaxDurationWithoutSilence() { + var chunker = BackgroundAudioChunker( + configuration: BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 0.5, + overlapDuration: 0.2, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4 + ) + ) + + let chunks = chunker.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 3) + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].startTime, 3) + XCTAssertEqual(sampleCount(chunks[0].pcmData), 10) + XCTAssertEqual(sampleCount(chunker.finishInput()[0].pcmData), 4) + } + + func testSessionBackpressuresWhenPendingQueueIsFull() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 0.5, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 1 + ) + var transcribedStarts: [Double] = [] + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + transcribedStarts.append(chunk.startTime) + return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", text: "hello", start: chunk.startTime, + end: chunk.startTime + 1) + ] + } + + let first = session.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 0) + let second = session.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 1.2) + + XCTAssertEqual(first.enqueuedChunks, 1) + XCTAssertEqual(first.pendingChunkCount, 1) + XCTAssertTrue(first.isBackpressured) + XCTAssertEqual(second.enqueuedChunks, 0) + XCTAssertEqual(second.pendingChunkCount, 1) + XCTAssertEqual(second.acceptedInputBytes, 0) + XCTAssertTrue(second.isBackpressured) + XCTAssertEqual(session.pendingChunkCount, 1) + XCTAssertTrue(transcribedStarts.isEmpty) + } + + func testChunkerDoesNotLoopWhenOverlapEqualsMinimumCut() { + var chunker = BackgroundAudioChunker( + configuration: BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 3.0, + minChunkDuration: 1.0, + overlapDuration: 1.0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4 + ) + ) + + let chunks = chunker.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 10) + [0, 0, 0]), startTime: 0) + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(sampleCount(chunks[0].pcmData), 11) + let final = chunker.finishInput() + XCTAssertEqual(final.count, 1) + XCTAssertEqual(sampleCount(final[0].pcmData), 12) + } + + func testFifteenSecondContinuousSpeechProducesChunkThroughSession() async throws { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var transcribedStarts: [Double] = [] + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + transcribedStarts.append(chunk.startTime) + return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", + text: "continuous speech", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let samplesPerFrame = configuration.sampleRate / 10 + let frame = pcm(samples: Array(repeating: 1_000, count: samplesPerFrame)) + var enqueuedChunks = 0 + + for frameIndex in 0..<160 { + let result = session.append(pcmData: frame, startTime: Double(frameIndex) / 10.0) + enqueuedChunks += result.enqueuedChunks + } + + XCTAssertEqual(enqueuedChunks, 1) + XCTAssertEqual(session.pendingChunkCount, 1) + + let first = try await session.transcribeNext() + XCTAssertNotNil(first) + XCTAssertEqual(first!.chunk.startTime, 0, accuracy: 0.001) + XCTAssertEqual(sampleCount(first!.chunk.pcmData), configuration.sampleRate * 15) + XCTAssertEqual(transcribedStarts, [0]) + XCTAssertEqual(session.snapshot().processedChunkCount, 1) + XCTAssertEqual(session.snapshot().segments.count, 1) + } + + func testCloudBatchDefaultKeepsFixedFifteenSecondFallback() { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var chunker = BackgroundAudioChunker(configuration: configuration) + + XCTAssertFalse(configuration.usesSilenceAwareChunking) + XCTAssertEqual(configuration.minChunkDuration, 15.0) + XCTAssertEqual(configuration.maxChunkDuration, 15.0) + + let samplesPerFrame = configuration.sampleRate / 10 + let speechFrame = pcm(samples: Array(repeating: 1_000, count: samplesPerFrame)) + let silenceFrame = pcm(samples: Array(repeating: 0, count: samplesPerFrame)) + var chunks: [BackgroundAudioChunk] = [] + + for frameIndex in 0..<70 { + let frame = frameIndex < 65 ? speechFrame : silenceFrame + chunks.append(contentsOf: chunker.append(pcmData: frame, startTime: Double(frameIndex) / 10.0)) + } + XCTAssertTrue(chunks.isEmpty) + + for frameIndex in 70..<160 { + chunks.append(contentsOf: chunker.append(pcmData: speechFrame, startTime: Double(frameIndex) / 10.0)) + } + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].startTime, 0, accuracy: 0.001) + XCTAssertEqual(sampleCount(chunks[0].pcmData), configuration.sampleRate * 15) + } + + func testSilenceAwareCloudBatchCandidateFlushesOnSilenceBeforeHardCap() { + let configuration = BackgroundTranscriptionConfiguration.silenceAwareCloudBatchCandidate + var chunker = BackgroundAudioChunker(configuration: configuration) + + XCTAssertTrue(configuration.usesSilenceAwareChunking) + XCTAssertEqual(configuration.maxChunkDuration, 15.0) + XCTAssertLessThan(configuration.minChunkDuration, configuration.maxChunkDuration) + + let speechSamples = Array(repeating: Int16(1_000), count: Int(Double(configuration.sampleRate) * 6.5)) + let silenceSamples = Array(repeating: Int16(0), count: Int(Double(configuration.sampleRate) * 0.5)) + let chunks = chunker.append(pcmData: pcm(samples: speechSamples + silenceSamples), startTime: 30) + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].startTime, 30, accuracy: 0.001) + XCTAssertLessThan(sampleCount(chunks[0].pcmData), configuration.sampleRate * 15) + XCTAssertGreaterThanOrEqual(sampleCount(chunks[0].pcmData), Int(Double(configuration.sampleRate) * 6.5)) + XCTAssertFalse(chunks[0].isFinal) + } + + func testSilenceAwareCloudBatchCandidateStillHardCapsContinuousSpeech() { + let configuration = BackgroundTranscriptionConfiguration.silenceAwareCloudBatchCandidate + var chunker = BackgroundAudioChunker(configuration: configuration) + + let samples = Array(repeating: Int16(1_000), count: configuration.sampleRate * 16) + let chunks = chunker.append(pcmData: pcm(samples: samples), startTime: 12) + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].startTime, 12, accuracy: 0.001) + XCTAssertEqual(sampleCount(chunks[0].pcmData), configuration.sampleRate * 15) + } + + func testCloudBatchDropsSilentChunksBeforeUpload() async throws { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var uploadCount = 0 + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + uploadCount += 1 + return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", + text: "should not upload", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let samplesPerFrame = configuration.sampleRate / 10 + let frame = pcm(samples: Array(repeating: 0, count: samplesPerFrame)) + var enqueuedChunks = 0 + + for frameIndex in 0..<160 { + let result = session.append(pcmData: frame, startTime: Double(frameIndex) / 10.0) + enqueuedChunks += result.enqueuedChunks + } + + XCTAssertEqual(enqueuedChunks, 0) + XCTAssertEqual(session.pendingChunkCount, 0) + let uploaded = try await session.transcribeNext() + XCTAssertNil(uploaded) + XCTAssertEqual(uploadCount, 0) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) + XCTAssertEqual(session.snapshot().lastSpeechActivityDecision?.reason, .insufficientSpeech) + } + + func testCloudBatchUploadsChunkWithMinimumSpeechEnergy() async throws { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var uploadCount = 0 + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + uploadCount += 1 + return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", + text: "speech", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let speechSamples = Array( + repeating: Int16(1_000), count: Int(Double(configuration.sampleRate) * 1.0)) + let silenceSamples = Array( + repeating: Int16(0), + count: configuration.sampleRate * 16 - speechSamples.count + ) + let result = session.append( + pcmData: pcm(samples: speechSamples + silenceSamples), + startTime: 0 + ) + + XCTAssertEqual(result.enqueuedChunks, 1) + XCTAssertEqual(session.pendingChunkCount, 1) + + let uploaded = try await session.transcribeNext() + XCTAssertEqual(uploaded?.chunk.startTime, 0) + XCTAssertEqual(uploadCount, 1) + XCTAssertEqual(session.snapshot().lastSpeechActivityDecision?.reason, .speechDetected) + } + + func testCloudBatchRejectsEnergeticNonSpeechNoiseBeforeUpload() async throws { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var uploadCount = 0 + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + uploadCount += 1 + return [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", + text: "noise", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let samples = (0..<(configuration.sampleRate * 16)).map { index -> Int16 in + index.isMultiple(of: 2) ? 2_000 : -2_000 + } + let result = session.append(pcmData: pcm(samples: samples), startTime: 0) + + XCTAssertEqual(result.enqueuedChunks, 0) + XCTAssertEqual(session.pendingChunkCount, 0) + let uploaded = try await session.transcribeNext() + XCTAssertNil(uploaded) + XCTAssertEqual(uploadCount, 0) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) + XCTAssertEqual(session.snapshot().lastSpeechActivityDecision?.reason, .energeticNonSpeech) + XCTAssertGreaterThan( + session.snapshot().lastSpeechActivityDecision?.rejectedHighZeroCrossingWindows ?? 0, + 0 + ) + } + + func testSilenceAwareCloudBatchCandidateDropsShortSpeechBeforeUpload() async throws { + let configuration = BackgroundTranscriptionConfiguration.silenceAwareCloudBatchCandidate + var uploadCount = 0 + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + uploadCount += 1 + return [ + Self.backendSegment( + id: "short-speech", + text: "short", + start: chunk.startTime, + end: chunk.startTime + 1 + ) + ] + } + + let shortSpeech = Array( + repeating: Int16(1_000), count: Int(Double(configuration.sampleRate) * 0.2)) + let silence = Array( + repeating: Int16(0), count: Int(Double(configuration.sampleRate) * 6.5)) + let result = session.append(pcmData: pcm(samples: shortSpeech + silence), startTime: 0) + + XCTAssertEqual(result.enqueuedChunks, 0) + XCTAssertEqual(session.pendingChunkCount, 0) + let shortSpeechUpload = try await session.transcribeNext() + XCTAssertNil(shortSpeechUpload) + XCTAssertEqual(uploadCount, 0) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) + XCTAssertEqual(session.snapshot().lastSpeechActivityDecision?.reason, .insufficientSpeech) + } + + func testCloudBatchFinishSplitsLongTailAtFifteenSecondWindows() { + let configuration = BackgroundTranscriptionConfiguration.cloudBatch + var chunker = BackgroundAudioChunker(configuration: configuration) + + let samples = Array(repeating: Int16(1_000), count: configuration.sampleRate * 31) + let chunks = chunker.append(pcmData: pcm(samples: samples), startTime: 0) + let final = chunker.finishInput() + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(sampleCount(chunks[0].pcmData), configuration.sampleRate * 15) + XCTAssertEqual(final.count, 2) + XCTAssertEqual(sampleCount(final[0].pcmData), configuration.sampleRate * 15) + XCTAssertLessThanOrEqual(sampleCount(final[1].pcmData), configuration.sampleRate * 2) + XCTAssertTrue(final[1].isFinal) + } + + func testSessionFinishEnqueuesTailEvenWhenLiveQueueIsFull() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 1.0, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 1 + ) + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + [ + Self.backendSegment( + id: "chunk-\(chunk.startTime)", text: "chunk", start: chunk.startTime, + end: chunk.startTime + 1) + ] + } + + let append = session.append( + pcmData: pcm(samples: Array(repeating: 1_000, count: 16)), startTime: 0) + let finish = session.finishInput() + + XCTAssertEqual(append.enqueuedChunks, 1) + XCTAssertEqual(finish.enqueuedChunks, 1) + XCTAssertEqual(finish.pendingChunkCount, 2) + XCTAssertTrue(finish.isBackpressured) + + let first = try await session.transcribeNext() + let tail = try await session.transcribeNext() + XCTAssertEqual(first?.chunk.startTime, 0) + XCTAssertEqual(tail!.chunk.startTime, 1, accuracy: 0.001) + XCTAssertTrue(tail?.chunk.isFinal ?? false) + } + + func testSessionFinishFlushesTail() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 10, + minChunkDuration: 1, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4 + ) + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + [ + Self.backendSegment( + id: "final", text: chunk.isFinal ? "final" : "not final", start: chunk.startTime, end: 1) + ] + } + + XCTAssertEqual( + session.append(pcmData: pcm(samples: [1_000, 1_000, 1_000]), startTime: 2).enqueuedChunks, + 0) + let finish = session.finishInput() + + XCTAssertTrue(finish.didFinishInput) + XCTAssertEqual(finish.enqueuedChunks, 1) + let result = try await session.transcribeNext() + XCTAssertEqual(result?.chunk.startTime, 2) + XCTAssertTrue(result?.chunk.isFinal ?? false) + } + + func testSessionRetainsFailedChunkForRetry() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 0.5, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4, + maxChunkTranscriptionAttempts: 3 + ) + var shouldFail = true + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + if shouldFail { + shouldFail = false + throw NSError(domain: "test", code: 1) + } + return [ + Self.backendSegment( + id: "retry", text: "retried", start: chunk.startTime, end: chunk.startTime + 1) + ] + } + + XCTAssertEqual( + session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 0) + .enqueuedChunks, 1) + + do { + _ = try await session.transcribeNext() + XCTFail("Expected first transcription attempt to fail") + } catch { + XCTAssertEqual(session.pendingChunkCount, 1) + XCTAssertEqual(session.snapshot().droppedChunkCount, 0) + } + + let retried = try await session.transcribeNext() + XCTAssertEqual(retried?.chunk.startTime, 0) + XCTAssertEqual(session.pendingChunkCount, 0) + XCTAssertEqual(session.snapshot().processedChunkCount, 1) + XCTAssertEqual(session.snapshot().droppedChunkCount, 0) + } + + func testSessionRetriesFailedChunkBeforeDroppingSoDrainCanContinue() async throws { + let configuration = BackgroundTranscriptionConfiguration( + sampleRate: 10, + maxChunkDuration: 1.0, + minChunkDuration: 0.5, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 10, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 20, + maxPendingChunks: 4, + maxChunkTranscriptionAttempts: 2 + ) + let session = CloudBackgroundTranscriptionSession(configuration: configuration) { chunk in + if chunk.startTime == 0 { + throw NSError(domain: "test", code: 1) + } + return [ + Self.backendSegment( + id: "retry", text: "retried", start: chunk.startTime, end: chunk.startTime + 1) + ] + } + + XCTAssertEqual( + session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 0) + .enqueuedChunks, 1) + XCTAssertEqual( + session.append(pcmData: pcm(samples: Array(repeating: 1_000, count: 12)), startTime: 1.2) + .enqueuedChunks, 1) + + do { + _ = try await session.transcribeNext() + XCTFail("Expected first transcription attempt to fail") + } catch { + XCTAssertEqual(session.pendingChunkCount, 2) + XCTAssertEqual(session.snapshot().droppedChunkCount, 0) + } + + do { + _ = try await session.transcribeNext() + XCTFail("Expected second transcription attempt to fail and drop the chunk") + } catch { + XCTAssertEqual(session.pendingChunkCount, 1) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) + } + + let next = try await session.transcribeNext() + XCTAssertEqual(next?.chunk.startTime, 1.0) + XCTAssertEqual(session.pendingChunkCount, 0) + XCTAssertEqual(session.snapshot().processedChunkCount, 1) + XCTAssertEqual(session.snapshot().droppedChunkCount, 1) + } + + func testBackgroundChunkIdIsStableAndPayloadSensitive() { + let first = TranscriptionService.backgroundChunkId( + conversationId: "conv.123", + chunkStartMs: 15_000, + audioData: Data([1, 2, 3, 4]) + ) + let retry = TranscriptionService.backgroundChunkId( + conversationId: "conv.123", + chunkStartMs: 15_000, + audioData: Data([1, 2, 3, 4]) + ) + let changedPayload = TranscriptionService.backgroundChunkId( + conversationId: "conv.123", + chunkStartMs: 15_000, + audioData: Data([1, 2, 3, 5]) + ) + + XCTAssertEqual(first, retry) + XCTAssertNotEqual(first, changedPayload) + XCTAssertTrue(first.hasPrefix("conv_123-15000-4-")) + XCTAssertNil(first.range(of: #"[^A-Za-z0-9_-]"#, options: .regularExpression)) + } + + func testTranscriptMergerDeduplicatesAndMergesOverlap() { + var merger = BackgroundTranscriptMerger() + let first = Self.backendSegment(id: "a", text: "hello world", speakerId: 0, start: 0, end: 2) + let duplicate = Self.backendSegment( + id: nil, text: "hello world", speakerId: 0, start: 0.5, end: 1.8) + let overlap = Self.backendSegment( + id: "b", text: "world again", speakerId: 0, start: 1.5, end: 3) + + XCTAssertEqual(merger.merge([first]).count, 1) + XCTAssertEqual(merger.merge([duplicate]).count, 0) + let changed = merger.merge([overlap]) + let merged = merger.segments + + XCTAssertEqual(changed.count, 1) + XCTAssertEqual(merged.count, 1) + XCTAssertEqual(merged[0].text, "hello world again") + XCTAssertEqual(merged[0].start, 0) + XCTAssertEqual(merged[0].end, 3) + } + + func testSpeakerSegmentReducerUpdatesAndPreservesTranslations() { + var reducer = SpeakerSegmentReducer(maxInMemorySegments: 10) + let original = SpeakerSegment( + segmentId: "seg-1", + speaker: 1, + text: "hello", + start: 0, + end: 1, + translations: [SegmentTranslation(lang: "es", text: "hola")] + ) + _ = reducer.apply([original]) + + let update = Self.backendSegment( + id: "seg-1", text: "hello again", speakerId: 1, start: 0, end: 2) + let result = reducer.apply([update]) + + XCTAssertEqual(result.added, 0) + XCTAssertEqual(result.updated, 1) + XCTAssertEqual(result.totalSegmentCount, 1) + XCTAssertEqual(result.totalWordCount, 2) + XCTAssertEqual(reducer.segments[0].translations.first?.text, "hola") + } + + func testRoutingGuardUsesCloudBatchForMicrophoneWhenCapabilityUsable() { + let guardrail = BackgroundTranscriptionRoutingGuard() + + XCTAssertEqual( + guardrail.decide( + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: true, effectiveProvider: "assemblyai"), + audioSource: .microphone), + .cloudBatchAssembly + ) + XCTAssertEqual( + guardrail.decide( + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: true, effectiveProvider: "assemblyai"), + audioSource: .bleDevice), + .cloudListenStreaming(reason: "batch_microphone_only") + ) + XCTAssertEqual( + guardrail.decide( + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: false, effectiveProvider: nil, reason: "missing_assemblyai_api_key"), + audioSource: .microphone), + .cloudListenStreaming(reason: "missing_assemblyai_api_key") + ) + XCTAssertEqual( + guardrail.decide( + backgroundBatchCapability: nil, + audioSource: .microphone), + .cloudListenStreaming(reason: "server_background_batch_capability_unavailable") + ) + XCTAssertEqual( + guardrail.decide( + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: true, effectiveProvider: nil, reason: "no_usable_batch_provider"), + audioSource: .microphone), + .cloudListenStreaming(reason: "no_usable_batch_provider") + ) + XCTAssertEqual( + guardrail.decide( + backgroundBatchCapability: Self.backgroundBatchCapability( + enabled: true, effectiveProvider: "unknown-provider"), + audioSource: .microphone), + .cloudListenStreaming(reason: "server_background_batch_provider_unsupported") + ) + XCTAssertTrue( + guardrail.shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: .microphone, + captureStarted: false + ) + ) + XCTAssertFalse( + guardrail.shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: .microphone, + captureStarted: true + ) + ) + XCTAssertFalse( + guardrail.shouldFallbackToStreamingAfterBatchStartupFailure( + audioSource: .bleDevice, + captureStarted: false + ) + ) + } + + func testDesktopCapabilitiesDecodeEffectiveProviderFields() throws { + let json = Data( + """ + { + "background_batch": { + "enabled": true, + "mode": "deepgram_fallback", + "provider": "assemblyai", + "primary_provider": "assemblyai", + "effective_provider": "deepgram", + "fallback_provider": "deepgram", + "fallback_enabled": true, + "fallback_available": true, + "workload": "background", + "reason": "fallback_deepgram_available", + "sample_rate": 16000, + "channels": 1, + "encoding": "linear16", + "max_chunk_seconds": 15 + } + } + """.utf8) + + let decoded = try JSONDecoder().decode(DesktopCapabilitiesResponse.self, from: json) + + XCTAssertTrue(decoded.backgroundBatch.enabled) + XCTAssertEqual(decoded.backgroundBatch.mode, "deepgram_fallback") + XCTAssertEqual(decoded.backgroundBatch.primaryProvider, "assemblyai") + XCTAssertEqual(decoded.backgroundBatch.effectiveProvider, "deepgram") + XCTAssertEqual(decoded.backgroundBatch.fallbackProvider, "deepgram") + XCTAssertEqual(decoded.backgroundBatch.reason, "fallback_deepgram_available") + } + + func testOlderDesktopCapabilitiesDecodeWithoutEffectiveProviderFields() throws { + let json = Data( + """ + { + "background_batch": { + "enabled": false, + "provider": "assemblyai", + "sample_rate": 16000, + "channels": 1, + "encoding": "linear16", + "max_chunk_seconds": 15 + } + } + """.utf8) + + let decoded = try JSONDecoder().decode(DesktopCapabilitiesResponse.self, from: json) + + XCTAssertFalse(decoded.backgroundBatch.enabled) + XCTAssertNil(decoded.backgroundBatch.effectiveProvider) + XCTAssertEqual( + BackgroundTranscriptionRoutingGuard().decide( + backgroundBatchCapability: decoded.backgroundBatch, + audioSource: .microphone), + .cloudListenStreaming(reason: "server_background_batch_disabled") + ) + } + + private static func backgroundBatchCapability( + enabled: Bool, + effectiveProvider: String?, + reason: String? = nil + ) -> DesktopBackgroundBatchCapability { + DesktopBackgroundBatchCapability( + enabled: enabled, + mode: enabled ? "assemblyai_primary" : "disabled", + provider: "assemblyai", + primaryProvider: "assemblyai", + effectiveProvider: effectiveProvider, + fallbackProvider: "deepgram", + fallbackEnabled: true, + fallbackAvailable: true, + reason: reason, + sampleRate: 16000, + channels: 1, + encoding: "linear16", + maxChunkSeconds: 15 + ) + } + + private static func backendSegment( + id: String?, + text: String, + speakerId: Int = 0, + start: Double, + end: Double + ) -> TranscriptionService.BackendSegment { + TranscriptionService.BackendSegment( + id: id, + text: text, + speaker: "SPEAKER_\(String(format: "%02d", speakerId))", + speaker_id: speakerId, + is_user: false, + person_id: nil, + start: start, + end: end, + translations: nil, + stt_provider: "assemblyai", + stt_model: "universal-2", + provider_cluster_id: nil, + provider_speaker_label: nil, + speaker_identity_state: nil, + speaker_identity_confidence: nil, + speaker_identity_source: nil, + speaker_identity_version: nil + ) + } + + private func pcm(samples: [Int16]) -> Data { + var samples = samples + return samples.withUnsafeMutableBufferPointer { Data(buffer: $0) } + } + + private func sampleCount(_ data: Data) -> Int { + data.count / MemoryLayout.size + } +} diff --git a/desktop/Desktop/Tests/ListenProtocolTests.swift b/desktop/Desktop/Tests/ListenProtocolTests.swift index 6ef3cddaa2a..406f5c2f197 100644 --- a/desktop/Desktop/Tests/ListenProtocolTests.swift +++ b/desktop/Desktop/Tests/ListenProtocolTests.swift @@ -42,6 +42,23 @@ final class ListenProtocolTests: XCTestCase { XCTAssertFalse(seg.is_user) } + func testDecodeSegmentWithProviderIdentityMetadata() throws { + let json = """ + [{"id":"seg-1","text":"hello","speaker":null,"speaker_id":null,"is_user":false,"person_id":null,"start":0.0,"end":1.0,"provider_cluster_id":"speaker-alpha","speaker_identity_state":"unknown","stt_provider":"provider-a","stt_model":"async-large"}] + """ + let data = json.data(using: .utf8)! + let segments = try JSONDecoder().decode([TranscriptionService.BackendSegment].self, from: data) + + XCTAssertEqual(segments.count, 1) + let seg = segments[0] + XCTAssertNil(seg.speaker) + XCTAssertNil(seg.speaker_id) + XCTAssertEqual(seg.provider_cluster_id, "speaker-alpha") + XCTAssertEqual(seg.speaker_identity_state, "unknown") + XCTAssertEqual(seg.stt_provider, "provider-a") + XCTAssertEqual(seg.stt_model, "async-large") + } + func testDecodeMultipleSegments() throws { let json = """ [ diff --git a/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift b/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift index 3b733688c71..52c012d9068 100644 --- a/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift +++ b/desktop/Desktop/Tests/TranscriptSpeakerAssignmentTests.swift @@ -205,6 +205,42 @@ final class TranscriptSpeakerAssignmentTests: XCTestCase { XCTAssertTrue(segment.translations.isEmpty, "Translations should default to empty array when not present in JSON") } + func testTranscriptSegmentDecodesProviderIdentityMetadata() throws { + let json = """ + { + "id": "seg_provider", + "text": "Hello", + "speaker": null, + "is_user": false, + "person_id": null, + "start": 0.0, + "end": 1.0, + "stt_provider": "provider-a", + "stt_model": "async-large", + "provider_cluster_id": "speaker-alpha", + "provider_speaker_label": null, + "speaker_identity_state": "unknown", + "speaker_identity_confidence": 0.42, + "speaker_identity_source": "omi_speaker_embedding", + "speaker_identity_version": "v1" + } + """.data(using: .utf8)! + + let segment = try JSONDecoder().decode(TranscriptSegment.self, from: json) + + XCTAssertNil(segment.speaker) + XCTAssertEqual(segment.speakerId, 0) + XCTAssertEqual(segment.sttProvider, "provider-a") + XCTAssertEqual(segment.sttModel, "async-large") + XCTAssertEqual(segment.providerClusterId, "speaker-alpha") + XCTAssertEqual(segment.speakerIdentityState, "unknown") + XCTAssertEqual(segment.speakerIdentityConfidence, 0.42) + XCTAssertEqual(segment.speakerIdentitySource, "omi_speaker_embedding") + XCTAssertEqual(segment.speakerIdentityVersion, "v1") + XCTAssertTrue(segment.hasExplicitUnknownSpeakerIdentity) + XCTAssertEqual(segment.displaySpeakerSuffix, "speaker-alpha") + } + func testSpeakerSegmentTranslationsPreserved() { let translations = [ SegmentTranslation(lang: "en", text: "Hello"), @@ -315,6 +351,40 @@ final class TranscriptSpeakerAssignmentTests: XCTestCase { XCTAssertEqual(segment.translations[1].text, "Hola") } + func testTranscriptionSegmentRecordRoundTripWithProviderIdentityMetadata() { + let source = TranscriptSegment( + id: "seg_provider_rt", + backendId: "seg_provider_rt", + text: "Hello", + speaker: nil, + isUser: false, + personId: nil, + start: 0, + end: 1, + sttProvider: "provider-a", + sttModel: "async-large", + providerClusterId: "speaker-alpha", + providerSpeakerLabel: nil, + speakerIdentityState: "unknown", + speakerIdentityConfidence: 0.42, + speakerIdentitySource: "omi_speaker_embedding", + speakerIdentityVersion: "v1" + ) + + let record = TranscriptionSegmentRecord.from(source, sessionId: 1, segmentOrder: 0) + let segment = record.toTranscriptSegment() + + XCTAssertNil(segment.speaker) + XCTAssertEqual(segment.displaySpeakerSuffix, "speaker-alpha") + XCTAssertEqual(segment.sttProvider, "provider-a") + XCTAssertEqual(segment.sttModel, "async-large") + XCTAssertEqual(segment.providerClusterId, "speaker-alpha") + XCTAssertEqual(segment.speakerIdentityState, "unknown") + XCTAssertEqual(segment.speakerIdentityConfidence, 0.42) + XCTAssertEqual(segment.speakerIdentitySource, "omi_speaker_embedding") + XCTAssertEqual(segment.speakerIdentityVersion, "v1") + } + func testTranscriptionSegmentRecordRoundTripNilTranslationsJson() { let record = TranscriptionSegmentRecord( sessionId: 1, diff --git a/desktop/run.sh b/desktop/run.sh index 90d9a489ef5..bff3a311df9 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -46,7 +46,9 @@ fi # ─── YOLO mode: use prod backend, zero local setup ─────────────────── # WARNING: Temporary shortcut while desktop dev setup is being cleaned up. # Will be removed once all desktop slop is fixed. +YOLO_MODE=0 if [ "$1" = "--yolo" ]; then + YOLO_MODE=1 echo "" echo "==========================================" echo " YOLO MODE — using production backend" @@ -280,6 +282,14 @@ fi if [ -f "$BACKEND_DIR/.env" ]; then set -a; source "$BACKEND_DIR/.env"; set +a fi +# YOLO must win over Backend-Rust/.env (e.g. OMI_PYTHON_API_URL=http://localhost:8080) +if [ "$YOLO_MODE" = "1" ]; then + export OMI_SKIP_BACKEND=1 + export OMI_SKIP_TUNNEL=1 + export OMI_DESKTOP_API_URL="https://desktop-backend-hhibjajaja-uc.a.run.app" + export OMI_PYTHON_API_URL="https://api.omi.me" + export FIREBASE_API_KEY="AIzaSyD9dzBdglc7IO9pPDIOvqnCoTis_xKkkC8" +fi # Read backend PORT from env (default: 10201, never use 8080) BACKEND_PORT="${PORT:-10201}" @@ -510,18 +520,20 @@ if ! grep -q "^FIREBASE_API_KEY=" "$APP_BUNDLE/Contents/Resources/.env"; then fi # Bootstrap OMI_PYTHON_API_URL — main Omi Python backend (auth, subscriptions, payments, transcription) # Do NOT fall back to OMI_DESKTOP_API_URL — that's the Rust desktop-backend which doesn't serve these routes -if ! grep -q "^OMI_PYTHON_API_URL=" "$APP_BUNDLE/Contents/Resources/.env"; then - PYTHON_API_URL="${OMI_PYTHON_API_URL:-}" - if [ -z "$PYTHON_API_URL" ] && [ -f "$BACKEND_DIR/.env" ]; then - PYTHON_API_URL=$(grep "^OMI_PYTHON_API_URL=" "$BACKEND_DIR/.env" | head -1 | cut -d= -f2-) - fi - if [ -z "$PYTHON_API_URL" ]; then - PYTHON_API_URL="https://api.omi.me" - substep "OMI_PYTHON_API_URL not set — defaulting to production: $PYTHON_API_URL" - fi +PYTHON_API_URL="${OMI_PYTHON_API_URL:-}" +if [ -z "$PYTHON_API_URL" ] && [ -f "$BACKEND_DIR/.env" ]; then + PYTHON_API_URL=$(grep "^OMI_PYTHON_API_URL=" "$BACKEND_DIR/.env" | head -1 | cut -d= -f2-) +fi +if [ -z "$PYTHON_API_URL" ]; then + PYTHON_API_URL="https://api.omi.me" + substep "OMI_PYTHON_API_URL not set — defaulting to production: $PYTHON_API_URL" +fi +if grep -q "^OMI_PYTHON_API_URL=" "$APP_BUNDLE/Contents/Resources/.env"; then + sed -i '' "s|^OMI_PYTHON_API_URL=.*|OMI_PYTHON_API_URL=$PYTHON_API_URL|" "$APP_BUNDLE/Contents/Resources/.env" +else echo "OMI_PYTHON_API_URL=$PYTHON_API_URL" >> "$APP_BUNDLE/Contents/Resources/.env" - substep "Set OMI_PYTHON_API_URL=$PYTHON_API_URL" fi +substep "OMI_PYTHON_API_URL=$PYTHON_API_URL" substep "Copying app icon" cp -f omi_icon.icns "$APP_BUNDLE/Contents/Resources/OmiIcon.icns" 2>/dev/null || true diff --git a/docs/doc/developer/backend/assemblyai_background_rollout.mdx b/docs/doc/developer/backend/assemblyai_background_rollout.mdx new file mode 100644 index 00000000000..32d2f10ab6b --- /dev/null +++ b/docs/doc/developer/backend/assemblyai_background_rollout.mdx @@ -0,0 +1,450 @@ +--- +title: "AssemblyAI Background Transcription Rollout" +icon: "waveform" +description: "Rollout gates, feature flags, instrumentation, and rollback for the AssemblyAI async background transcription MVP." +--- + +# AssemblyAI Background Transcription Rollout + +## Scope + +AssemblyAI is the default provider for passive async background transcription. +In this document, "background" means latency-tolerant prerecorded work: +`sync`, desktop `background`, and `postprocess`. These jobs run after audio has +already been captured and can tolerate async provider polling, retries, and +Deepgram fallback. + +Current saved-provider benchmark evidence still favors Deepgram on some +speaker-quality fixtures: Deepgram reached 99.79% speaker purity, while +AssemblyAI reached 95.91%. Current public pay-as-you-go pricing checked on +2026-05-25 makes AssemblyAI Universal-2 plus diarization about `$0.170/hour` +and Deepgram Nova-3 prerecorded plus diarization about `$0.408/hour` for +monolingual audio or `$0.468/hour` for multilingual audio. Older experiment +summaries used stale Deepgram rates that omitted the diarization add-on and made +AssemblyAI look more expensive; do not use those numbers for rollout decisions. +The production stance is: AssemblyAI first, conservative speaker safety, +privacy-safe monitoring, and Deepgram fallback. + +Deepgram remains the provider for mobile/BLE `/v4/listen`, realtime assistant +streaming, Hold-to-Talk streaming, voice-message finalize semantics, and +background fallback. + +## Current Default Decision + +As of the TICKET-028 readiness gate, AssemblyAI is the default for eligible +passive background workloads when credentials and workload flags allow it. +Deepgram remains the fallback and explicit rollback provider. + +The latest offline command was: + +```bash +cd backend +python3 scripts/stt/provider_comparison_gate.py \ + --manifest tests/fixtures/stt_provider_eval/manifest.json \ + --output-md /tmp/stt-provider-eval.md \ + --output-json /tmp/stt-provider-eval.json +``` + +It passed 10 fixture cases with no failures or warnings and reported all +required production strategies: + +| Strategy | Provider | Speaker purity | Empty transcript | Fallback | Timeout/error | Cost/hour | +| --- | --- | ---: | ---: | ---: | ---: | ---: | +| `always_deepgram` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.408` | +| `always_assemblyai` | AssemblyAI | 98.9% | 10.0% | 1.0% | 0.0% | `$0.170` | +| `current_policy` | AssemblyAI | 98.9% | 10.0% | 1.0% | 0.0% | `$0.170` | +| `shadow_only` | Deepgram | 100.0% | 10.0% | 0.0% | 0.0% | `$0.408` | + +The gate still marked AssemblyAI as `limited` because the +`saved_real_policy_router_outputs` scenario had 95.5% AssemblyAI speaker-word +purity versus 100.0% for Deepgram. The mitigation remains to keep +split-before-identity enabled, keep unresolved provider clusters chunk-scoped, +and monitor real-session metrics that include both purity and fragmentation +budgets. + +Selected rollout decision: AssemblyAI is the passive background default with +Deepgram fallback and hard rollback thresholds. + +Eligible AssemblyAI workloads are: + +- `sync` +- `background` +- `postprocess` + +Ineligible latency-sensitive workloads stay on Deepgram: + +- `realtime` +- streaming `ptt` +- `voice_message` + +## Feature Flags And Environment + +AssemblyAI is enabled by default for eligible passive background prerecorded +workloads when credentials are configured. Desktop background has its own +provider mode so it can be rolled back independently from `sync` and +`postprocess`. + +| Variable | Default | Purpose | +| --- | --- | --- | +| `ASSEMBLYAI_API_KEY` | unset | Required before any AssemblyAI request can run. | +| `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE` | `assemblyai` | Desktop background policy mode: `assemblyai`, `deepgram`, or `shadow_only`. `assemblyai` is the default. `deepgram` is the rollback override. `shadow_only` is retained only as an explicit diagnostic/rollback mode. | +| `ASSEMBLYAI_PRERECORDED_STT_ENABLED` | `true` | Main rollout switch for passive prerecorded workloads. | +| `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS` | `sync,background,postprocess` | Comma-separated eligible prerecorded workloads. Unknown or ineligible values are ignored. | +| `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED` | `true` | Use Deepgram when AssemblyAI is unavailable or fails and Deepgram credentials are usable. | +| `ASSEMBLYAI_STT_MODEL` | `universal-2` | AssemblyAI model used when a Deepgram `nova-*` model request is routed to AssemblyAI. | +| `ASSEMBLYAI_BASE_URL` | `https://api.assemblyai.com` | Provider API base URL. | +| `ASSEMBLYAI_POLL_INTERVAL_SECONDS` | `3` | Poll interval for async transcript completion. | +| `ASSEMBLYAI_MAX_POLL_SECONDS` | `900` | Max wait before the async run times out and can fall back. | +| `ASSEMBLYAI_SMOKE_AUDIO_URL` | unset | Optional live smoke-test audio URL. | + +## Routing And Fallback + +`backend/utils/stt/providers.py` owns provider selection, and +`backend/utils/stt/provider_service.py` resolves request-level credentials and +fallbacks. Eligible passive background workloads use AssemblyAI when the main +flag is enabled and the workload is in `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`; +otherwise they use Deepgram. + +Desktop background routing is controlled by +`ASSEMBLYAI_BACKGROUND_PROVIDER_MODE`: + +| Mode | Runtime provider | Use | +| --- | --- | --- | +| `assemblyai` | AssemblyAI when credentials and workload gates allow it; Deepgram fallback when configured and usable. | Default mode. | +| `deepgram` | Deepgram | Kill switch or explicit opt-out without a code deploy. | +| `shadow_only` | Deepgram | Explicit diagnostic/rollback mode only. | + +Desktop Audio Recording uses the background workload through: + +- `POST /v2/desktop/background-conversation/start` +- `POST /v2/desktop/background-transcribe` +- `POST /v1/conversations` for finalization + +The chunk endpoint accepts raw PCM, wraps linear16 PCM as WAV before provider +upload, applies the chunk timeline offset, maps provider speaker clusters to +stable numeric `speaker_id` values per conversation, and appends segments to the +in-progress conversation when `persist=true`. + +Desktop chunking is currently config-gated to the fixed 15s safety profile: +`BackgroundTranscriptionConfiguration.cloudBatch` aliases the fixed-15s +configuration with 0.5s overlap and client-side speech-activity suppression. +The Swift chunker also has a named silence-aware candidate configuration that +can flush after the minimum speech/audio window and a 350ms silence gap, while +still enforcing the same 15s hard cap. Do not switch the runtime background path +to that candidate until Deepgram-vs-AssemblyAI replay evals show no speaker +quality regression on fast-turn, overlap, sparse, and variable-density fixtures. + +Deepgram is the prerecorded fallback provider. If AssemblyAI fails, times out, +or exhausts retries for an eligible prerecorded workload, the failed AssemblyAI +run is finalized in the provider ledger and the request retries through +Deepgram with fallback metadata. If AssemblyAI credentials are missing before +the request starts, provider resolution skips directly to Deepgram when +fallback is enabled and Deepgram credentials are usable. If fallback is disabled +or no Deepgram key is usable, the AssemblyAI request fails closed and records +the failed provider run. + +Rollback for background is to set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram` +or `shadow_only`. The broader kill switch is +`ASSEMBLYAI_PRERECORDED_STT_ENABLED=false` or removing the affected workload +from `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`. No client deploy is required for +rollback. + +## BYOK (Bring Your Own Keys) + +Desktop BYOK users still require four keys (OpenAI, Anthropic, Gemini, Deepgram) +for the free plan. AssemblyAI is an **optional fifth** key (`X-BYOK-AssemblyAI`). + +**Deploy backend routing before enabling Assembly for BYOK cohorts.** `resolve_prerecorded_provider_for_request()` in `backend/utils/stt/provider_service.py` applies: + +| Flags select Assembly? | Request headers | Provider used | +| --- | --- | --- | +| No | any | Deepgram (BYOK or server) | +| Yes | `x-byok-assemblyai` | AssemblyAI with user key | +| Yes | `x-byok-deepgram` only | Deepgram BYOK (avoids billing Omi's Assembly key) | +| Yes | neither | AssemblyAI with server `ASSEMBLYAI_API_KEY` | + +If a user enrolls an optional Assembly fingerprint in Firestore, they must send +`x-byok-assemblyai` on requests that send other BYOK headers. Clearing the +Assembly key requires re-activation with fingerprints that omit `assemblyai`. + +Listen/realtime STT is unchanged for mobile/BLE (Deepgram only). Assembly BYOK +does not affect `/v4/listen`. + +## Storage And Identity + +Provider output is normalized into `ProviderTranscriptResult` and reconstructed +into canonical `TranscriptSegment` records before conversation storage. +Provider-local speaker labels are stored as `provider_cluster_id` and +`provider_speaker_label`; they do not become canonical identity. + +Canonical speaker identity is assigned after normalization with the Omi speaker +embedding service. Low-confidence clusters remain explicitly unknown rather +than being mapped to a fake person or to `speaker_id=0`. + +### Speaker Detection Contract + +AssemblyAI and Deepgram speaker labels are provider-local diarization hints. +They are not stable people, and they are not safe to compare across prerecorded +jobs. AssemblyAI labels such as `A`/`B` and Deepgram labels such as `0`/`1` +can reset on every uploaded chunk. + +Desktop background transcription therefore uses a split-then-reconcile flow: + +1. Normalize provider words and utterances. +2. Reconstruct `TranscriptSegment` records with provider cluster metadata. +3. Split repeated non-contiguous provider clusters inside the chunk into local + contiguous groups before identity matching. +4. Match each local group to known Omi user/person embeddings when clean + samples exist. +5. Assign final app-visible `speaker_id` values from durable identity when + available, otherwise from a chunk-scoped provider key. + +The chunk-scoped unresolved key is: + +```text +provider:{provider}:chunk:{chunk_id}:cluster:{provider_cluster_id} +``` + +Known Omi identities intentionally ignore the chunk scope: + +```text +identity:user +identity:person:{person_id} +``` + +This is deliberate. False speaker merges are more damaging than temporary +anonymous-speaker fragmentation. The product should prefer splitting uncertain +provider clusters first, then merging only with Omi identity evidence, +high-confidence embeddings, or explicit user corrections. + +### Benchmark Learnings + +The AssemblyAI rollout benchmark found: + +- Fixed 15s batches expose provider-label resets; raw `A`/`B` labels must not + carry across chunks. +- Longer AssemblyAI batches reduced request count but often collapsed speakers + in fast-turn or overlap-heavy synthetic sessions. +- Batch-size, silence-gap, speech-duration, embedding-span, overlap-tolerance, + and match-threshold sweeps did not beat the chunk-scoped baseline. +- The strongest improvement came from splitting non-contiguous same-cluster + groups inside a chunk. That fixed rapid turn-taking cases where one provider + cluster contained multiple true speakers. +- Graph clustering and anonymous reconciliation are useful as reconciliation + structures, but they cannot repair a provider cluster that already contains + multiple speakers unless the cluster is split first. + +The current landed rule is: provider clusters are local hints, repeated +non-contiguous groups are split, and stable app identity comes from Omi +evidence rather than provider labels. + +Provider run ledgers live in `transcription_provider_runs`, and daily rollups +live in `transcription_provider_usage_daily`. Ledger payload guards reject raw +audio bytes, transcript text, words, utterances, API keys, tokens, secrets, and +similar high-risk payloads. + +Privacy-safe provider run fields include: + +- `provider`, `model`, `workload`, and `status` +- `raw_audio_seconds`, `speech_active_seconds`, `billable_seconds`, and `chunk_duration_seconds` +- `retry_count`, `fallback_count`, and `fallback.reason` +- `timing.latency_ms` +- `estimated_cost_usd` +- `speaker_cluster_count`, `split_count`, `identified_speaker_cluster_count`, and `identity_match_count` +- `provider_speaker_count`, `mapped_speaker_count`, `mapped_person_count`, `unmapped_speaker_count` +- `unknown_speaker_count` and `unknown_speaker_duration_seconds` +- `identity_confidence_summary` buckets + +Do not add raw audio, full transcript text, provider `words`, provider +`utterances`, user secrets, or provider API responses to these ledgers. Store +only opaque artifact references such as `provider_result_id` when a separate +retention-controlled artifact is required. + +## Dashboards And Reports + +Expected rollout dashboards should use `/metrics` counters and histograms for: + +- `transcription_provider_requests_total` +- `transcription_provider_latency_seconds` +- `transcription_provider_retries_total` +- `transcription_provider_fallbacks_total` +- `transcription_provider_audio_seconds_total` +- `transcription_provider_billable_seconds_total` +- `transcription_provider_speaker_clusters_total` +- `transcription_provider_identity_confidence_total` + +Expected cost and quality review should inspect: + +- `transcription_provider_usage_daily` by provider, model, workload, and day. +- `transcription_provider_runs` for failures, fallback direction, fallback reason, retry counts, latency, chunk duration, billable seconds, estimated cost, split count, identity match count, and unknown speaker fields. +- `backend/scripts/stt/provider_comparison_gate.py` reports for transcript drift, speaker cluster stability, identity quality, fallback rate, failure rate, and economics. + +## Default And Rollback Stages + +Default state: + +- Passive background workloads use AssemblyAI: + `ASSEMBLYAI_PRERECORDED_STT_ENABLED=true` and + `ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess`. +- Desktop background uses AssemblyAI: + `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=assemblyai`. +- Deepgram credentials remain available for fallback. +- Fallback runs must record `fallback.from_provider=assemblyai`, + `fallback.to_provider=deepgram`, and a classified `fallback.reason`. + +Owner/on-call action: + +- The backend on-call owns first response for AssemblyAI default regressions. +- Roll back immediately when a hard threshold below trips. +- Open a follow-up ticket when any warning threshold is close to tripping or + when support reports speaker/transcript regressions not visible in aggregate + metrics. +- Product/business owner approval is required before accepting any sustained + cost or speaker-quality regression versus Deepgram. + +Default health criteria: + +- Offline provider gate continues to pass with no failures: + `provider_comparison_gate.py --manifest tests/fixtures/stt_provider_eval/manifest.json`. +- AssemblyAI background error rate is not more than 1 percentage point above + the Deepgram baseline and remains below 3% absolute. +- Fallback rate stays at or below 5% and does not double the previous healthy + window. +- Timeout rate stays at or below 1%. +- Empty transcript rate is not more than 1 percentage point above Deepgram. +- Unknown or low-confidence speaker rate is not more than 5 percentage points + above Deepgram, and unknown speaker duration does not create visible speaker + degradation. +- App-visible speaker inflation stays within the replay budget and does not + rise without a matching purity improvement. +- Cost/hour stays within the approved margin. The current offline gate should + show AssemblyAI below Deepgram for diarized prerecorded work, roughly + `$0.170/hour` versus `$0.408/hour` with monolingual Deepgram Nova-3 pricing. + If the gate shows AssemblyAI above Deepgram, first audit stale fixture ledgers, + omitted Deepgram diarization cost, and billable-duration inflation. +- No sustained increase in user corrections, self-voice review failures, + support complaints, or billing surprises. + +Rollback command/config: + +```bash +ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram +``` + +After rollback, confirm new background runs show `provider=deepgram`, verify +`transcription_provider_fallbacks_total` is not continuing to rise from +AssemblyAI failures, and keep `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true` +until the incident review is complete. + +## Health Thresholds + +Roll back to Deepgram or open an incident follow-up when any condition is true +for the review window: + +- AssemblyAI background error rate is higher than Deepgram by more than 1 + percentage point, or exceeds 3% absolute. +- Fallback rate exceeds 5% or doubles the previous healthy window. +- Timeout rate exceeds 1% or any timeout spike correlates with support reports. +- Empty transcript rate exceeds the Deepgram baseline by more than 1 percentage + point. +- Cost/hour unexpectedly exceeds the approved provider margin. The current + public pricing baseline for diarized prerecorded work is about `$0.170/hour` + for AssemblyAI Universal-2, `$0.408/hour` for Deepgram Nova-3 monolingual, and + `$0.468/hour` for Deepgram Nova-3 multilingual. A report showing AssemblyAI as + more expensive is a stale-pricing or billable-duration investigation until + proven otherwise. +- Unknown or low-confidence speaker rate exceeds the Deepgram baseline by more + than 5 percentage points, or `unknown_speaker_duration_seconds` grows enough + to make speaker labels visibly worse. +- App-visible speaker inflation or `split_count` rises without a matching + speaker purity improvement in replay reports. +- User correction/self-voice review rate rises above the Deepgram baseline when + that signal is available. +- Support complaints mention missing transcripts, wrong speakers, degraded + background transcription, or account/billing surprises. + +## Operations Runbook + +Enable AssemblyAI for passive background workloads: + +```bash +ASSEMBLYAI_API_KEY= +DEEPGRAM_API_KEY= +ASSEMBLYAI_PRERECORDED_STT_ENABLED=true +ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess +ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true +ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=assemblyai +``` + +Disable AssemblyAI for background: + +```bash +ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram +``` + +Disable AssemblyAI for all passive prerecorded workloads: + +```bash +ASSEMBLYAI_PRERECORDED_STT_ENABLED=false +``` + +Validate after each config change: + +1. Confirm `/metrics` has background samples for + `transcription_provider_requests_total`, + `transcription_provider_latency_seconds`, + `transcription_provider_fallbacks_total`, and + `transcription_provider_audio_seconds_total`. +2. Query `transcription_provider_usage_daily` for `provider=assemblyai`, + `model=universal-2`, and `workload=background`; compute cost/hour as + `estimated_cost_usd / (billable_seconds / 3600)`. +3. Inspect recent `transcription_provider_runs` without opening transcript or + audio artifacts. Check `status`, `fallback.reason`, `retry_count`, + `timing.latency_ms`, `chunk_duration_seconds`, `estimated_cost_usd`, + `split_count`, `identity_match_count`, `unknown_speaker_count`, and + `unknown_speaker_duration_seconds`. +4. Run the offline gate during rollout checks: + `python3 scripts/stt/provider_comparison_gate.py --manifest tests/fixtures/stt_provider_eval/manifest.json`. +5. If a rollback threshold trips, set `ASSEMBLYAI_BACKGROUND_PROVIDER_MODE=deepgram`, + verify new background runs show `provider=deepgram`, and leave + `ASSEMBLYAI_PRERECORDED_STT_FALLBACK_ENABLED=true` until the incident review + is complete. + +Debugging checklist: + +- `missing_assemblyai_api_key`: verify `ASSEMBLYAI_API_KEY` is present only in + secret storage and never in logs, ledgers, or dispatch artifacts. +- `provider_failure`: inspect provider status, timeout settings + `ASSEMBLYAI_POLL_INTERVAL_SECONDS` and `ASSEMBLYAI_MAX_POLL_SECONDS`, retry + counts, and Deepgram fallback health. +- High cost/hour: compare `billable_seconds`, `chunk_duration_seconds`, and + request count against the fixed-15s background profile before changing + chunking. +- Speaker regressions: compare `split_count`, `identity_match_count`, + `unknown_speaker_count`, `unknown_speaker_duration_seconds`, and replay + speaker-purity reports. Do not reduce split-before-match protections unless + replay gates show no regression. + +## Data Handling + +AssemblyAI receives uploaded audio only for eligible prerecorded workloads after +the routing flags above allow it. User consent and BYOK behavior must be +consistent with the desktop background transcription setting and the optional +`X-BYOK-AssemblyAI` header. Regional data handling follows the configured +AssemblyAI account and endpoint; do not enable a region or cohort where vendor +retention, residency, or customer terms are not approved. + +Secrets are configuration-only. Never place `ASSEMBLYAI_API_KEY`, user BYOK +headers, provider bearer tokens, raw audio, full transcript text, provider +`words`, or provider `utterances` in metrics, logs, Firestore ledgers, CAR +artifacts, screenshots, or support notes. + +Fixture validation: + +```bash +cd backend +python3 scripts/stt/provider_comparison_gate.py --manifest tests/fixtures/stt_provider_eval/manifest.json +``` + +Live validation requires `DEEPGRAM_API_KEY`, `ASSEMBLYAI_API_KEY`, and manifest +cases with `audio_url` or `audio_file`. diff --git a/docs/doc/developer/backend/listen_pusher_pipeline.mdx b/docs/doc/developer/backend/listen_pusher_pipeline.mdx index 8c8ef372b37..2d35c9fb5be 100644 --- a/docs/doc/developer/backend/listen_pusher_pipeline.mdx +++ b/docs/doc/developer/backend/listen_pusher_pipeline.mdx @@ -6,11 +6,11 @@ description: "Sequence diagrams for the /v4/listen WebSocket and Pusher processi # Listen + Pusher Pipeline — Sequence Diagrams -> Last updated: 2026-03-28 (PR #6061 — remove local fallback) +> Last updated: 2026-05-21 (AssemblyAI prerecorded STT MVP) > > These diagrams document the real behavior observed during E2E testing with -> live services (backend, pusher, Deepgram, embedding API). Update when the -> pipeline changes. +> live services (backend, pusher, Deepgram, AssemblyAI, embedding API). Update +> when the pipeline changes. ## 1. Connection + Streaming + Transcription @@ -226,6 +226,89 @@ sequenceDiagram Note over Pusher: In LOCAL_DEVELOPMENT mode,
GCS upload fails (no prod creds).
Conversation lifecycle still works. ``` +## 5.1 Background / Sync Upload Transcription + +Realtime listen and Hold-to-Talk streaming remain Deepgram paths. Background +prerecorded transcription routes through `utils/stt/provider_service.py`, which +selects AssemblyAI by default for eligible prerecorded workloads when +`ASSEMBLYAI_PRERECORDED_STT_ENABLED=true` and the workload is listed in +`ASSEMBLYAI_PRERECORDED_STT_WORKLOADS`. Deepgram remains the fallback when +AssemblyAI is disabled, unavailable, or fails and fallback is enabled. + +Desktop Audio Recording can use the batch background path instead of `/v4/listen` +when the desktop batch setting is enabled and the client is pointed at a backend +with AssemblyAI background support. The desktop client creates a stub +conversation with `POST /v2/desktop/background-conversation/start`, sends raw +mono PCM chunks to `POST /v2/desktop/background-transcribe`, and finalizes that +explicit conversation with +`POST /v2/desktop/background-conversation/{conversation_id}/finish`. Mobile +listen, BLE listen, and PTT continue to use their existing Deepgram paths. + +For cloud batch, the desktop client currently uses the fixed 15s safety profile: +15s PCM chunks with 0.5s overlap, client-side sustained-speech checks before +uploading, and backpressure instead of stopping recording when AssemblyAI is +slower than realtime. A silence-aware client configuration exists for eval and +can flush earlier after sustained speech plus a silence gap, but it remains off +the runtime path until replay gates show no speaker-quality regression versus +fixed 15s. AssemblyAI batch requests use the selected single language rather +than `multi` to avoid provider-side language-detection failures on low-speech +audio. + +Eligible AssemblyAI workloads are `sync`, `background`, and `postprocess`. +Latency-critical workloads (`realtime`, streaming `ptt`, and +`voice_message`) continue to use Deepgram and the existing websocket/finalize +semantics. + +```mermaid +sequenceDiagram + participant Client + participant Backend as Backend + participant ProviderService as Provider Service + participant AssemblyAI + participant Deepgram + participant EmbeddingAPI as Embedding API + participant Firestore + participant Metrics as /metrics + + Client->>Backend: Upload sync audio URL or desktop PCM background chunk + Backend->>ProviderService: transcribe_url/transcribe_bytes workload=sync|background|postprocess + alt AssemblyAI enabled for workload + ProviderService->>AssemblyAI: Submit async prerecorded transcript request + AssemblyAI-->>ProviderService: Completed transcript with utterances, words, speaker labels + else AssemblyAI disabled or workload ineligible + ProviderService->>Deepgram: Prerecorded transcript request + Deepgram-->>ProviderService: Transcript with words and speaker labels + end + + alt AssemblyAI fails, times out, or exhausts retries + ProviderService->>Firestore: Finalize failed AssemblyAI provider run + ProviderService->>Deepgram: Fallback prerecorded transcript request + Deepgram-->>ProviderService: Transcript result + end + + Note over ProviderService: Missing AssemblyAI credentials skip directly
to Deepgram when fallback credentials are usable. + +ProviderService->>ProviderService: Normalize provider result and reconstruct canonical segments + ProviderService->>ProviderService: Split repeated non-contiguous provider clusters inside each chunk + ProviderService->>EmbeddingAPI: Extract cluster samples for Omi speaker identity + ProviderService->>Firestore: Store canonical transcript segments and provider run ledger + ProviderService->>Firestore: Increment daily provider usage rollup + ProviderService->>Metrics: Emit provider request, latency, retry, fallback, audio seconds, clusters, identity confidence +``` + +AssemblyAI speaker labels are provider/session-local cluster IDs. Canonical +speaker identity is applied after provider normalization with the Omi embedding +service. Low-confidence matches remain explicit unknown speakers; provider +cluster metadata is preserved separately from `person_id` and `is_user`. +For desktop background chunks, unresolved provider clusters are scoped by +provider and `chunk_id`; repeated non-contiguous uses of the same provider +cluster inside one chunk are split before embedding identity matching. This +prevents rapid turn-taking from assigning multiple real speakers to one +app-visible speaker just because the provider reused `A`, `B`, `0`, or `1`. + +Optional desktop BYOK for AssemblyAI applies only to async prerecorded workloads +above; listen/realtime streaming in this pipeline remains Deepgram-only. + ## 6. Event Wire Protocol ### Server → Client (JSON over WS text frames) @@ -283,4 +366,3 @@ Both `transcribe.py` and `pusher.py` use an `asyncio.wait(FIRST_COMPLETED)` supe - Finite task normal completion (e.g. `process_pending_conversations`, `speaker_identification_task`) After supervisor exit, remaining tasks drain with `BG_DRAIN_TIMEOUT` (30s) before force-cancel. The connection gauge (`inc`/`dec`) is always paired in `try`/`finally`. - diff --git a/docs/docs.json b/docs/docs.json index fa8805c7abf..6644a1692f2 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -35,6 +35,7 @@ "doc/developer/backend/chat_system", "doc/developer/backend/StoringConversations", "doc/developer/backend/transcription", + "doc/developer/backend/assemblyai_background_rollout", "doc/developer/backend/listen_pusher_pipeline" ], "icon": "server" diff --git a/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md b/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md new file mode 100644 index 00000000000..99ee1787553 --- /dev/null +++ b/scripts/ASSEMBLYAI_BACKGROUND_E2E_AGENT_PROMPT.md @@ -0,0 +1,168 @@ +# Agent prompt: AssemblyAI background batch E2E (no desktop app) + +## Context + +Branch work adds **desktop always-on Audio Recording via AssemblyAI batch chunks** instead of `/v4/listen` (Deepgram streaming). Implementation is committed on the current branch. + +**Goal for you:** Build and run the strongest **non-desktop-app** E2E proof that the pipeline works end-to-end. Do **not** launch Omi Dev or use agent-swift unless absolutely necessary. + +## What exists today + +### Backend (working — verified manually) + +- `POST /v2/desktop/background-conversation/start` → `{ conversation_id }` +- `POST /v2/desktop/background-transcribe` → `{ segments, provider, run_id }` using `STTWorkload.background` → AssemblyAI +- Helpers: [`backend/utils/conversations/desktop_background.py`](backend/utils/conversations/desktop_background.py) +- Router: [`backend/routers/desktop_background.py`](backend/routers/desktop_background.py) +- Unit tests: [`backend/tests/unit/test_desktop_background_transcribe.py`](backend/tests/unit/test_desktop_background_transcribe.py) + +### Script (partial E2E — single-chunk only) + +[`scripts/desktop_assemblyai_e2e.py`](scripts/desktop_assemblyai_e2e.py): + +```bash +# Requires: local backend on :8080, ASSEMBLYAI_PRERECORDED_STT_ENABLED=true, Omi Dev signed in +python3 scripts/desktop_assemblyai_e2e.py --background-chunk --api http://127.0.0.1:8080 +``` + +Currently uploads **one full sample MP3 as a single PCM blob** (~154s). Proves backend + AssemblyAI once; does **not** simulate desktop chunking cadence or multi-chunk timeline. + +### Desktop (implemented but not in scope for your E2E) + +- Chunker/session in [`desktop/Desktop/Sources/BackgroundTranscription/`](desktop/Desktop/Sources/BackgroundTranscription/) +- AppState wiring in [`desktop/Desktop/Sources/AppState.swift`](desktop/Desktop/Sources/AppState.swift) +- Swift unit tests: [`desktop/Desktop/Tests/BackgroundTranscriptionTests.swift`](desktop/Desktop/Tests/BackgroundTranscriptionTests.swift) — chunker/merger/reducer only; **does not** exercise `AudioMixer` → AppState path + +### Known gap (why desktop failed in manual test) + +Live desktop recording started batch mode (`background-conversation/start`) but **never POSTed chunks** for 90+ seconds. Suspected cause: `AudioMixer` callback invokes `handleMixedBackgroundAudio` off MainActor. Your E2E should **not depend on fixing this** — but optionally add a Swift integration test that would catch it (see below). + +--- + +## Your mission + +Extend automated testing to get **as close as possible** to proving the full batch background path works **without running the desktop app**. + +### Tier 1 — Must deliver (Python, runnable in CI-like env) + +Extend [`scripts/desktop_assemblyai_e2e.py`](scripts/desktop_assemblyai_e2e.py) (or add `scripts/desktop_assemblyai_e2e_batch.py`) with a **`--background-batch`** mode that: + +1. **Prerequisites check** (exit with clear message if missing): + - Backend reachable at `--api` (default `http://127.0.0.1:8080`) + - Firebase token from `defaults read com.omi.desktop-dev auth_idToken` OR `--token` flag for CI + - Optional: grep backend `/docs` or a health endpoint; document required env vars + +2. **Simulate desktop chunking in Python** (mirror [`BackgroundTranscriptionConfiguration`](desktop/Desktop/Sources/BackgroundTranscription/BackgroundTranscriptionConfiguration.swift)): + - 16 kHz mono int16 PCM + - Split sample audio into **15s max chunks** with **1s overlap** (same as desktop chunker defaults) + - Use silence-boundary cuts if easy; hard-cut at 15s is acceptable for v1 + +3. **Full conversation lifecycle:** + ``` + POST /v2/desktop/background-conversation/start + for each chunk i: + POST /v2/desktop/background-transcribe + ?conversation_id=... + &chunk_start_ms= + &language=es # or en + POST /v1/conversations # force-process finalize + GET /v1/conversations?... # verify segments exist on conversation + ``` + +4. **Assertions (fail loudly):** + - Every chunk returns HTTP 200 + - Every chunk has `provider == "assemblyai"` + - `segments` non-empty for at least one chunk (speech sample) + - `chunk_start_ms` offsets applied: segment `start` values increase across chunks, no regression + - Multi-chunk: at least **2 chunks** uploaded from sample audio + - Finalize returns 200 or expected 404 if already processed + - Optional: fetch conversation from Firestore/API and assert `transcript_segments` length > 0 + +5. **CLI UX:** + ```bash + python3 scripts/desktop_assemblyai_e2e.py --background-batch [--api URL] [--language es] [--token TOKEN] + ``` + Print summary: chunk count, total segments, conversation_id, sample transcript snippet. + +6. **Document** in script header how to run locally: + ```bash + cd backend && DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib" ./run-local.sh + # .env: ASSEMBLYAI_PRERECORDED_STT_ENABLED=true, ASSEMBLYAI_API_KEY=... + python3 scripts/desktop_assemblyai_e2e.py --background-batch + ``` + +### Tier 2 — Backend pytest integration (no live AssemblyAI if possible) + +Add [`backend/tests/integration/test_desktop_background_batch_e2e.py`](backend/tests/integration/test_desktop_background_batch_e2e.py) OR extend unit tests: + +- Mock `transcribe_bytes` for most tests (fast, no API key) +- One optional `@pytest.mark.integration` test that calls real AssemblyAI when `ASSEMBLYAI_API_KEY` set (skip otherwise) +- Test multi-chunk append: 3 chunks with offsets → Firestore segments merged correctly +- Test `provider_cluster_id` → distinct `speaker_id` when mock returns two clusters + +Run: `cd backend && python3 -m pytest tests/unit/test_desktop_background_transcribe.py -v` + +### Tier 3 — Swift integration test (no desktop app launch) + +Add test in [`desktop/Desktop/Tests/BackgroundTranscriptionTests.swift`](desktop/Desktop/Tests/BackgroundTranscriptionTests.swift) or new file: + +- **`testMixerCallbackDispatchesChunksToSession`**: Feed 16+ seconds of synthetic PCM through `BackgroundAudioChunker` the same way AppState would after receiving mixer output; mock HTTP handler returns fake segments; assert `transcribeNext` called and segments merged. +- **`testFifteenSecondContinuousSpeechProducesChunk`**: Append PCM at 100ms frames for 16s → assert at least one chunk enqueued without silence gaps. + +Run: `cd desktop && xcrun swift test -c debug --package-path Desktop --filter BackgroundTranscription` + +This catches the class of bug where audio never reaches the session (won't fix MainActor alone but validates chunker + session + drain loop). + +### Tier 4 — Optional: chunker parity test (Python ↔ Swift) + +Export chunk boundaries from Python chunker and assert Swift `BackgroundAudioChunker` produces same cut points for a fixture PCM file. Low priority unless easy. + +--- + +## Out of scope + +- Launching Omi Dev / agent-swift / live mic +- PTT path (`/v2/voice-message/transcribe`) +- `/v4/listen` WebSocket +- Prod Helm rollout + +--- + +## Success criteria + +You are done when: + +1. `python3 scripts/desktop_assemblyai_e2e.py --background-batch` passes against local backend with AssemblyAI enabled, printing conversation_id + segment count. +2. `cd backend && python3 -m pytest tests/unit/test_desktop_background_transcribe.py -v` passes (0 failures). +3. Swift `BackgroundTranscription` tests pass including at least one **multi-chunk / 15s boundary** test. +4. README or script docstring explains how to run without desktop app. + +--- + +## Reference files + +| File | Purpose | +|------|---------| +| [`backend/routers/desktop_background.py`](backend/routers/desktop_background.py) | HTTP endpoints | +| [`scripts/desktop_assemblyai_e2e.py`](scripts/desktop_assemblyai_e2e.py) | Extend this | +| [`desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift`](desktop/Desktop/Sources/BackgroundTranscription/BackgroundAudioChunker.swift) | Match chunk sizes | +| [`.cursor/plans/assemblyai_background_listen_098fc011.plan.md`](.cursor/plans/assemblyai_background_listen_098fc011.plan.md) | Full architecture plan | + +## Environment + +```bash +# backend/.env (required) +ASSEMBLYAI_PRERECORDED_STT_ENABLED=true +ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess +ASSEMBLYAI_API_KEY= +LOCAL_DEVELOPMENT=true +``` + +Auth: `LOCAL_DEVELOPMENT=true` maps failed token verify → uid `123`. + +--- + +## Notes from manual debugging + +- Single-chunk `--background-chunk` **passed** (AssemblyAI returned transcript for NBC sample). +- Desktop live path: conversation created but **no chunk POSTs** — fix is likely `Task { @MainActor in handleMixedBackgroundAudio(...) }` in mixer callback; **separate task**, not required for your Python E2E. diff --git a/scripts/desktop_assemblyai_e2e.py b/scripts/desktop_assemblyai_e2e.py new file mode 100755 index 00000000000..b2656a3ecdb --- /dev/null +++ b/scripts/desktop_assemblyai_e2e.py @@ -0,0 +1,575 @@ +#!/usr/bin/env python3 +"""Desktop-adjacent AssemblyAI E2E: uses Omi Dev auth + local Python backend. + +By default this script exercises the same backend path as mobile offline sync +(POST /v2/sync-local-files -> STTWorkload.sync). With --background-chunk it +exercises desktop batch listen (POST /v2/desktop/background-transcribe -> +STTWorkload.background) using raw PCM bytes. With --background-batch it +simulates desktop background chunking without launching the desktop app. + +Usage: + cd backend && DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib" ./run-local.sh + # backend/.env: + # LOCAL_DEVELOPMENT=true + # ASSEMBLYAI_PRERECORDED_STT_ENABLED=true + # ASSEMBLYAI_PRERECORDED_STT_WORKLOADS=sync,background,postprocess + # ASSEMBLYAI_API_KEY=... + python3 scripts/desktop_assemblyai_e2e.py [--api http://127.0.0.1:8080] + python3 scripts/desktop_assemblyai_e2e.py --background-chunk [--api http://127.0.0.1:8080] + python3 scripts/desktop_assemblyai_e2e.py --background-batch [--api http://127.0.0.1:8080] [--language en] [--token TOKEN] +""" +from __future__ import annotations + +import argparse +import hashlib +import json +import os +import struct +import subprocess +import sys +import time +import urllib.error +import urllib.parse +import urllib.request +import wave +from pathlib import Path + +DEFAULTS_DOMAIN = "com.omi.desktop-dev" +SAMPLE_MP3_URL = "https://storage.googleapis.com/aai-docs-samples/nbc.mp3" +SAMPLE_RATE = 16000 +CHANNELS = 1 +BYTES_PER_SAMPLE = 2 +BACKGROUND_CHUNK_SECONDS = 15 +BACKGROUND_OVERLAP_SECONDS = 1 + + +def read_desktop_auth_token() -> str: + result = subprocess.run( + ["defaults", "read", DEFAULTS_DOMAIN, "auth_idToken"], + capture_output=True, + text=True, + check=False, + ) + token = (result.stdout or "").strip() + if result.returncode != 0 or not token: + raise SystemExit( + "No Omi Dev auth token found. Sign in via ./run.sh --yolo first, " + f"then retry (defaults domain: {DEFAULTS_DOMAIN})." + ) + return token + + +def read_backend_admin_key() -> str | None: + admin_key = os.getenv("ADMIN_KEY") + if admin_key: + return admin_key + + env_path = Path(__file__).resolve().parents[1] / "backend" / ".env" + if not env_path.exists(): + return None + + for line in env_path.read_text().splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#") or not stripped.startswith("ADMIN_KEY="): + continue + return stripped.split("=", 1)[1].strip().strip('"').strip("'") or None + return None + + +def resolve_auth_token(args: argparse.Namespace) -> str: + if args.token: + return args.token + + if args.background_chunk or args.background_batch: + if args.use_desktop_auth: + print("Using Omi Dev desktop auth; this will persist sample transcripts to the signed-in local account.") + return read_desktop_auth_token() + + admin_key = read_backend_admin_key() + if not admin_key: + raise SystemExit( + "Background e2e persists transcript segments. Set ADMIN_KEY in backend/.env, pass --token, " + "or pass --use-desktop-auth to explicitly use the signed-in Omi Dev account." + ) + print(f"Using isolated local e2e uid={args.e2e_uid}.") + return f"{admin_key}{args.e2e_uid}" + + return read_desktop_auth_token() + + +def require_backend_reachable(api_base: str) -> None: + url = f"{api_base.rstrip('/')}/docs" + try: + req = urllib.request.Request(url, method="GET") + with urllib.request.urlopen(req, timeout=10) as resp: + if resp.status >= 500: + raise RuntimeError(f"Backend returned HTTP {resp.status}") + except (urllib.error.URLError, TimeoutError, RuntimeError) as exc: + raise SystemExit( + f"Backend is not reachable at {api_base}. Start the local backend first, for example:\n" + ' cd backend && DYLD_FALLBACK_LIBRARY_PATH="/opt/homebrew/lib" ./run-local.sh\n' + "Required backend env includes LOCAL_DEVELOPMENT=true, " + "ASSEMBLYAI_PRERECORDED_STT_ENABLED=true, and ASSEMBLYAI_API_KEY.\n" + f"Reachability error: {exc}" + ) + + +def mp3_to_pcm_bytes(mp3_path: Path, wav_path: Path, sample_rate: int = 16000) -> bytes: + subprocess.run( + [ + "ffmpeg", + "-y", + "-i", + str(mp3_path), + "-ar", + str(sample_rate), + "-ac", + "1", + "-f", + "wav", + str(wav_path), + ], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + with wave.open(str(wav_path), "rb") as wav_file: + return wav_file.readframes(wav_file.getnframes()) + + +def write_length_prefixed_pcm_bin(pcm: bytes, bin_path: Path, sample_rate: int = 16000) -> None: + # Length-prefixed PCM frames (~100ms) for sync-local-files decoder + frame_samples = sample_rate // 10 + frame_bytes = frame_samples * 2 + with bin_path.open("wb") as out: + for offset in range(0, len(pcm), frame_bytes): + chunk = pcm[offset : offset + frame_bytes] + if not chunk: + continue + out.write(struct.pack(" None: + wav_path = bin_path.with_suffix(".wav") + pcm = mp3_to_pcm_bytes(mp3_path, wav_path, sample_rate) + write_length_prefixed_pcm_bin(pcm, bin_path, sample_rate) + wav_path.unlink(missing_ok=True) + + +def ensure_sample_mp3(workdir: Path) -> Path: + workdir.mkdir(parents=True, exist_ok=True) + mp3_path = workdir / "sample.mp3" + if not mp3_path.exists(): + print("Downloading sample audio...") + urllib.request.urlretrieve(SAMPLE_MP3_URL, mp3_path) + return mp3_path + + +def ensure_sample_bin(workdir: Path) -> Path: + mp3_path = ensure_sample_mp3(workdir) + bin_path = workdir / "desktop_e2e_sample.bin" + if not bin_path.exists(): + print("Converting sample to sync .bin format...") + mp3_to_pcm_bin(mp3_path, bin_path) + return bin_path + + +def ensure_sample_pcm(workdir: Path) -> Path: + mp3_path = ensure_sample_mp3(workdir) + pcm_path = workdir / "desktop_e2e_sample.raw.pcm" + if not pcm_path.exists(): + print("Converting sample to raw PCM format...") + wav_path = workdir / "desktop_e2e_sample.raw.wav" + pcm_path.write_bytes(mp3_to_pcm_bytes(mp3_path, wav_path)) + wav_path.unlink(missing_ok=True) + return pcm_path + + +def split_background_pcm_chunks(pcm: bytes) -> list[tuple[int, bytes]]: + """Mirror desktop's 15s chunks with 1s retained overlap. + + Returns (chunk_start_ms, pcm_bytes). Hard cuts are intentional here; the + Swift chunker may cut earlier at silence, but hard-cut parity is enough to + prove multi-chunk backend persistence and offset handling. + """ + bytes_per_second = SAMPLE_RATE * CHANNELS * BYTES_PER_SAMPLE + chunk_bytes = BACKGROUND_CHUNK_SECONDS * bytes_per_second + overlap_bytes = BACKGROUND_OVERLAP_SECONDS * bytes_per_second + stride_bytes = chunk_bytes - overlap_bytes + if stride_bytes <= 0: + raise ValueError("background chunk stride must be positive") + + chunks: list[tuple[int, bytes]] = [] + offset = 0 + while offset < len(pcm): + chunk = pcm[offset : offset + chunk_bytes] + if not chunk: + break + start_ms = int(offset / bytes_per_second * 1000) + chunks.append((start_ms, chunk)) + if offset + chunk_bytes >= len(pcm): + break + offset += stride_bytes + return chunks + + +def json_request( + url: str, + token: str, + *, + method: str = "POST", + payload: dict | None = None, + timeout: int = 120, +) -> dict: + data = json.dumps(payload or {}).encode() + req = urllib.request.Request( + url, + data=data, + method=method, + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + ) + with urllib.request.urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode()) + + +def get_json_request(url: str, token: str, *, timeout: int = 60) -> dict: + req = urllib.request.Request( + url, + method="GET", + headers={"Authorization": f"Bearer {token}"}, + ) + with urllib.request.urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode()) + + +def multipart_upload(url: str, token: str, bin_path: Path, filename: str) -> dict: + boundary = "----omiAssemblyAIe2e" + body = bytearray() + body.extend(f"--{boundary}\r\n".encode()) + body.extend(f'Content-Disposition: form-data; name="files"; filename="{filename}"\r\n'.encode()) + body.extend(b"Content-Type: application/octet-stream\r\n\r\n") + body.extend(bin_path.read_bytes()) + body.extend(f"\r\n--{boundary}--\r\n".encode()) + + req = urllib.request.Request( + url, + data=bytes(body), + method="POST", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": f"multipart/form-data; boundary={boundary}", + }, + ) + with urllib.request.urlopen(req, timeout=120) as resp: + return json.loads(resp.read().decode()) + + +def background_chunk_upload(api_base: str, token: str, pcm_path: Path) -> dict: + start_url = f"{api_base.rstrip('/')}/v2/desktop/background-conversation/start" + started = json_request(start_url, token, payload={"language": "en", "source": "desktop"}) + conversation_id = started.get("conversation_id") + if not conversation_id: + raise RuntimeError(f"Unexpected background-conversation response: {started}") + + transcribe_url = ( + f"{api_base.rstrip('/')}/v2/desktop/background-transcribe" + f"?conversation_id={conversation_id}&chunk_start_ms=0&sample_rate=16000&channels=1" + ) + req = urllib.request.Request( + transcribe_url, + data=pcm_path.read_bytes(), + method="POST", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/octet-stream", + }, + ) + with urllib.request.urlopen(req, timeout=180) as resp: + result = json.loads(resp.read().decode()) + result["_conversation_id"] = conversation_id + return result + + +def background_transcribe_chunk( + api_base: str, + token: str, + conversation_id: str, + chunk_start_ms: int, + chunk: bytes, + language: str, +) -> dict: + chunk_hash = hashlib.sha256(chunk).hexdigest() + chunk_id = f"{conversation_id}-{chunk_start_ms}-{len(chunk)}-{chunk_hash[:16]}" + transcribe_url = ( + f"{api_base.rstrip('/')}/v2/desktop/background-transcribe" + f"?conversation_id={conversation_id}" + f"&chunk_id={urllib.parse.quote(chunk_id)}" + f"&chunk_start_ms={chunk_start_ms}" + f"&sample_rate={SAMPLE_RATE}" + f"&channels={CHANNELS}" + f"&language={urllib.parse.quote(language)}" + ) + req = urllib.request.Request( + transcribe_url, + data=chunk, + method="POST", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/octet-stream", + }, + ) + with urllib.request.urlopen(req, timeout=180) as resp: + return json.loads(resp.read().decode()) + + +def background_batch_upload(api_base: str, token: str, pcm_path: Path, language: str) -> dict: + require_backend_reachable(api_base) + api_base = api_base.rstrip("/") + + start_url = f"{api_base}/v2/desktop/background-conversation/start" + started = json_request(start_url, token, payload={"language": language, "source": "desktop"}) + conversation_id = started.get("conversation_id") + if not conversation_id: + raise RuntimeError(f"Unexpected background-conversation response: {started}") + + chunks = split_background_pcm_chunks(pcm_path.read_bytes()) + if len(chunks) < 2: + raise RuntimeError(f"Expected sample to produce at least 2 chunks, got {len(chunks)}") + + total_segments = 0 + non_empty_chunks = 0 + previous_first_segment_start = None + snippet_parts: list[str] = [] + chunk_summaries: list[dict] = [] + + for index, (chunk_start_ms, chunk) in enumerate(chunks, start=1): + duration_ms = int(len(chunk) / (SAMPLE_RATE * CHANNELS * BYTES_PER_SAMPLE) * 1000) + print( + f"Posting chunk {index}/{len(chunks)} " + f"start_ms={chunk_start_ms} duration_ms={duration_ms} bytes={len(chunk)} ..." + ) + result = background_transcribe_chunk(api_base, token, conversation_id, chunk_start_ms, chunk, language) + provider = result.get("provider") + if provider != "assemblyai": + raise RuntimeError(f"Chunk {index} expected provider=assemblyai, got {provider!r}: {result}") + + segments = result.get("segments") or [] + for segment in segments: + if segment.get("start") is None or segment.get("end") is None: + raise RuntimeError(f"Chunk {index} returned segment without start/end: {segment}") + if segment["start"] + 0.001 < chunk_start_ms / 1000.0: + raise RuntimeError( + f"Chunk {index} segment start {segment['start']} regressed before offset {chunk_start_ms / 1000.0}" + ) + + first_segment_start = next((segment["start"] for segment in segments if segment.get("start") is not None), None) + if first_segment_start is not None: + if previous_first_segment_start is not None and first_segment_start + 0.001 < previous_first_segment_start: + raise RuntimeError( + f"Chunk {index} first segment start regressed: " + f"{first_segment_start} < {previous_first_segment_start}" + ) + previous_first_segment_start = first_segment_start + + if segments: + non_empty_chunks += 1 + snippet_parts.extend(segment.get("text", "") for segment in segments[:2]) + total_segments += len(segments) + chunk_summaries.append( + { + "index": index, + "chunk_start_ms": chunk_start_ms, + "chunk_duration_ms": result.get("chunk_duration_ms"), + "segments": len(segments), + "run_id": result.get("run_id"), + } + ) + + if non_empty_chunks == 0 or total_segments == 0: + raise RuntimeError("Expected non-empty AssemblyAI segments from at least one chunk.") + + before_finalize = get_json_request(f"{api_base}/v1/conversations/{conversation_id}", token) + persisted_before = before_finalize.get("transcript_segments") or [] + if not persisted_before: + raise RuntimeError(f"Expected persisted transcript_segments before finalize: {before_finalize}") + + finalize_status = None + finalize_body: dict | str | None = None + try: + finalize_body = json_request(f"{api_base}/v1/conversations", token, payload={}, timeout=240) + finalize_status = 200 + except urllib.error.HTTPError as exc: + finalize_status = exc.code + body = exc.read().decode(errors="replace") + try: + finalize_body = json.loads(body) + except json.JSONDecodeError: + finalize_body = body + if exc.code != 404: + raise RuntimeError(f"Finalize failed: HTTP {exc.code}\n{body}") from exc + + after_finalize = get_json_request(f"{api_base}/v1/conversations/{conversation_id}", token) + persisted_after = after_finalize.get("transcript_segments") or [] + if not persisted_after: + raise RuntimeError(f"Expected transcript_segments on conversation after finalize: {after_finalize}") + + return { + "conversation_id": conversation_id, + "chunks": chunk_summaries, + "total_segments_returned": total_segments, + "persisted_segments_before_finalize": len(persisted_before), + "persisted_segments_after_finalize": len(persisted_after), + "finalize_status": finalize_status, + "finalize_body": finalize_body, + "snippet": " ".join(part.strip() for part in snippet_parts if part.strip())[:300], + } + + +def poll_job(api_base: str, token: str, job_id: str, timeout_s: int = 900) -> dict: + url = f"{api_base.rstrip('/')}/v2/sync-local-files/{job_id}" + deadline = time.time() + timeout_s + while time.time() < deadline: + req = urllib.request.Request( + url, + headers={"Authorization": f"Bearer {token}"}, + ) + with urllib.request.urlopen(req, timeout=60) as resp: + job = json.loads(resp.read().decode()) + status = job.get("status") + stage = job.get("stage") + print(f" job {job_id}: status={status} stage={stage}") + if status in {"completed", "partial_failure", "failed"}: + return job + time.sleep(3) + raise TimeoutError(f"Timed out waiting for job {job_id}") + + +def main() -> int: + parser = argparse.ArgumentParser(description="Desktop AssemblyAI E2E via sync-local-files") + parser.add_argument("--api", default="http://127.0.0.1:8080", help="Local Python backend base URL") + parser.add_argument("--workdir", default="/tmp/omi-assemblyai-e2e", help="Temp dir for sample audio") + parser.add_argument("--language", default="en", help="Language passed to background transcription") + parser.add_argument( + "--token", + help=( + "Firebase ID token or ADMIN_KEY-prefixed uid. Background modes default to an isolated local e2e uid; " + "sync mode defaults to Omi Dev auth_idToken from macOS defaults." + ), + ) + parser.add_argument( + "--e2e-uid", + default="desktop-assemblyai-e2e", + help="Isolated local uid used by background modes when ADMIN_KEY is available and --token is omitted.", + ) + parser.add_argument( + "--use-desktop-auth", + action="store_true", + help="For background modes, explicitly persist sample transcripts to the signed-in Omi Dev account.", + ) + parser.add_argument( + "--background-chunk", + action="store_true", + help="Exercise /v2/desktop/background-transcribe with raw PCM instead of sync-local-files", + ) + parser.add_argument( + "--background-batch", + action="store_true", + help="Exercise desktop background batch lifecycle with 15s/1s-overlap raw PCM chunks", + ) + args = parser.parse_args() + + if args.background_chunk and args.background_batch: + print("Choose only one of --background-chunk or --background-batch.", file=sys.stderr) + return 2 + + token = resolve_auth_token(args) + + if args.background_chunk: + pcm_path = ensure_sample_pcm(Path(args.workdir)) + print(f"Posting raw PCM chunk {pcm_path.name} to {args.api.rstrip('/')}/v2/desktop/background-transcribe ...") + try: + result = background_chunk_upload(args.api, token, pcm_path) + except urllib.error.HTTPError as exc: + body = exc.read().decode(errors="replace") + print(f"Background chunk failed: HTTP {exc.code}\n{body}", file=sys.stderr) + return 1 + except RuntimeError as exc: + print(str(exc), file=sys.stderr) + return 1 + + print(json.dumps(result, indent=2)) + if result.get("provider") != "assemblyai": + print(f"Expected provider=assemblyai, got {result.get('provider')!r}.", file=sys.stderr) + return 1 + if not result.get("segments"): + print("Expected non-empty segments from background chunk.", file=sys.stderr) + return 1 + + print("\nBackground chunk succeeded with provider=assemblyai.") + print(f"Conversation: {result.get('_conversation_id')}") + return 0 + + if args.background_batch: + pcm_path = ensure_sample_pcm(Path(args.workdir)) + print(f"Posting simulated background batch from {pcm_path.name} to {args.api.rstrip('/')} ...") + try: + summary = background_batch_upload(args.api, token, pcm_path, args.language) + except urllib.error.HTTPError as exc: + body = exc.read().decode(errors="replace") + print(f"Background batch failed: HTTP {exc.code}\n{body}", file=sys.stderr) + return 1 + except (RuntimeError, SystemExit) as exc: + print(str(exc), file=sys.stderr) + return 1 + + print("\nBackground batch succeeded with provider=assemblyai.") + print(f"Conversation: {summary['conversation_id']}") + print(f"Chunks uploaded: {len(summary['chunks'])}") + print(f"Segments returned by chunks: {summary['total_segments_returned']}") + print(f"Segments persisted before finalize: {summary['persisted_segments_before_finalize']}") + print(f"Segments persisted after finalize: {summary['persisted_segments_after_finalize']}") + print(f"Finalize status: HTTP {summary['finalize_status']}") + if summary["snippet"]: + print(f"Transcript snippet: {summary['snippet']}") + print("\nChunk summary:") + print(json.dumps(summary["chunks"], indent=2)) + return 0 + + bin_path = ensure_sample_bin(Path(args.workdir)) + ts = int(time.time()) + # Filename must include _pcm16_{sampleRate}_ so sync decode uses PCM path (not Opus). + filename = f"audio_desktop_pcm16_16000_1_fs160_{ts}.bin" + + upload_url = f"{args.api.rstrip('/')}/v2/sync-local-files" + print(f"Uploading {bin_path.name} to {upload_url} ...") + try: + queued = multipart_upload(upload_url, token, bin_path, filename) + except urllib.error.HTTPError as exc: + body = exc.read().decode(errors="replace") + print(f"Upload failed: HTTP {exc.code}\n{body}", file=sys.stderr) + return 1 + + job_id = queued.get("job_id") + if not job_id: + print(f"Unexpected response: {queued}", file=sys.stderr) + return 1 + + print(f"Queued job_id={job_id}; polling...") + final = poll_job(args.api, token, job_id) + print(json.dumps(final, indent=2)) + + if final.get("status") not in {"completed", "partial_failure"}: + print("Job did not complete successfully.", file=sys.stderr) + return 1 + + print("\nCheck backend logs for provider=assemblyai on workload=sync.") + print("Firestore: transcription_provider_runs collection for your uid.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/web/frontend/package.json b/web/frontend/package.json index c55ff317aaf..33016c30ac7 100644 --- a/web/frontend/package.json +++ b/web/frontend/package.json @@ -6,7 +6,7 @@ "dev": "next dev --turbo", "build": "next build", "start": "next start", - "lint": "next lint", + "lint": "eslint ./src --ext .jsx,.js,.ts,.tsx --quiet --ignore-path ./.eslintignore", "lint:fix": "eslint ./src --ext .jsx,.js,.ts,.tsx --quiet --fix --ignore-path ./.eslintignore", "lint:format": "prettier --loglevel warn --write \"./**/*.{js,jsx,ts,tsx,css,md,json}\" " }, diff --git a/web/frontend/src/app/globals.css b/web/frontend/src/app/globals.css index 45b289a6989..deea1d98934 100644 --- a/web/frontend/src/app/globals.css +++ b/web/frontend/src/app/globals.css @@ -35,5 +35,6 @@ } .font-system-ui { - font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; + font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, + 'Helvetica Neue', Arial, sans-serif; }