Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions backend/database/redis_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,11 @@ def has_silent_user_notification_been_sent(uid: str) -> bool:
return r.exists(f'users:{uid}:silent_notification_sent')


def try_acquire_byok_llm_error_notification_lock(uid: str, provider: str, reason: str, ttl: int = 60 * 60 * 24) -> bool:
"""Return True once per BYOK provider/error reason per TTL window."""
return bool(r.set(f'users:{uid}:byok_llm_error:{provider}:{reason}', '1', ex=ttl, nx=True))


# ******************************************************
# ******* IMPORTANT CONVERSATION NOTIFICATIONS *********
# ******************************************************
Expand Down Expand Up @@ -636,8 +641,7 @@ def remove_conversation_summary_app_id(app_id: str) -> bool:
# Lua script: atomic increment + TTL in a single round-trip.
# Returns [current_count, ttl_remaining]. Sets TTL on first hit
# and self-heals any key that lost its TTL (prevents permanent buckets).
_RATE_LIMIT_LUA = r.register_script(
"""
_RATE_LIMIT_LUA = r.register_script("""
local key = KEYS[1]
local window = tonumber(ARGV[1])
local current = redis.call('INCR', key)
Expand All @@ -650,8 +654,7 @@ def remove_conversation_summary_app_id(app_id: str) -> bool:
ttl = window
end
return {current, ttl}
"""
)
""")


def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> tuple[bool, int, int]:
Expand Down Expand Up @@ -680,8 +683,7 @@ def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> t
# Burst uses a sorted set keyed by timestamp-ms for sliding-window accuracy,
# trimmed on every call (O(log n)). Daily char counter auto-expires at midnight
# UTC (caller passes seconds_until_midnight_utc as the TTL).
_TTS_RATE_LIMIT_LUA = r.register_script(
"""
_TTS_RATE_LIMIT_LUA = r.register_script("""
local burst_key = KEYS[1]
local daily_key = KEYS[2]
local now_ms = tonumber(ARGV[1])
Expand Down Expand Up @@ -709,8 +711,7 @@ def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> t
redis.call('EXPIRE', daily_key, daily_ttl)
end
return {0, 0}
"""
)
""")


def _seconds_until_midnight_utc() -> int:
Expand Down
3 changes: 2 additions & 1 deletion backend/routers/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
trigger_external_integrations,
)
from utils.conversations.location import async_get_google_maps_location
from utils.byok import set_byok_keys
from utils.byok import set_byok_keys, set_byok_uid
from utils.conversations.process_conversation import process_conversation
from utils.executors import storage_executor
from utils.webhooks import (
Expand Down Expand Up @@ -79,6 +79,7 @@ async def _process_conversation_task(
"""
if byok_keys:
set_byok_keys(byok_keys)
set_byok_uid(uid)
try:
conversation_data = conversations_db.get_conversation(uid, conversation_id)
if not conversation_data:
Expand Down
4 changes: 3 additions & 1 deletion backend/routers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
)

from utils import encryption
from utils.byok import get_byok_keys, set_byok_keys
from utils.byok import get_byok_keys, set_byok_keys, set_byok_uid
from utils.log_sanitizer import sanitize
from utils.stt.pre_recorded import deepgram_prerecorded, get_deepgram_model_for_language, postprocess_words
from utils.stt.vad import vad_is_empty
Expand Down Expand Up @@ -1359,6 +1359,7 @@ def _run_full_pipeline_background(
Moved ALL heavy processing here so the v2 endpoint returns 202 immediately.
"""
set_byok_keys(byok_keys or {})
set_byok_uid(uid if byok_keys else None)
segmented_paths = set()
wav_paths = []
stage_timings = {}
Expand Down Expand Up @@ -1580,6 +1581,7 @@ def _process_one_segment(path):
pass
finally:
set_byok_keys({})
set_byok_uid(None)
_cleanup_files(list(segmented_paths))
_cleanup_files(wav_paths)
try:
Expand Down
4 changes: 2 additions & 2 deletions backend/routers/users.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import re
import threading
import uuid
from typing import List, Dict, Any, Union, Optional
import hashlib
Expand Down Expand Up @@ -62,6 +61,7 @@
PricingOption,
PhoneCallQuota,
)
from utils.executors import storage_executor, submit_with_context
from utils.phone_calls import get_quota_snapshot as get_phone_call_quota_snapshot
from utils.apps import get_available_app_by_id
from utils.subscription import (
Expand Down Expand Up @@ -165,7 +165,7 @@ def delete_account(

# 3. Wipe Firestore subcollections in the background — can take minutes
# for heavy users and would otherwise time out at the load balancer.
threading.Thread(target=_background_wipe_user_data, args=(uid,), daemon=True).start()
submit_with_context(storage_executor, _background_wipe_user_data, uid)

return {'status': 'ok', 'message': 'Account deletion started'}
except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions backend/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ pytest tests/unit/test_thread_join_elimination.py -v
pytest tests/unit/test_async_http_infrastructure.py -v
pytest tests/unit/test_clean_sweep_migrations.py -v
pytest tests/unit/test_omi_qos_tiers.py -v
pytest tests/unit/test_byok_llm_errors.py -v
pytest tests/unit/test_byok_security.py -v
pytest tests/unit/test_paywall_reconnect_gate.py -v
pytest tests/unit/test_vertex_ai_system_role.py -v
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/unit/test_action_item_date_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def _load_module_from_file(module_name, file_path):
if mod_name not in sys.modules:
_stub_module(mod_name)

sys.modules["database.auth"].get_user_name = MagicMock(return_value="TestUser")

# Stub database.action_items
action_items_db = _stub_module("database.action_items")
action_items_db.create_action_item = MagicMock(return_value="test-item-id")
Expand Down
6 changes: 6 additions & 0 deletions backend/tests/unit/test_async_app_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
"vector_db",
"apps",
"llm_usage",
"user_usage",
"chat",
"goals",
"announcements",
]:
mod = types.ModuleType(f"database.{submod}")
sys.modules.setdefault(f"database.{submod}", mod)
Expand All @@ -53,6 +55,10 @@
sys.modules["database.redis_db"].get_proactive_noti_sent_at_ttl = MagicMock(return_value=0)
sys.modules["database.redis_db"].incr_daily_notification_count = MagicMock()
sys.modules["database.redis_db"].get_daily_notification_count = MagicMock(return_value=0)
sys.modules["database.redis_db"].delete_generic_cache = MagicMock()
sys.modules["database.user_usage"].get_monthly_chat_usage = MagicMock(return_value={})
sys.modules["database.user_usage"].get_monthly_usage_stats_since = MagicMock(return_value={})
sys.modules["database.announcements"].compare_versions = MagicMock(return_value=0)
sys.modules["database.vector_db"].query_vectors_by_metadata = MagicMock(return_value=[])
sys.modules["database.apps"].record_app_usage = MagicMock()
sys.modules["database.llm_usage"].record_llm_usage = MagicMock()
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/unit/test_available_plans_resilience.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _compare_versions(a, b):


_announcements_mod._compare_versions = _compare_versions
_announcements_mod.compare_versions = _compare_versions

# database.users needs the functions payment.py imports by name
_users_mod = sys.modules["database.users"]
Expand Down Expand Up @@ -126,6 +127,7 @@ def _compare_versions(a, b):

_endpoints_mod = sys.modules["utils.other.endpoints"]
_endpoints_mod.get_current_user_uid = lambda: "test-user"
_endpoints_mod.get_current_user_uid_no_byok_validation = lambda: "test-user"

# Ensure utils.other has endpoints attr for `from utils.other import endpoints`
sys.modules["utils.other"].endpoints = _endpoints_mod
Expand Down
4 changes: 4 additions & 0 deletions backend/tests/unit/test_batch_upload_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@

import os
import sys
import types
from unittest.mock import MagicMock, patch

os.environ.setdefault("ENCRYPTION_SECRET", "omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv")

# Mock heavy dependencies at sys.modules level before importing storage
sys.modules.setdefault("database._client", MagicMock())
subscription_mod = types.ModuleType("utils.subscription")
subscription_mod.get_default_basic_subscription = MagicMock()
sys.modules.setdefault("utils.subscription", subscription_mod)

_mock_gcs_storage = MagicMock()
_mock_gcs_client_instance = MagicMock()
Expand Down
112 changes: 112 additions & 0 deletions backend/tests/unit/test_byok_llm_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

os.environ.setdefault('OPENAI_API_KEY', 'sk-test-fake-for-unit-tests')
os.environ.setdefault('ANTHROPIC_API_KEY', 'ant-test-fake-for-unit-tests')
os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv')

sys.modules.setdefault('database._client', MagicMock())


class _HTTPError(Exception):
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code


def test_classify_byok_llm_error_authentication():
from utils.llm.byok_errors import classify_byok_llm_error

assert classify_byok_llm_error(_HTTPError("bad api key", 401)) == 'invalid'


def test_classify_byok_llm_error_permission():
from utils.llm.byok_errors import classify_byok_llm_error

assert classify_byok_llm_error(_HTTPError("project denied", 403)) == 'permission'


def test_classify_byok_llm_error_insufficient_quota():
from utils.llm.byok_errors import classify_byok_llm_error

assert classify_byok_llm_error(_HTTPError("insufficient_quota", 429)) == 'quota'


def test_classify_byok_llm_error_ignores_transient_rate_limit():
from utils.llm.byok_errors import classify_byok_llm_error

assert classify_byok_llm_error(_HTTPError("rate limit reached, retry later", 429)) is None


@patch('utils.llm.byok_errors.messaging.send_each')
@patch('utils.llm.byok_errors.notification_db.get_all_tokens', return_value=['token-1'])
@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock', return_value=True)
@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user')
def test_handle_llm_error_notifies_actionable_byok_error(
mock_get_key,
mock_get_uid,
mock_lock,
mock_get_tokens,
mock_send_each,
):
from utils.llm.byok_errors import handle_llm_error

mock_send_each.return_value = SimpleNamespace(responses=[SimpleNamespace(success=True, exception=None)])

handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test')

mock_lock.assert_called_once_with('user-1', 'openai', 'quota')
mock_get_tokens.assert_called_once_with('user-1')
mock_send_each.assert_called_once()
message = mock_send_each.call_args.args[0][0]
assert message.data == {'type': 'byok_llm_error', 'provider': 'openai', 'reason': 'quota'}


@patch('utils.llm.byok_errors.messaging.send_each')
@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock')
@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value=None)
def test_handle_llm_error_does_not_notify_platform_error(
mock_get_key,
mock_get_uid,
mock_lock,
mock_send_each,
):
from utils.llm.byok_errors import handle_llm_error

handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test')

mock_lock.assert_not_called()
mock_send_each.assert_not_called()


def test_validate_byok_request_records_current_uid():
from utils.byok import get_byok_uid, validate_byok_request

with patch('utils.byok._check_byok_validity', return_value=None):
validate_byok_request('user-1')

assert get_byok_uid() == 'user-1'


def test_anthropic_proxy_constructs_default_client_lazily():
from utils.llm.clients import _AnthropicClientProxy

created = []

def _fake_client(**kwargs):
created.append(kwargs)
return object()

proxy = _AnthropicClientProxy()

with patch('utils.llm.clients.get_byok_key', return_value=None), patch(
'utils.llm.clients.anthropic.AsyncAnthropic', side_effect=_fake_client
):
assert created == []
proxy._resolve()

assert created == [{'timeout': 120.0, 'max_retries': 1}]
1 change: 1 addition & 0 deletions backend/tests/unit/test_byok_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
sys.modules.setdefault('database._client', MagicMock())
sys.modules.setdefault('database.redis_db', MagicMock())
sys.modules.setdefault('database.users', MagicMock())
sys.modules.setdefault('database.notifications', MagicMock())
sys.modules.setdefault('database.user_usage', MagicMock())
sys.modules.setdefault('database.llm_usage', MagicMock())
sys.modules.setdefault('database.announcements', MagicMock())
Expand Down
Loading
Loading