diff --git a/backend/routers/sync.py b/backend/routers/sync.py index b3907a11f22..d6061eb859f 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -205,6 +205,17 @@ def precache_conversation_audio_endpoint( if not audio_files: return {"status": "no_audio", "message": "No audio files in conversation"} + _precache_sem = threading.Semaphore(10) + + def _precache_all_parallel(): + logger.info(f"Pre-caching all {len(audio_files)} audio files for conversation {conversation_id} (parallel)") + + def _bounded_precache(af): + with _precache_sem: + return _precache_audio_file(uid, conversation_id, af) + + futures = [submit_with_context(storage_executor, _bounded_precache, af) for af in audio_files] + # Start background parallel pre-caching with bounded concurrency (#7387) def _precache_all_parallel(): logger.info(f"Pre-caching all {len(audio_files)} audio files for conversation {conversation_id} (parallel)") @@ -220,6 +231,7 @@ def _precache_all_parallel(): except Exception: _PRECACHE_FILE_SEM.release() raise + for future in futures: try: future.result() @@ -312,10 +324,17 @@ def get_audio_signed_urls_endpoint( ) uncached_files.append(af) - # Cache remaining files in background if uncached_files: + _uncached_sem = threading.Semaphore(10) def _cache_uncached_parallel(): + + def _bounded_precache(af): + with _uncached_sem: + return _precache_audio_file(uid, conversation_id, af) + + futures = [submit_with_context(storage_executor, _bounded_precache, af) for af in uncached_files] + futures = [] for af in uncached_files: _PRECACHE_FILE_SEM.acquire() @@ -328,6 +347,7 @@ def _cache_uncached_parallel(): except Exception: _PRECACHE_FILE_SEM.release() raise + for future in futures: try: future.result() diff --git a/backend/test.sh b/backend/test.sh index 29ab17f03a6..730d906b2dc 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -121,8 +121,12 @@ pytest tests/unit/test_trial_metadata.py -v pytest tests/unit/test_neo_desktop_grandfather.py -v pytest tests/unit/test_vertex_ai_system_role.py -v pytest tests/unit/test_tts.py -v + +pytest tests/unit/test_storage_fanout_limit.py -v + pytest tests/unit/test_webhook_auto_disable.py -v + # Fair-use integration tests (require Redis; skip gracefully if unavailable) if redis-cli ping >/dev/null 2>&1; then pytest tests/integration/test_fair_use_live.py -v diff --git a/backend/tests/unit/test_storage_fanout_limit.py b/backend/tests/unit/test_storage_fanout_limit.py new file mode 100644 index 00000000000..eb42d72d390 --- /dev/null +++ b/backend/tests/unit/test_storage_fanout_limit.py @@ -0,0 +1,130 @@ +"""Tests for bounded fan-out in storage executor submissions (issue #7387). + +Verifies that concurrent storage_executor submissions are capped by +threading.Semaphore to prevent queue spikes from unbounded fan-out. +""" + +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from unittest.mock import patch, MagicMock + +import pytest + + +class TestStorageFanoutSemaphore: + """Verify that _STORAGE_FANOUT_SEMAPHORE limits concurrent GCS operations.""" + + def test_semaphore_limits_concurrency(self): + """Fan-out of 50 tasks with semaphore(5) should never exceed 5 concurrent.""" + sem = threading.Semaphore(5) + max_concurrent = 0 + current = 0 + lock = threading.Lock() + + def work(i): + nonlocal max_concurrent, current + with sem: + with lock: + current += 1 + if current > max_concurrent: + max_concurrent = current + time.sleep(0.01) + with lock: + current -= 1 + return i + + executor = ThreadPoolExecutor(max_workers=20) + futures = [executor.submit(work, i) for i in range(50)] + for f in as_completed(futures): + f.result() + executor.shutdown(wait=True) + + assert max_concurrent <= 5, f"Max concurrent was {max_concurrent}, expected <= 5" + assert max_concurrent >= 2, "Semaphore should allow some parallelism" + + def test_semaphore_does_not_deadlock(self): + """Nested semaphore acquisition (precache -> chunk download) must not deadlock.""" + outer_sem = threading.Semaphore(3) + inner_sem = threading.Semaphore(3) + results = [] + lock = threading.Lock() + + def inner_work(val): + with inner_sem: + time.sleep(0.005) + return val * 2 + + def outer_work(i): + with outer_sem: + inner_executor = ThreadPoolExecutor(max_workers=2) + futs = [inner_executor.submit(inner_work, j) for j in range(3)] + vals = [f.result(timeout=5) for f in futs] + inner_executor.shutdown(wait=True) + with lock: + results.append(sum(vals)) + return i + + executor = ThreadPoolExecutor(max_workers=10) + futures = [executor.submit(outer_work, i) for i in range(6)] + for f in as_completed(futures): + f.result(timeout=10) + executor.shutdown(wait=True) + + assert len(results) == 6 + assert all(r == 6 for r in results) # 0*2 + 1*2 + 2*2 = 6 + + def test_storage_module_has_semaphore(self): + """The storage module must define a _STORAGE_FANOUT_SEMAPHORE. + + Uses ast parsing to avoid importing heavy native deps (opuslib, GCS). + """ + import ast + import pathlib + + storage_path = pathlib.Path(__file__).resolve().parents[2] / 'utils' / 'other' / 'storage.py' + tree = ast.parse(storage_path.read_text()) + + semaphore_found = False + semaphore_value = None + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == '_STORAGE_FANOUT_SEMAPHORE': + semaphore_found = True + if isinstance(node.value, ast.Call) and node.value.args: + arg = node.value.args[0] + if isinstance(arg, ast.Constant): + semaphore_value = arg.value + + assert semaphore_found, "_STORAGE_FANOUT_SEMAPHORE not found in storage.py" + assert semaphore_value is not None, "Could not determine semaphore value" + assert 5 <= semaphore_value <= 30, f"Semaphore value {semaphore_value} outside expected range [5, 30]" + + def test_bounded_fan_out_caps_queue_depth(self): + """Simulate the storage pattern: N tasks submitted, only K run concurrently.""" + sem = threading.Semaphore(10) + task_count = 100 + peak_queue = 0 + running = 0 + lock = threading.Lock() + + def bounded_work(i): + nonlocal peak_queue, running + with sem: + with lock: + running += 1 + if running > peak_queue: + peak_queue = running + time.sleep(0.005) + with lock: + running -= 1 + return i + + executor = ThreadPoolExecutor(max_workers=50) + futures = [executor.submit(bounded_work, i) for i in range(task_count)] + results = [f.result() for f in as_completed(futures)] + executor.shutdown(wait=True) + + assert len(results) == task_count + assert peak_queue <= 10, f"Peak concurrent {peak_queue} exceeded semaphore limit 10" diff --git a/backend/utils/llm/knowledge_graph.py b/backend/utils/llm/knowledge_graph.py index 3eec2a9006a..271f4272ea7 100644 --- a/backend/utils/llm/knowledge_graph.py +++ b/backend/utils/llm/knowledge_graph.py @@ -262,6 +262,14 @@ def process_memory(memory): all_nodes = [] all_edges = [] + _kg_sem = threading.Semaphore(10) + + def _bounded_process_memory(m): + with _kg_sem: + return process_memory(m) + + futures = [storage_executor.submit(_bounded_process_memory, m) for m in memories] + futures = [] for m in memories: _KG_REBUILD_SEM.acquire() @@ -272,6 +280,7 @@ def process_memory(memory): except Exception: _KG_REBUILD_SEM.release() raise + for future in as_completed(futures): try: result = future.result() diff --git a/backend/utils/other/storage.py b/backend/utils/other/storage.py index f12ad2742c4..3c893905c04 100644 --- a/backend/utils/other/storage.py +++ b/backend/utils/other/storage.py @@ -4,13 +4,18 @@ import os import struct import threading + + import time + import wave from typing import List from concurrent.futures import as_completed, wait, FIRST_COMPLETED from utils.executors import postprocess_executor, storage_executor +_STORAGE_FANOUT_SEMAPHORE = threading.Semaphore(20) + import opuslib from google.cloud import storage from google.oauth2 import service_account @@ -770,6 +775,24 @@ def download_single_chunk(timestamp: float) -> tuple[float, bytes | None]: chunk_results = {} individual_timestamps = [ts for ts in timestamps if round(ts, 3) not in ts_to_batch_path] + + unique_batch_paths = set(ts_to_batch_path.values()) + + def _bounded_download_single_chunk(ts): + with _STORAGE_FANOUT_SEMAPHORE: + return download_single_chunk(ts) + + def _bounded_download_and_decode_blob(path): + with _STORAGE_FANOUT_SEMAPHORE: + return _download_and_decode_blob(path) + + individual_futures = { + storage_executor.submit(_bounded_download_single_chunk, ts): ts for ts in individual_timestamps + } + batch_futures = { + storage_executor.submit(_bounded_download_and_decode_blob, path): path for path in unique_batch_paths + } + unique_batch_paths = list(set(ts_to_batch_path.values())) # Build unified job list: ('individual', ts) or ('batch', path) @@ -1053,6 +1076,12 @@ def _cache_single(af): except Exception as e: logger.error(f"[PRECACHE] Error caching audio file {af.get('id')}: {e}") + def _bounded_cache_single(af): + with _STORAGE_FANOUT_SEMAPHORE: + return _cache_single(af) + + futures = [storage_executor.submit(_bounded_cache_single, af) for af in audio_files] + futures = [] for af in audio_files: _PRECACHE_FILE_SEM.acquire() @@ -1063,6 +1092,7 @@ def _cache_single(af): except Exception: _PRECACHE_FILE_SEM.release() raise + for future in as_completed(futures): try: future.result()