Skip to content
111 changes: 99 additions & 12 deletions embedding_cluster/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable

from chromadb.api import ClientAPI
from chromadb.api.models.Collection import Collection
Expand All @@ -29,12 +29,23 @@

logger = logging.getLogger(__name__)

PROGRESS_UPDATE_INTERVAL = 10


async def main_indexer(
settings: Settings,
on_progress: Callable[[dict[str, Any]], None] | None = None,
on_log: Callable[[str, str, str], Awaitable[None]] | None = None,
cancel_event: asyncio.Event | None = None,
) -> None:
async def _emit_log(
message: str,
level: str = "info",
verbosity: str = "low",
) -> None:
if on_log is not None:
await on_log(message, level, verbosity)

chromadb_client: ClientAPI = chromadb.PersistentClient(path="./chromadb")
chromadb_docs_collections: dict[str, ChromaDocsCollection] = (
init_chroma_docs_collection(settings)
Expand All @@ -53,27 +64,53 @@ async def main_indexer(
and len(settings.image_embedding_fields) > 0
):
logger.info("Loading image model: %s", settings.image_model_name)
image_model = CLIPModel.from_pretrained(settings.image_model_name).to(
settings.process_unit_device
)
image_model_processor = CLIPProcessor.from_pretrained(settings.image_model_name)
await _emit_log(f"Loading image model: {settings.image_model_name}...")
try:
image_model = await asyncio.to_thread(
lambda: CLIPModel.from_pretrained(settings.image_model_name).to(
settings.process_unit_device
)
)
image_model_processor = await asyncio.to_thread(
CLIPProcessor.from_pretrained, settings.image_model_name
)
await _emit_log("Image model loaded successfully")
except Exception as exc:
await _emit_log(
f"Failed to load image model: {exc}",
level="error",
)
raise

if (
settings.text_embedding_fields is not None
and len(settings.text_embedding_fields) > 0
):
logger.info("Loading text model: %s", settings.text_model_name)
text_model_transformer = SentenceTransformer(settings.text_model_name).to(
settings.process_unit_device
)
await _emit_log(f"Loading text model: {settings.text_model_name}...")
try:
text_model_transformer = await asyncio.to_thread(
lambda: SentenceTransformer(settings.text_model_name).to(
settings.process_unit_device
)
)
await _emit_log("Text model loaded successfully")
except Exception as exc:
await _emit_log(
f"Failed to load text model: {exc}",
level="error",
)
raise

start_time = time.perf_counter()

await _emit_log("Loading CSV file...")
with open(settings.local_csv_filename) as csv_file:
csv_iter = csv.DictReader(csv_file)
await _emit_log("CSV file opened, reading rows...")
rows_read = 0
curr_rows: list[dict[str, Any]] = []

batch_num = 0
skipped_rows = 0
if settings.index_start_line is not None:
skipped_rows = 1
Expand All @@ -84,16 +121,42 @@ async def main_indexer(

for row in csv_iter:
if cancel_event is not None and cancel_event.is_set():
logger.info("Indexing cancelled at row %d", rows_read + skipped_rows)
logger.info(
"Indexing cancelled at row %d",
rows_read + skipped_rows,
)
await _emit_log(
f"Indexing cancelled at row {rows_read + skipped_rows}",
level="warning",
)
break
rows_read += 1
curr_rows.append(row)
if on_progress is not None and rows_read % PROGRESS_UPDATE_INTERVAL == 0:
on_progress(
{
"rows_indexed": rows_read,
"total_rows": None,
"errors": 0,
"elapsed_seconds": (time.perf_counter() - start_time),
}
)
await _emit_log(
f"Reading row {rows_read}...",
verbosity="high",
)
if (
settings.index_end_line is not None
and settings.index_end_line == rows_read + skipped_rows
):
break
if len(curr_rows) == settings.index_bulk_size:
batch_num += 1
batch_start = rows_read - len(curr_rows) + 1
await _emit_log(
f"Processing batch {batch_num} ({batch_start}-{rows_read})...",
verbosity="medium",
)
await _handle_batch(
settings=settings,
rows=curr_rows,
Expand All @@ -104,6 +167,10 @@ async def main_indexer(
chromadb_docs_collections=chromadb_docs_collections,
chromadb_collections=chromadb_collections,
)
await _emit_log(
f"Batch {batch_num} complete",
verbosity="medium",
)
curr_rows = []
chromadb_docs_collections = init_chroma_docs_collection(settings)
if on_progress is not None:
Expand All @@ -112,15 +179,25 @@ async def main_indexer(
"rows_indexed": rows_read,
"total_rows": None,
"errors": 0,
"elapsed_seconds": time.perf_counter() - start_time,
"elapsed_seconds": (time.perf_counter() - start_time),
}
)
await _emit_log(
f"Indexed {rows_read} rows so far",
verbosity="medium",
)
logger.info(
"Indexed %d rows. [%d]",
rows_read,
skipped_rows + rows_read,
)
if len(curr_rows) > 0:
batch_num += 1
batch_start = rows_read - len(curr_rows) + 1
await _emit_log(
f"Processing batch {batch_num} ({batch_start}-{rows_read})...",
verbosity="medium",
)
await _handle_batch(
settings=settings,
rows=curr_rows,
Expand All @@ -131,16 +208,26 @@ async def main_indexer(
chromadb_docs_collections=chromadb_docs_collections,
chromadb_collections=chromadb_collections,
)
await _emit_log(
f"Batch {batch_num} complete",
verbosity="medium",
)
if on_progress is not None:
on_progress(
{
"rows_indexed": rows_read,
"total_rows": None,
"errors": 0,
"elapsed_seconds": time.perf_counter() - start_time,
"elapsed_seconds": (time.perf_counter() - start_time),
}
)

elapsed = time.perf_counter() - start_time
await _emit_log(
f"Indexing complete: {rows_read} rows in {elapsed:.1f}s",
level="success",
)


async def _handle_batch(
settings: Settings,
Expand Down
106 changes: 90 additions & 16 deletions embedding_cluster/server/routes/index.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import asyncio
import contextlib
import logging
import time
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -39,8 +41,22 @@ def resolve_csv_path(csv_filename: str) -> Path:
return Path("./uploads") / candidate


def _get_collection_names(settings: Settings) -> list[str]:
"""Build collection names from settings."""
names: list[str] = []
prefix = settings.chromadb_collection_prefix
if settings.image_embedding_fields:
for field in settings.image_embedding_fields:
names.append(f"{prefix}{field}")
if settings.text_embedding_fields:
for field in settings.text_embedding_fields:
names.append(f"{prefix}{field}")
return names


async def _run_indexing(task_state: TaskState, request: IndexRequest) -> None:
"""Run indexing in background, updating task state and broadcasting progress."""
heartbeat_task: asyncio.Task[None] | None = None
try:
# Construct Settings from IndexRequest
try:
Expand All @@ -67,6 +83,7 @@ async def _run_indexing(task_state: TaskState, request: IndexRequest) -> None:

# Update task status to RUNNING
task_state.status = TaskStatus.RUNNING
start_time = time.monotonic()

# Define progress callback
def on_progress(progress_data: dict[str, Any]) -> None:
Expand All @@ -79,20 +96,33 @@ def on_progress(progress_data: dict[str, Any]) -> None:
# ruff: noqa: RUF006
asyncio.create_task(ws_manager.broadcast(task_state.job_id, progress_data))

rows_indexed = progress_data.get("rows_indexed")
if isinstance(rows_indexed, int) and rows_indexed > 0:
# ruff: noqa: RUF006
asyncio.create_task(
ws_manager.broadcast(
task_state.job_id,
{
"type": "log",
"level": "info",
"message": f"Indexed {rows_indexed} rows",
},
)
# Define log callback
async def on_log(message: str, level: str, verbosity: str) -> None:
await ws_manager.broadcast(
task_state.job_id,
{
"type": "log",
"level": level,
"message": message,
"verbosity": verbosity,
},
)

# Heartbeat background task
async def _heartbeat() -> None:
while True:
await asyncio.sleep(3)
elapsed = time.monotonic() - start_time
await ws_manager.broadcast(
task_state.job_id,
{
"type": "heartbeat",
"elapsed_seconds": elapsed,
},
)

heartbeat_task = asyncio.create_task(_heartbeat())

total_rows = request.total_rows
on_progress(
{
Expand All @@ -103,13 +133,51 @@ def on_progress(progress_data: dict[str, Any]) -> None:
}
)

# Run indexer with callback and cancel event
# Run indexer with callbacks and cancel event
await main_indexer(
settings, on_progress=on_progress, cancel_event=task_state.cancel_event
settings,
on_progress=on_progress,
on_log=on_log,
cancel_event=task_state.cancel_event,
)

# Success
task_state.status = TaskStatus.COMPLETED
elapsed = time.monotonic() - start_time
rows_indexed = task_state.progress.get("rows_indexed", 0)

# Check if cancelled (cancel_event set by cancel endpoint)
if task_state.status == TaskStatus.CANCELLED:
logger.info("Indexing cancelled for job %s", task_state.job_id)
# ruff: noqa: RUF006
asyncio.create_task(
ws_manager.broadcast(
task_state.job_id,
{
"type": "cancelled",
"status": "cancelled",
"progress": task_state.progress,
"total_indexed": rows_indexed,
"elapsed_seconds": elapsed,
},
)
)
else:
# Success — send completion message
task_state.status = TaskStatus.COMPLETED
collection_names = _get_collection_names(settings)
# ruff: noqa: RUF006
asyncio.create_task(
ws_manager.broadcast(
task_state.job_id,
{
"type": "completed",
"status": "completed",
"progress": task_state.progress,
"total_indexed": rows_indexed,
"collection_names": collection_names,
"elapsed_seconds": elapsed,
},
)
)
except Exception as e:
logger.exception("Indexing failed for job %s", task_state.job_id)
task_state.status = TaskStatus.FAILED
Expand All @@ -119,12 +187,18 @@ def on_progress(progress_data: dict[str, Any]) -> None:
ws_manager.broadcast(
task_state.job_id,
{
"type": "error",
"status": task_state.status.value,
"error": task_state.error,
"message": str(e),
"progress": task_state.progress,
},
)
)
finally:
if heartbeat_task is not None:
with contextlib.suppress(RuntimeError):
heartbeat_task.cancel()


@router.post("/start", response_model=IndexStartResponse)
Expand Down
Loading