Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 30 additions & 26 deletions embedding_cluster/scatter_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ def compute_plot_data(settings: Settings) -> dict[str, Any]:
)

# Build structured point data
# Points are built in original data order so that points[i] aligns
# with cluster_labels[i] and embeddings_standardized[i]. The
# cluster-detail and sub-cluster endpoints rely on this alignment.
points: list[dict[str, Any]] = []
clusters: list[dict[str, Any]] = []

Expand All @@ -308,33 +311,34 @@ def compute_plot_data(settings: Settings) -> dict[str, Any]:
}
)

for idx in clusters_indices[cluster_i]:
metadata: dict[str, Any] = {}
if idx < len(collection_content["metadatas"]):
raw_metadata = dict(collection_content["metadatas"][idx])
if display_fields:
metadata = {
key: value
for key, value in raw_metadata.items()
if key in display_fields
}
else:
metadata = raw_metadata
point_id = (
collection_content["ids"][idx]
if idx < len(collection_content["ids"])
else str(idx)
)
points.append(
{
"x": float(reduced[idx, 0]),
"y": float(reduced[idx, 1]),
"z": float(reduced[idx, 2]),
"cluster": cluster_i,
"metadata": metadata,
"id": point_id,
for idx in range(len(collection_content["ids"])):
cluster_i = int(pred_arr[idx])
metadata: dict[str, Any] = {}
if idx < len(collection_content["metadatas"]):
raw_metadata = dict(collection_content["metadatas"][idx])
if display_fields:
metadata = {
key: value
for key, value in raw_metadata.items()
if key in display_fields
}
)
else:
metadata = raw_metadata
point_id = (
collection_content["ids"][idx]
if idx < len(collection_content["ids"])
else str(idx)
)
points.append(
{
"x": float(reduced[idx, 0]),
"y": float(reduced[idx, 1]),
"z": float(reduced[idx, 2]),
"cluster": cluster_i,
"metadata": metadata,
"id": point_id,
}
)

return {
"points": points,
Expand Down
98 changes: 98 additions & 0 deletions tests/test_cluster_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,101 @@ async def test_cluster_detail_job_not_ready(
status.HTTP_409_CONFLICT,
status.HTTP_200_OK,
)


def _fake_compute_interleaved_clusters(
    _settings: object,
) -> dict[str, object]:
    """Fake compute with interleaved cluster labels.

    Regression fixture: simulates real KMeans output where labels are NOT
    grouped (e.g. [0, 1, 0, 1, 0, 1]). The old code built points grouped
    by cluster, so points[i] and cluster_labels[i] referred to different
    items — causing the cluster-detail endpoint to return wrong products.

    Args:
        _settings: Ignored; accepted only so the fake can stand in for the
            real compute function via ``side_effect``.

    Returns:
        A dict with ``points``, ``clusters``, ``total_points``,
        ``embeddings_standardized``, ``cluster_labels`` and ``point_ids``.
        ``points[i]``, ``cluster_labels[i]``, ``point_ids[i]`` and
        ``embeddings_standardized[i]`` all describe the same item ``i``.
    """
    # The labels are the single source of truth: the point count and the
    # per-cluster counts are derived from them so the fixture cannot drift
    # out of sync if the label pattern is ever edited.
    labels = [0, 1, 0, 1, 0, 1]
    n = len(labels)

    # Fixed seed keeps the fake embeddings deterministic across runs.
    rng = np.random.default_rng(99)
    embeddings = rng.random((n, 4))

    # Built in label order (NOT grouped by cluster) so index i lines up
    # across points, labels, ids and embeddings.
    points = [
        {
            "x": float(i),
            "y": float(i),
            "z": float(i),
            "cluster": label,
            "metadata": {"name": f"item{i}"},
            "id": str(i),
        }
        for i, label in enumerate(labels)
    ]

    return {
        "points": points,
        "clusters": [
            {
                "index": 0,
                "name": "Group 1",
                "color": "hsl(0, 70%, 50%)",
                "count": labels.count(0),
            },
            {
                "index": 1,
                "name": "Group 2",
                "color": "hsl(180, 70%, 50%)",
                "count": labels.count(1),
            },
        ],
        "total_points": n,
        "embeddings_standardized": embeddings.tolist(),
        "cluster_labels": labels,
        "point_ids": [str(i) for i in range(n)],
    }


@pytest.fixture
def mock_compute_interleaved() -> Iterator[None]:
    """Patch the plot route's ``compute_plot_data`` with the interleaved fake.

    While the fixture is active, calls to the patched name are routed to
    ``_fake_compute_interleaved_clusters``; the patch is undone on teardown.
    """
    patch_target = "embedding_cluster.server.routes.plot.compute_plot_data"
    with patch(patch_target, side_effect=_fake_compute_interleaved_clusters):
        yield


@pytest.mark.asyncio
async def test_cluster_detail_returns_correct_items_with_interleaved_labels(
    app: FastAPI, mock_compute_interleaved: None
) -> None:
    """Regression: cluster-detail must return items belonging to the
    requested cluster when cluster_labels are interleaved (not grouped).

    Starts a compute job, waits for the cluster-detail endpoint to answer
    200, then checks that each cluster's item ids match the interleaved
    label pattern (cluster 0 -> 0, 2, 4; cluster 1 -> 1, 3, 5).
    """
    # The fixture is requested only for its patching side effect.
    _ = mock_compute_interleaved
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as client:
        start = await client.post(
            "/api/plot/compute",
            json={"chromadb_collection_name": "test"},
        )
        job_id = cast("str", start.json()["job_id"])

        # Poll with a bounded number of short waits instead of one fixed
        # sleep: a fixed delay is a race on a slow machine and wasted time
        # on a fast one. Presumably the job is computed in the background
        # and the endpoint answers non-200 until it is done — confirm
        # against the route implementation.
        cluster0_url = f"/api/plot/{job_id}/cluster/0"
        resp_c0 = await client.get(cluster0_url)
        for _attempt in range(40):  # up to ~2 s in total
            if resp_c0.status_code == status.HTTP_200_OK:
                break
            await asyncio.sleep(0.05)
            resp_c0 = await client.get(cluster0_url)

        resp_c1 = await client.get(f"/api/plot/{job_id}/cluster/1")

    assert resp_c0.status_code == status.HTTP_200_OK
    assert resp_c1.status_code == status.HTTP_200_OK

    c0_data = resp_c0.json()
    c1_data = resp_c1.json()

    c0_ids = {item["id"] for item in c0_data["items"]}
    c1_ids = {item["id"] for item in c1_data["items"]}

    assert c0_ids == {"0", "2", "4"}, (
        f"Cluster 0 should contain items 0,2,4 but got {c0_ids}"
    )
    assert c1_ids == {"1", "3", "5"}, (
        f"Cluster 1 should contain items 1,3,5 but got {c1_ids}"
    )
29 changes: 29 additions & 0 deletions tests/test_scatter_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,32 @@ def test_point_ids_match_collection_ids(self) -> None:
point_ids = result["point_ids"]
assert isinstance(point_ids, list)
assert sorted(point_ids) == ["0", "1", "2", "3"]

def test_points_aligned_with_cluster_labels(self) -> None:
    """Check that every point carries the label found at its own index.

    Regression: compute_plot_data used to emit points grouped by cluster
    rather than in original-data order, so the cluster-detail endpoint
    returned items from the wrong cluster.
    """
    outcome = self._run_compute(self._make_settings(), n_points=6)
    plotted = outcome["points"]
    expected_labels = outcome["cluster_labels"]

    # Explicit length check first gives a clearer failure than the
    # ValueError that strict zip would raise on a mismatch.
    assert len(plotted) == len(expected_labels)

    pairs = zip(plotted, expected_labels, strict=True)
    for i, (point, label) in enumerate(pairs):
        assert point["cluster"] == label, (
            f"points[{i}].cluster={point['cluster']} != cluster_labels[{i}]={label}"
        )

def test_points_preserve_original_id_order(self) -> None:
    """Check that point ids come back in original collection order.

    Regression: grouping points by cluster scrambled the id ordering,
    which broke the index-based lookup used by the cluster-detail and
    sub-cluster endpoints.
    """
    n_points = 6
    outcome = self._run_compute(self._make_settings(), n_points=n_points)

    actual_ids = [entry["id"] for entry in outcome["points"]]
    expected_ids = [str(position) for position in range(n_points)]
    assert actual_ids == expected_ids
Loading