From c8e782f654a9e23d75c378856264b6ef371646f7 Mon Sep 17 00:00:00 2001 From: aGallea Date: Tue, 17 Mar 2026 09:31:08 +0200 Subject: [PATCH 1/2] fix(plot): align points array order with cluster labels Points were built grouped-by-cluster while cluster_labels stayed in original data order, causing get_cluster_detail and sub_cluster endpoints to return wrong items for each cluster group. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- embedding_cluster/scatter_plot.py | 56 +++++++++++++++++-------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/embedding_cluster/scatter_plot.py b/embedding_cluster/scatter_plot.py index 769422a..532d1f1 100644 --- a/embedding_cluster/scatter_plot.py +++ b/embedding_cluster/scatter_plot.py @@ -292,6 +292,9 @@ def compute_plot_data(settings: Settings) -> dict[str, Any]: ) # Build structured point data + # Points are built in original data order so that points[i] aligns + # with cluster_labels[i] and embeddings_standardized[i]. The + # cluster-detail and sub-cluster endpoints rely on this alignment. points: list[dict[str, Any]] = [] clusters: list[dict[str, Any]] = [] @@ -308,33 +311,34 @@ def compute_plot_data(settings: Settings) -> dict[str, Any]: } ) - for idx in clusters_indices[cluster_i]: - metadata: dict[str, Any] = {} - if idx < len(collection_content["metadatas"]): - raw_metadata = dict(collection_content["metadatas"][idx]) - if display_fields: - metadata = { - key: value - for key, value in raw_metadata.items() - if key in display_fields - } - else: - metadata = raw_metadata - point_id = ( - collection_content["ids"][idx] - if idx < len(collection_content["ids"]) - else str(idx) - ) - points.append( - { - "x": float(reduced[idx, 0]), - "y": float(reduced[idx, 1]), - "z": float(reduced[idx, 2]), - "cluster": cluster_i, - "metadata": metadata, - "id": point_id, + for idx in range(len(collection_content["ids"])): + cluster_i = int(pred_arr[idx]) + metadata: dict[str, Any] = {} + if idx < len(collection_content["metadatas"]): + raw_metadata = dict(collection_content["metadatas"][idx]) + if display_fields: + metadata = { + key: value + for key, value in raw_metadata.items() + if key in display_fields } - ) + else: + metadata = raw_metadata + point_id = ( + collection_content["ids"][idx] + if idx < len(collection_content["ids"]) + else str(idx) + ) + points.append( + { + "x": float(reduced[idx, 0]), + "y": float(reduced[idx, 1]), + "z": float(reduced[idx, 2]), + "cluster": cluster_i, + "metadata": metadata, + "id": point_id, + } + ) return { "points": points, From e1e365c846cd26240f9c7bf123d94794159a8c58 Mon Sep 17 00:00:00 2001 From: aGallea Date: Tue, 17 Mar 2026 09:32:46 +0200 Subject: [PATCH 2/2] test(plot): add regression tests for cluster detail ordering Verify points array aligns with cluster_labels and original ID order. Integration test uses interleaved labels to catch index mismatch between points and cluster_labels. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- tests/test_cluster_detail.py | 98 ++++++++++++++++++++++++++++++++++++ tests/test_scatter_plot.py | 29 +++++++++++ 2 files changed, 127 insertions(+) diff --git a/tests/test_cluster_detail.py b/tests/test_cluster_detail.py index 37ce460..9908719 100644 --- a/tests/test_cluster_detail.py +++ b/tests/test_cluster_detail.py @@ -203,3 +203,101 @@ async def test_cluster_detail_job_not_ready( status.HTTP_409_CONFLICT, status.HTTP_200_OK, ) + + +def _fake_compute_interleaved_clusters( + _settings: object, +) -> dict[str, object]: + """Fake compute with interleaved cluster labels. + + Regression fixture: simulates real KMeans output where labels are NOT + grouped (e.g. [0, 1, 0, 1, 0, 1]). The old code built points grouped + by cluster, so points[i] and cluster_labels[i] referred to different + items — causing the cluster-detail endpoint to return wrong products. + """ + n = 6 + rng = np.random.default_rng(99) + embeddings = rng.random((n, 4)) + labels = [0, 1, 0, 1, 0, 1] + + points = [] + for i in range(n): + points.append( + { + "x": float(i), + "y": float(i), + "z": float(i), + "cluster": labels[i], + "metadata": {"name": f"item{i}"}, + "id": str(i), + } + ) + + return { + "points": points, + "clusters": [ + { + "index": 0, + "name": "Group 1", + "color": "hsl(0, 70%, 50%)", + "count": 3, + }, + { + "index": 1, + "name": "Group 2", + "color": "hsl(180, 70%, 50%)", + "count": 3, + }, + ], + "total_points": n, + "embeddings_standardized": embeddings.tolist(), + "cluster_labels": labels, + "point_ids": [str(i) for i in range(n)], + } + + +@pytest.fixture +def mock_compute_interleaved() -> Iterator[None]: + with patch( + "embedding_cluster.server.routes.plot.compute_plot_data", + side_effect=_fake_compute_interleaved_clusters, + ): + yield + + +@pytest.mark.asyncio +async def test_cluster_detail_returns_correct_items_with_interleaved_labels( + app: FastAPI, mock_compute_interleaved: None +) -> None: + """Regression: cluster-detail must return items belonging to the + requested cluster when cluster_labels are interleaved (not grouped). + """ + _ = mock_compute_interleaved + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + start = await client.post( + "/api/plot/compute", + json={"chromadb_collection_name": "test"}, + ) + job_id = cast("str", start.json()["job_id"]) + await asyncio.sleep(0.2) + + resp_c0 = await client.get(f"/api/plot/{job_id}/cluster/0") + resp_c1 = await client.get(f"/api/plot/{job_id}/cluster/1") + + assert resp_c0.status_code == status.HTTP_200_OK + assert resp_c1.status_code == status.HTTP_200_OK + + c0_data = resp_c0.json() + c1_data = resp_c1.json() + + c0_ids = {item["id"] for item in c0_data["items"]} + c1_ids = {item["id"] for item in c1_data["items"]} + + assert c0_ids == {"0", "2", "4"}, ( + f"Cluster 0 should contain items 0,2,4 but got {c0_ids}" + ) + assert c1_ids == {"1", "3", "5"}, ( + f"Cluster 1 should contain items 1,3,5 but got {c1_ids}" + ) diff --git a/tests/test_scatter_plot.py b/tests/test_scatter_plot.py index e98e58f..6e8128e 100644 --- a/tests/test_scatter_plot.py +++ b/tests/test_scatter_plot.py @@ -681,3 +681,32 @@ def test_point_ids_match_collection_ids(self) -> None: point_ids = result["point_ids"] assert isinstance(point_ids, list) assert sorted(point_ids) == ["0", "1", "2", "3"] + + def test_points_aligned_with_cluster_labels(self) -> None: + """points[i].cluster must equal cluster_labels[i] for every i. + + Regression: compute_plot_data previously grouped points by cluster + instead of preserving original-data order, causing the cluster-detail + endpoint to return items from the wrong cluster. + """ + result = self._run_compute(self._make_settings(), n_points=6) + points = result["points"] + labels = result["cluster_labels"] + assert len(points) == len(labels) + for i, (point, label) in enumerate(zip(points, labels, strict=True)): + assert point["cluster"] == label, ( + f"points[{i}].cluster={point['cluster']} != cluster_labels[{i}]={label}" + ) + + def test_points_preserve_original_id_order(self) -> None: + """points[i].id must equal the i-th collection id. + + Regression: when points were grouped by cluster the id ordering + no longer matched the original collection order, breaking the + index-based lookup in cluster-detail and sub-cluster endpoints. + """ + result = self._run_compute(self._make_settings(), n_points=6) + points = result["points"] + expected_ids = [str(i) for i in range(6)] + actual_ids = [p["id"] for p in points] + assert actual_ids == expected_ids