Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions embedding_cluster/ai_naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@
)


def _normalize_base_url(model: str, base_url: str | None) -> str | None:
"""Strip /v1 suffix for Ollama models (litellm uses native API)."""
if base_url and model.startswith("ollama/"):
stripped = base_url.rstrip("/")
if stripped.endswith("/v1"):
return stripped[:-3]
return base_url


def _call_llm(
messages: list[dict[str, str]],
api_key: str,
Expand All @@ -36,11 +45,13 @@ def _call_llm(
kwargs: dict[str, object] = {
"model": model,
"messages": messages,
"api_key": api_key,
"temperature": temperature,
}
if base_url:
kwargs["api_base"] = base_url
if api_key:
kwargs["api_key"] = api_key
resolved_url = _normalize_base_url(model, base_url)
if resolved_url:
kwargs["api_base"] = resolved_url

response = litellm_completion(**kwargs)
content: str = response.choices[0].message.content or ""
Expand Down Expand Up @@ -97,11 +108,13 @@ def test_connection(
kwargs: dict[str, object] = {
"model": model,
"messages": [{"role": "user", "content": "Say hello"}],
"api_key": api_key,
"max_tokens": 5,
}
if base_url:
kwargs["api_base"] = base_url
if api_key:
kwargs["api_key"] = api_key
resolved_url = _normalize_base_url(model, base_url)
if resolved_url:
kwargs["api_base"] = resolved_url
litellm_completion(**kwargs)
return True, None
except Exception as exc:
Expand Down
15 changes: 15 additions & 0 deletions embedding_cluster/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,18 @@ class AiTestConnectionRequest(BaseModel):
# Response body of the `/test-connection` route.
class AiTestConnectionResponse(BaseModel):
# Outcome flag of the connection test.
success: bool
# Error detail reported alongside the flag; defaults to None.
error: str | None = None


# Request body of the `/ollama/models` route.
class OllamaModelsRequest(BaseModel):
# Base URL of the Ollama server to query. The route strips trailing "/"
# and a trailing "/v1" before appending "/api/tags".
base_url: str = "http://localhost:11434"


# One model entry, built by the `/ollama/models` route from an item of the
# "models" list in Ollama's /api/tags response.
class OllamaModel(BaseModel):
# Taken from the entry's "name" key ("" when the key is absent).
name: str
# Taken from the entry's "size" key (presumably bytes -- TODO confirm).
size: int | None = None
# Taken from details["parameter_size"] when "details" is present.
parameter_size: str | None = None
# Taken from details["family"] when "details" is present.
family: str | None = None


# Response body of the `/ollama/models` route.
class OllamaModelsResponse(BaseModel):
# One OllamaModel per entry of the "models" list returned by /api/tags.
models: list[OllamaModel]
107 changes: 86 additions & 21 deletions embedding_cluster/server/routes/ai.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import asyncio
import logging
import random
from functools import partial
from typing import Any, cast

import httpx
from fastapi import APIRouter, HTTPException

from embedding_cluster.ai_naming import (
Expand All @@ -19,6 +22,9 @@
AiSubClusterNamingRequest,
AiTestConnectionRequest,
AiTestConnectionResponse,
OllamaModel,
OllamaModelsRequest,
OllamaModelsResponse,
)
from embedding_cluster.server.tasks import TaskStatus, task_registry

Expand Down Expand Up @@ -103,49 +109,63 @@ def _get_item_names_for_sub_cluster(
@router.post("/name-clusters", response_model=AiNamingResponse)
async def name_clusters(request: AiNamingRequest) -> AiNamingResponse:
result = _get_completed_job(request.job_id)
loop = asyncio.get_running_loop()

names: dict[str, str] = {}
for cluster_index in request.cluster_indices:
async def _name_one(cluster_index: int) -> tuple[str, str]:
item_names = _get_item_names_for_cluster(result, cluster_index)
name = get_cluster_name(
item_names=item_names,
api_key=request.api_key,
model=request.model,
base_url=request.base_url,
temperature=request.temperature,
name = await loop.run_in_executor(
None,
partial(
get_cluster_name,
item_names=item_names,
api_key=request.api_key,
model=request.model,
base_url=request.base_url,
temperature=request.temperature,
),
)
names[str(cluster_index)] = name
return str(cluster_index), name

return AiNamingResponse(names=names)
results = await asyncio.gather(
*(_name_one(idx) for idx in request.cluster_indices),
)
return AiNamingResponse(names=dict(results))


@router.post("/name-sub-clusters", response_model=AiNamingResponse)
async def name_sub_clusters(
request: AiSubClusterNamingRequest,
) -> AiNamingResponse:
result = _get_completed_job(request.job_id)
loop = asyncio.get_running_loop()

unique_labels = sorted(set(request.sub_cluster_labels))
names: dict[str, str] = {}

for label in unique_labels:
async def _name_one(label: int) -> tuple[str, str]:
item_names = _get_item_names_for_sub_cluster(
result,
request.point_ids,
request.sub_cluster_labels,
label,
)
name = get_sub_cluster_name(
item_names=item_names,
api_key=request.api_key,
model=request.model,
base_url=request.base_url,
temperature=request.temperature,
parent_cluster_name=request.parent_cluster_name,
name = await loop.run_in_executor(
None,
partial(
get_sub_cluster_name,
item_names=item_names,
api_key=request.api_key,
model=request.model,
base_url=request.base_url,
temperature=request.temperature,
parent_cluster_name=request.parent_cluster_name,
),
)
names[str(label)] = name
return str(label), name

return AiNamingResponse(names=names)
results = await asyncio.gather(
*(_name_one(lbl) for lbl in unique_labels),
)
return AiNamingResponse(names=dict(results))


@router.post("/test-connection", response_model=AiTestConnectionResponse)
Expand All @@ -158,3 +178,48 @@ async def test_connection(
base_url=request.base_url,
)
return AiTestConnectionResponse(success=success, error=error)


@router.post("/ollama/models", response_model=OllamaModelsResponse)
async def list_ollama_models(
    request: OllamaModelsRequest,
) -> OllamaModelsResponse:
    """Proxy to Ollama ``/api/tags`` to list locally installed models.

    The base URL from the request is normalised (trailing slashes and a
    trailing ``/v1`` are removed) before ``/api/tags`` is appended.

    Raises:
        HTTPException: 502 when Ollama cannot be reached, answers with an
            error status, or returns a body that is not the expected JSON
            object; 504 when the request times out.
    """
    # NOTE(review): this endpoint fetches a caller-supplied URL. If the server
    # is ever exposed beyond localhost, consider restricting allowed hosts.
    stripped = request.base_url.rstrip("/")
    if stripped.endswith("/v1"):
        stripped = stripped[:-3]
    url = stripped + "/api/tags"
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(url)
            resp.raise_for_status()
    except httpx.ConnectError:
        raise HTTPException(
            status_code=502,
            detail=f"Cannot connect to Ollama at {request.base_url}",
        ) from None
    except httpx.HTTPStatusError as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Ollama returned {exc.response.status_code}",
        ) from None
    except httpx.TimeoutException:
        raise HTTPException(
            status_code=504,
            detail="Ollama request timed out",
        ) from None
    except httpx.RequestError as exc:
        # Remaining transport failures (unsupported/missing URL scheme, read
        # errors, protocol errors, ...) would otherwise surface as a 500.
        raise HTTPException(
            status_code=502,
            detail=f"Request to Ollama at {request.base_url} failed: {exc}",
        ) from None

    try:
        data = resp.json()
    except ValueError:
        # The URL answered 2xx but is not an Ollama API (e.g. an HTML page).
        raise HTTPException(
            status_code=502,
            detail="Ollama returned a non-JSON response",
        ) from None
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=502,
            detail="Ollama returned an unexpected response",
        )

    # "models" may be missing or null; both mean "no models installed".
    raw_models: list[dict[str, Any]] = data.get("models") or []
    models = [
        OllamaModel(
            name=m.get("name") or "",
            size=m.get("size"),
            parameter_size=(m.get("details") or {}).get(
                "parameter_size",
            ),
            family=(m.get("details") or {}).get("family"),
        )
        for m in raw_models
        if isinstance(m, dict)  # ignore malformed entries instead of crashing
    ]
    return OllamaModelsResponse(models=models)
20 changes: 19 additions & 1 deletion frontend/src/api/ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,23 @@ import type {
AiSubClusterNamingRequest,
AiTestConnectionRequest,
AiTestConnectionResponse,
OllamaModelsResponse,
} from "../types";
import { apiPost } from "./client";

const AI_SETTINGS_KEY = "ai-cluster-naming-settings";

// Known AI providers. `defaultBaseUrl` is empty for OpenAI, Google and
// Anthropic; Ollama defaults to http://localhost:11434.
export const AI_PROVIDERS = [
{ value: "openai", label: "OpenAI", defaultBaseUrl: "" },
{ value: "google", label: "Google", defaultBaseUrl: "" },
{ value: "anthropic", label: "Anthropic", defaultBaseUrl: "" },
{ value: "ollama", label: "Ollama", defaultBaseUrl: "http://localhost:11434" },
] as const;

// Union of the `value` literals declared in AI_PROVIDERS.
export type AiProvider = (typeof AI_PROVIDERS)[number]["value"];

export interface StoredAiSettings {
provider: string;
provider: AiProvider;
model: string;
apiKey: string;
baseUrl: string;
Expand Down Expand Up @@ -56,3 +66,11 @@ export async function nameAiSubClusters(
): Promise<AiNamingResponse> {
return apiPost<AiNamingResponse>("/ai/name-sub-clusters", request);
}

/**
 * Ask the backend for the models installed on an Ollama server.
 *
 * @param baseUrl Base URL of the Ollama server; sent to the backend as
 *   `base_url`.
 * @returns The response of the `/ai/ollama/models` endpoint.
 */
export async function fetchOllamaModels(
  baseUrl: string = "http://localhost:11434",
): Promise<OllamaModelsResponse> {
  const payload = { base_url: baseUrl };
  return apiPost<OllamaModelsResponse>("/ai/ollama/models", payload);
}
2 changes: 1 addition & 1 deletion frontend/src/components/plot/ClusterDetailDrawer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ export default function ClusterDetailDrawer({ jobId, imageField }: ClusterDetail
parentName: string | undefined,
) => {
const settings = loadAiSettings()
if (!settings.apiKey || !jobId) return
if ((!settings.apiKey && settings.provider !== 'ollama') || !jobId) return

setIsNamingSubClusters(true)
try {
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/components/plot/ClusterLegend.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export default function ClusterLegend() {
if (!plotData || !plotJobId || isNamingClusters) return

const settings = loadAiSettings()
if (!settings.apiKey) {
if (!settings.apiKey && settings.provider !== 'ollama') {
setNamingError('Configure AI settings first (Settings page)')
return
}
Expand Down
1 change: 0 additions & 1 deletion frontend/src/components/plot/PlotControls.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,6 @@ export default function PlotControls({ onCompute, isComputing }: PlotControlsPro
{/* Rendering (Render Mode + Point Size) */}
<CollapsibleSection title="Rendering" defaultOpen={false}>
<div className="space-y-1.5">
<label className="block text-xs font-medium text-gray-600">Render Mode</label>
<div className="flex space-x-2">
{(['particles', 'sprites', 'spheres'] as const).map((mode) => (
<label key={mode} className="flex items-center text-xs cursor-pointer">
Expand Down
Loading
Loading