diff --git a/backend/database/redis_db.py b/backend/database/redis_db.py index 00f76372ada..e33447a3b45 100644 --- a/backend/database/redis_db.py +++ b/backend/database/redis_db.py @@ -586,6 +586,11 @@ def has_silent_user_notification_been_sent(uid: str) -> bool: return r.exists(f'users:{uid}:silent_notification_sent') +def try_acquire_byok_llm_error_notification_lock(uid: str, provider: str, reason: str, ttl: int = 60 * 60 * 24) -> bool: + """Return True once per BYOK provider/error reason per TTL window.""" + return bool(r.set(f'users:{uid}:byok_llm_error:{provider}:{reason}', '1', ex=ttl, nx=True)) + + # ****************************************************** # ******* IMPORTANT CONVERSATION NOTIFICATIONS ********* # ****************************************************** @@ -636,8 +641,7 @@ def remove_conversation_summary_app_id(app_id: str) -> bool: # Lua script: atomic increment + TTL in a single round-trip. # Returns [current_count, ttl_remaining]. Sets TTL on first hit # and self-heals any key that lost its TTL (prevents permanent buckets). -_RATE_LIMIT_LUA = r.register_script( - """ +_RATE_LIMIT_LUA = r.register_script(""" local key = KEYS[1] local window = tonumber(ARGV[1]) local current = redis.call('INCR', key) @@ -650,8 +654,7 @@ def remove_conversation_summary_app_id(app_id: str) -> bool: ttl = window end return {current, ttl} -""" -) +""") def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> tuple[bool, int, int]: @@ -680,8 +683,7 @@ def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> t # Burst uses a sorted set keyed by timestamp-ms for sliding-window accuracy, # trimmed on every call (O(log n)). Daily char counter auto-expires at midnight # UTC (caller passes seconds_until_midnight_utc as the TTL). -_TTS_RATE_LIMIT_LUA = r.register_script( - """ +_TTS_RATE_LIMIT_LUA = r.register_script(""" local burst_key = KEYS[1] local daily_key = KEYS[2] local now_ms = tonumber(ARGV[1]) @@ -709,8 +711,7 @@ def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> t redis.call('EXPIRE', daily_key, daily_ttl) end return {0, 0} -""" -) +""") def _seconds_until_midnight_utc() -> int: diff --git a/backend/routers/pusher.py b/backend/routers/pusher.py index 2e4625076df..f1967ea3fc2 100644 --- a/backend/routers/pusher.py +++ b/backend/routers/pusher.py @@ -25,7 +25,7 @@ trigger_external_integrations, ) from utils.conversations.location import async_get_google_maps_location -from utils.byok import set_byok_keys +from utils.byok import set_byok_keys, set_byok_uid from utils.conversations.process_conversation import process_conversation from utils.executors import storage_executor from utils.webhooks import ( @@ -79,6 +79,7 @@ async def _process_conversation_task( """ if byok_keys: set_byok_keys(byok_keys) + set_byok_uid(uid) try: conversation_data = conversations_db.get_conversation(uid, conversation_id) if not conversation_data: diff --git a/backend/routers/sync.py b/backend/routers/sync.py index 671deba0881..ace4806c7f0 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -49,7 +49,7 @@ ) from utils import encryption -from utils.byok import get_byok_keys, set_byok_keys +from utils.byok import get_byok_keys, set_byok_keys, set_byok_uid from utils.log_sanitizer import sanitize from utils.stt.pre_recorded import deepgram_prerecorded, get_deepgram_model_for_language, postprocess_words from utils.stt.vad import vad_is_empty @@ -1359,6 +1359,7 @@ def _run_full_pipeline_background( Moved ALL heavy processing here so the v2 endpoint returns 202 immediately. """ set_byok_keys(byok_keys or {}) + set_byok_uid(uid if byok_keys else None) segmented_paths = set() wav_paths = [] stage_timings = {} @@ -1580,6 +1581,7 @@ def _process_one_segment(path): pass finally: set_byok_keys({}) + set_byok_uid(None) _cleanup_files(list(segmented_paths)) _cleanup_files(wav_paths) try: diff --git a/backend/routers/users.py b/backend/routers/users.py index 0d78de5722f..df19f2c1af1 100644 --- a/backend/routers/users.py +++ b/backend/routers/users.py @@ -1,6 +1,5 @@ import json import re -import threading import uuid from typing import List, Dict, Any, Union, Optional import hashlib @@ -62,6 +61,7 @@ PricingOption, PhoneCallQuota, ) +from utils.executors import storage_executor, submit_with_context from utils.phone_calls import get_quota_snapshot as get_phone_call_quota_snapshot from utils.apps import get_available_app_by_id from utils.subscription import ( @@ -165,7 +165,7 @@ def delete_account( # 3. Wipe Firestore subcollections in the background โ€” can take minutes # for heavy users and would otherwise time out at the load balancer. - threading.Thread(target=_background_wipe_user_data, args=(uid,), daemon=True).start() + submit_with_context(storage_executor, _background_wipe_user_data, uid) return {'status': 'ok', 'message': 'Account deletion started'} except Exception as e: diff --git a/backend/test.sh b/backend/test.sh index 6ec08b6d4d9..1da0e2ffcb9 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -111,6 +111,7 @@ pytest tests/unit/test_thread_join_elimination.py -v pytest tests/unit/test_async_http_infrastructure.py -v pytest tests/unit/test_clean_sweep_migrations.py -v pytest tests/unit/test_omi_qos_tiers.py -v +pytest tests/unit/test_byok_llm_errors.py -v pytest tests/unit/test_byok_security.py -v pytest tests/unit/test_paywall_reconnect_gate.py -v pytest tests/unit/test_vertex_ai_system_role.py -v diff --git a/backend/tests/unit/test_action_item_date_validation.py b/backend/tests/unit/test_action_item_date_validation.py index 5a36a8455c2..cd42eade7d9 100644 --- a/backend/tests/unit/test_action_item_date_validation.py +++ b/backend/tests/unit/test_action_item_date_validation.py @@ -82,6 +82,8 @@ def _load_module_from_file(module_name, file_path): if mod_name not in sys.modules: _stub_module(mod_name) +sys.modules["database.auth"].get_user_name = MagicMock(return_value="TestUser") + # Stub database.action_items action_items_db = _stub_module("database.action_items") action_items_db.create_action_item = MagicMock(return_value="test-item-id") diff --git a/backend/tests/unit/test_async_app_integrations.py b/backend/tests/unit/test_async_app_integrations.py index 796b029f200..d1210d63140 100644 --- a/backend/tests/unit/test_async_app_integrations.py +++ b/backend/tests/unit/test_async_app_integrations.py @@ -37,8 +37,10 @@ "vector_db", "apps", "llm_usage", + "user_usage", "chat", "goals", + "announcements", ]: mod = types.ModuleType(f"database.{submod}") sys.modules.setdefault(f"database.{submod}", mod) @@ -53,6 +55,10 @@ sys.modules["database.redis_db"].get_proactive_noti_sent_at_ttl = MagicMock(return_value=0) sys.modules["database.redis_db"].incr_daily_notification_count = MagicMock() sys.modules["database.redis_db"].get_daily_notification_count = MagicMock(return_value=0) +sys.modules["database.redis_db"].delete_generic_cache = MagicMock() +sys.modules["database.user_usage"].get_monthly_chat_usage = MagicMock(return_value={}) +sys.modules["database.user_usage"].get_monthly_usage_stats_since = MagicMock(return_value={}) +sys.modules["database.announcements"].compare_versions = MagicMock(return_value=0) sys.modules["database.vector_db"].query_vectors_by_metadata = MagicMock(return_value=[]) sys.modules["database.apps"].record_app_usage = MagicMock() sys.modules["database.llm_usage"].record_llm_usage = MagicMock() diff --git a/backend/tests/unit/test_available_plans_resilience.py b/backend/tests/unit/test_available_plans_resilience.py index ef19495fdfc..49ff8b629c5 100644 --- a/backend/tests/unit/test_available_plans_resilience.py +++ b/backend/tests/unit/test_available_plans_resilience.py @@ -73,6 +73,7 @@ def _compare_versions(a, b): _announcements_mod._compare_versions = _compare_versions +_announcements_mod.compare_versions = _compare_versions # database.users needs the functions payment.py imports by name _users_mod = sys.modules["database.users"] @@ -126,6 +127,7 @@ def _compare_versions(a, b): _endpoints_mod = sys.modules["utils.other.endpoints"] _endpoints_mod.get_current_user_uid = lambda: "test-user" +_endpoints_mod.get_current_user_uid_no_byok_validation = lambda: "test-user" # Ensure utils.other has endpoints attr for `from utils.other import endpoints` sys.modules["utils.other"].endpoints = _endpoints_mod diff --git a/backend/tests/unit/test_batch_upload_storage.py b/backend/tests/unit/test_batch_upload_storage.py index 08045593a20..e9b13d8139a 100644 --- a/backend/tests/unit/test_batch_upload_storage.py +++ b/backend/tests/unit/test_batch_upload_storage.py @@ -11,12 +11,16 @@ import os import sys +import types from unittest.mock import MagicMock, patch os.environ.setdefault("ENCRYPTION_SECRET", "omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv") # Mock heavy dependencies at sys.modules level before importing storage sys.modules.setdefault("database._client", MagicMock()) +subscription_mod = types.ModuleType("utils.subscription") +subscription_mod.get_default_basic_subscription = MagicMock() +sys.modules.setdefault("utils.subscription", subscription_mod) _mock_gcs_storage = MagicMock() _mock_gcs_client_instance = MagicMock() diff --git a/backend/tests/unit/test_byok_llm_errors.py b/backend/tests/unit/test_byok_llm_errors.py new file mode 100644 index 00000000000..619a9f36500 --- /dev/null +++ b/backend/tests/unit/test_byok_llm_errors.py @@ -0,0 +1,112 @@ +import os +import sys +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +os.environ.setdefault('OPENAI_API_KEY', 'sk-test-fake-for-unit-tests') +os.environ.setdefault('ANTHROPIC_API_KEY', 'ant-test-fake-for-unit-tests') +os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv') + +sys.modules.setdefault('database._client', MagicMock()) + + +class _HTTPError(Exception): + def __init__(self, message: str, status_code: int): + super().__init__(message) + self.status_code = status_code + + +def test_classify_byok_llm_error_authentication(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("bad api key", 401)) == 'invalid' + + +def test_classify_byok_llm_error_permission(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("project denied", 403)) == 'permission' + + +def test_classify_byok_llm_error_insufficient_quota(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("insufficient_quota", 429)) == 'quota' + + +def test_classify_byok_llm_error_ignores_transient_rate_limit(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("rate limit reached, retry later", 429)) is None + + +@patch('utils.llm.byok_errors.messaging.send_each') +@patch('utils.llm.byok_errors.notification_db.get_all_tokens', return_value=['token-1']) +@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock', return_value=True) +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user') +def test_handle_llm_error_notifies_actionable_byok_error( + mock_get_key, + mock_get_uid, + mock_lock, + mock_get_tokens, + mock_send_each, +): + from utils.llm.byok_errors import handle_llm_error + + mock_send_each.return_value = SimpleNamespace(responses=[SimpleNamespace(success=True, exception=None)]) + + handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test') + + mock_lock.assert_called_once_with('user-1', 'openai', 'quota') + mock_get_tokens.assert_called_once_with('user-1') + mock_send_each.assert_called_once() + message = mock_send_each.call_args.args[0][0] + assert message.data == {'type': 'byok_llm_error', 'provider': 'openai', 'reason': 'quota'} + + +@patch('utils.llm.byok_errors.messaging.send_each') +@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock') +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value=None) +def test_handle_llm_error_does_not_notify_platform_error( + mock_get_key, + mock_get_uid, + mock_lock, + mock_send_each, +): + from utils.llm.byok_errors import handle_llm_error + + handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test') + + mock_lock.assert_not_called() + mock_send_each.assert_not_called() + + +def test_validate_byok_request_records_current_uid(): + from utils.byok import get_byok_uid, validate_byok_request + + with patch('utils.byok._check_byok_validity', return_value=None): + validate_byok_request('user-1') + + assert get_byok_uid() == 'user-1' + + +def test_anthropic_proxy_constructs_default_client_lazily(): + from utils.llm.clients import _AnthropicClientProxy + + created = [] + + def _fake_client(**kwargs): + created.append(kwargs) + return object() + + proxy = _AnthropicClientProxy() + + with patch('utils.llm.clients.get_byok_key', return_value=None), patch( + 'utils.llm.clients.anthropic.AsyncAnthropic', side_effect=_fake_client + ): + assert created == [] + proxy._resolve() + + assert created == [{'timeout': 120.0, 'max_retries': 1}] diff --git a/backend/tests/unit/test_byok_security.py b/backend/tests/unit/test_byok_security.py index 5353218fce5..55376e5544e 100644 --- a/backend/tests/unit/test_byok_security.py +++ b/backend/tests/unit/test_byok_security.py @@ -26,6 +26,7 @@ sys.modules.setdefault('database._client', MagicMock()) sys.modules.setdefault('database.redis_db', MagicMock()) sys.modules.setdefault('database.users', MagicMock()) +sys.modules.setdefault('database.notifications', MagicMock()) sys.modules.setdefault('database.user_usage', MagicMock()) sys.modules.setdefault('database.llm_usage', MagicMock()) sys.modules.setdefault('database.announcements', MagicMock()) diff --git a/backend/tests/unit/test_chat_quota.py b/backend/tests/unit/test_chat_quota.py index ef779bb9aef..88472ded661 100644 --- a/backend/tests/unit/test_chat_quota.py +++ b/backend/tests/unit/test_chat_quota.py @@ -23,9 +23,12 @@ def _compare_versions(a, b): _announcements_mod._compare_versions = _compare_versions +_announcements_mod.compare_versions = _compare_versions # Create stubs for database modules used by get_chat_quota_snapshot -_db_users_mod = types.SimpleNamespace(get_user_valid_subscription=MagicMock()) +_db_users_mod = types.SimpleNamespace( + get_user_valid_subscription=MagicMock(), is_byok_active=MagicMock(return_value=False) +) _db_user_usage_mod = types.SimpleNamespace(get_monthly_chat_usage=MagicMock()) sys.modules.setdefault("database._client", types.SimpleNamespace(db=MagicMock())) @@ -319,8 +322,8 @@ def test_enforcement_allowed(self, monkeypatch): ): sub_mod.enforce_chat_quota("uid123") # no exception - def test_enforcement_exceeded_raises_402(self, monkeypatch): - """When user exceeds quota, raises HTTPException 402.""" + def test_enforcement_exceeded_basic_raises_402(self, monkeypatch): + """When a basic user exceeds quota, raises HTTPException 402.""" from fastapi import HTTPException sub_mod = _reload_subscription_module() @@ -330,10 +333,10 @@ def test_enforcement_exceeded_raises_402(self, monkeypatch): "get_chat_quota_snapshot", return_value={ 'allowed': False, - 'plan': PlanType.unlimited, + 'plan': PlanType.basic, 'unit': 'questions', - 'used': 2001, - 'limit': 2000, + 'used': 31, + 'limit': 30, 'reset_at': _RESET_AT, }, ): @@ -342,17 +345,15 @@ def test_enforcement_exceeded_raises_402(self, monkeypatch): assert exc_info.value.status_code == 402 assert exc_info.value.detail['error'] == 'quota_exceeded' - assert exc_info.value.detail['plan'] == 'Neo' - assert exc_info.value.detail['plan_type'] == 'unlimited' + assert exc_info.value.detail['plan'] == 'Free' + assert exc_info.value.detail['plan_type'] == 'basic' assert exc_info.value.detail['unit'] == 'questions' - assert exc_info.value.detail['used'] == 2001 - assert exc_info.value.detail['limit'] == 2000 + assert exc_info.value.detail['used'] == 31 + assert exc_info.value.detail['limit'] == 30 assert exc_info.value.detail['reset_at'] == _RESET_AT - def test_enforcement_402_operator_plan(self, monkeypatch): - """Operator plan shows correct display name in 402 detail.""" - from fastapi import HTTPException - + def test_enforcement_operator_overage_allowed(self, monkeypatch): + """Operator users over quota are allowed and handled as overage.""" sub_mod = _reload_subscription_module() with patch.object( @@ -367,16 +368,10 @@ def test_enforcement_402_operator_plan(self, monkeypatch): 'reset_at': _RESET_AT, }, ): - with pytest.raises(HTTPException) as exc_info: - sub_mod.enforce_chat_quota("uid123") - - assert exc_info.value.status_code == 402 - assert exc_info.value.detail['plan'] == 'Operator' - - def test_enforcement_402_architect_cost_based(self, monkeypatch): - """Architect plan shows cost_usd unit in 402 detail.""" - from fastapi import HTTPException + sub_mod.enforce_chat_quota("uid123") + def test_enforcement_architect_cost_overage_allowed(self, monkeypatch): + """Architect users over cost cap are allowed and handled as overage.""" sub_mod = _reload_subscription_module() with patch.object( @@ -391,9 +386,4 @@ def test_enforcement_402_architect_cost_based(self, monkeypatch): 'reset_at': _RESET_AT, }, ): - with pytest.raises(HTTPException) as exc_info: - sub_mod.enforce_chat_quota("uid123") - - assert exc_info.value.status_code == 402 - assert exc_info.value.detail['unit'] == 'cost_usd' - assert exc_info.value.detail['used'] == 400.5 + sub_mod.enforce_chat_quota("uid123") diff --git a/backend/tests/unit/test_daily_summary_race_condition.py b/backend/tests/unit/test_daily_summary_race_condition.py index 3289eb8e45b..11ea99b8db2 100644 --- a/backend/tests/unit/test_daily_summary_race_condition.py +++ b/backend/tests/unit/test_daily_summary_race_condition.py @@ -73,6 +73,7 @@ def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) for name in [ "utils.llm.external_integrations", "utils.notifications", + "utils.subscription", "utils.webhooks", "utils.conversations", "utils.conversations.factory", @@ -97,6 +98,9 @@ def try_acquire_daily_summary_lock(uid: str, date: str, ttl: int = 60 * 60 * 2) utils_notifications.send_bulk_notification = MagicMock() utils_notifications.send_notification = MagicMock() +utils_subscription = sys.modules["utils.subscription"] +utils_subscription.is_trial_paywalled = MagicMock(return_value=False) + utils_webhooks = sys.modules["utils.webhooks"] utils_webhooks.day_summary_webhook = MagicMock() diff --git a/backend/tests/unit/test_desktop_migration.py b/backend/tests/unit/test_desktop_migration.py index 289d88811e9..a3070d6caf6 100644 --- a/backend/tests/unit/test_desktop_migration.py +++ b/backend/tests/unit/test_desktop_migration.py @@ -69,6 +69,9 @@ def _stub_package(name): if mod_name not in sys.modules: _stub_module(mod_name) +redis_stub = sys.modules["database.redis_db"] +redis_stub.try_acquire_user_platform_write_lock = MagicMock(return_value=True) + # Stub google.cloud.firestore sentinels firestore_stub = sys.modules["google.cloud.firestore"] firestore_stub.Increment = lambda x: f"__increment_{x}__" @@ -1694,7 +1697,7 @@ def test_returns_title(self, mock_update): session_id='s1', messages=[TitleMessageInput(text='hi', sender='human'), TitleMessageInput(text='hello', sender='ai')], ) - with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(llm_mini=mock_llm)}): + with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(get_llm=MagicMock(return_value=mock_llm))}): result = generate_session_title(request, uid='u1') assert result == {'title': 'Project Discussion'} @@ -1713,7 +1716,7 @@ def test_empty_response_defaults_to_new_chat(self, mock_update): session_id='s1', messages=[TitleMessageInput(text='hi', sender='human')], ) - with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(llm_mini=mock_llm)}): + with patch.dict('sys.modules', {'utils.llm.clients': MagicMock(get_llm=MagicMock(return_value=mock_llm))}): result = generate_session_title(request, uid='u1') assert result == {'title': 'New Chat'} diff --git a/backend/tests/unit/test_desktop_transcribe.py b/backend/tests/unit/test_desktop_transcribe.py index 3b750a4f180..c992b6942fe 100644 --- a/backend/tests/unit/test_desktop_transcribe.py +++ b/backend/tests/unit/test_desktop_transcribe.py @@ -27,6 +27,7 @@ 'transcript_segment', 'chat', 'conversation', + 'conversation_enums', 'notification_message', 'app', 'memory', @@ -113,8 +114,6 @@ 'utils.llm.chat', 'utils.llm.goals', 'utils.llm.usage_tracker', - 'utils.conversations', - 'utils.conversations.process_conversation', 'utils.notifications', 'utils.other.storage', 'utils.other.chat_file', @@ -131,6 +130,24 @@ ]: sys.modules.setdefault(_ufull, MagicMock()) +_utils_pkg = sys.modules.get('utils') +_conv_pkg = ModuleType('utils.conversations') +_conv_pkg.__path__ = ['utils/conversations'] +_conv_pkg.__package__ = 'utils.conversations' +sys.modules['utils.conversations'] = _conv_pkg +if _utils_pkg is not None: + setattr(_utils_pkg, 'conversations', _conv_pkg) + +_conv_factory = ModuleType('utils.conversations.factory') +_conv_factory.deserialize_conversation = MagicMock() +sys.modules['utils.conversations.factory'] = _conv_factory +setattr(_conv_pkg, 'factory', _conv_factory) + +_conv_process = ModuleType('utils.conversations.process_conversation') +_conv_process.process_conversation = MagicMock() +sys.modules['utils.conversations.process_conversation'] = _conv_process +setattr(_conv_pkg, 'process_conversation', _conv_process) + # Force-import real models.chat (has no project deps, needed for FastAPI response_model) import importlib.util as _ilu @@ -430,14 +447,14 @@ class TestDeepgramPrerecordedFromBytesEdgeCases: @patch('utils.stt.pre_recorded._deepgram_client') def test_retry_raises_after_max_attempts(self, mock_client): - """After 3 failed attempts, should raise RuntimeError.""" + """After the configured retry is exhausted, should raise RuntimeError.""" mock_client.listen.rest.v.return_value.transcribe_file.side_effect = Exception('connection timeout') - with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'): + with pytest.raises(RuntimeError, match='Deepgram transcription failed after 2 attempts'): deepgram_prerecorded_from_bytes(b'\x00' * 100, encoding='linear16') - # Should have been called 3 times (attempts 0, 1, 2) - assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 3 + # Should have been called twice (initial + one retry) + assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 2 @patch('utils.stt.pre_recorded._deepgram_client') def test_return_language_empty_words_returns_detected_lang(self, mock_client): @@ -462,10 +479,10 @@ def test_no_channels_raises_and_retries(self, mock_client): mock_response.to_dict.return_value = {'results': {'channels': []}} mock_client.listen.rest.v.return_value.transcribe_file.return_value = mock_response - with pytest.raises(RuntimeError, match='Deepgram transcription failed after 3 attempts'): + with pytest.raises(RuntimeError, match='Deepgram transcription failed after 2 attempts'): deepgram_prerecorded_from_bytes(b'\x00' * 100) - assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 3 + assert mock_client.listen.rest.v.return_value.transcribe_file.call_count == 2 # --------------------------------------------------------------------------- @@ -509,6 +526,10 @@ def _stub_router_deps(): rdb = sys.modules.get('database.redis_db') if rdb: rdb.check_rate_limit = MagicMock(return_value=(True, 99, 0)) + subscription = sys.modules.get('utils.subscription') + if subscription: + subscription.enforce_chat_quota = MagicMock() + subscription.is_trial_paywalled = MagicMock(return_value=False) def _make_chat_client(): diff --git a/backend/tests/unit/test_dg_usage_batch.py b/backend/tests/unit/test_dg_usage_batch.py index c32a8362728..af948b0cbef 100644 --- a/backend/tests/unit/test_dg_usage_batch.py +++ b/backend/tests/unit/test_dg_usage_batch.py @@ -76,11 +76,14 @@ def setup_method(self): 'database.user_usage', 'database.conversations', 'firebase_admin', + 'firebase_admin.auth', 'firebase_admin.messaging', ]: if mod_name not in sys.modules: sys.modules[mod_name] = ModuleType(mod_name) + sys.modules['firebase_admin'].auth = sys.modules['firebase_admin.auth'] + sys.modules['firebase_admin'].messaging = sys.modules['firebase_admin.messaging'] sys.modules['database._client'].db = MagicMock() sys.modules['database.redis_db'].r = MagicMock() diff --git a/backend/tests/unit/test_fair_use_classifier.py b/backend/tests/unit/test_fair_use_classifier.py index 88bf803cdbb..0ed6467a669 100644 --- a/backend/tests/unit/test_fair_use_classifier.py +++ b/backend/tests/unit/test_fair_use_classifier.py @@ -1,6 +1,7 @@ """Tests for the LLM fair-use classifier (utils/llm/fair_use_classifier.py).""" import json +import os import sys import types from datetime import datetime, timedelta @@ -8,6 +9,8 @@ import pytest +os.environ.setdefault("OPENAI_API_KEY", "test-openai-key") + # --------------------------------------------------------------------------- # Stub heavy dependencies before importing the module under test # --------------------------------------------------------------------------- @@ -35,6 +38,11 @@ import utils.llm.fair_use_classifier as classifier_mod +def _mock_classifier_llm(**kwargs): + classifier_mod._classifier_llm = MagicMock() + classifier_mod._classifier_llm.ainvoke = AsyncMock(**kwargs) + + class TestSelectRecipes: """Test dynamic recipe selection based on conversation patterns.""" @@ -167,7 +175,7 @@ async def test_parses_llm_response_correctly(self): 'reasoning': 'Clear audiobook pattern', } ) - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + _mock_classifier_llm(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') @@ -192,7 +200,7 @@ async def test_handles_markdown_code_block_response(self): llm_response = MagicMock() llm_response.content = '```json\n{"misuse_score": 0.1, "usage_type": "none", "confidence": 0.9, "evidence": [], "reasoning": "Normal"}\n```' - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + _mock_classifier_llm(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == pytest.approx(0.1) @@ -215,7 +223,7 @@ async def test_clamps_score_to_valid_range(self): llm_response.content = json.dumps( {'misuse_score': 1.5, 'usage_type': 'none', 'confidence': -0.2, 'evidence': [], 'reasoning': ''} ) - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + _mock_classifier_llm(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == 1.0 @@ -237,7 +245,7 @@ async def test_returns_default_on_json_parse_error(self): llm_response = MagicMock() llm_response.content = 'This is not JSON at all' - _llm_clients.llm_mini.ainvoke = AsyncMock(return_value=llm_response) + _mock_classifier_llm(return_value=llm_response) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == 0.0 @@ -256,7 +264,7 @@ async def test_returns_default_on_llm_error(self): 'created_at': now, } ] - _llm_clients.llm_mini.ainvoke = AsyncMock(side_effect=Exception('LLM timeout')) + _mock_classifier_llm(side_effect=Exception('LLM timeout')) result = await classifier_mod.classify_user_purpose('user1') assert result['misuse_score'] == 0.0 diff --git a/backend/tests/unit/test_firestore_read_ops_cache.py b/backend/tests/unit/test_firestore_read_ops_cache.py index 21d6c77edca..6235c5c279c 100644 --- a/backend/tests/unit/test_firestore_read_ops_cache.py +++ b/backend/tests/unit/test_firestore_read_ops_cache.py @@ -653,7 +653,7 @@ def test_schedule_completed_calls_invalidation(self): idx_scheduled = source.find("Scheduled upgrade completed for user") assert idx_scheduled > 0 # Find the invalidation call before the log line (it's called right after update) - section = source[idx_scheduled - 200 : idx_scheduled] + section = source[idx_scheduled - 500 : idx_scheduled] assert 'set_credits_invalidation_signal(uid)' in section def test_schedule_canceled_calls_invalidation(self): diff --git a/backend/tests/unit/test_geocoding_cache.py b/backend/tests/unit/test_geocoding_cache.py index 8e67f2b61cb..e89570432e6 100644 --- a/backend/tests/unit/test_geocoding_cache.py +++ b/backend/tests/unit/test_geocoding_cache.py @@ -31,6 +31,7 @@ sys.modules["utils.http_client"] = _http_mod _http_mod.get_maps_client = MagicMock() _http_mod.get_webhook_client = MagicMock() +_http_mod.get_maps_semaphore = MagicMock() from models.geolocation import Geolocation from utils.conversations.location import get_google_maps_location @@ -44,7 +45,7 @@ def test_3_decimal_rounding(self): # 37.78512 -> 37.785, -122.40932 -> -122.409 with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = {"status": "OK", "results": []} mock_req.get.return_value = mock_resp @@ -80,7 +81,7 @@ def test_cache_hit_returns_geolocation(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = json.dumps(cached) - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: result = get_google_maps_location(37.78512, -122.40932) # Should NOT call Google API @@ -96,7 +97,7 @@ def test_cache_hit_no_api_key_needed(self): with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = json.dumps(cached) with patch.dict("os.environ", {}, clear=True): - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: result = get_google_maps_location(37.785, -122.409) mock_req.get.assert_not_called() assert result is not None @@ -118,7 +119,7 @@ def test_cache_miss_calls_api_and_caches(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -139,7 +140,7 @@ def test_cache_miss_calls_api_and_caches(self): def test_api_no_results_returns_none(self): with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = {"status": "OK", "results": []} mock_req.get.return_value = mock_resp @@ -165,7 +166,7 @@ def test_redis_read_failure_falls_through_to_api(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.side_effect = ConnectionError("Redis down") - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -191,7 +192,7 @@ def test_redis_write_failure_still_returns_result(self): with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None mock_r.set.side_effect = ConnectionError("Redis down") - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -211,7 +212,7 @@ def test_api_status_not_ok_returns_none(self): """Non-OK status (e.g. ZERO_RESULTS, OVER_QUERY_LIMIT) returns None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = {"status": "ZERO_RESULTS", "results": []} mock_req.get.return_value = mock_resp @@ -225,7 +226,7 @@ def test_missing_place_id_returns_none(self): """Result with no place_id returns None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -241,7 +242,7 @@ def test_missing_place_id_key_returns_none(self): """Result with no place_id key at all returns None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -257,7 +258,7 @@ def test_empty_types_gives_none_location_type(self): """Result with no types gives location_type=None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -274,7 +275,7 @@ def test_missing_types_key_gives_none_location_type(self): """Result with no 'types' key at all gives location_type=None.""" with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = None - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = { "status": "OK", @@ -305,7 +306,7 @@ def test_invalid_json_falls_through_to_api(self): } with patch("utils.conversations.location.r") as mock_r: mock_r.get.return_value = "not-valid-json{{" - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp @@ -332,7 +333,7 @@ def test_schema_mismatch_falls_through_to_api(self): with patch("utils.conversations.location.r") as mock_r: # Missing required 'latitude' and 'longitude' fields mock_r.get.return_value = json.dumps({"bad_field": "bad_value"}) - with patch("utils.conversations.location.requests") as mock_req: + with patch("utils.conversations.location.httpx") as mock_req: mock_resp = MagicMock() mock_resp.json.return_value = api_response mock_req.get.return_value = mock_resp diff --git a/backend/tests/unit/test_kg_user_type_mismatch.py b/backend/tests/unit/test_kg_user_type_mismatch.py index 0d20aa3273c..386c8611d81 100644 --- a/backend/tests/unit/test_kg_user_type_mismatch.py +++ b/backend/tests/unit/test_kg_user_type_mismatch.py @@ -63,6 +63,9 @@ def _stub_module(name: str) -> types.ModuleType: "upsert_memory_vector", "delete_memory_vector", "upsert_vector2", + "find_similar_action_items", + "upsert_action_item_vectors_batch", + "delete_action_item_vectors_batch", "update_vector_metadata", ]: setattr(vector_db_mod, attr, MagicMock()) @@ -98,6 +101,7 @@ def _stub_module(name: str) -> types.ModuleType: for name in [ "utils.apps", "utils.analytics", + "utils.subscription", "utils.llm.memories", "utils.llm.conversation_processing", "utils.llm.external_integrations", @@ -117,12 +121,15 @@ def _stub_module(name: str) -> types.ModuleType: sys.modules[name] = types.ModuleType(name) utils_apps = sys.modules["utils.apps"] -for attr in ["get_available_apps", "update_personas_async", "sync_update_persona_prompt"]: +for attr in ["get_available_apps", "update_personas_async", "update_persona_prompt", "sync_update_persona_prompt"]: setattr(utils_apps, attr, MagicMock()) utils_analytics = sys.modules["utils.analytics"] utils_analytics.record_usage = MagicMock() +utils_subscription = sys.modules["utils.subscription"] +utils_subscription.is_trial_paywalled = MagicMock(return_value=False) + llm_memories = sys.modules["utils.llm.memories"] for attr in ["resolve_memory_conflict", "extract_memories_from_text", "new_memories_extractor"]: setattr(llm_memories, attr, MagicMock()) diff --git a/backend/tests/unit/test_llm_usage_endpoints.py b/backend/tests/unit/test_llm_usage_endpoints.py index c7b1d6bb7a4..09ff34d496c 100644 --- a/backend/tests/unit/test_llm_usage_endpoints.py +++ b/backend/tests/unit/test_llm_usage_endpoints.py @@ -109,6 +109,8 @@ def _passthrough_decorator(func): "adapt_plans_for_legacy_client", "legacy_plan_features", "is_paid_plan", + "is_trial_paywalled", + "clear_trial_paywall_cache", ]: setattr(subscription_mod, attr, MagicMock()) subscription_mod.get_paid_plan_definitions = MagicMock(return_value=[]) diff --git a/backend/tests/unit/test_lock_bypass_fixes.py b/backend/tests/unit/test_lock_bypass_fixes.py index 04449e2f8c2..cac558fe0be 100644 --- a/backend/tests/unit/test_lock_bypass_fixes.py +++ b/backend/tests/unit/test_lock_bypass_fixes.py @@ -740,7 +740,9 @@ def test_scheduled_summary_excludes_locked(self): unlocked_conv = _make_conversation(locked=False, conversation_id='conv-2') conversations_db.get_conversations = MagicMock(return_value=[locked_conv, unlocked_conv]) - with patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True): + with patch('utils.other.notifications.try_acquire_daily_summary_lock', return_value=True), patch( + 'utils.other.notifications.is_trial_paywalled', return_value=False + ): with patch( 'utils.other.notifications.generate_comprehensive_daily_summary', return_value={'headline': 'Test', 'day_emoji': '๐Ÿ“…', 'overview': 'ok'}, @@ -1259,9 +1261,11 @@ def test_suggest_goal_filters_locked_memories(self): mock_track.__enter__ = MagicMock(return_value=None) mock_track.__exit__ = MagicMock(return_value=False) + mock_llm = MagicMock() + mock_llm.invoke.return_value = mock_llm_response + with patch('utils.llm.goals.track_usage', return_value=mock_track): - with patch('utils.llm.goals.llm_mini') as mock_llm: - mock_llm.invoke.return_value = mock_llm_response + with patch('utils.llm.goals.get_llm', return_value=mock_llm): from utils.llm.goals import suggest_goal diff --git a/backend/tests/unit/test_mentor_notifications.py b/backend/tests/unit/test_mentor_notifications.py index 3adb6842e0d..9f3e2b1492f 100644 --- a/backend/tests/unit/test_mentor_notifications.py +++ b/backend/tests/unit/test_mentor_notifications.py @@ -111,6 +111,8 @@ def _stub_module(name: str) -> types.ModuleType: llm_mod = _stub_module("utils.llm") if not hasattr(llm_mod, '__path__'): llm_mod.__path__ = [os.path.join(_backend_root, "utils", "llm")] +subscription_mod = _stub_module("utils.subscription") +subscription_mod.is_trial_paywalled = MagicMock(return_value=False) tracker_mod = _stub_module("utils.llm.usage_tracker") tracker_mod.get_usage_callback = MagicMock(return_value=[]) tracker_mod.track_usage = MagicMock() diff --git a/backend/tests/unit/test_process_conversation_usage_context.py b/backend/tests/unit/test_process_conversation_usage_context.py index 4f4695059d5..8d48cf7f234 100644 --- a/backend/tests/unit/test_process_conversation_usage_context.py +++ b/backend/tests/unit/test_process_conversation_usage_context.py @@ -53,6 +53,7 @@ def _stub_module(name: str) -> types.ModuleType: vector_db_mod = sys.modules["database.vector_db"] for attr in [ "find_similar_memories", + "find_similar_action_items", "upsert_memory_vector", "delete_memory_vector", "upsert_vector2", @@ -82,6 +83,7 @@ def _stub_module(name: str) -> types.ModuleType: for name in [ "utils.apps", "utils.analytics", + "utils.subscription", "utils.llm.memories", "utils.llm.conversation_processing", "utils.llm.external_integrations", @@ -106,6 +108,9 @@ def _stub_module(name: str) -> types.ModuleType: utils_analytics = sys.modules["utils.analytics"] utils_analytics.record_usage = MagicMock() +utils_subscription = sys.modules["utils.subscription"] +utils_subscription.is_trial_paywalled = MagicMock(return_value=False) + llm_memories = sys.modules["utils.llm.memories"] for attr in ["resolve_memory_conflict", "extract_memories_from_text", "new_memories_extractor"]: setattr(llm_memories, attr, MagicMock()) diff --git a/backend/tests/unit/test_prompt_cache_integration.py b/backend/tests/unit/test_prompt_cache_integration.py index 554ff5dc520..6055a227700 100644 --- a/backend/tests/unit/test_prompt_cache_integration.py +++ b/backend/tests/unit/test_prompt_cache_integration.py @@ -89,6 +89,7 @@ def _stub_module(name: str) -> types.ModuleType: clients_mod = _stub_module("utils.llm.clients") clients_mod.get_llm = MagicMock(return_value=mock_llm) clients_mod.get_model = MagicMock(return_value="gpt-4.1-mini") +clients_mod.get_openai_agent_llm = MagicMock(return_value=mock_llm) clients_mod.llm_mini = mock_llm clients_mod.llm_mini_stream = mock_llm clients_mod.llm_medium = mock_llm @@ -113,6 +114,8 @@ def _stub_module(name: str) -> types.ModuleType: tracker_mod.reset_usage_context = MagicMock() tracker_mod.Features = MagicMock() tracker_mod.track_usage = MagicMock() +byok_errors_mod = _stub_module("utils.llm.byok_errors") +byok_errors_mod.handle_llm_error = MagicMock() # --- LLMs/memory stubs --- llms_mod = _stub_module("utils.llms") @@ -611,8 +614,13 @@ def __init__(self, **kwargs): source = source.replace("from langchain_openai import ChatOpenAI, OpenAIEmbeddings", "") source = source.replace("import tiktoken", "") source = source.replace("import anthropic", "") + source = source.replace("from langchain_core.callbacks import BaseCallbackHandler", "") + source = source.replace("from langchain_core.language_models import BaseChatModel", "") source = source.replace("from langchain_core.output_parsers import PydanticOutputParser", "") - source = source.replace("from models.conversation import Structured", "") + source = source.replace("from langchain_google_genai import ChatGoogleGenerativeAI", "") + source = source.replace("from models.structured import Structured", "") + source = source.replace("from utils.byok import get_byok_key", "") + source = source.replace("from utils.llm.byok_errors import handle_llm_error", "") source = source.replace("from utils.llm.usage_tracker import get_usage_callback", "") # Create a fake anthropic module with AsyncAnthropic @@ -623,10 +631,15 @@ def __init__(self, **kwargs): "os": os, "ChatOpenAI": FakeChatOpenAI, "OpenAIEmbeddings": FakeOpenAIEmbeddings, + "ChatGoogleGenerativeAI": FakeChatOpenAI, + "BaseCallbackHandler": object, + "BaseChatModel": object, "tiktoken": fake_tiktoken, "anthropic": fake_anthropic, "PydanticOutputParser": MagicMock(), "Structured": MagicMock(), + "get_byok_key": MagicMock(return_value=None), + "handle_llm_error": MagicMock(), "get_usage_callback": MagicMock(return_value=[]), "List": list, } diff --git a/backend/tests/unit/test_rate_limiting.py b/backend/tests/unit/test_rate_limiting.py index c358050d7cd..90c736968f9 100644 --- a/backend/tests/unit/test_rate_limiting.py +++ b/backend/tests/unit/test_rate_limiting.py @@ -17,6 +17,7 @@ 'google.cloud.firestore', 'database.redis_db', 'database.auth', + 'database.users', ]: if mod_name not in sys.modules: sys.modules[mod_name] = types.ModuleType(mod_name) @@ -43,6 +44,7 @@ class _RedisError(Exception): redis_db_stub = sys.modules['database.redis_db'] redis_db_stub._RATE_LIMIT_LUA = MagicMock(return_value=[1, 3600]) redis_db_stub.try_acquire_listen_lock = MagicMock(return_value=True) +sys.modules['database.users'].record_user_platform = MagicMock() def _check_rate_limit(key, policy, max_requests, window): @@ -479,18 +481,23 @@ def setUpClass(cls): cls.mock_lua = mock_lua_callable + def _rate_limit_lua_source(self): + for call in self.real_module.r.register_script.call_args_list: + lua_source = call.args[0] + if 'local key = KEYS[1]' in lua_source and 'INCR' in lua_source: + return lua_source + self.fail('Rate-limit Lua script was not registered') + def test_lua_script_has_ttl_self_heal(self): """Verify the registered Lua script contains TTL self-heal logic.""" - # register_script was called with the Lua source - call_args = self.real_module.r.register_script.call_args - lua_source = call_args[0][0] + lua_source = self._rate_limit_lua_source() self.assertIn('TTL', lua_source) self.assertIn('ttl < 0', lua_source) self.assertIn('EXPIRE', lua_source) def test_lua_script_uses_incr(self): """Verify Lua uses INCR for atomic counter.""" - lua_source = self.real_module.r.register_script.call_args[0][0] + lua_source = self._rate_limit_lua_source() self.assertIn('INCR', lua_source) def test_real_check_rate_limit_key_format(self): diff --git a/backend/tests/unit/test_realtime_integrations_usage_tracking.py b/backend/tests/unit/test_realtime_integrations_usage_tracking.py index cfc6e4875af..fd9bd1d878a 100644 --- a/backend/tests/unit/test_realtime_integrations_usage_tracking.py +++ b/backend/tests/unit/test_realtime_integrations_usage_tracking.py @@ -41,10 +41,12 @@ def _stub_module(name: str) -> types.ModuleType: "vector_db", "apps", "llm_usage", + "user_usage", "_client", "chat", "goals", "auth", + "announcements", ]: mod = _stub_module(f"database.{submodule}") setattr(database_mod, submodule, mod) @@ -75,6 +77,14 @@ def _stub_module(name: str) -> types.ModuleType: redis_mod.incr_daily_notification_count = MagicMock() redis_mod.get_daily_notification_count = MagicMock(return_value=0) redis_mod.get_proactive_noti_sent_at_ttl = MagicMock(return_value=0) +redis_mod.delete_generic_cache = MagicMock() + +user_usage_mod = sys.modules["database.user_usage"] +user_usage_mod.get_monthly_chat_usage = MagicMock(return_value={}) +user_usage_mod.get_monthly_usage_stats_since = MagicMock(return_value={}) + +announcements_mod = sys.modules["database.announcements"] +announcements_mod.compare_versions = MagicMock(return_value=0) goals_mod = sys.modules["database.goals"] goals_mod.get_user_goals = MagicMock(return_value=[]) diff --git a/backend/tests/unit/test_speaker_sample_migration.py b/backend/tests/unit/test_speaker_sample_migration.py index a0ad4456dc7..3fe990520d7 100644 --- a/backend/tests/unit/test_speaker_sample_migration.py +++ b/backend/tests/unit/test_speaker_sample_migration.py @@ -14,6 +14,8 @@ sys.modules["utils.other.storage"] = MagicMock() sys.modules["utils.stt.pre_recorded"] = MagicMock() sys.modules["utils.stt.speaker_embedding"] = MagicMock() +sys.modules["firebase_admin"] = MagicMock() +sys.modules["firebase_admin.auth"] = MagicMock() sys.modules["stripe"] = MagicMock() diff --git a/backend/tests/unit/test_storage_opus_encoding.py b/backend/tests/unit/test_storage_opus_encoding.py index 5d5b6f43de5..f24315deb2b 100644 --- a/backend/tests/unit/test_storage_opus_encoding.py +++ b/backend/tests/unit/test_storage_opus_encoding.py @@ -13,6 +13,7 @@ import os import struct import sys +import types from unittest.mock import MagicMock, patch import pytest @@ -21,6 +22,9 @@ # Mock heavy dependencies at sys.modules level before importing storage sys.modules.setdefault("database._client", MagicMock()) +subscription_mod = types.ModuleType("utils.subscription") +subscription_mod.get_default_basic_subscription = MagicMock() +sys.modules.setdefault("utils.subscription", subscription_mod) _mock_gcs_storage = MagicMock() _mock_gcs_client_instance = MagicMock() diff --git a/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py b/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py index 9fedab43864..247bb26c49f 100644 --- a/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py +++ b/backend/tests/unit/test_storage_upload_audio_chunk_data_protection.py @@ -7,6 +7,7 @@ import os import sys +import types from unittest.mock import MagicMock, patch, call import pytest @@ -15,6 +16,9 @@ # Mock heavy dependencies at sys.modules level before importing storage sys.modules.setdefault("database._client", MagicMock()) +subscription_mod = types.ModuleType("utils.subscription") +subscription_mod.get_default_basic_subscription = MagicMock() +sys.modules.setdefault("utils.subscription", subscription_mod) # We need the real storage module but with mocked GCS client _mock_gcs_storage = MagicMock() @@ -73,29 +77,33 @@ def test_falls_back_to_db_when_level_not_provided(self, mock_users_db): @patch.object(storage_mod, 'users_db') def test_standard_level_uploads_unencrypted(self, mock_users_db): - """Standard protection level should upload .bin (no encryption).""" + """Standard protection level should upload encoded .opus data without encryption.""" _, mock_blob = self._setup_mock_bucket() + raw_chunk = b'\x00' * 100 path = storage_mod.upload_audio_chunk( - chunk_data=b'\x00' * 100, + chunk_data=raw_chunk, uid='test-uid', conversation_id='conv-1', timestamp=1234567890.123, data_protection_level='standard', ) - assert path.endswith('.bin') + assert path.endswith('.opus') mock_blob.upload_from_string.assert_called_once() + upload_data = mock_blob.upload_from_string.call_args.args[0] + assert upload_data != raw_chunk @patch.object(storage_mod, 'encryption') @patch.object(storage_mod, 'users_db') def test_enhanced_level_uploads_encrypted(self, mock_users_db, mock_encryption): - """Enhanced protection level should encrypt and upload .enc.""" + """Enhanced protection level should encrypt encoded Opus data and upload .enc.""" _, mock_blob = self._setup_mock_bucket() mock_encryption.encrypt_audio_chunk.return_value = b'\x01' * 120 + raw_chunk = b'\x00' * 100 path = storage_mod.upload_audio_chunk( - chunk_data=b'\x00' * 100, + chunk_data=raw_chunk, uid='test-uid', conversation_id='conv-1', timestamp=1234567890.123, @@ -103,7 +111,10 @@ def test_enhanced_level_uploads_encrypted(self, mock_users_db, mock_encryption): ) assert path.endswith('.enc') - mock_encryption.encrypt_audio_chunk.assert_called_once_with(b'\x00' * 100, 'test-uid') + mock_encryption.encrypt_audio_chunk.assert_called_once() + encrypted_input, encrypted_uid = mock_encryption.encrypt_audio_chunk.call_args.args + assert encrypted_input != raw_chunk + assert encrypted_uid == 'test-uid' @patch.object(storage_mod, 'users_db') def test_explicit_none_falls_back_to_db(self, mock_users_db): diff --git a/backend/tests/unit/test_subscription_plans.py b/backend/tests/unit/test_subscription_plans.py index c7db2277813..671bebca11b 100644 --- a/backend/tests/unit/test_subscription_plans.py +++ b/backend/tests/unit/test_subscription_plans.py @@ -1,8 +1,12 @@ import sys import types +_announcements_mod = types.ModuleType("database.announcements") +_announcements_mod.compare_versions = lambda a, b: 0 + sys.modules.setdefault("database.users", types.SimpleNamespace()) sys.modules.setdefault("database.user_usage", types.SimpleNamespace()) +sys.modules.setdefault("database.announcements", _announcements_mod) from models.users import PlanType from utils.subscription import get_plan_features, get_plan_limits, get_plan_type_from_price_id, is_paid_plan diff --git a/backend/tests/unit/test_subscription_restructure.py b/backend/tests/unit/test_subscription_restructure.py index f2c60f6d9b7..f2c4f44e12c 100644 --- a/backend/tests/unit/test_subscription_restructure.py +++ b/backend/tests/unit/test_subscription_restructure.py @@ -19,6 +19,7 @@ def _compare_versions(a, b): _announcements_mod._compare_versions = _compare_versions +_announcements_mod.compare_versions = _compare_versions sys.modules.setdefault("database.users", types.SimpleNamespace()) sys.modules.setdefault("database.user_usage", types.SimpleNamespace()) sys.modules.setdefault("database.announcements", _announcements_mod) diff --git a/backend/tests/unit/test_sync_fair_use_gate.py b/backend/tests/unit/test_sync_fair_use_gate.py index 7ad8f01c61e..8748bb1ab4b 100644 --- a/backend/tests/unit/test_sync_fair_use_gate.py +++ b/backend/tests/unit/test_sync_fair_use_gate.py @@ -17,11 +17,15 @@ 'database.user_usage', 'database.conversations', 'firebase_admin', + 'firebase_admin.auth', 'firebase_admin.messaging', ]: if mod_name not in sys.modules: sys.modules[mod_name] = ModuleType(mod_name) +sys.modules['firebase_admin'].auth = sys.modules['firebase_admin.auth'] +sys.modules['firebase_admin'].messaging = sys.modules['firebase_admin.messaging'] + # Stub redis_db.r _mock_redis = MagicMock() sys.modules['database.redis_db'].r = _mock_redis @@ -256,9 +260,13 @@ def _read_sync_source(): return f.read() def test_no_402_block(self): - """sync.py must not raise 402 (lock instead of block).""" + """Sync upload endpoints must lock credit-exhausted conversations instead of raising 402.""" source = self._read_sync_source() - assert 'status_code=402' not in source + for route in ['@router.post("/v1/sync-local-files"', '@router.post("/v2/sync-local-files"']: + start = source.index(route) + next_route = source.find('\n@router.', start + 1) + endpoint_source = source[start:] if next_route == -1 else source[start:next_route] + assert 'status_code=402' not in endpoint_source def test_should_lock_flag_exists(self): """sync.py must use should_lock flag for credit-exhausted locking.""" diff --git a/backend/tests/unit/test_sync_opus_decode.py b/backend/tests/unit/test_sync_opus_decode.py index 8563f274bfc..7284a93d74b 100644 --- a/backend/tests/unit/test_sync_opus_decode.py +++ b/backend/tests/unit/test_sync_opus_decode.py @@ -52,6 +52,9 @@ 'utils.subscription', 'utils.log_sanitizer', 'utils.executors', + 'utils.speaker_assignment', + 'utils.speaker_identification', + 'utils.stt.speaker_embedding', 'pydub', 'numpy', 'httpx', @@ -64,16 +67,20 @@ sys.modules['database._client'].db = MagicMock() sys.modules['utils.log_sanitizer'].sanitize = lambda x: x sys.modules['utils.log_sanitizer'].sanitize_pii = lambda x: x +sys.modules['utils.speaker_assignment'].process_speaker_assigned_segments = MagicMock() +sys.modules['utils.speaker_identification'].detect_speaker_from_text = MagicMock(return_value=None) +sys.modules['utils.stt.speaker_embedding'].extract_embedding_from_bytes = MagicMock(return_value=None) +sys.modules['utils.stt.speaker_embedding'].compare_embeddings = MagicMock(return_value=0) +sys.modules['utils.stt.speaker_embedding'].SPEAKER_MATCH_THRESHOLD = 0.75 from routers.sync import decode_opus_file_to_wav, decode_files_to_wav # noqa: E402 - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- #: One frame of fake Opus-encoded bytes (content doesn't matter โ€” decoder is mocked). -FAKE_OPUS_FRAME = b'\xAA\xBB\xCC' * 34 # 102 bytes +FAKE_OPUS_FRAME = b'\xaa\xbb\xcc' * 34 # 102 bytes #: PCM returned by the mocked decoder: 320 bytes = 160 mono samples at 16-bit. #: 100 such frames = 16 000 samples = 1.0 s at 16 kHz. @@ -279,7 +286,7 @@ def test_truncated_frame_data_stops_cleanly(self): f.write(FAKE_OPUS_FRAME) # Write a length prefix claiming 1000 bytes but only supply 10 f.write(struct.pack(' anthropic.AsyncAnthropic:') + end = source.find('\n def ', start + 1) + func_body = source[start:end] + assert 'anthropic.AsyncAnthropic(timeout=120.0, max_retries=1)' in func_body + assert '_default_anthropic_client = anthropic.AsyncAnthropic' not in source def test_anthropic_byok_has_timeout(self): source = self._read_clients_source() diff --git a/backend/tests/unit/test_task_sharing.py b/backend/tests/unit/test_task_sharing.py index 969c719200f..a7df888463c 100644 --- a/backend/tests/unit/test_task_sharing.py +++ b/backend/tests/unit/test_task_sharing.py @@ -80,6 +80,7 @@ def _stub_module(name): notif_mod.send_action_item_data_message = MagicMock() notif_mod.send_action_item_update_message = MagicMock() notif_mod.send_action_item_deletion_message = MagicMock() +notif_mod.send_action_items_batch_deletion_message = MagicMock() _stub_module("utils.task_sync") sys.modules["utils.task_sync"].auto_sync_action_item = MagicMock() diff --git a/backend/tests/unit/test_thread_join_elimination.py b/backend/tests/unit/test_thread_join_elimination.py index e7d173b1ad5..7979b858f03 100644 --- a/backend/tests/unit/test_thread_join_elimination.py +++ b/backend/tests/unit/test_thread_join_elimination.py @@ -236,8 +236,8 @@ async def test_async_vad_local_fallback(self): import importlib mod = importlib.import_module('utils.stt.vad') - # _local_vad should be called via run_in_executor(critical_executor, ...) - with patch.object(mod, '_local_vad', return_value=[]) as mock_local: + # _run_file_vad should be called via run_in_executor(critical_executor, ...) + with patch.object(mod, '_run_file_vad', return_value=[]) as mock_local: result = await mod.async_vad_is_empty('/tmp/nonexistent.wav') mock_local.assert_called_once_with('/tmp/nonexistent.wav') assert result is True # empty segments = True diff --git a/backend/tests/unit/test_users_add_sample_transaction.py b/backend/tests/unit/test_users_add_sample_transaction.py index 59cc08006e5..98e16d69f5b 100644 --- a/backend/tests/unit/test_users_add_sample_transaction.py +++ b/backend/tests/unit/test_users_add_sample_transaction.py @@ -10,6 +10,8 @@ # Mock the database client to avoid needing GCP credentials sys.modules["database._client"] = MagicMock() +sys.modules["firebase_admin"] = MagicMock() +sys.modules["firebase_admin.auth"] = MagicMock() sys.modules["stripe"] = MagicMock() diff --git a/backend/tests/unit/test_vad_onnx.py b/backend/tests/unit/test_vad_onnx.py index d3d6d6f1b1e..c87cd24a0ad 100644 --- a/backend/tests/unit/test_vad_onnx.py +++ b/backend/tests/unit/test_vad_onnx.py @@ -151,7 +151,7 @@ class TestVadIsEmptyHostedSuccess: """vad_is_empty() when HOSTED_VAD_API_URL is set and succeeds.""" @patch.dict(os.environ, {'HOSTED_VAD_API_URL': 'http://vad.test/v1/vad'}) - @patch('utils.stt.vad.requests.post') + @patch('utils.stt.vad.httpx.post') @patch.object(vad, 'redis_db') def test_hosted_returns_segments(self, mock_redis, mock_post, tmp_wav_dir): """Hosted VAD returns segments โ€” vad_is_empty returns False (not empty).""" @@ -169,7 +169,7 @@ def test_hosted_returns_segments(self, mock_redis, mock_post, tmp_wav_dir): mock_post.assert_called_once() @patch.dict(os.environ, {'HOSTED_VAD_API_URL': 'http://vad.test/v1/vad'}) - @patch('utils.stt.vad.requests.post') + @patch('utils.stt.vad.httpx.post') @patch.object(vad, 'redis_db') def test_hosted_returns_empty(self, mock_redis, mock_post, tmp_wav_dir): """Hosted VAD returns empty list โ€” vad_is_empty returns True.""" @@ -185,7 +185,7 @@ def test_hosted_returns_empty(self, mock_redis, mock_post, tmp_wav_dir): assert result is True @patch.dict(os.environ, {'HOSTED_VAD_API_URL': 'http://vad.test/v1/vad'}) - @patch('utils.stt.vad.requests.post') + @patch('utils.stt.vad.httpx.post') @patch.object(vad, 'redis_db') def test_hosted_return_segments_mode(self, mock_redis, mock_post, tmp_wav_dir): """return_segments=True returns the hosted segment list directly.""" @@ -211,7 +211,7 @@ class TestVadIsEmptyFallback: """vad_is_empty() falls back to local ONNX when hosted VAD fails.""" @patch.dict(os.environ, {'HOSTED_VAD_API_URL': 'http://vad.test/v1/vad'}) - @patch('utils.stt.vad.requests.post', side_effect=Exception('connection refused')) + @patch('utils.stt.vad.httpx.post', side_effect=Exception('connection refused')) @patch('utils.stt.vad._run_file_vad') @patch.object(vad, 'redis_db') def test_hosted_exception_falls_back(self, mock_redis, mock_local, mock_post, tmp_wav_dir): @@ -226,7 +226,7 @@ def test_hosted_exception_falls_back(self, mock_redis, mock_local, mock_post, tm mock_local.assert_called_once_with(wav_path) @patch.dict(os.environ, {'HOSTED_VAD_API_URL': 'http://vad.test/v1/vad'}) - @patch('utils.stt.vad.requests.post') + @patch('utils.stt.vad.httpx.post') @patch('utils.stt.vad._run_file_vad') @patch.object(vad, 'redis_db') def test_hosted_http_error_falls_back(self, mock_redis, mock_local, mock_post, tmp_wav_dir): diff --git a/backend/tests/unit/test_voice_message_language.py b/backend/tests/unit/test_voice_message_language.py index f9eea84fdf9..64f8c1e6b02 100644 --- a/backend/tests/unit/test_voice_message_language.py +++ b/backend/tests/unit/test_voice_message_language.py @@ -7,6 +7,8 @@ from unittest.mock import MagicMock sys.modules["database._client"] = MagicMock() +sys.modules["firebase_admin"] = MagicMock() +sys.modules["firebase_admin.auth"] = MagicMock() sys.modules["stripe"] = MagicMock() sys.modules["database.chat"] = MagicMock() sys.modules["database.notifications"] = MagicMock() diff --git a/backend/tests/unit/test_ws_auth_handshake.py b/backend/tests/unit/test_ws_auth_handshake.py index 91d14b25209..77353cb1950 100644 --- a/backend/tests/unit/test_ws_auth_handshake.py +++ b/backend/tests/unit/test_ws_auth_handshake.py @@ -7,9 +7,26 @@ """ import asyncio +import sys +import types import unittest from unittest.mock import patch, MagicMock +db_client_mod = types.ModuleType("database._client") +db_client_mod.db = MagicMock() +db_client_mod.document_id_from_seed = MagicMock(return_value="doc-id") +sys.modules.setdefault("database._client", db_client_mod) + +redis_db_mod = types.ModuleType("database.redis_db") +redis_db_mod.check_rate_limit = MagicMock(return_value=(True, 0, 0)) +redis_db_mod.try_acquire_listen_lock = MagicMock(return_value=True) +sys.modules.setdefault("database.redis_db", redis_db_mod) + +users_db_mod = types.ModuleType("database.users") +users_db_mod.record_user_platform = MagicMock() +users_db_mod.get_byok_state = MagicMock(return_value={}) +sys.modules.setdefault("database.users", users_db_mod) + from fastapi import FastAPI, WebSocket, WebSocketException, Depends from fastapi.testclient import TestClient from firebase_admin.auth import InvalidIdTokenError diff --git a/backend/utils/byok.py b/backend/utils/byok.py index 39d355273dd..6c3c9fdfa1d 100644 --- a/backend/utils/byok.py +++ b/backend/utils/byok.py @@ -73,6 +73,7 @@ def invalidate_byok_state_cache(uid: str) -> None: # Keys for the current request, if the client supplied them. # Default is None (not {}) to avoid sharing a mutable object across contexts. _byok_ctx: ContextVar[Optional[Dict[str, str]]] = ContextVar('byok_keys', default=None) +_byok_uid_ctx: ContextVar[Optional[str]] = ContextVar('byok_uid', default=None) def get_byok_keys() -> Dict[str, str]: @@ -87,6 +88,16 @@ def get_byok_key(provider: str) -> Optional[str]: return keys.get(provider) +def get_byok_uid() -> Optional[str]: + """Return the authenticated uid for the current request, when known.""" + return _byok_uid_ctx.get() + + +def set_byok_uid(uid: Optional[str]) -> None: + """Attach the authenticated uid to the current request context.""" + _byok_uid_ctx.set(uid) + + def has_byok_keys() -> bool: """True if the current request carries at least one BYOK header.""" keys = _byok_ctx.get() @@ -127,10 +138,12 @@ async def dispatch(self, request: Request, call_next): if value: keys[provider] = value token = _byok_ctx.set(keys) + uid_token = _byok_uid_ctx.set(None) try: return await call_next(request) finally: _byok_ctx.reset(token) + _byok_uid_ctx.reset(uid_token) # --------------------------------------------------------------------------- @@ -203,6 +216,7 @@ def validate_byok_request(uid: str) -> None: if error: logger.warning('BYOK validation failed uid=%s: %s', uid, error) raise HTTPException(status_code=403, detail=error) + set_byok_uid(uid) def validate_byok_websocket(uid: str) -> Optional[str]: @@ -215,4 +229,6 @@ def validate_byok_websocket(uid: str) -> Optional[str]: error = _check_byok_validity(uid) if error: logger.warning('BYOK WS validation failed uid=%s: %s', uid, error) + else: + set_byok_uid(uid) return error diff --git a/backend/utils/llm/byok_errors.py b/backend/utils/llm/byok_errors.py new file mode 100644 index 00000000000..24829ce7da4 --- /dev/null +++ b/backend/utils/llm/byok_errors.py @@ -0,0 +1,168 @@ +import logging +from typing import Optional + +from firebase_admin import messaging + +try: + import database.notifications as notification_db +except ImportError: + notification_db = None + +try: + from database.redis_db import try_acquire_byok_llm_error_notification_lock +except ImportError: + + def try_acquire_byok_llm_error_notification_lock( + uid: str, provider: str, reason: str, ttl: int = 60 * 60 * 24 + ) -> bool: + logger.error('BYOK LLM notification lock unavailable uid=%s provider=%s reason=%s', uid, provider, reason) + return False + + +from utils.byok import get_byok_key, get_byok_uid +from utils.log_sanitizer import sanitize + +logger = logging.getLogger(__name__) + +_PERMANENT_FAILURE_CODES = frozenset({'UNREGISTERED', 'INVALID_REGISTRATION_TOKEN', 'NOT_FOUND'}) +_QUOTA_ERROR_NAMES = frozenset({'RateLimitError'}) + + +def get_llm_error_source(provider: Optional[str]) -> str: + """Return platform/byok for the current request and provider.""" + if provider and get_byok_key(provider): + return 'byok' + return 'platform' + + +def classify_byok_llm_error(error: Exception) -> Optional[str]: + """Classify user-actionable BYOK failures for logging and notification.""" + status_code = _get_status_code(error) + error_name = type(error).__name__ + error_text = sanitize(str(error)).lower() + + if status_code == 401 or error_name == 'AuthenticationError': + return 'invalid' + if status_code == 403 or error_name == 'PermissionDeniedError': + return 'permission' + if status_code == 429 or error_name in _QUOTA_ERROR_NAMES: + if 'insufficient_quota' in error_text or 'quota' in error_text: + return 'quota' + return None + + +def handle_llm_error( + error: Exception, + provider: Optional[str], + feature: Optional[str] = None, + model: Optional[str] = None, + operation: str = 'chat', +) -> None: + """Log LLM failures with source context and notify users about broken BYOK keys.""" + source = get_llm_error_source(provider) + reason = classify_byok_llm_error(error) if source == 'byok' else None + uid = get_byok_uid() + status_code = _get_status_code(error) + + logger.error( + 'LLM error source=%s provider=%s feature=%s model=%s operation=%s uid=%s status_code=%s reason=%s ' + 'error_type=%s error=%s', + source, + provider or 'unknown', + feature or 'unknown', + model or 'unknown', + operation, + uid or 'unknown', + status_code or 'unknown', + reason or 'unknown', + type(error).__name__, + sanitize(str(error)), + ) + + if source == 'byok' and uid and provider and reason: + _send_byok_llm_error_notification(uid, provider, reason) + + +def _get_status_code(error: Exception) -> Optional[int]: + status_code = getattr(error, 'status_code', None) + if isinstance(status_code, int): + return status_code + + response = getattr(error, 'response', None) + response_status = getattr(response, 'status_code', None) + if isinstance(response_status, int): + return response_status + return None + + +def _send_byok_llm_error_notification(uid: str, provider: str, reason: str) -> None: + if notification_db is None: + logger.error('BYOK LLM notification database unavailable uid=%s provider=%s reason=%s', uid, provider, reason) + return + + provider_name = provider.capitalize() + if reason == 'quota': + body = f'Your {provider_name} BYOK key appears to be out of quota. Update it to restore AI features.' + elif reason == 'permission': + body = f'Your {provider_name} BYOK key was denied access. Check its project and permissions in Omi settings.' + else: + body = f'Your {provider_name} BYOK key was rejected. Update it in Omi settings to restore AI features.' + + try: + tokens = notification_db.get_all_tokens(uid) + except Exception as e: + logger.error( + 'BYOK LLM notification token lookup failed uid=%s provider=%s reason=%s: %s', uid, provider, reason, e + ) + return + + if not tokens: + logger.info('No tokens found for BYOK LLM notification uid=%s provider=%s reason=%s', uid, provider, reason) + return + + try: + acquired = try_acquire_byok_llm_error_notification_lock(uid, provider, reason) + except Exception as e: + logger.error('BYOK LLM notification lock failed uid=%s provider=%s reason=%s: %s', uid, provider, reason, e) + return + + if not acquired: + logger.info('BYOK LLM notification already sent recently uid=%s provider=%s reason=%s', uid, provider, reason) + return + + notification = messaging.Notification(title='omi', body=body) + data = {'type': 'byok_llm_error', 'provider': provider, 'reason': reason} + messages = [messaging.Message(token=token, notification=notification, data=data) for token in tokens] + + try: + response = messaging.send_each(messages) + except Exception as e: + logger.error('BYOK LLM notification send failed uid=%s provider=%s reason=%s: %s', uid, provider, reason, e) + return + + invalid_tokens = [] + success_count = 0 + for idx, result in enumerate(response.responses): + if result.success: + success_count += 1 + elif result.exception: + error_code = getattr(result.exception, 'code', None) + if error_code in _PERMANENT_FAILURE_CODES: + invalid_tokens.append(tokens[idx]) + else: + logger.error('BYOK LLM notification FCM send failed uid=%s error=%s', uid, result.exception) + + if invalid_tokens: + try: + notification_db.remove_bulk_tokens(invalid_tokens) + except Exception as e: + logger.error('BYOK LLM notification invalid token cleanup failed uid=%s: %s', uid, e) + + logger.info( + 'BYOK LLM notification sent uid=%s provider=%s reason=%s success=%s total=%s', + uid, + provider, + reason, + success_count, + len(tokens), + ) diff --git a/backend/utils/llm/clients.py b/backend/utils/llm/clients.py index 2d73a028f59..2d2fc94a86b 100644 --- a/backend/utils/llm/clients.py +++ b/backend/utils/llm/clients.py @@ -6,6 +6,7 @@ import anthropic import httpx from cachetools import TTLCache +from langchain_core.callbacks import BaseCallbackHandler from langchain_core.language_models import BaseChatModel from langchain_core.output_parsers import PydanticOutputParser from langchain_google_genai import ChatGoogleGenerativeAI @@ -14,12 +15,49 @@ from models.structured import Structured from utils.byok import get_byok_key +from utils.llm.byok_errors import handle_llm_error from utils.llm.usage_tracker import get_usage_callback logger = logging.getLogger(__name__) _usage_callback = get_usage_callback() + +class _LLMErrorCallback(BaseCallbackHandler): + """LangChain callback that tags provider errors with platform/BYOK source.""" + + def __init__(self, provider: str, model: str = '', feature: str = ''): + self.provider = provider + self.model = model + self.feature = feature + + def on_llm_error(self, error: BaseException, **kwargs) -> None: + if isinstance(error, Exception): + handle_llm_error(error, self.provider, feature=self.feature, model=self.model) + + +_llm_error_callbacks: Dict[Tuple[str, str, str], _LLMErrorCallback] = {} + + +def _get_llm_error_callback(provider: str, model: str = '', feature: str = '') -> _LLMErrorCallback: + key = (provider, model, feature) + if key not in _llm_error_callbacks: + _llm_error_callbacks[key] = _LLMErrorCallback(provider, model=model, feature=feature) + return _llm_error_callbacks[key] + + +def _with_llm_callbacks(kwargs: Dict[str, Any], provider: str, model: str = '', feature: str = '') -> Dict[str, Any]: + result = dict(kwargs) + callbacks = list(result.get('callbacks') or []) + if _usage_callback not in callbacks: + callbacks.append(_usage_callback) + error_callback = _get_llm_error_callback(provider, model=model, feature=feature) + if error_callback not in callbacks: + callbacks.append(error_callback) + result['callbacks'] = callbacks + return result + + # --------------------------------------------------------------------------- # BYOK (Bring Your Own Key) # @@ -39,14 +77,21 @@ class _AnthropicClientProxy: __slots__ = ('_default',) - def __init__(self, default: anthropic.AsyncAnthropic): + def __init__(self, default: Optional[anthropic.AsyncAnthropic] = None): object.__setattr__(self, '_default', default) + def _get_default(self) -> anthropic.AsyncAnthropic: + default = self._default + if default is None: + default = anthropic.AsyncAnthropic(timeout=120.0, max_retries=1) + object.__setattr__(self, '_default', default) + return default + def _resolve(self) -> anthropic.AsyncAnthropic: byok = get_byok_key('anthropic') if byok: return _cached_anthropic(byok) - return self._default + return self._get_default() def __getattr__(self, name: str): return getattr(self._resolve(), name) @@ -56,6 +101,7 @@ class _OpenAIEmbeddingsProxy: """Transparent proxy for OpenAIEmbeddings that uses BYOK OpenAI when set.""" __slots__ = ('_model', '_default', '_ctor_kwargs') + _METHODS_TO_WRAP = {'embed_documents', 'aembed_documents', 'embed_query', 'aembed_query'} def __init__(self, model: str, default: OpenAIEmbeddings, ctor_kwargs: Dict[str, Any]): object.__setattr__(self, '_model', model) @@ -74,7 +120,28 @@ def _resolve(self) -> OpenAIEmbeddings: return self._default def __getattr__(self, name: str): - return getattr(self._resolve(), name) + attr = getattr(self._resolve(), name) + if name not in self._METHODS_TO_WRAP or not callable(attr): + return attr + if name.startswith('a'): + + async def _wrapped_async(*args, **kwargs): + try: + return await attr(*args, **kwargs) + except Exception as e: + handle_llm_error(e, 'openai', feature='embeddings', model=self._model, operation=name) + raise + + return _wrapped_async + + def _wrapped(*args, **kwargs): + try: + return attr(*args, **kwargs) + except Exception as e: + handle_llm_error(e, 'openai', feature='embeddings', model=self._model, operation=name) + raise + + return _wrapped _BYOK_CACHE_MAX_SIZE = 256 @@ -111,7 +178,10 @@ def _create_byok_client( model: str, provider: str, byok_key: str, streaming: bool = False, feature: str = '' ) -> Optional[ChatOpenAI]: """Create a ChatOpenAI using the user's BYOK key. Returns None if BYOK not supported for this provider.""" - kwargs: Dict[str, Any] = {'callbacks': [_usage_callback], 'request_timeout': 120, 'max_retries': 1} + callback_provider = _effective_byok_provider(model, provider) + kwargs: Dict[str, Any] = _with_llm_callbacks( + {'request_timeout': 120, 'max_retries': 1}, callback_provider, model=model, feature=feature + ) if model == 'gpt-5.1': kwargs['extra_body'] = {"prompt_cache_retention": "24h"} if streaming: @@ -137,8 +207,7 @@ def _create_byok_client( # Anthropic client for chat agent (module-level, BYOK-aware) -_default_anthropic_client = anthropic.AsyncAnthropic(timeout=120.0, max_retries=1) -anthropic_client = _AnthropicClientProxy(_default_anthropic_client) +anthropic_client = _AnthropicClientProxy() def get_anthropic_client() -> anthropic.AsyncAnthropic: @@ -148,6 +217,7 @@ def get_anthropic_client() -> anthropic.AsyncAnthropic: def get_openai_chat(model: str, **kwargs) -> ChatOpenAI: """Explicit factory; equivalent to using the module-level proxies.""" + kwargs = _with_llm_callbacks(kwargs, 'openai', model=model) byok = get_byok_key('openai') if byok: return _cached_openai_chat(model, byok, kwargs) @@ -417,11 +487,9 @@ def _get_or_create_openai_llm(model_name: str, streaming: bool = False) -> ChatO """Get or create a cached ChatOpenAI for an OpenAI model.""" key = (model_name, streaming, 'openai') if key not in _llm_cache: - kwargs: Dict[str, Any] = { - 'callbacks': [_usage_callback], - 'request_timeout': 120, - 'max_retries': 1, - } + kwargs: Dict[str, Any] = _with_llm_callbacks( + {'request_timeout': 120, 'max_retries': 1}, 'openai', model=model_name + ) if model_name == 'gpt-5.1': kwargs['extra_body'] = {"prompt_cache_retention": "24h"} if streaming: @@ -447,10 +515,10 @@ def _get_or_create_openrouter_llm( 'api_key': os.environ.get('OPENROUTER_API_KEY'), 'base_url': "https://openrouter.ai/api/v1", 'default_headers': {"X-Title": "Omi Chat"}, - 'callbacks': [_usage_callback], 'request_timeout': 120, 'max_retries': 1, } + kwargs = _with_llm_callbacks(kwargs, 'openrouter', model=api_model) if temperature is not None: kwargs['temperature'] = temperature if streaming: @@ -478,7 +546,7 @@ def _get_or_create_gemini_llm(model_name: str, streaming: bool = False) -> BaseC use_vertex = os.environ.get('USE_VERTEX_AI', '').lower() == 'true' gcp_project = os.environ.get('GOOGLE_CLOUD_PROJECT', '') if use_vertex else '' gemini_key = os.environ.get('GEMINI_API_KEY', '') - kwargs: Dict[str, Any] = {'callbacks': [_usage_callback], 'timeout': 120, 'max_retries': 1} + kwargs: Dict[str, Any] = _with_llm_callbacks({'timeout': 120, 'max_retries': 1}, 'gemini', model=model_name) if streaming: kwargs['streaming'] = True @@ -587,6 +655,11 @@ def get_llm(feature: str, streaming: bool = False, cache_key: Optional[str] = No return result +def get_openai_agent_llm(streaming: bool = False) -> BaseChatModel: + """OpenAI-compatible agent model used when CHAT_PROVIDER=openai.""" + return get_llm('chat_graph', streaming=streaming) + + def get_qos_info() -> Dict[str, Dict[str, str]]: """Return full featureโ†’(model, provider) mapping for the active profile (debugging/monitoring).""" info: Dict[str, Dict[str, str]] = {} @@ -624,7 +697,10 @@ def get_qos_info() -> Dict[str, Dict[str, str]]: # Legacy module-level alias (kept for test compatibility). # Production code should use get_llm(feature) exclusively. # --------------------------------------------------------------------------- -llm_mini = ChatOpenAI(model='gpt-4.1-mini', callbacks=[_usage_callback], request_timeout=120, max_retries=1) +llm_mini = ChatOpenAI( + model='gpt-4.1-mini', + **_with_llm_callbacks({'request_timeout': 120, 'max_retries': 1}, 'openai', model='gpt-4.1-mini'), +) # --------------------------------------------------------------------------- # Embeddings, parser, utilities @@ -667,6 +743,10 @@ def gemini_embed_query(text: str) -> List[float]: 'taskType': 'RETRIEVAL_QUERY', } headers = {'x-goog-api-key': api_key, 'Content-Type': 'application/json'} - resp = httpx.post(url, json=payload, headers=headers, timeout=10) - resp.raise_for_status() - return resp.json()['embedding']['values'] + try: + resp = httpx.post(url, json=payload, headers=headers, timeout=10) + resp.raise_for_status() + return resp.json()['embedding']['values'] + except Exception as e: + handle_llm_error(e, 'gemini', feature='embeddings', model='embedding-001', operation='embed_query') + raise diff --git a/backend/utils/retrieval/agentic.py b/backend/utils/retrieval/agentic.py index 84c3697ffb4..e5709abe205 100644 --- a/backend/utils/retrieval/agentic.py +++ b/backend/utils/retrieval/agentic.py @@ -6,14 +6,21 @@ tool use API with streaming for real-time responses. """ -import uuid import asyncio import contextvars +import os import traceback +import uuid from typing import List, Optional, AsyncGenerator, Any, Tuple +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.runnables import RunnableConfig +try: + from langgraph.prebuilt import create_react_agent +except ImportError: + create_react_agent = None + # Context variable to store config for tools agent_config_context: contextvars.ContextVar[dict] = contextvars.ContextVar('agent_config', default=None) @@ -47,10 +54,12 @@ ) from utils.retrieval.tools.app_tools import load_app_tools, get_tool_status_message from utils.retrieval.safety import AgentSafetyGuard, SafetyGuardError -from utils.llm.clients import anthropic_client, ANTHROPIC_AGENT_MODEL +from utils.llm.byok_errors import handle_llm_error +from utils.llm.clients import anthropic_client, ANTHROPIC_AGENT_MODEL, get_openai_agent_llm from utils.llm.chat import _get_agentic_qa_prompt from utils.other.endpoints import timeit -from utils.observability.langsmith import is_langsmith_enabled +from utils.observability.langsmith import get_chat_tracer_callbacks, is_langsmith_enabled +from utils.observability.langsmith_prompts import get_prompt_metadata import logging # Import langsmith traceable if available @@ -67,6 +76,12 @@ def decorator(func): logger = logging.getLogger(__name__) +CHAT_PROVIDER = os.getenv('CHAT_PROVIDER', 'anthropic').strip().lower() +if CHAT_PROVIDER not in {'anthropic', 'openai'}: + logger.warning("Unsupported CHAT_PROVIDER=%s; falling back to anthropic", CHAT_PROVIDER) + CHAT_PROVIDER = 'anthropic' +logger.info("Chat provider: %s", CHAT_PROVIDER) + # PROMPT CACHE OPTIMIZATION: This list MUST stay fixed and in this exact order. # Anthropic caches the tools array as part of the request prefix. If the tool # definitions are identical across requests they are cached automatically. @@ -420,7 +435,7 @@ async def _run_anthropic_agent_stream( response = await stream.get_final_message() except Exception as e: - logger.error(f"Anthropic API error: {e}") + handle_llm_error(e, 'anthropic', feature='chat_agent', model=ANTHROPIC_AGENT_MODEL) await callback.put_data(f"\n\nSorry, I encountered an error. Please try again.") await callback.end() return @@ -502,12 +517,241 @@ async def _run_anthropic_agent_stream( await callback.end() +def _messages_to_langchain(messages: List[Message]) -> List: + """Convert chat messages to LangChain message objects.""" + result = [] + for msg in messages: + if msg.sender == "ai": + result.append(AIMessage(content=msg.text)) + else: + result.append(HumanMessage(content=msg.text)) + return result + + +def _append_openai_app_tool_prompt(system_prompt: str, app_tools: list) -> str: + if not app_tools: + return system_prompt + + app_tool_names = ", ".join(sorted(t.name for t in app_tools)) + return f"""{system_prompt} + + +You have access to additional tools from the user's connected apps. Use these tools when the user asks for actions or data from matching external services. + +Available app tool names: {app_tool_names} +""" + + +async def _run_openai_agent_stream( + agent, + messages: List, + config: dict, + callback: AsyncStreamingCallback, + full_response: List[str], +): + """Run the LangGraph ReAct agent and feed events into the callback queue.""" + safety_guard = config['configurable'].get('safety_guard') + + try: + async for event in agent.astream_events({"messages": messages}, config=config, version="v2"): + kind = event.get("event") + + if kind == "on_chat_model_stream": + chunk = event.get("data", {}).get("chunk") + token = getattr(chunk, "content", None) + if isinstance(token, str) and token: + full_response.append(token) + await callback.put_data(token) + + elif kind == "on_tool_start": + tool_name = event.get("name", "unknown") + tool_input = event.get("data", {}).get("input", {}) + if not isinstance(tool_input, dict): + tool_input = {} + + logger.info(f"Tool started: {tool_name}") + + app_id = _extract_app_id(tool_name) + tools_list = config.get('configurable', {}).get('tools', []) + tool_obj = next((t for t in tools_list if getattr(t, 'name', None) == tool_name), None) + await callback.put_thought(get_tool_display_name(tool_name, tool_obj), app_id=app_id) + + if safety_guard: + try: + safety_guard.validate_tool_call(tool_name, tool_input) + warning = safety_guard.should_warn_user() + if warning: + await callback.put_thought(warning) + except SafetyGuardError as e: + await callback.put_data(f"\n\n{str(e)}") + logger.error(f"Safety Guard blocked tool call: {e}") + await callback.end() + return + + elif kind == "on_tool_end": + tool_name = event.get("name", "unknown") + output_raw = event.get("data", {}).get("output", "") + output = str(getattr(output_raw, 'content', output_raw)) + + logger.info(f"Tool ended: {tool_name}") + await _emit_calendar_status(callback, tool_name, output) + + if safety_guard and output: + try: + safety_guard.check_context_size(output) + except SafetyGuardError as e: + await callback.put_data(f"\n\n{str(e)}") + logger.error(f"Safety Guard blocked due to context size: {e}") + await callback.end() + return + + elif kind == "on_tool_error": + logger.error(f"Tool error: {event.get('name', 'unknown')} - {event.get('data', {}).get('error', '')}") + elif kind == "on_chain_error": + logger.error(f"Chain error: {event.get('data', {}).get('error', '')}") + + if safety_guard: + logger.info(f"Safety Guard final stats: {safety_guard.get_stats()}") + + await callback.end() + + except SafetyGuardError as e: + await callback.put_data(f"\n\n{str(e)}") + logger.error(f"Safety Guard stopped execution: {e}") + await callback.end() + except Exception as e: + logger.error(f"Error in OpenAI agent stream: {e}") + traceback.print_exc() + await callback.end() + + +async def _execute_agentic_chat_stream_openai( + uid: str, + messages: List[Message], + app: Optional[App], + callback_data: dict, + chat_session: Optional[ChatSession], + context: Optional[PageContext], +) -> AsyncGenerator[str, None]: + """Execute agentic chat through LangGraph/OpenAI for self-hosted fallback.""" + if create_react_agent is None: + logger.error("CHAT_PROVIDER=openai but langgraph is not installed") + if callback_data is not None: + callback_data['error'] = 'langgraph is not installed' + yield "data: Sorry, I encountered an error. Please try again." + yield None + return + + system_prompt = _get_agentic_qa_prompt(uid, app, messages, context=context) + + prompt_name, prompt_commit, prompt_source = None, None, None + try: + prompt_name, prompt_commit, prompt_source = get_prompt_metadata() + except Exception as e: + logger.error(f"Could not get prompt metadata: {e}") + + tools = list(CORE_TOOLS) + try: + app_tools = load_app_tools(uid) + if app_tools: + tools.extend(app_tools) + logger.info(f"Added {len(app_tools)} app tools to OpenAI chat") + system_prompt = _append_openai_app_tool_prompt(system_prompt, app_tools) + except Exception as e: + logger.error(f"Error loading app tools: {e}") + + langchain_messages = [SystemMessage(content=system_prompt)] + langchain_messages.extend(_messages_to_langchain(messages)) + + conversations_collected = [] + safety_guard = AgentSafetyGuard(max_tool_calls=25, max_context_tokens=500000) + langsmith_run_id = str(uuid.uuid4()) + metadata = { + "uid": uid, + "app_id": app.id if app else None, + "app_name": app.name if app else None, + "chat_session_id": chat_session.id if chat_session else None, + "has_context": context is not None, + "context_type": context.type if context else None, + "num_tools": len(tools), + "prompt_name": prompt_name, + "prompt_commit": prompt_commit, + "provider": "openai", + } + tracer_callbacks = get_chat_tracer_callbacks( + run_id=langsmith_run_id, + run_name="chat.agentic.stream", + tags=["chat", "agentic", "streaming", "openai"], + metadata=metadata, + ) + config = { + "run_id": langsmith_run_id, + "callbacks": tracer_callbacks, + "run_name": "chat.agentic.stream", + "tags": ["chat", "agentic", "streaming", "openai"], + "metadata": metadata, + "configurable": { + "user_id": uid, + "thread_id": str(uuid.uuid4()), + "conversations_collected": conversations_collected, + "safety_guard": safety_guard, + "chat_session_id": chat_session.id if chat_session else None, + "tools": tools, + }, + } + agent_config_context.set(config) + + if callback_data is not None: + callback_data['langsmith_run_id'] = langsmith_run_id + callback_data['prompt_name'] = prompt_name + callback_data['prompt_commit'] = prompt_commit + + callback = AsyncStreamingCallback() + full_response = [] + tool_usage_count = 0 + agent = create_react_agent(model=get_openai_agent_llm(streaming=True), tools=tools) + task = asyncio.create_task(_run_openai_agent_stream(agent, langchain_messages, config, callback, full_response)) + + try: + while True: + chunk = await callback.queue.get() + if chunk is None: + break + + if chunk.startswith("think: "): + tool_usage_count += 1 + + yield chunk + + await task + + if callback_data is not None: + callback_data['answer'] = ''.join(full_response) + callback_data['memories_found'] = conversations_collected if conversations_collected else [] + callback_data['ask_for_nps'] = tool_usage_count > 0 + chart_data_from_config = config['configurable'].get('chart_data') + if chart_data_from_config: + callback_data['chart_data'] = chart_data_from_config + logger.info(f"Collected {len(callback_data['memories_found'])} conversations for citation") + + except asyncio.CancelledError: + task.cancel() + raise + except Exception as e: + logger.error(f"Error in execute_agentic_chat_stream openai: {e}") + traceback.print_exc() + if callback_data is not None: + callback_data['error'] = str(e) + + yield None + + # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- -@_traceable(name="chat.anthropic.stream", run_type="chain") +@_traceable(name="chat.agentic.stream", run_type="chain") async def execute_agentic_chat_stream( uid: str, messages: List[Message], @@ -520,14 +764,19 @@ async def execute_agentic_chat_stream( Yields formatted chunks with "data: " or "think: " prefixes. """ + if CHAT_PROVIDER == 'openai': + async for chunk in _execute_agentic_chat_stream_openai( + uid, messages, app, callback_data=callback_data, chat_session=chat_session, context=context + ): + yield chunk + return + # Build system prompt system_prompt = _get_agentic_qa_prompt(uid, app, messages, context=context) # Get prompt metadata for tracing/versioning prompt_name, prompt_commit, prompt_source = None, None, None try: - from utils.observability.langsmith_prompts import get_prompt_metadata - prompt_name, prompt_commit, prompt_source = get_prompt_metadata() except Exception as e: logger.error(f"Could not get prompt metadata: {e}") diff --git a/backend/utils/stt/vad.py b/backend/utils/stt/vad.py index 2fcc0d38f45..e9b345100eb 100644 --- a/backend/utils/stt/vad.py +++ b/backend/utils/stt/vad.py @@ -4,7 +4,7 @@ import numpy as np import onnxruntime as ort -import requests +import httpx from fastapi import HTTPException from pydub import AudioSegment @@ -111,7 +111,7 @@ def vad_is_empty(file_path, return_segments: bool = False, cache: bool = False): try: with open(file_path, 'rb') as file: files = {'file': (file_path.split('/')[-1], file, 'audio/wav')} - response = requests.post(hosted_vad_url, files=files, timeout=300) + response = httpx.post(hosted_vad_url, files=files, timeout=300.0) response.raise_for_status() segments = response.json() except Exception as e: