Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,5 +161,6 @@ cython_debug/
#.idea/

chromadb
annotations/

.worktrees/
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,44 @@ visualization in a few steps:
focus on specific groups. Use the optional GPT-powered naming
to label each cluster automatically.

## Cluster Drill-Down and Annotation

After generating a plot you can inspect, subdivide, and annotate
individual clusters directly from the web UI.

### Cluster Detail Panel

Click a cluster name in the legend to open a side panel listing every
item in that cluster. Items are sorted by distance to the centroid so
the most representative points appear first. The panel supports
pagination, displays item metadata, and shows image thumbnails when
an image field is available.

### Sub-Clustering

Inside the detail panel, toggle **Sub-cluster** to re-run k-means
within a single cluster. The result is rendered as a mini 3D scatter
plot (PCA-reduced) so you can explore hierarchical structure without
leaving the page.

### Annotations

Each cluster can be renamed, tagged, and annotated with free-form
notes. Changes are saved automatically (debounced) and persisted as
JSON sidecar files in the `annotations/` directory. Annotations
survive page reloads and are scoped per plot job.

### API Endpoints

The feature exposes the following REST endpoints under `/api`:

- `GET /plot/{job_id}/cluster/{index}` -- paginated cluster detail
- `POST /plot/{job_id}/cluster/{index}/sub-cluster` -- sub-cluster
a single cluster with configurable k
- `GET /annotations/{job_id}` -- fetch all annotations for a job
- `PUT /annotations/{job_id}/cluster/{index}` -- update the annotation of a single cluster
- `DELETE /annotations/{job_id}` -- delete annotations

```text
CSV --> Select Fields --> Download Model --> Embed & Store
--> Configure Plot --> 3D Visualization --> Search & Explore
Expand Down
69 changes: 69 additions & 0 deletions embedding_cluster/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Directory used for annotation files when AnnotationManager is built
# without an explicit base_dir. The path is relative, so it is resolved
# against the current working directory of the process.
_DEFAULT_BASE_DIR = Path("./annotations")


class AnnotationManager:
    """Store per-job cluster annotations as JSON files on disk.

    Every job gets one file, ``<base_dir>/<job_id>.json``, shaped like
    ``{"job_id": "...", "clusters": {"<cluster index>": {...}}}``.
    """

    # Characters that would let a job id address anything other than a
    # plain file directly inside the base directory.
    _FORBIDDEN_ID_CHARS = ("/", "\\", "\x00")

    def __init__(self, base_dir: Path | None = None) -> None:
        """Bind the manager to a directory, creating it if necessary.

        Args:
            base_dir: Directory holding the JSON files. The module
                default is used when omitted.
        """
        self._base_dir = (
            base_dir if base_dir is not None else _DEFAULT_BASE_DIR
        )
        self._base_dir.mkdir(parents=True, exist_ok=True)

    def _file_path(self, job_id: str) -> Path:
        """Return the JSON file path for ``job_id``.

        Raises:
            ValueError: If ``job_id`` contains a path separator or a NUL
                byte and so cannot name a file inside the base directory.
        """
        # job ids reach this class from URL path parameters, so refuse
        # anything that could make the path leave the base directory.
        if any(ch in job_id for ch in self._FORBIDDEN_ID_CHARS):
            msg = f"Invalid job_id: {job_id!r}"
            raise ValueError(msg)
        return self._base_dir / f"{job_id}.json"

    def _read(self, job_id: str) -> dict[str, Any]:
        """Load the stored document, or an empty one if none exists.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid
                JSON. The file is left untouched in that case.
        """
        path = self._file_path(job_id)
        # EAFP: the file may disappear between a check and the read.
        try:
            raw = path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return {"job_id": job_id, "clusters": {}}
        data: dict[str, Any] = json.loads(raw)
        # Tolerate documents that lack the expected top-level keys so
        # callers can always index data["clusters"].
        data.setdefault("job_id", job_id)
        if data.get("clusters") is None:
            data["clusters"] = {}
        return data

    def _write(self, job_id: str, data: dict[str, Any]) -> None:
        """Persist ``data`` for ``job_id``, replacing any previous file."""
        path = self._file_path(job_id)
        # Write to a sibling temp file and rename it over the target so a
        # reader never observes a half-written JSON document.
        tmp_path = path.with_name(f"{path.name}.tmp")
        tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
        tmp_path.replace(path)

    def get_annotations(self, job_id: str) -> dict[str, Any]:
        """Return all annotations stored for ``job_id``.

        A job without a file yields ``{"job_id": job_id, "clusters": {}}``.
        """
        return self._read(job_id)

    def update_annotation(
        self,
        job_id: str,
        cluster_index: int,
        name: str | None = None,
        notes: str | None = None,
        tags: list[str] | None = None,
    ) -> dict[str, Any]:
        """Create or update the annotation of one cluster and save it.

        Args:
            job_id: Job whose annotation file is updated.
            cluster_index: Cluster to annotate; stored under its string
                form because JSON object keys are strings.
            name: New name, or ``None`` to keep the stored value.
            notes: New notes, or ``None`` to keep the stored value.
            tags: New tag list, or ``None`` to keep the stored value.

        Returns:
            The complete annotation document for the job after the update.
        """
        data = self._read(job_id)
        entry = data["clusters"].setdefault(
            str(cluster_index),
            {"name": None, "notes": None, "tags": None, "updated_at": None},
        )
        if name is not None:
            entry["name"] = name
        if notes is not None:
            entry["notes"] = notes
        if tags is not None:
            entry["tags"] = tags
        # The timestamp is refreshed on every call, even when no field
        # was supplied.
        entry["updated_at"] = datetime.now(
            tz=timezone.utc  # noqa: UP017
        ).isoformat()
        self._write(job_id, data)
        return data

    def delete_annotations(self, job_id: str) -> None:
        """Remove the annotation file of ``job_id``; no-op if absent."""
        self._file_path(job_id).unlink(missing_ok=True)
3 changes: 3 additions & 0 deletions embedding_cluster/scatter_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ def compute_plot_data(settings: Settings) -> dict[str, Any]:
"points": points,
"clusters": clusters,
"total_points": len(collection_content["ids"]),
"embeddings_standardized": embeddings_standardized.tolist(),
"cluster_labels": np.asarray(pred_arr).tolist(),
"point_ids": [p["id"] for p in points],
}


Expand Down
4 changes: 4 additions & 0 deletions embedding_cluster/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from embedding_cluster.server.routes.annotations import (
router as annotations_router,
)
from embedding_cluster.server.routes.collections import (
router as collections_router,
)
Expand Down Expand Up @@ -44,6 +47,7 @@ async def health_check() -> dict[str, str]:
app.include_router(csv_router)
app.include_router(index_router)
app.include_router(plot_router)
app.include_router(annotations_router)
app.include_router(search_router)

if FRONTEND_DIR.is_dir():
Expand Down
67 changes: 67 additions & 0 deletions embedding_cluster/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,70 @@ class SearchRequest(BaseModel):

class SearchResponse(BaseModel):
    """Response body wrapping a list of ``SearchResult`` entries."""

    results: list[SearchResult]


class ClusterItemResponse(BaseModel):
    """A single item belonging to a cluster."""

    id: str
    # Free-form per-item metadata; values are not typed at this layer.
    metadata: dict[str, object]
    # NOTE(review): presumably measured in the embedding space used for
    # clustering -- the computation is not visible here; confirm.
    distance_to_centroid: float


class ClusterDetailResponse(BaseModel):
    """One page of items from a single cluster."""

    cluster_index: int
    cluster_name: str
    # NOTE(review): presumably the size of the whole cluster rather than
    # of this page -- confirm against the route that builds the response.
    total_items: int
    # NOTE(review): whether paging is 0- or 1-based is not visible here.
    page: int
    page_size: int
    items: list[ClusterItemResponse]


class SubClusterRequest(BaseModel):
    """Request body for sub-clustering a single cluster.

    ``num_sub_clusters`` defaults to 3 and must be at least 2.
    """

    num_sub_clusters: int = 3

    @field_validator("num_sub_clusters")
    @classmethod
    def validate_num_sub_clusters(cls, v: int) -> int:
        """Pass through values of 2 or more; reject anything smaller."""
        if v >= 2:
            return v
        msg = "num_sub_clusters must be at least 2"
        raise ValueError(msg)


class SubClusterPoint(BaseModel):
    """A point with 3D coordinates and a sub-cluster assignment."""

    id: str
    # 3D plot coordinates.
    # NOTE(review): the README describes the sub-cluster plot as
    # PCA-reduced; the code producing these values is not visible here.
    x: float
    y: float
    z: float
    # NOTE(review): presumably matches SubClusterInfo.index -- confirm.
    sub_cluster: int
    metadata: dict[str, object]


class SubClusterInfo(BaseModel):
    """Summary of one sub-cluster: its index, a count and a color."""

    index: int
    # NOTE(review): presumably the number of points assigned to this
    # sub-cluster -- confirm against the producing route.
    count: int
    # NOTE(review): color format (hex, CSS name, ...) is not visible here.
    color: str


class SubClusterResponse(BaseModel):
    """Result of sub-clustering one parent cluster."""

    # Index of the cluster that was subdivided.
    parent_cluster_index: int
    points: list[SubClusterPoint]
    sub_clusters: list[SubClusterInfo]
    # NOTE(review): presumably equals len(points) -- confirm.
    total_points: int


class AnnotationUpdate(BaseModel):
    """Request body for updating one cluster's annotation.

    Every field is optional. The update route forwards them to
    ``AnnotationManager.update_annotation``, which treats ``None`` as
    "leave the stored value unchanged", so a field cannot be reset to
    null through this model.
    """

    name: str | None = None
    notes: str | None = None
    tags: list[str] | None = None


class ClusterAnnotation(BaseModel):
    """Stored annotation of a single cluster."""

    name: str | None = None
    notes: str | None = None
    tags: list[str] | None = None
    # ISO-8601 timestamp string; AnnotationManager.update_annotation
    # writes it from a timezone-aware UTC datetime.
    updated_at: str | None = None


class AnnotationsResponse(BaseModel):
    """All annotations stored for one job."""

    job_id: str
    # Keyed by the cluster index rendered as a string, which is how
    # AnnotationManager stores entries in the JSON document.
    clusters: dict[str, ClusterAnnotation]
57 changes: 57 additions & 0 deletions embedding_cluster/server/routes/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import logging
from pathlib import Path

from fastapi import APIRouter

from embedding_cluster.annotations import AnnotationManager
from embedding_cluster.server.models import (
AnnotationsResponse,
AnnotationUpdate,
MessageResponse,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Directory handed to every AnnotationManager built by this module. The
# path is relative, so it is resolved against the current working
# directory of the process.
_DEFAULT_ANNOTATIONS_DIR = Path("./annotations")

# All routes declared on this router live under /api/annotations.
router = APIRouter(prefix="/api/annotations", tags=["annotations"])


def _get_manager() -> AnnotationManager:
    """Build an AnnotationManager for the default annotations directory."""
    annotations_dir = _DEFAULT_ANNOTATIONS_DIR
    return AnnotationManager(base_dir=annotations_dir)


@router.get("/{job_id}", response_model=AnnotationsResponse)
async def get_annotations(job_id: str) -> AnnotationsResponse:
    """Return every stored cluster annotation for ``job_id``."""
    stored = _get_manager().get_annotations(job_id)
    return AnnotationsResponse(**stored)


@router.put(
    "/{job_id}/cluster/{cluster_index}",
    response_model=AnnotationsResponse,
)
async def update_annotation(
    job_id: str,
    cluster_index: int,
    body: AnnotationUpdate,
) -> AnnotationsResponse:
    """Update one cluster's annotation and return the job's annotations.

    Fields left unset in ``body`` arrive as ``None`` and are forwarded
    unchanged to the manager.
    """
    changes = {
        "name": body.name,
        "notes": body.notes,
        "tags": body.tags,
    }
    snapshot = _get_manager().update_annotation(
        job_id,
        cluster_index,
        **changes,
    )
    return AnnotationsResponse(**snapshot)


@router.delete("/{job_id}", response_model=MessageResponse)
async def delete_annotations(job_id: str) -> MessageResponse:
    """Delete the stored annotations of ``job_id`` and confirm."""
    _get_manager().delete_annotations(job_id)
    confirmation = MessageResponse(message="Annotations deleted")
    return confirmation
Loading
Loading