diff --git a/.gitignore b/.gitignore index 86be558..96b7c89 100644 --- a/.gitignore +++ b/.gitignore @@ -161,5 +161,6 @@ cython_debug/ #.idea/ chromadb +annotations/ .worktrees/ diff --git a/README.md b/README.md index 10e3656..23661f9 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,44 @@ visualization in a few steps: focus on specific groups. Use the optional GPT-powered naming to label each cluster automatically. +## Cluster Drill-Down and Annotation + +After generating a plot you can inspect, subdivide, and annotate +individual clusters directly from the web UI. + +### Cluster Detail Panel + +Click a cluster name in the legend to open a side panel listing every +item in that cluster. Items are sorted by distance to the centroid so +the most representative points appear first. The panel supports +pagination, displays item metadata, and shows image thumbnails when +an image field is available. + +### Sub-Clustering + +Inside the detail panel, toggle **Sub-cluster** to re-run k-means +within a single cluster. The result is rendered as a mini 3D scatter +plot (PCA-reduced) so you can explore hierarchical structure without +leaving the page. + +### Annotations + +Each cluster can be renamed, tagged, and annotated with free-form +notes. Changes are saved automatically (debounced) and persisted as +JSON sidecar files in the `annotations/` directory. Annotations +survive page reloads and are scoped per plot job. 
+ +### API Endpoints + +The feature exposes the following REST endpoints under `/api`: + +- `GET /plot/{job_id}/cluster/{index}` -- paginated cluster detail +- `POST /plot/{job_id}/cluster/{index}/sub-cluster` -- sub-cluster + a single cluster with configurable k +- `GET /annotations/{job_id}` -- fetch all annotations for a job +- `PUT /annotations/{job_id}` -- update annotations +- `DELETE /annotations/{job_id}` -- delete annotations + ```text CSV --> Select Fields --> Download Model --> Embed & Store --> Configure Plot --> 3D Visualization --> Search & Explore diff --git a/embedding_cluster/annotations.py b/embedding_cluster/annotations.py new file mode 100644 index 0000000..5e4aabc --- /dev/null +++ b/embedding_cluster/annotations.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import json +import logging +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + +_DEFAULT_BASE_DIR = Path("./annotations") + + +class AnnotationManager: + def __init__(self, base_dir: Path | None = None) -> None: + self._base_dir = base_dir or _DEFAULT_BASE_DIR + self._base_dir.mkdir(parents=True, exist_ok=True) + + def _file_path(self, job_id: str) -> Path: + return self._base_dir / f"{job_id}.json" + + def _read(self, job_id: str) -> dict[str, Any]: + path = self._file_path(job_id) + if not path.exists(): + return {"job_id": job_id, "clusters": {}} + data: dict[str, Any] = json.loads(path.read_text()) + return data + + def _write(self, job_id: str, data: dict[str, Any]) -> None: + path = self._file_path(job_id) + path.write_text(json.dumps(data, indent=2)) + + def get_annotations(self, job_id: str) -> dict[str, Any]: + return self._read(job_id) + + def update_annotation( + self, + job_id: str, + cluster_index: int, + name: str | None = None, + notes: str | None = None, + tags: list[str] | None = None, + ) -> dict[str, Any]: + data = self._read(job_id) + key = str(cluster_index) + if key not in 
data["clusters"]: + data["clusters"][key] = { + "name": None, + "notes": None, + "tags": None, + "updated_at": None, + } + cluster = data["clusters"][key] + if name is not None: + cluster["name"] = name + if notes is not None: + cluster["notes"] = notes + if tags is not None: + cluster["tags"] = tags + cluster["updated_at"] = datetime.now( + tz=timezone.utc # noqa: UP017 + ).isoformat() + self._write(job_id, data) + return data + + def delete_annotations(self, job_id: str) -> None: + path = self._file_path(job_id) + if path.exists(): + path.unlink() diff --git a/embedding_cluster/scatter_plot.py b/embedding_cluster/scatter_plot.py index 6400305..769422a 100644 --- a/embedding_cluster/scatter_plot.py +++ b/embedding_cluster/scatter_plot.py @@ -340,6 +340,9 @@ def compute_plot_data(settings: Settings) -> dict[str, Any]: "points": points, "clusters": clusters, "total_points": len(collection_content["ids"]), + "embeddings_standardized": embeddings_standardized.tolist(), + "cluster_labels": np.asarray(pred_arr).tolist(), + "point_ids": [p["id"] for p in points], } diff --git a/embedding_cluster/server/app.py b/embedding_cluster/server/app.py index 937c0c4..7a352da 100644 --- a/embedding_cluster/server/app.py +++ b/embedding_cluster/server/app.py @@ -8,6 +8,9 @@ from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles +from embedding_cluster.server.routes.annotations import ( + router as annotations_router, +) from embedding_cluster.server.routes.collections import ( router as collections_router, ) @@ -44,6 +47,7 @@ async def health_check() -> dict[str, str]: app.include_router(csv_router) app.include_router(index_router) app.include_router(plot_router) + app.include_router(annotations_router) app.include_router(search_router) if FRONTEND_DIR.is_dir(): diff --git a/embedding_cluster/server/models.py b/embedding_cluster/server/models.py index 8de9ca2..b53843e 100644 --- a/embedding_cluster/server/models.py +++ 
# NOTE: these models are documented with comments rather than class
# docstrings on purpose -- a docstring on a Pydantic model is emitted as the
# schema "description" and would change the generated OpenAPI document.


# One member of a cluster plus its distance to the cluster centroid.
class ClusterItemResponse(BaseModel):
    id: str
    metadata: dict[str, object]
    distance_to_centroid: float


# One page of a cluster's members.
class ClusterDetailResponse(BaseModel):
    cluster_index: int
    cluster_name: str
    total_items: int
    page: int
    page_size: int
    items: list[ClusterItemResponse]


# Request body for sub-clustering a single cluster.
class SubClusterRequest(BaseModel):
    num_sub_clusters: int = 3

    @field_validator("num_sub_clusters")
    @classmethod
    def validate_num_sub_clusters(cls, v: int) -> int:
        """Accept only values of 2 or more; anything smaller is rejected."""
        if v >= 2:
            return v
        msg = "num_sub_clusters must be at least 2"
        raise ValueError(msg)


# A point of the sub-cluster plot: 3D coordinates plus its sub-cluster label.
class SubClusterPoint(BaseModel):
    id: str
    x: float
    y: float
    z: float
    sub_cluster: int
    metadata: dict[str, object]


# Summary of one sub-cluster (its index, member count and display color).
class SubClusterInfo(BaseModel):
    index: int
    count: int
    color: str


# Full result of sub-clustering one parent cluster.
class SubClusterResponse(BaseModel):
    parent_cluster_index: int
    points: list[SubClusterPoint]
    sub_clusters: list[SubClusterInfo]
    total_points: int


# Partial update for one cluster's annotation; every field is optional.
class AnnotationUpdate(BaseModel):
    name: str | None = None
    notes: str | None = None
    tags: list[str] | None = None


# Stored annotation of one cluster.
class ClusterAnnotation(BaseModel):
    name: str | None = None
    notes: str | None = None
    tags: list[str] | None = None
    updated_at: str | None = None


# All annotations of a job, keyed by the cluster index as a string.
class AnnotationsResponse(BaseModel):
    job_id: str
    clusters: dict[str, ClusterAnnotation]
_DEFAULT_ANNOTATIONS_DIR = Path("./annotations")

router = APIRouter(prefix="/api/annotations", tags=["annotations"])


def _get_manager() -> AnnotationManager:
    """Build an AnnotationManager bound to the default annotations directory."""
    return AnnotationManager(base_dir=_DEFAULT_ANNOTATIONS_DIR)


# NOTE(review): the manager calls below are plain synchronous calls made from
# inside ``async def`` handlers.


@router.get("/{job_id}", response_model=AnnotationsResponse)
async def get_annotations(job_id: str) -> AnnotationsResponse:
    """Return every stored annotation for the given job."""
    stored = _get_manager().get_annotations(job_id)
    return AnnotationsResponse(**stored)


@router.put(
    "/{job_id}/cluster/{cluster_index}",
    response_model=AnnotationsResponse,
)
async def update_annotation(
    job_id: str,
    cluster_index: int,
    body: AnnotationUpdate,
) -> AnnotationsResponse:
    """Apply the request body to one cluster's annotation and return the job's annotations."""
    updated = _get_manager().update_annotation(
        job_id,
        cluster_index,
        name=body.name,
        notes=body.notes,
        tags=body.tags,
    )
    return AnnotationsResponse(**updated)


@router.delete("/{job_id}", response_model=MessageResponse)
async def delete_annotations(job_id: str) -> MessageResponse:
    """Delete the job's annotations and confirm with a fixed message."""
    _get_manager().delete_annotations(job_id)
    return MessageResponse(message="Annotations deleted")
def _load_cluster_context(
    job_id: str,
    cluster_index: int,
) -> tuple[
    dict[str, Any],
    list[int],
    list[list[float]],
    list[dict[str, object]],
]:
    """Resolve a completed plot job and the members of one of its clusters.

    Shared by the cluster-detail and sub-cluster endpoints so both apply the
    same lookups and return the same HTTP errors.

    Returns:
        ``(cluster_info, member_indices, embeddings, points)`` where
        ``member_indices`` are the positions in ``cluster_labels`` whose
        label equals *cluster_index*; the same positions are used to index
        ``embeddings`` and ``points``.

    Raises:
        HTTPException: 404 if the job or the cluster is unknown, 409 if the
            job has not completed.
    """
    task = task_registry.get(job_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Job not found")
    if task.status != TaskStatus.COMPLETED:
        raise HTTPException(status_code=409, detail="Job not completed")

    result = cast("dict[str, object]", task.result)
    clusters = cast("list[dict[str, object]]", result["clusters"])
    cluster_labels = cast("list[int]", result["cluster_labels"])
    embeddings = cast("list[list[float]]", result["embeddings_standardized"])
    points = cast("list[dict[str, object]]", result["points"])

    cluster_info = next(
        (
            cast("dict[str, Any]", c)
            for c in clusters
            if cast("int", c["index"]) == cluster_index
        ),
        None,
    )
    if cluster_info is None:
        raise HTTPException(status_code=404, detail="Cluster not found")

    member_indices = [
        i for i, label in enumerate(cluster_labels) if label == cluster_index
    ]
    return cluster_info, member_indices, embeddings, points


@router.get(
    "/{job_id}/cluster/{cluster_index}",
    response_model=ClusterDetailResponse,
)
async def get_cluster_detail(
    job_id: str,
    cluster_index: int,
    page: int = 1,
    page_size: int = 50,
) -> ClusterDetailResponse:
    """Return one page of a cluster's items, nearest to the centroid first.

    The centroid is the mean of the cluster's standardized embeddings and
    each item's distance is the Euclidean norm of its offset from it.

    Args:
        job_id: Completed plot job to read from.
        cluster_index: Index of the cluster within that job.
        page: 1-based page number.
        page_size: Number of items per page.

    Raises:
        HTTPException: 404/409 from the job and cluster lookup, 400 if
            *page* or *page_size* is smaller than 1.
    """
    cluster_info, member_indices, embeddings, points = _load_cluster_context(
        job_id, cluster_index
    )

    # Without this check a page of 0 or below yields a negative slice start,
    # which silently returns items from the wrong end of the list.
    if page < 1 or page_size < 1:
        raise HTTPException(
            status_code=400,
            detail="page and page_size must be at least 1",
        )

    import numpy as np

    page_items: list[tuple[float, dict[str, object]]] = []
    # Skip the maths for an empty cluster: the mean of zero rows is NaN.
    if member_indices:
        cluster_embeddings = np.array([embeddings[i] for i in member_indices])
        centroid = cluster_embeddings.mean(axis=0)
        distances = np.linalg.norm(cluster_embeddings - centroid, axis=1)
        # Stable sort keeps the original point order among equal distances.
        order = np.argsort(distances, kind="stable")

        start = (page - 1) * page_size
        page_items = [
            (float(distances[pos]), points[member_indices[int(pos)]])
            for pos in order[start : start + page_size]
        ]

    return ClusterDetailResponse(
        cluster_index=cluster_index,
        cluster_name=cast("str", cluster_info["name"]),
        total_items=len(member_indices),
        page=page,
        page_size=page_size,
        items=[
            ClusterItemResponse(
                id=cast("str", point["id"]),
                metadata=cast("dict[str, object]", point["metadata"]),
                distance_to_centroid=dist,
            )
            for dist, point in page_items
        ],
    )


@router.post(
    "/{job_id}/cluster/{cluster_index}/sub-cluster",
    response_model=SubClusterResponse,
)
async def sub_cluster(
    job_id: str,
    cluster_index: int,
    request: SubClusterRequest,
) -> SubClusterResponse:
    """Run k-means inside one cluster and return a 3-component projection.

    Args:
        job_id: Completed plot job to read from.
        cluster_index: Index of the cluster to subdivide.
        request: Carries ``num_sub_clusters``, the k for k-means.

    Raises:
        HTTPException: 404/409 from the job and cluster lookup, 400 if more
            sub-clusters are requested than the cluster has items.
    """
    _, member_indices, embeddings, points = _load_cluster_context(
        job_id, cluster_index
    )

    num_sub = request.num_sub_clusters
    if num_sub > len(member_indices):
        raise HTTPException(
            status_code=400,
            detail=(
                f"num_sub_clusters ({num_sub}) exceeds "
                f"items in cluster ({len(member_indices)})"
            ),
        )

    import numpy as np

    cluster_embeddings = np.array([embeddings[i] for i in member_indices])

    def _compute() -> SubClusterResponse:
        """Cluster and project the subset; runs in a worker thread."""
        from sklearn.cluster import KMeans

        kmeans = KMeans(
            n_clusters=num_sub,
            n_init="auto",
            random_state=171,
            max_iter=1000,
        )
        sub_labels = kmeans.fit_predict(cluster_embeddings)

        # NOTE(review): three components are requested even when the cluster
        # has fewer than three items -- confirm reduce_dimensions copes.
        reduced = reduce_dimensions(
            cluster_embeddings,
            algorithm="pca",
            n_components=3,
        )

        sub_points = [
            SubClusterPoint(
                id=cast("str", points[idx]["id"]),
                x=float(reduced[j, 0]),
                y=float(reduced[j, 1]),
                z=float(reduced[j, 2]),
                sub_cluster=int(sub_labels[j]),
                metadata=cast("dict[str, object]", points[idx]["metadata"]),
            )
            for j, idx in enumerate(member_indices)
        ]

        # Spread the sub-cluster hues evenly around the color wheel.
        sub_cluster_infos = [
            SubClusterInfo(
                index=si,
                count=int(np.sum(sub_labels == si)),
                color=f"hsl({si * 360 // num_sub}, 70%, 50%)",
            )
            for si in range(num_sub)
        ]

        return SubClusterResponse(
            parent_cluster_index=cluster_index,
            points=sub_points,
            sub_clusters=sub_cluster_infos,
            total_points=len(member_indices),
        )

    return await asyncio.to_thread(_compute)
__plotStore?: { + getState: () => PlotStoreSnapshot + } +} + +test.describe('Cluster detail drawer', () => { + test('opens drawer when clicking cluster legend item', async ({ page, plotPage: _ }) => { + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + const firstCluster = page.getByRole('button', { name: /Group 1\s+\d+ points/i }) + await expect(firstCluster).toBeVisible({ timeout: 10_000 }) + await firstCluster.click() + + await expect( + page.getByRole('button', { name: 'Close drawer' }) + ).toBeVisible({ timeout: 10_000 }) + }) + + test('opens even when plot data lacks job id', async ({ page, plotPage: _ }) => { + await page.route(/.*\/api\/plot\/data\/.*/, async (route) => { + const response = await route.fetch() + const body = await response.json() + if (typeof body === 'object' && body !== null && 'job_id' in body) { + delete body.job_id + } + + await route.fulfill({ + status: response.status(), + headers: { + ...response.headers(), + 'content-type': 'application/json', + }, + body: JSON.stringify(body), + }) + }) + + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + const firstCluster = page.getByRole('button', { name: /Group 1\s+\d+ points/i }) + await expect(firstCluster).toBeVisible({ timeout: 10_000 }) + await firstCluster.click() + + await expect( + page.getByRole('button', { name: 'Close drawer' }) + ).toBeVisible({ timeout: 10_000 }) + }) + + test('reopens after closing without recompute', async ({ page, plotPage: _ }) => { + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await 
page.getByRole('button', { name: 'Compute Plot' }).click() + + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + const firstCluster = page.getByRole('button', { name: /Group 1\s+\d+ points/i }) + await expect(firstCluster).toBeVisible({ timeout: 10_000 }) + await firstCluster.click() + + const closeButton = page.getByRole('button', { name: 'Close drawer' }) + await expect(closeButton).toBeVisible({ timeout: 10_000 }) + await closeButton.click() + + await expect(closeButton).toBeHidden({ timeout: 10_000 }) + + await firstCluster.click() + await expect(closeButton).toBeVisible({ timeout: 10_000 }) + }) + + test('does not require multiple computes to open drawer', async ({ page, plotPage: _ }) => { + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + const firstCluster = page.getByRole('button', { name: /Group 1\s+\d+ points/i }) + await expect(firstCluster).toBeVisible({ timeout: 10_000 }) + await firstCluster.click() + + await expect( + page.getByRole('button', { name: 'Close drawer' }) + ).toBeVisible({ timeout: 10_000 }) + }) + + test('stale plot does not auto-render after visualize', async ({ page, plotPage: _ }) => { + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + await page.getByRole('link', { name: 'Home' }).click() + await expect(page.getByRole('link', { name: 'Home' })).toHaveAttribute('aria-current', 'page') + + await page.getByRole('button', { name: 'Visualize' }).first().click() + await page.waitForURL('**/plot?collection=*') + + await expect(page.getByRole('heading', { name: 'Clusters' })).toBeHidden({ timeout: 10_000 }) + }) + + 
test('opens and reopens drawer at 1472x838 viewport', async ({ page, plotPage: _ }) => { + await page.setViewportSize({ width: 1472, height: 838 }) + + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + const firstCluster = page.getByTestId('cluster-legend-name-0') + await expect(firstCluster).toBeVisible({ timeout: 10_000 }) + + await firstCluster.click() + const closeButton = page.getByRole('button', { name: 'Close drawer' }) + await expect(closeButton).toBeVisible({ timeout: 10_000 }) + + await closeButton.click() + await expect(closeButton).toBeHidden({ timeout: 10_000 }) + + await firstCluster.click() + await expect(closeButton).toBeVisible({ timeout: 10_000 }) + await expect(page.getByTestId('cluster-detail-drawer')).toBeVisible({ timeout: 10_000 }) + }) + + test('group selection shows visible fixed-width drawer at large viewport', async ({ page, plotPage: _ }) => { + await page.setViewportSize({ width: 1596, height: 958 }) + + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + await page.getByTestId('cluster-legend-name-3').click() + + const drawer = page.getByTestId('cluster-detail-drawer') + await expect(drawer).toBeVisible({ timeout: 10_000 }) + + const drawerWidth = await drawer.evaluate((el) => el.getBoundingClientRect().width) + expect(drawerWidth).toBeGreaterThanOrEqual(380) + + const drawerRect = await drawer.evaluate((el) => { + const rect = el.getBoundingClientRect() + return { + x: rect.x, + right: rect.right, + y: rect.y, + bottom: rect.bottom, + viewportW: window.innerWidth, + viewportH: window.innerHeight, + } + }) + 
expect(drawerRect.x).toBeGreaterThanOrEqual(0) + expect(drawerRect.right).toBeLessThanOrEqual(drawerRect.viewportW) + expect(drawerRect.y).toBeGreaterThanOrEqual(0) + expect(drawerRect.bottom).toBeLessThanOrEqual(drawerRect.viewportH) + + await expect(page.getByRole('button', { name: 'Close drawer' })).toBeVisible({ timeout: 10_000 }) + }) + + test('drawer stays in viewport on narrow layouts', async ({ page, plotPage: _ }) => { + await page.setViewportSize({ width: 980, height: 900 }) + + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + const firstCluster = page.getByTestId('cluster-legend-name-0') + await expect(firstCluster).toBeVisible({ timeout: 10_000 }) + await firstCluster.click() + + const drawer = page.getByTestId('cluster-detail-drawer') + await expect(drawer).toBeVisible({ timeout: 10_000 }) + + const drawerRect = await drawer.evaluate((el) => { + const rect = el.getBoundingClientRect() + return { + x: rect.x, + right: rect.right, + y: rect.y, + bottom: rect.bottom, + width: rect.width, + viewportW: window.innerWidth, + viewportH: window.innerHeight, + } + }) + + expect(drawerRect.x).toBeGreaterThanOrEqual(0) + expect(drawerRect.right).toBeLessThanOrEqual(drawerRect.viewportW) + expect(drawerRect.y).toBeGreaterThanOrEqual(0) + expect(drawerRect.bottom).toBeLessThanOrEqual(drawerRect.viewportH) + expect(drawerRect.width).toBeGreaterThan(260) + }) + + + + test('drawer item click toggles multi-select state', async ({ page, plotPage: _ }) => { + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + const firstCluster = page.getByTestId('cluster-legend-name-0') + await 
expect(firstCluster).toBeVisible({ timeout: 10_000 }) + await firstCluster.click() + + const drawer = page.getByTestId('cluster-detail-drawer') + await expect(drawer).toBeVisible({ timeout: 10_000 }) + + const rows = drawer.getByRole('button').filter({ hasText: /dist: /i }) + await expect(rows.first()).toBeVisible({ timeout: 10_000 }) + + await rows.nth(0).click() + await rows.nth(1).click() + + let size = await page.evaluate(() => { + const plotWindow = window as PlotStoreWindow + return plotWindow.__plotStore?.getState().selectedPointIds?.size + }) + expect(size).toBe(2) + + await rows.nth(0).click() + + size = await page.evaluate(() => { + const plotWindow = window as PlotStoreWindow + return plotWindow.__plotStore?.getState().selectedPointIds?.size + }) + expect(size).toBe(1) + }) + + test('drawer supports clear selected and select page', async ({ page, plotPage: _ }) => { + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + const firstCluster = page.getByTestId('cluster-legend-name-0') + await expect(firstCluster).toBeVisible({ timeout: 10_000 }) + await firstCluster.click() + + const drawer = page.getByTestId('cluster-detail-drawer') + await expect(drawer).toBeVisible({ timeout: 10_000 }) + + const rows = drawer.getByRole('button').filter({ hasText: /dist: /i }) + await expect(rows.first()).toBeVisible({ timeout: 10_000 }) + + await rows.nth(0).click() + await rows.nth(1).click() + + let size = await page.evaluate(() => { + const plotWindow = window as PlotStoreWindow + return plotWindow.__plotStore?.getState().selectedPointIds?.size + }) + expect(size).toBe(2) + + await page.getByRole('button', { name: 'Clear selected' }).click() + + size = await page.evaluate(() => { + const plotWindow = window as PlotStoreWindow + return 
plotWindow.__plotStore?.getState().selectedPointIds?.size + }) + expect(size).toBe(0) + + await page.getByRole('button', { name: 'Select page' }).click() + + size = await page.evaluate(() => { + const plotWindow = window as PlotStoreWindow + return plotWindow.__plotStore?.getState().selectedPointIds?.size + }) + expect(size).toBeGreaterThan(1) + }) + + test('selected items remain emphasized in particle mode', async ({ page, plotPage: _ }) => { + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Rendering' }).click() + await page.getByRole('radio', { name: 'Particles' }).check() + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + await page.getByTestId('cluster-legend-name-0').click() + const drawer = page.getByTestId('cluster-detail-drawer') + await expect(drawer).toBeVisible({ timeout: 10_000 }) + + const rows = drawer.getByRole('button').filter({ hasText: /dist: /i }) + await rows.nth(0).click() + await rows.nth(1).click() + + const selectedState = await page.evaluate(() => { + const plotWindow = window as PlotStoreWindow + return plotWindow.__plotStore?.getState().selectedPointIds?.size + }) + expect(selectedState).toBe(2) + + await expect(page.getByRole('heading', { name: 'Visualization Error' })).toHaveCount(0) + }) + + test('selected items persist in spheres and sprites modes', async ({ page, plotPage: _ }) => { + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Rendering' }).click() + await page.getByRole('radio', { name: 'Spheres' }).check() + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + await page.getByTestId('cluster-legend-name-0').click() + const drawer = 
page.getByTestId('cluster-detail-drawer') + await expect(drawer).toBeVisible({ timeout: 10_000 }) + + const rows = drawer.getByRole('button').filter({ hasText: /dist: /i }) + await rows.nth(0).click() + await rows.nth(1).click() + + let selectedState = await page.evaluate(() => { + const plotWindow = window as PlotStoreWindow + return plotWindow.__plotStore?.getState().selectedPointIds?.size + }) + expect(selectedState).toBe(2) + await expect(page.getByRole('heading', { name: 'Visualization Error' })).toHaveCount(0) + + await page.getByRole('button', { name: 'Rendering' }).click() + await page.getByRole('button', { name: 'Rendering' }).click() + await page.getByRole('radio', { name: 'Sprites' }).check() + + selectedState = await page.evaluate(() => { + const plotWindow = window as PlotStoreWindow + return plotWindow.__plotStore?.getState().selectedPointIds?.size + }) + expect(selectedState).toBe(2) + await expect(page.getByRole('heading', { name: 'Visualization Error' })).toHaveCount(0) + }) + + test('distance panel shows pairwise distances for selected products', async ({ page, plotPage: _ }) => { + await expect( + page.getByRole('button', { name: 'Compute Plot' }) + ).toBeVisible({ timeout: 10_000 }) + + await page.getByRole('button', { name: 'Compute Plot' }).click() + await expect(page.locator('canvas')).toBeVisible({ timeout: 120_000 }) + + await page.getByTestId('cluster-legend-name-0').click() + const drawer = page.getByTestId('cluster-detail-drawer') + await expect(drawer).toBeVisible({ timeout: 10_000 }) + + const rows = drawer.getByRole('button').filter({ hasText: /dist: /i }) + await rows.nth(0).click() + await rows.nth(1).click() + + await expect(page.getByRole('heading', { name: 'Selected distances' })).toBeVisible({ timeout: 10_000 }) + await expect(page.getByTestId('selected-distance-row').first()).toBeVisible({ timeout: 10_000 }) + await expect(page.getByTestId('selected-distance-row').first()).toContainText(/vs/i) + }) +}) diff --git 
a/frontend/e2e/fixtures.ts b/frontend/e2e/fixtures.ts index 0d46ab1..2aa10f1 100644 --- a/frontend/e2e/fixtures.ts +++ b/frontend/e2e/fixtures.ts @@ -12,7 +12,7 @@ export const test = base.extend<{ }>({ plotPage: async ({ page }, use) => { // Navigate with collection param so SearchBar renders after compute - await page.goto('/plot?collection=fashion_productDisplayName') + await page.goto('/plot?collection=fashionimageUrl') // Wait for the collection dropdown to be populated with options await expect( page.locator('select option:not([value=""])').first() diff --git a/frontend/e2e/search.spec.ts b/frontend/e2e/search.spec.ts index 511c17a..c2c7f30 100644 --- a/frontend/e2e/search.spec.ts +++ b/frontend/e2e/search.spec.ts @@ -183,4 +183,28 @@ test.describe('Semantic Search', () => { page.getByText('Results: 25') ).toBeVisible() }) + + test('hovering points does not trigger visualization error', async ({ page }) => { + const canvas = page.locator('canvas') + await expect(canvas).toBeVisible({ timeout: 10_000 }) + + const box = await canvas.boundingBox() + expect(box).not.toBeNull() + + if (!box) { + throw new Error('Canvas bounding box is null') + } + + for (let i = 0; i < 12; i += 1) { + const x = Math.floor(box.x + box.width * (0.25 + (i % 4) * 0.15)) + const y = Math.floor(box.y + box.height * (0.25 + Math.floor(i / 4) * 0.2)) + await page.mouse.move(x, y) + await page.mouse.move(Math.floor(box.x + 4), Math.floor(box.y + 4)) + } + + await expect( + page.getByRole('heading', { name: 'Visualization Error' }) + ).toHaveCount(0) + }) + }) diff --git a/frontend/e2e/tooltip-positioning.spec.ts b/frontend/e2e/tooltip-positioning.spec.ts new file mode 100644 index 0000000..6cfd326 --- /dev/null +++ b/frontend/e2e/tooltip-positioning.spec.ts @@ -0,0 +1,57 @@ +import { test, expect } from '@playwright/test' +import { computeTooltipPlacement } from '../src/components/plot/tooltipPositioning' + +test.describe('Tooltip placement algorithm', () => { + test('clamps and flips 
within viewport bounds', async () => { + const placement = { + centeredAbove: computeTooltipPlacement({ + x: 400, + y: 300, + width: 256, + height: 180, + viewportWidth: 1200, + viewportHeight: 900, + offset: 12, + margin: 8, + }), + flipBelow: computeTooltipPlacement({ + x: 400, + y: 40, + width: 256, + height: 180, + viewportWidth: 1200, + viewportHeight: 900, + offset: 12, + margin: 8, + }), + clampLeft: computeTooltipPlacement({ + x: 30, + y: 400, + width: 256, + height: 180, + viewportWidth: 1200, + viewportHeight: 900, + offset: 12, + margin: 8, + }), + clampRight: computeTooltipPlacement({ + x: 1180, + y: 400, + width: 256, + height: 180, + viewportWidth: 1200, + viewportHeight: 900, + offset: 12, + margin: 8, + }), + } + + expect(placement.centeredAbove.x).toBe(272) + expect(placement.centeredAbove.y).toBe(108) + + expect(placement.flipBelow.y).toBe(52) + + expect(placement.clampLeft.x).toBe(8) + expect(placement.clampRight.x).toBe(936) + }) +}) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 493badd..c97f470 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,8 +1,10 @@ -import { QueryClient, QueryClientProvider } from '@tanstack/react-query' -import { BrowserRouter, Link, NavLink, Route, Routes } from 'react-router-dom' +import { QueryClient, QueryClientProvider, useQueryClient } from '@tanstack/react-query' +import { BrowserRouter, Link, NavLink, Route, Routes, useLocation } from 'react-router-dom' +import { useEffect, useRef } from 'react' import HomePage from './pages/HomePage' import IndexPage from './pages/IndexPage' import PlotPage from './pages/PlotPage' +import { usePlotStore } from './stores/plotStore' const queryClient = new QueryClient() @@ -41,12 +43,44 @@ function NavBar() { ) } +function PlotStateResetter() { + const location = useLocation() + const setPlotData = usePlotStore((state) => state.setPlotData) + const resetPlotJobId = usePlotStore((state) => state.resetPlotJobId) + const clearClusterDrillDown = 
usePlotStore((state) => state.clearClusterDrillDown) + const resetPlotCollectionName = usePlotStore((state) => state.resetPlotCollectionName) + const queryClient = useQueryClient() + const previousLocationRef = useRef(location) + + useEffect(() => { + const previousLocation = previousLocationRef.current + const leftPlot = previousLocation.pathname === '/plot' && location.pathname !== '/plot' + const enteredPlotWithNewQuery = + location.pathname === '/plot' && previousLocation.search !== location.search + const enteredPlot = + location.pathname === '/plot' && previousLocation.pathname !== '/plot' + + if (leftPlot || enteredPlotWithNewQuery || enteredPlot) { + setPlotData(null) + resetPlotJobId() + clearClusterDrillDown() + resetPlotCollectionName() + queryClient.removeQueries({ queryKey: ['plotData'] }) + } + + previousLocationRef.current = location + }, [clearClusterDrillDown, location, resetPlotCollectionName, resetPlotJobId, setPlotData, queryClient]) + + return null +} + export default function App() { return (
+
} /> diff --git a/frontend/src/api/plot.ts b/frontend/src/api/plot.ts index da9de6e..d24a44a 100644 --- a/frontend/src/api/plot.ts +++ b/frontend/src/api/plot.ts @@ -1,4 +1,18 @@ -import type { IndexStartResponse, PlotRequest, PlotResponse, SearchRequest, SearchResponse, SuggestClustersRequest, SuggestClustersStatusResponse } from "../types"; +import type { + AnnotationUpdate, + AnnotationsResponse, + ClusterDetailResponse, + IndexStartResponse, + MessageResponse, + PlotRequest, + PlotResponse, + SearchRequest, + SearchResponse, + SubClusterRequest, + SubClusterResponse, + SuggestClustersRequest, + SuggestClustersStatusResponse, +} from "../types"; import { apiFetch, apiPost } from "./client"; export async function startPlotCompute( @@ -34,3 +48,54 @@ export async function getSuggestClustersStatus( `/plot/suggest-clusters/${jobId}`, ); } + +export async function getClusterDetail( + jobId: string, + clusterIndex: number, + page = 1, + pageSize = 50, +): Promise { + return apiFetch( + `/plot/${jobId}/cluster/${clusterIndex}?page=${page}&page_size=${pageSize}`, + ); +} + +export async function subCluster( + jobId: string, + clusterIndex: number, + request: SubClusterRequest, +): Promise { + return apiPost( + `/plot/${jobId}/cluster/${clusterIndex}/sub-cluster`, + request, + ); +} + +export async function getAnnotations( + jobId: string, +): Promise { + return apiFetch(`/annotations/${jobId}`); +} + +export async function updateAnnotation( + jobId: string, + clusterIndex: number, + body: AnnotationUpdate, +): Promise { + return apiFetch( + `/annotations/${jobId}/cluster/${clusterIndex}`, + { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }, + ); +} + +export async function deleteAnnotations( + jobId: string, +): Promise { + return apiFetch(`/annotations/${jobId}`, { + method: "DELETE", + }); +} diff --git a/frontend/src/components/plot/ClusterDetailDrawer.tsx b/frontend/src/components/plot/ClusterDetailDrawer.tsx new 
file mode 100644 index 0000000..5fecc28 --- /dev/null +++ b/frontend/src/components/plot/ClusterDetailDrawer.tsx @@ -0,0 +1,327 @@ +import { useState, useEffect, useCallback, useRef } from 'react' +import { usePlotStore, CLUSTER_COLORS } from '../../stores/plotStore' +import { getClusterDetail, updateAnnotation, getAnnotations } from '../../api/plot' +import type { ClusterDetailResponse } from '../../types' +import SubClusterView from './SubClusterView' +import SelectedPointsDistancePanel from './SelectedPointsDistancePanel' + +interface ClusterDetailDrawerProps { + jobId: string + imageField?: string +} + +export default function ClusterDetailDrawer({ jobId, imageField }: ClusterDetailDrawerProps) { + const selectedCluster = usePlotStore((s) => s.selectedCluster) + const clusterDetail = usePlotStore((s) => s.clusterDetail) + const annotations = usePlotStore((s) => s.annotations) + const isLoadingClusterDetail = usePlotStore((s) => s.isLoadingClusterDetail) + const plotData = usePlotStore((s) => s.plotData) + const selectedPointIds = usePlotStore((s) => s.selectedPointIds) + const setClusterDetail = usePlotStore((s) => s.setClusterDetail) + const setAnnotations = usePlotStore((s) => s.setAnnotations) + const setIsLoadingClusterDetail = usePlotStore((s) => s.setIsLoadingClusterDetail) + const setHighlightedIds = usePlotStore((s) => s.setHighlightedIds) + const clearSelectedPointIds = usePlotStore((s) => s.clearSelectedPointIds) + const setSelectedPointIds = usePlotStore((s) => s.setSelectedPointIds) + const clearClusterDrillDown = usePlotStore((s) => s.clearClusterDrillDown) + + const [page, setPage] = useState(1) + const [isEditingName, setIsEditingName] = useState(false) + const [editName, setEditName] = useState('') + const [notes, setNotes] = useState('') + const [tagsInput, setTagsInput] = useState('') + const [showSubCluster, setShowSubCluster] = useState(false) + const notesTimeoutRef = useRef>(undefined) + const tagsTimeoutRef = useRef>(undefined) + + const 
clusterIndex = selectedCluster + const cluster = plotData?.clusters.find((c) => c.index === clusterIndex) + const color = clusterIndex != null ? CLUSTER_COLORS[clusterIndex % CLUSTER_COLORS.length] : '#999' + const annotation = clusterIndex != null ? annotations?.clusters[String(clusterIndex)] : undefined + + // Load cluster detail when selected cluster changes + useEffect(() => { + if (clusterIndex == null) return + setPage(1) + setShowSubCluster(false) + setIsLoadingClusterDetail(true) + setClusterDetail(null) + + getClusterDetail(jobId, clusterIndex, 1) + .then((data: ClusterDetailResponse) => setClusterDetail(data)) + .catch(() => setClusterDetail(null)) + .finally(() => setIsLoadingClusterDetail(false)) + }, [jobId, clusterIndex, setClusterDetail, setIsLoadingClusterDetail]) + + // Load annotations + useEffect(() => { + getAnnotations(jobId) + .then(setAnnotations) + .catch(() => setAnnotations(null)) + }, [jobId, setAnnotations]) + + // Sync local notes/tags state with annotation + useEffect(() => { + setNotes(annotation?.notes ?? '') + setTagsInput(annotation?.tags?.join(', ') ?? 
'') + }, [annotation]) + + const handlePageChange = useCallback((newPage: number) => { + if (clusterIndex == null) return + setPage(newPage) + setIsLoadingClusterDetail(true) + getClusterDetail(jobId, clusterIndex, newPage) + .then((data: ClusterDetailResponse) => setClusterDetail(data)) + .catch(() => {/* keep previous data */}) + .finally(() => setIsLoadingClusterDetail(false)) + }, [jobId, clusterIndex, setClusterDetail, setIsLoadingClusterDetail]) + + const handleSaveName = useCallback(() => { + if (clusterIndex == null) return + setIsEditingName(false) + if (editName.trim()) { + updateAnnotation(jobId, clusterIndex, { name: editName.trim() }) + .then(setAnnotations) + .catch(() => {/* silent */}) + } + }, [jobId, clusterIndex, editName, setAnnotations]) + + const handleNotesChange = useCallback((value: string) => { + setNotes(value) + if (notesTimeoutRef.current) clearTimeout(notesTimeoutRef.current) + notesTimeoutRef.current = setTimeout(() => { + if (clusterIndex != null) { + updateAnnotation(jobId, clusterIndex, { notes: value }) + .then(setAnnotations) + .catch(() => {/* silent */}) + } + }, 800) + }, [jobId, clusterIndex, setAnnotations]) + + const handleTagsChange = useCallback((value: string) => { + setTagsInput(value) + if (tagsTimeoutRef.current) clearTimeout(tagsTimeoutRef.current) + tagsTimeoutRef.current = setTimeout(() => { + if (clusterIndex != null) { + const tags = value.split(',').map((t) => t.trim()).filter(Boolean) + updateAnnotation(jobId, clusterIndex, { tags }) + .then(setAnnotations) + .catch(() => {/* silent */}) + } + }, 800) + }, [jobId, clusterIndex, setAnnotations]) + + const handleItemClick = useCallback((id: string) => { + const nextSelected = new Set(selectedPointIds) + if (nextSelected.has(id)) { + nextSelected.delete(id) + } else { + nextSelected.add(id) + } + setSelectedPointIds(nextSelected) + setHighlightedIds(nextSelected) + }, [selectedPointIds, setHighlightedIds, setSelectedPointIds]) + + const handleClearSelected = 
useCallback(() => { + clearSelectedPointIds() + setHighlightedIds(new Set()) + }, [clearSelectedPointIds, setHighlightedIds]) + + const handleSelectPage = useCallback(() => { + if (!clusterDetail) return + const pageIds = new Set(clusterDetail.items.map((item) => item.id)) + setSelectedPointIds(pageIds) + setHighlightedIds(pageIds) + }, [clusterDetail, setHighlightedIds, setSelectedPointIds]) + + if (clusterIndex == null) return null + + const totalPages = clusterDetail ? Math.ceil(clusterDetail.total_items / clusterDetail.page_size) : 0 + const displayName = annotation?.name ?? cluster?.name ?? `Cluster ${clusterIndex}` + const selectedItems = clusterDetail?.items.filter((item) => selectedPointIds.has(item.id)) ?? [] + + return ( +
+ {/* Header */} +
+
+ + {isEditingName ? ( + setEditName(e.target.value)} + onBlur={handleSaveName} + onKeyDown={(e) => { + if (e.key === 'Enter') handleSaveName() + if (e.key === 'Escape') setIsEditingName(false) + }} + className="text-sm font-bold text-gray-900 border-b border-blue-500 outline-none bg-transparent w-full" + autoFocus + /> + ) : ( + + )} + {clusterDetail && ( + + ({clusterDetail.total_items} items) + + )} +
+ +
+ + {/* Sub-cluster toggle */} +
+ + + +
+ + {/* Sub-cluster view */} + {showSubCluster && ( +
+ +
+ )} + + {/* Items list */} +
+ {isLoadingClusterDetail && ( +
+
+
+ )} + + {clusterDetail && !isLoadingClusterDetail && ( +
+ {clusterDetail.items.map((item) => { + const isSelected = selectedPointIds.has(item.id) + return ( + + )})} +
+ )} +
+ + {/* Pagination */} + {totalPages > 1 && ( +
+ + + Page {page} / {totalPages} + + +
+ )} + + + + {/* Annotation section */} +
+
+ +