Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions embedding_cluster/ai_naming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import annotations

import logging

import litellm

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Alias for testability (easy to mock): the functions in this module call
# `litellm_completion` rather than `litellm.completion` directly, so a test
# can replace this one module attribute.
litellm_completion = litellm.completion

# System prompt used when naming a cluster with no parent context. It is
# sent as-is by get_cluster_name() and is also the fallback prompt in
# get_sub_cluster_name() when no parent cluster name is given.
SYSTEM_PROMPT_TOP_LEVEL = (
    "Your role is to find a very short (max 5 words), concise name "
    "for a group of items, one name to rule them all. "
    "The user will provide a list of item names. Do your best."
)

# System prompt template for naming a sub-cluster. This is a str.format()
# template: "{parent_name}" is filled in by get_sub_cluster_name(), so any
# literal brace added to this text must be doubled ("{{" / "}}").
SYSTEM_PROMPT_SUB_CLUSTER = (
    "Your role is to find a very short (max 5 words), concise name "
    "for a sub-group of items within a larger group called "
    '"{parent_name}". '
    "The name should distinguish this sub-group from its siblings "
    "while relating to the parent theme. The user will provide a "
    "list of item names. Do your best."
)


def _call_llm(
    messages: list[dict[str, str]],
    api_key: str,
    model: str,
    base_url: str | None = None,
    temperature: float = 0.5,
) -> str:
    """Send a chat request through LiteLLM and return a short text reply.

    Args:
        messages: Chat messages forwarded unchanged as ``messages``.
        api_key: Credential forwarded as ``api_key``.
        model: Model identifier forwarded as ``model``.
        base_url: Optional endpoint; forwarded as ``api_base`` only when
            truthy.
        temperature: Sampling temperature forwarded as ``temperature``.

    Returns:
        The content of the first choice, or ``""`` when that content is
        empty or ``None``. Replies longer than 30 characters are cut to
        their first 30 characters followed by ``".."``.
    """
    request: dict[str, object] = dict(
        model=model,
        messages=messages,
        api_key=api_key,
        temperature=temperature,
    )
    # Leave the endpoint out entirely unless the caller supplied one.
    if base_url:
        request["api_base"] = base_url

    reply = litellm_completion(**request)
    text: str = reply.choices[0].message.content or ""

    # Keep labels compact: at most 30 characters plus a ".." marker.
    if len(text) <= 30:
        return text
    return text[:30] + ".."


def get_cluster_name(
    item_names: list[str],
    api_key: str,
    model: str,
    base_url: str | None = None,
    temperature: float = 0.5,
) -> str:
    """Ask the LLM for a short name describing a cluster of items.

    Args:
        item_names: Item names; each becomes one ``name: <item>`` line of
            the user message.
        api_key: Credential passed through to ``_call_llm``.
        model: Model identifier passed through to ``_call_llm``.
        base_url: Optional endpoint passed through to ``_call_llm``.
        temperature: Sampling temperature passed through to ``_call_llm``.

    Returns:
        Whatever ``_call_llm`` returns for the top-level naming prompt.
    """
    listing = [f"name: {name}" for name in item_names]
    system_message = {"role": "system", "content": SYSTEM_PROMPT_TOP_LEVEL}
    user_message = {"role": "user", "content": "\n".join(listing)}
    return _call_llm(
        [system_message, user_message],
        api_key,
        model,
        base_url,
        temperature,
    )


def get_sub_cluster_name(
    item_names: list[str],
    api_key: str,
    model: str,
    base_url: str | None = None,
    temperature: float = 0.5,
    parent_cluster_name: str | None = None,
) -> str:
    """Ask the LLM for a short name describing a sub-cluster of items.

    Args:
        item_names: Item names; each becomes one ``name: <item>`` line of
            the user message.
        api_key: Credential passed through to ``_call_llm``.
        model: Model identifier passed through to ``_call_llm``.
        base_url: Optional endpoint passed through to ``_call_llm``.
        temperature: Sampling temperature passed through to ``_call_llm``.
        parent_cluster_name: Name of the enclosing cluster. When truthy it
            is substituted into the sub-cluster prompt; otherwise the
            top-level prompt is used instead.

    Returns:
        Whatever ``_call_llm`` returns for the selected prompt.
    """
    # Without a parent name there is nothing to relate to, so fall back to
    # the generic top-level prompt.
    system_content = (
        SYSTEM_PROMPT_SUB_CLUSTER.format(parent_name=parent_cluster_name)
        if parent_cluster_name
        else SYSTEM_PROMPT_TOP_LEVEL
    )
    listing = [f"name: {name}" for name in item_names]
    conversation = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": "\n".join(listing)},
    ]
    return _call_llm(conversation, api_key, model, base_url, temperature)


def test_connection(
    api_key: str,
    model: str,
    base_url: str | None = None,
) -> tuple[bool, str | None]:
    """Probe the LLM endpoint with a tiny request.

    Sends a single "Say hello" user message with ``max_tokens`` set to 5
    and discards the reply; only success or failure matters.

    Args:
        api_key: Credential forwarded as ``api_key``.
        model: Model identifier forwarded as ``model``.
        base_url: Optional endpoint; forwarded as ``api_base`` only when
            truthy.

    Returns:
        ``(True, None)`` when the call completes, otherwise
        ``(False, message)`` where ``message`` is the exception text with
        every occurrence of ``api_key`` replaced by ``"***"``.
    """
    probe: dict[str, object] = dict(
        model=model,
        messages=[{"role": "user", "content": "Say hello"}],
        api_key=api_key,
        max_tokens=5,
    )
    if base_url:
        probe["api_base"] = base_url

    try:
        litellm_completion(**probe)
    except Exception as exc:
        message = str(exc)
        # Never echo the caller's secret back in an error string. The
        # emptiness guard matters: replacing "" would insert "***" between
        # every character.
        if api_key:
            message = message.replace(api_key, "***")
        return False, message
    else:
        return True, None
55 changes: 4 additions & 51 deletions embedding_cluster/scatter_plot.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from __future__ import annotations

import logging
import random
from typing import TYPE_CHECKING, Any

import chromadb
import numpy as np
import plotly.graph_objects as go
from dash import Dash, Input, Output, callback, dcc, html, no_update
from openai import OpenAI
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
Expand Down Expand Up @@ -71,31 +69,6 @@ def reduce_dimensions(
return result


def gpt_get_cluster_name(info: str, settings: Settings) -> str:
    """Ask an OpenAI chat model for a short name covering a group of items.

    The client is created with no arguments, so it relies on the OpenAI
    SDK's default configuration.

    Args:
        info: Text sent as the user message (the list of item names).
        settings: Supplies ``gpt_default_model`` and
            ``gpt_default_temperature`` for the request.

    Returns:
        The content of the first choice, or ``""`` when it is empty or
        ``None``. Replies longer than 30 characters are cut to their first
        30 characters followed by ``".."``.
    """
    system_prompt = (
        "Your role is to find a very short (max 5 words), concise "
        "name for a group of items, one name to rule them all. "
        "the user will provide a list of item names. do your best"
    )
    chat: list[dict[str, str]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": info},
    ]
    response = OpenAI().chat.completions.create(
        model=settings.gpt_default_model,
        temperature=settings.gpt_default_temperature,
        messages=chat,  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
    )
    name = response.choices[0].message.content or ""
    # Keep labels compact: at most 30 characters plus a ".." marker.
    if len(name) > 30:
        return name[:30] + ".."
    return name


def load_chromadb_collection(settings: Settings) -> Any:
chromadb_client: ClientAPI = chromadb.PersistentClient(path="./chromadb")
collection = chromadb_client.get_or_create_collection(
Expand Down Expand Up @@ -201,38 +174,18 @@ def generate_cluster_props(
num_clusters: int,
pred_arr: Any,
collection_content_text_display: list[str],
settings: Settings,
num_products_for_cluster_name: int = 10,
) -> tuple[list[list[int]], list[str]]:
_ = (collection_content_text_display, num_products_for_cluster_name)
clusters_indices: list[list[int]] = []
cluster_names: list[str] = []
group_index = 1
for cluster_i in range(num_clusters):
curr_cluster_indices = [i for i, x in enumerate(pred_arr) if x == cluster_i]
clusters_indices.append(curr_cluster_indices)
logger.info("Generating cluster %d names ...", cluster_i)
if settings.gpt_generate_cluster_name is True:
random_product_indexes = random.sample(
range(0, len(curr_cluster_indices)),
min(
num_products_for_cluster_name,
len(curr_cluster_indices),
),
)
curr_descriptions = ""
for product_index in random_product_indexes:
idx = curr_cluster_indices[product_index]
item = (
collection_content_text_display[idx]
if idx < len(collection_content_text_display)
else f"Item {idx}"
)
curr_descriptions += f"name: {item} \n"
cluster_name = gpt_get_cluster_name(curr_descriptions, settings)
cluster_names.append(cluster_name)
else:
cluster_names.append(f"Group {group_index}")
group_index += 1
cluster_names.append(f"Group {group_index}")
group_index += 1
return clusters_indices, cluster_names


Expand Down Expand Up @@ -288,7 +241,7 @@ def compute_plot_data(settings: Settings) -> dict[str, Any]:
)

clusters_indices, cluster_names = generate_cluster_props(
num_clusters, pred_arr, collection_content_text_display, settings
num_clusters, pred_arr, collection_content_text_display
)

# Build structured point data
Expand Down
2 changes: 2 additions & 0 deletions embedding_cluster/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from embedding_cluster.server.routes.ai import router as ai_router
from embedding_cluster.server.routes.annotations import (
router as annotations_router,
)
Expand Down Expand Up @@ -43,6 +44,7 @@ def create_app() -> FastAPI:
async def health_check() -> dict[str, str]:
return {"status": "ok"}

app.include_router(ai_router)
app.include_router(collections_router)
app.include_router(csv_router)
app.include_router(index_router)
Expand Down
39 changes: 36 additions & 3 deletions embedding_cluster/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ class PlotRequest(BaseModel):
num_clusters: int = 10
text_display_fields: list[str] | None = None
image_field: str | None = None
gpt_generate_cluster_name: bool = False
gpt_default_model: str = "gpt-3.5-turbo"
gpt_default_temperature: float = 0.51
reduction_algorithm: Literal["tsne", "umap", "pca"] = "tsne"
tsne_perplexity: float = 30.0
tsne_learning_rate: str = "auto"
Expand Down Expand Up @@ -199,6 +196,7 @@ class SubClusterInfo(BaseModel):
index: int
count: int
color: str
name: str | None = None


class SubClusterResponse(BaseModel):
Expand Down Expand Up @@ -248,3 +246,38 @@ class ClusterAnnotation(BaseModel):
# Response model carrying the annotations stored for one job.
class AnnotationsResponse(BaseModel):
    # Identifier of the job the annotations belong to.
    job_id: str
    # One ClusterAnnotation per cluster, keyed by a string.
    # NOTE(review): keys are presumably stringified cluster indices; confirm
    # against the annotations route.
    clusters: dict[str, ClusterAnnotation]


# Request model for AI-generated cluster names. The api_key / model /
# base_url / temperature fields use the same names and defaults as the
# parameters of get_cluster_name() in embedding_cluster/ai_naming.py.
# NOTE(review): presumably consumed by embedding_cluster.server.routes.ai;
# the route is not visible here.
class AiNamingRequest(BaseModel):
    # Identifier of the job whose clusters should be named.
    job_id: str
    # Clusters to generate names for.
    # NOTE(review): presumably indices into the job's cluster list; confirm.
    cluster_indices: list[int]
    # LLM credential; required, no default.
    api_key: str
    # LLM model identifier; required, no default.
    model: str
    # Optional custom LLM endpoint.
    base_url: str | None = None
    # Sampling temperature for the naming request.
    temperature: float = 0.5


# Response model carrying generated names.
class AiNamingResponse(BaseModel):
    # Mapping from a string key to the generated name.
    # NOTE(review): keys are presumably stringified cluster indices; confirm
    # against the route that builds this response.
    names: dict[str, str]


# Request model for AI-generated sub-cluster names. The api_key / model /
# base_url / temperature / parent_cluster_name fields use the same names and
# defaults as the parameters of get_sub_cluster_name() in
# embedding_cluster/ai_naming.py.
class AiSubClusterNamingRequest(BaseModel):
    # Identifier of the job the points belong to.
    job_id: str
    # Identifiers of the points being sub-clustered.
    point_ids: list[str]
    # Sub-cluster label values.
    # NOTE(review): presumably parallel to point_ids (one label per point);
    # nothing in this model enforces equal lengths, so confirm in the route.
    sub_cluster_labels: list[int]
    # LLM credential; required, no default.
    api_key: str
    # LLM model identifier; required, no default.
    model: str
    # Optional custom LLM endpoint.
    base_url: str | None = None
    # Sampling temperature for the naming request.
    temperature: float = 0.5
    # Optional name of the enclosing cluster.
    parent_cluster_name: str | None = None


# Request model for checking LLM connectivity. Its fields use the same names
# and defaults as the parameters of test_connection() in
# embedding_cluster/ai_naming.py.
class AiTestConnectionRequest(BaseModel):
    # LLM credential; required, no default.
    api_key: str
    # LLM model identifier; required, no default.
    model: str
    # Optional custom LLM endpoint.
    base_url: str | None = None


# Response model for an LLM connectivity check.
# NOTE(review): the two fields line up with the (success, error) tuple
# returned by test_connection() in embedding_cluster/ai_naming.py; confirm
# the route maps them that way.
class AiTestConnectionResponse(BaseModel):
    # Whether the check succeeded.
    success: bool
    # Error text; defaults to None.
    error: str | None = None
Loading
Loading