diff --git a/embedding_cluster/ai_naming.py b/embedding_cluster/ai_naming.py index ed62fd2..0af4e8f 100644 --- a/embedding_cluster/ai_naming.py +++ b/embedding_cluster/ai_naming.py @@ -25,6 +25,15 @@ ) +def _normalize_base_url(model: str, base_url: str | None) -> str | None: + """Strip /v1 suffix for Ollama models (litellm uses native API).""" + if base_url and model.startswith("ollama/"): + stripped = base_url.rstrip("/") + if stripped.endswith("/v1"): + return stripped[:-3] + return base_url + + def _call_llm( messages: list[dict[str, str]], api_key: str, @@ -36,11 +45,13 @@ def _call_llm( kwargs: dict[str, object] = { "model": model, "messages": messages, - "api_key": api_key, "temperature": temperature, } - if base_url: - kwargs["api_base"] = base_url + if api_key: + kwargs["api_key"] = api_key + resolved_url = _normalize_base_url(model, base_url) + if resolved_url: + kwargs["api_base"] = resolved_url response = litellm_completion(**kwargs) content: str = response.choices[0].message.content or "" @@ -97,11 +108,13 @@ def test_connection( kwargs: dict[str, object] = { "model": model, "messages": [{"role": "user", "content": "Say hello"}], - "api_key": api_key, "max_tokens": 5, } - if base_url: - kwargs["api_base"] = base_url + if api_key: + kwargs["api_key"] = api_key + resolved_url = _normalize_base_url(model, base_url) + if resolved_url: + kwargs["api_base"] = resolved_url litellm_completion(**kwargs) return True, None except Exception as exc: diff --git a/embedding_cluster/server/models.py b/embedding_cluster/server/models.py index 56dcd22..0d574c0 100644 --- a/embedding_cluster/server/models.py +++ b/embedding_cluster/server/models.py @@ -281,3 +281,18 @@ class AiTestConnectionRequest(BaseModel): class AiTestConnectionResponse(BaseModel): success: bool error: str | None = None + + +class OllamaModelsRequest(BaseModel): + base_url: str = "http://localhost:11434" + + +class OllamaModel(BaseModel): + name: str + size: int | None = None + parameter_size: str | None = None + family: str | None = None + + +class OllamaModelsResponse(BaseModel): + models: list[OllamaModel] diff --git a/embedding_cluster/server/routes/ai.py b/embedding_cluster/server/routes/ai.py index bd029fa..4cb0683 100644 --- a/embedding_cluster/server/routes/ai.py +++ b/embedding_cluster/server/routes/ai.py @@ -1,9 +1,12 @@ from __future__ import annotations +import asyncio import logging import random +from functools import partial from typing import Any, cast +import httpx from fastapi import APIRouter, HTTPException from embedding_cluster.ai_naming import ( @@ -19,6 +22,9 @@ AiSubClusterNamingRequest, AiTestConnectionRequest, AiTestConnectionResponse, + OllamaModel, + OllamaModelsRequest, + OllamaModelsResponse, ) from embedding_cluster.server.tasks import TaskStatus, task_registry @@ -103,20 +109,27 @@ def _get_item_names_for_sub_cluster( @router.post("/name-clusters", response_model=AiNamingResponse) async def name_clusters(request: AiNamingRequest) -> AiNamingResponse: result = _get_completed_job(request.job_id) + loop = asyncio.get_running_loop() - names: dict[str, str] = {} - for cluster_index in request.cluster_indices: + async def _name_one(cluster_index: int) -> tuple[str, str]: item_names = _get_item_names_for_cluster(result, cluster_index) - name = get_cluster_name( - item_names=item_names, - api_key=request.api_key, - model=request.model, - base_url=request.base_url, - temperature=request.temperature, + name = await loop.run_in_executor( + None, + partial( + get_cluster_name, + item_names=item_names, + api_key=request.api_key, + model=request.model, + base_url=request.base_url, + temperature=request.temperature, + ), ) - names[str(cluster_index)] = name + return str(cluster_index), name - return AiNamingResponse(names=names) + results = await asyncio.gather( + *(_name_one(idx) for idx in request.cluster_indices), + ) + return AiNamingResponse(names=dict(results)) @router.post("/name-sub-clusters", response_model=AiNamingResponse) @@ -124,28 +137,35 @@ async def name_sub_clusters( request: AiSubClusterNamingRequest, ) -> AiNamingResponse: result = _get_completed_job(request.job_id) + loop = asyncio.get_running_loop() unique_labels = sorted(set(request.sub_cluster_labels)) - names: dict[str, str] = {} - for label in unique_labels: + async def _name_one(label: int) -> tuple[str, str]: item_names = _get_item_names_for_sub_cluster( result, request.point_ids, request.sub_cluster_labels, label, ) - name = get_sub_cluster_name( - item_names=item_names, - api_key=request.api_key, - model=request.model, - base_url=request.base_url, - temperature=request.temperature, - parent_cluster_name=request.parent_cluster_name, + name = await loop.run_in_executor( + None, + partial( + get_sub_cluster_name, + item_names=item_names, + api_key=request.api_key, + model=request.model, + base_url=request.base_url, + temperature=request.temperature, + parent_cluster_name=request.parent_cluster_name, + ), ) - names[str(label)] = name + return str(label), name - return AiNamingResponse(names=names) + results = await asyncio.gather( + *(_name_one(lbl) for lbl in unique_labels), + ) + return AiNamingResponse(names=dict(results)) @router.post("/test-connection", response_model=AiTestConnectionResponse) @@ -158,3 +178,48 @@ async def test_connection( base_url=request.base_url, ) return AiTestConnectionResponse(success=success, error=error) + + +@router.post("/ollama/models", response_model=OllamaModelsResponse) +async def list_ollama_models( + request: OllamaModelsRequest, +) -> OllamaModelsResponse: + """Proxy to Ollama /api/tags to list locally installed models.""" + stripped = request.base_url.rstrip("/") + if stripped.endswith("/v1"): + stripped = stripped[:-3] + url = stripped + "/api/tags" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(url) + resp.raise_for_status() + except httpx.ConnectError: + raise HTTPException( + status_code=502, + detail=f"Cannot connect to Ollama at {request.base_url}", + ) from None + except httpx.HTTPStatusError as exc: + raise HTTPException( + status_code=502, + detail=f"Ollama returned {exc.response.status_code}", + ) from None + except httpx.TimeoutException: + raise HTTPException( + status_code=504, + detail="Ollama request timed out", + ) from None + + data = resp.json() + raw_models: list[dict[str, Any]] = data.get("models", []) + models = [ + OllamaModel( + name=m.get("name", ""), + size=m.get("size"), + parameter_size=(m.get("details") or {}).get( + "parameter_size", + ), + family=(m.get("details") or {}).get("family"), + ) + for m in raw_models + ] + return OllamaModelsResponse(models=models) diff --git a/frontend/src/api/ai.ts b/frontend/src/api/ai.ts index 59f57a9..8f424e7 100644 --- a/frontend/src/api/ai.ts +++ b/frontend/src/api/ai.ts @@ -4,13 +4,23 @@ import type { AiSubClusterNamingRequest, AiTestConnectionRequest, AiTestConnectionResponse, + OllamaModelsResponse, } from "../types"; import { apiPost } from "./client"; const AI_SETTINGS_KEY = "ai-cluster-naming-settings"; +export const AI_PROVIDERS = [ + { value: "openai", label: "OpenAI", defaultBaseUrl: "" }, + { value: "google", label: "Google", defaultBaseUrl: "" }, + { value: "anthropic", label: "Anthropic", defaultBaseUrl: "" }, + { value: "ollama", label: "Ollama", defaultBaseUrl: "http://localhost:11434" }, +] as const; + +export type AiProvider = (typeof AI_PROVIDERS)[number]["value"]; + export interface StoredAiSettings { - provider: string; + provider: AiProvider; model: string; apiKey: string; baseUrl: string; @@ -56,3 +66,11 @@ export async function nameAiSubClusters( ): Promise { return apiPost("/ai/name-sub-clusters", request); } + +export async function fetchOllamaModels( + baseUrl: string = "http://localhost:11434", +): Promise { + return apiPost("/ai/ollama/models", { + base_url: baseUrl, + }); +} diff --git a/frontend/src/components/plot/ClusterDetailDrawer.tsx b/frontend/src/components/plot/ClusterDetailDrawer.tsx index 7db58b3..b33a36d 100644 --- a/frontend/src/components/plot/ClusterDetailDrawer.tsx +++ b/frontend/src/components/plot/ClusterDetailDrawer.tsx @@ -109,7 +109,7 @@ export default function ClusterDetailDrawer({ jobId, imageField }: ClusterDetail parentName: string | undefined, ) => { const settings = loadAiSettings() - if (!settings.apiKey || !jobId) return + if ((!settings.apiKey && settings.provider !== 'ollama') || !jobId) return setIsNamingSubClusters(true) try { diff --git a/frontend/src/components/plot/ClusterLegend.tsx b/frontend/src/components/plot/ClusterLegend.tsx index f836cc5..b034ea1 100644 --- a/frontend/src/components/plot/ClusterLegend.tsx +++ b/frontend/src/components/plot/ClusterLegend.tsx @@ -36,7 +36,7 @@ export default function ClusterLegend() { if (!plotData || !plotJobId || isNamingClusters) return const settings = loadAiSettings() - if (!settings.apiKey) { + if (!settings.apiKey && settings.provider !== 'ollama') { setNamingError('Configure AI settings first (Settings page)') return } diff --git a/frontend/src/components/plot/PlotControls.tsx b/frontend/src/components/plot/PlotControls.tsx index 92a0e69..f71c628 100644 --- a/frontend/src/components/plot/PlotControls.tsx +++ b/frontend/src/components/plot/PlotControls.tsx @@ -316,7 +316,6 @@ export default function PlotControls({ onCompute, isComputing }: PlotControlsPro {/* Rendering (Render Mode + Point Size) */}
-
{(['particles', 'sprites', 'spheres'] as const).map((mode) => (
- handleChange('model', e.target.value)} - className="w-full border-gray-300 rounded-md shadow-sm focus:border-blue-500 focus:ring-blue-500 sm:text-sm p-2 border" - placeholder="e.g. gpt-4o-mini" - /> + {settings.provider === 'ollama' && ollamaModels.length > 0 ? ( + + ) : ( + handleChange('model', e.target.value)} + className="w-full border-gray-300 rounded-md shadow-sm focus:border-blue-500 focus:ring-blue-500 sm:text-sm p-2 border" + placeholder="e.g. gpt-4o-mini" + /> + )} + {settings.provider === 'ollama' && ollamaModelsLoading && ( +

Loading models...

+ )} + {settings.provider === 'ollama' && ollamaModelsError && ( +

{ollamaModelsError}

+ )}
@@ -146,7 +216,7 @@ export default function SettingsPage() {