Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion backend/routers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,17 @@ def precache_conversation_audio_endpoint(
if not audio_files:
return {"status": "no_audio", "message": "No audio files in conversation"}

_precache_sem = threading.Semaphore(10)

def _precache_all_parallel():
logger.info(f"Pre-caching all {len(audio_files)} audio files for conversation {conversation_id} (parallel)")

def _bounded_precache(af):
with _precache_sem:
return _precache_audio_file(uid, conversation_id, af)

futures = [submit_with_context(storage_executor, _bounded_precache, af) for af in audio_files]
Comment on lines +208 to +217
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Per-request semaphore provides no cross-request fan-out protection

_precache_sem is allocated fresh inside every HTTP handler invocation. Each concurrent request gets its own independent Semaphore(10), so N simultaneous requests each submitting M audio files contribute N×M tasks to storage_executor with no global ceiling. The same pattern applies to _uncached_sem at line 308. The _STORAGE_FANOUT_SEMAPHORE already defined at module level in storage.py is the correct shared throttle — it should be imported here (or mirrored at module level in sync.py) so that concurrency is bounded across all requests, not just within a single one.


# Start background parallel pre-caching with bounded concurrency (#7387)
def _precache_all_parallel():
logger.info(f"Pre-caching all {len(audio_files)} audio files for conversation {conversation_id} (parallel)")
Expand All @@ -220,6 +231,7 @@ def _precache_all_parallel():
except Exception:
_PRECACHE_FILE_SEM.release()
raise

for future in futures:
try:
future.result()
Expand Down Expand Up @@ -312,10 +324,17 @@ def get_audio_signed_urls_endpoint(
)
uncached_files.append(af)

# Cache remaining files in background
if uncached_files:
_uncached_sem = threading.Semaphore(10)

def _cache_uncached_parallel():

def _bounded_precache(af):
with _uncached_sem:
return _precache_audio_file(uid, conversation_id, af)

futures = [submit_with_context(storage_executor, _bounded_precache, af) for af in uncached_files]

futures = []
for af in uncached_files:
_PRECACHE_FILE_SEM.acquire()
Expand All @@ -328,6 +347,7 @@ def _cache_uncached_parallel():
except Exception:
_PRECACHE_FILE_SEM.release()
raise

for future in futures:
try:
future.result()
Expand Down
4 changes: 4 additions & 0 deletions backend/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,12 @@ pytest tests/unit/test_trial_metadata.py -v
pytest tests/unit/test_neo_desktop_grandfather.py -v
pytest tests/unit/test_vertex_ai_system_role.py -v
pytest tests/unit/test_tts.py -v

pytest tests/unit/test_storage_fanout_limit.py -v

pytest tests/unit/test_webhook_auto_disable.py -v


# Fair-use integration tests (require Redis; skip gracefully if unavailable)
if redis-cli ping >/dev/null 2>&1; then
pytest tests/integration/test_fair_use_live.py -v
Expand Down
130 changes: 130 additions & 0 deletions backend/tests/unit/test_storage_fanout_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Tests for bounded fan-out in storage executor submissions (issue #7387).

Verifies that concurrent storage_executor submissions are capped by
threading.Semaphore to prevent queue spikes from unbounded fan-out.
"""

import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from unittest.mock import patch, MagicMock

import pytest


class TestStorageFanoutSemaphore:
"""Verify that _STORAGE_FANOUT_SEMAPHORE limits concurrent GCS operations."""

def test_semaphore_limits_concurrency(self):
"""Fan-out of 50 tasks with semaphore(5) should never exceed 5 concurrent."""
sem = threading.Semaphore(5)
max_concurrent = 0
current = 0
lock = threading.Lock()

def work(i):
nonlocal max_concurrent, current
with sem:
with lock:
current += 1
if current > max_concurrent:
max_concurrent = current
time.sleep(0.01)
with lock:
current -= 1
return i

executor = ThreadPoolExecutor(max_workers=20)
futures = [executor.submit(work, i) for i in range(50)]
for f in as_completed(futures):
f.result()
executor.shutdown(wait=True)

assert max_concurrent <= 5, f"Max concurrent was {max_concurrent}, expected <= 5"
assert max_concurrent >= 2, "Semaphore should allow some parallelism"

def test_semaphore_does_not_deadlock(self):
"""Nested semaphore acquisition (precache -> chunk download) must not deadlock."""
outer_sem = threading.Semaphore(3)
inner_sem = threading.Semaphore(3)
results = []
lock = threading.Lock()

def inner_work(val):
with inner_sem:
time.sleep(0.005)
return val * 2

def outer_work(i):
with outer_sem:
inner_executor = ThreadPoolExecutor(max_workers=2)
futs = [inner_executor.submit(inner_work, j) for j in range(3)]
vals = [f.result(timeout=5) for f in futs]
inner_executor.shutdown(wait=True)
with lock:
results.append(sum(vals))
return i

executor = ThreadPoolExecutor(max_workers=10)
futures = [executor.submit(outer_work, i) for i in range(6)]
for f in as_completed(futures):
f.result(timeout=10)
executor.shutdown(wait=True)

assert len(results) == 6
assert all(r == 6 for r in results) # 0*2 + 1*2 + 2*2 = 6

def test_storage_module_has_semaphore(self):
    """The storage module must define a _STORAGE_FANOUT_SEMAPHORE.

    Uses ast parsing to avoid importing heavy native deps (opuslib, GCS).
    """
    # NOTE(review): this only proves that the symbol exists in storage.py
    # with a value in [5, 30]. The per-request semaphores in sync.py
    # (_precache_sem, _uncached_sem) and knowledge_graph.py (_kg_sem) are
    # separate objects created at call time, and nothing here checks that
    # those call sites share a global limit.
    import ast
    import pathlib

    storage_path = pathlib.Path(__file__).resolve().parents[2] / 'utils' / 'other' / 'storage.py'
    # Read with an explicit encoding so the result does not depend on the
    # locale of the machine running the tests.
    tree = ast.parse(storage_path.read_text(encoding='utf-8'))

    semaphore_found = False
    semaphore_value = None
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == '_STORAGE_FANOUT_SEMAPHORE':
                    semaphore_found = True
                    # Capture the first positional argument of the call on
                    # the right-hand side when it is a literal constant.
                    if isinstance(node.value, ast.Call) and node.value.args:
                        arg = node.value.args[0]
                        if isinstance(arg, ast.Constant):
                            semaphore_value = arg.value

    assert semaphore_found, "_STORAGE_FANOUT_SEMAPHORE not found in storage.py"
    assert semaphore_value is not None, "Could not determine semaphore value"
    assert 5 <= semaphore_value <= 30, f"Semaphore value {semaphore_value} outside expected range [5, 30]"

def test_bounded_fan_out_caps_queue_depth(self):
"""Simulate the storage pattern: N tasks submitted, only K run concurrently."""
sem = threading.Semaphore(10)
task_count = 100
peak_queue = 0
running = 0
lock = threading.Lock()

def bounded_work(i):
nonlocal peak_queue, running
with sem:
with lock:
running += 1
if running > peak_queue:
peak_queue = running
time.sleep(0.005)
with lock:
running -= 1
return i

executor = ThreadPoolExecutor(max_workers=50)
futures = [executor.submit(bounded_work, i) for i in range(task_count)]
results = [f.result() for f in as_completed(futures)]
executor.shutdown(wait=True)

assert len(results) == task_count
assert peak_queue <= 10, f"Peak concurrent {peak_queue} exceeded semaphore limit 10"
9 changes: 9 additions & 0 deletions backend/utils/llm/knowledge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ def process_memory(memory):
all_nodes = []
all_edges = []

_kg_sem = threading.Semaphore(10)

def _bounded_process_memory(m):
with _kg_sem:
return process_memory(m)

futures = [storage_executor.submit(_bounded_process_memory, m) for m in memories]
Comment on lines +265 to +271
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Per-call semaphore and thread-parking anti-pattern

_kg_sem is created fresh on every call to rebuild_knowledge_graph, so concurrent callers each get their own Semaphore(10) with no shared global limit. Additionally, all futures are submitted to storage_executor (96 threads) before any semaphore is acquired. When len(memories) > 10, up to 96 threads start executing _bounded_process_memory, and 86+ of them immediately block on _kg_sem, occupying worker slots while doing no useful work. This can starve unrelated tasks queued on the same executor.


futures = []
for m in memories:
_KG_REBUILD_SEM.acquire()
Expand All @@ -272,6 +280,7 @@ def process_memory(memory):
except Exception:
_KG_REBUILD_SEM.release()
raise

for future in as_completed(futures):
try:
result = future.result()
Expand Down
30 changes: 30 additions & 0 deletions backend/utils/other/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
import os
import struct
import threading


import time

import wave
from typing import List
from concurrent.futures import as_completed, wait, FIRST_COMPLETED

from utils.executors import postprocess_executor, storage_executor

_STORAGE_FANOUT_SEMAPHORE = threading.Semaphore(20)

import opuslib
Comment on lines +17 to 19
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Module-level object initialized between import blocks

_STORAGE_FANOUT_SEMAPHORE is instantiated between the internal-package imports (from utils.executors import ...) and the third-party library imports (import opuslib). Moving it after all imports are complete keeps the standard stdlib → third-party → local import ordering intact.

from google.cloud import storage
from google.oauth2 import service_account
Expand Down Expand Up @@ -770,6 +775,24 @@ def download_single_chunk(timestamp: float) -> tuple[float, bytes | None]:
chunk_results = {}

individual_timestamps = [ts for ts in timestamps if round(ts, 3) not in ts_to_batch_path]

unique_batch_paths = set(ts_to_batch_path.values())

def _bounded_download_single_chunk(ts):
with _STORAGE_FANOUT_SEMAPHORE:
return download_single_chunk(ts)

def _bounded_download_and_decode_blob(path):
with _STORAGE_FANOUT_SEMAPHORE:
return _download_and_decode_blob(path)

individual_futures = {
storage_executor.submit(_bounded_download_single_chunk, ts): ts for ts in individual_timestamps
}
batch_futures = {
storage_executor.submit(_bounded_download_and_decode_blob, path): path for path in unique_batch_paths
}

unique_batch_paths = list(set(ts_to_batch_path.values()))

# Build unified job list: ('individual', ts) or ('batch', path)
Expand Down Expand Up @@ -1053,6 +1076,12 @@ def _cache_single(af):
except Exception as e:
logger.error(f"[PRECACHE] Error caching audio file {af.get('id')}: {e}")

def _bounded_cache_single(af):
with _STORAGE_FANOUT_SEMAPHORE:
return _cache_single(af)

futures = [storage_executor.submit(_bounded_cache_single, af) for af in audio_files]

futures = []
for af in audio_files:
_PRECACHE_FILE_SEM.acquire()
Expand All @@ -1063,6 +1092,7 @@ def _cache_single(af):
except Exception:
_PRECACHE_FILE_SEM.release()
raise

for future in as_completed(futures):
try:
future.result()
Expand Down
Loading