From d0d7f1f33b54ad665f8246b57a119efd30ed5699 Mon Sep 17 00:00:00 2001 From: djgpp6 Date: Tue, 9 Jun 2026 04:48:47 +0530 Subject: [PATCH] Auto-discover models when adding custom model configs When setting up a custom model, query the provider's models endpoint after entering an endpoint and API key. Populate the model field from the response and auto-select when only one model is available. Fixes #204 --- API.md | 2 + server/remote_models.go | 291 ++++++++++++++++++++++++++++++ server/remote_models_test.go | 165 +++++++++++++++++ server/server.go | 1 + ui/src/components/ModelsModal.tsx | 107 ++++++++++- ui/src/services/api.ts | 26 +++ 6 files changed, 587 insertions(+), 5 deletions(-) create mode 100644 server/remote_models.go create mode 100644 server/remote_models_test.go diff --git a/API.md b/API.md index 7f4d610b..c596813d 100644 --- a/API.md +++ b/API.md @@ -212,6 +212,8 @@ fresh reset event. - `GET /api/tools` — registered tool definitions. - `GET/POST/PUT/DELETE /api/custom-models[/]` — custom model CRUD. - `POST /api/custom-models-test` — test a custom model config. +- `POST /api/custom-models-discover` — list models from a remote provider's + models endpoint (e.g. `/v1/models`). - `GET/POST/PUT/DELETE /api/notification-channels[/]`, `GET /api/notification-channel-types` — notification CRUD. diff --git a/server/remote_models.go b/server/remote_models.go new file mode 100644 index 00000000..dd8e348b --- /dev/null +++ b/server/remote_models.go @@ -0,0 +1,291 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const remoteModelsTimeout = 15 * time.Second + +// DiscoverModelsRequest is the request body for listing models from a remote API. +type DiscoverModelsRequest struct { + ModelID string `json:"model_id,omitempty"` // If provided with empty api_key, use stored key + ProviderType string `json:"provider_type"` + Endpoint string `json:"endpoint"` + APIKey string `json:"api_key"` +} + +// RemoteModelOption is one model returned by a remote provider's models endpoint. +type RemoteModelOption struct { + ID string `json:"id"` + DisplayName string `json:"display_name,omitempty"` +} + +type openAIModelsResponse struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` +} + +type anthropicModelsResponse struct { + Data []struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + } `json:"data"` +} + +type geminiModelsResponse struct { + Models []struct { + Name string `json:"name"` + DisplayName string `json:"displayName"` + } `json:"models"` +} + +func (s *Server) handleDiscoverModels(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req DiscoverModelsRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + if req.ModelID != "" && req.APIKey == "" { + model, err := s.db.GetModel(r.Context(), req.ModelID) + if err != nil { + http.Error(w, fmt.Sprintf("Model not found: %v", err), http.StatusNotFound) + return + } + req.APIKey = model.ApiKey + } + + if req.ProviderType == "" || req.Endpoint == "" || req.APIKey == "" { + http.Error(w, "provider_type, endpoint, and api_key are required", http.StatusBadRequest) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), remoteModelsTimeout) + defer cancel() + + models, err := discoverRemoteModels(ctx, http.DefaultClient, req.ProviderType, req.Endpoint, req.APIKey) + w.Header().Set("Content-Type", "application/json") + if err != nil { + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": false, + "message": err.Error(), + "models": []RemoteModelOption{}, + }) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "models": models, + }) +} + +func discoverRemoteModels(ctx context.Context, httpc *http.Client, providerType, endpoint, apiKey string) ([]RemoteModelOption, error) { + switch providerType { + case "openai", "openai-responses": + return discoverOpenAIModels(ctx, httpc, endpoint, apiKey) + case "anthropic": + return discoverAnthropicModels(ctx, httpc, endpoint, apiKey) + case "gemini": + return discoverGeminiModels(ctx, httpc, endpoint, apiKey) + default: + return nil, fmt.Errorf("unsupported provider_type %q", providerType) + } +} + +func discoverOpenAIModels(ctx context.Context, httpc *http.Client, endpoint, apiKey string) ([]RemoteModelOption, error) { + modelsURL, err := modelsListURL(endpoint) + if err != nil { + return nil, err + } + body, status, err := getRemoteJSON(ctx, httpc, modelsURL, openAIAuthHeaders(apiKey)) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("models endpoint returned HTTP %d", status) + } + var resp openAIModelsResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse models response: %w", err) + } + return remoteModelsFromIDs(resp.Data, func(item struct { + ID string `json:"id"` + }) (string, string) { + return item.ID, item.ID + }), nil +} + +func discoverAnthropicModels(ctx context.Context, httpc *http.Client, endpoint, apiKey string) ([]RemoteModelOption, error) { + modelsURL, err := modelsListURL(endpoint) + if err != nil { + return nil, err + } + body, status, err := getRemoteJSON(ctx, httpc, modelsURL, anthropicAuthHeaders(apiKey)) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("models endpoint returned HTTP %d", status) + } + var resp anthropicModelsResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse models response: %w", err) + } + return remoteModelsFromIDs(resp.Data, func(item struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + }) (string, string) { + displayName := item.DisplayName + if displayName == "" { + displayName = item.ID + } + return item.ID, displayName + }), nil +} + +func discoverGeminiModels(ctx context.Context, httpc *http.Client, endpoint, apiKey string) ([]RemoteModelOption, error) { + modelsURL, err := geminiModelsListURL(endpoint, apiKey) + if err != nil { + return nil, err + } + body, status, err := getRemoteJSON(ctx, httpc, modelsURL, nil) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("models endpoint returned HTTP %d", status) + } + var resp geminiModelsResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse models response: %w", err) + } + var out []RemoteModelOption + for _, model := range resp.Models { + id := strings.TrimPrefix(model.Name, "models/") + if id == "" { + continue + } + displayName := model.DisplayName + if displayName == "" { + displayName = id + } + out = append(out, RemoteModelOption{ID: id, DisplayName: displayName}) + } + return out, nil +} + +func remoteModelsFromIDs[T any](items []T, pick func(T) (id, displayName string)) []RemoteModelOption { + out := make([]RemoteModelOption, 0, len(items)) + for _, item := range items { + id, displayName := pick(item) + if id == "" { + continue + } + out = append(out, RemoteModelOption{ID: id, DisplayName: displayName}) + } + return out +} + +func modelsListURL(endpoint string) (string, error) { + endpoint = strings.TrimSpace(endpoint) + if endpoint == "" { + return "", fmt.Errorf("endpoint is required") + } + parsed, err := url.Parse(endpoint) + if err != nil { + return "", fmt.Errorf("invalid endpoint URL: %w", err) + } + if parsed.Scheme == "" || parsed.Host == "" { + return "", fmt.Errorf("endpoint must be an absolute URL") + } + + path := strings.TrimSuffix(parsed.Path, "/") + for _, suffix := range []string{"/chat/completions", "/responses", "/messages"} { + if strings.HasSuffix(path, suffix) { + path = strings.TrimSuffix(path, suffix) + break + } + } + switch { + case strings.HasSuffix(path, "/models"): + // already a models endpoint + case strings.HasSuffix(path, "/v1beta"): + path += "/models" + case strings.HasSuffix(path, "/v1"): + path += "/models" + case idx := strings.LastIndex(path, "/v1"); idx >= 0: + path = path[:idx+3] + "/models" + default: + path += "/models" + } + parsed.Path = path + parsed.RawQuery = "" + parsed.Fragment = "" + return parsed.String(), nil +} + +func geminiModelsListURL(endpoint, apiKey string) (string, error) { + u, err := modelsListURL(endpoint) + if err != nil { + return "", err + } + parsed, err := url.Parse(u) + if err != nil { + return "", err + } + q := parsed.Query() + q.Set("key", apiKey) + parsed.RawQuery = q.Encode() + return parsed.String(), nil +} + +func openAIAuthHeaders(apiKey string) http.Header { + h := make(http.Header) + h.Set("Authorization", "Bearer "+apiKey) + return h +} + +func anthropicAuthHeaders(apiKey string) http.Header { + h := make(http.Header) + h.Set("x-api-key", apiKey) + h.Set("anthropic-version", "2023-06-01") + return h +} + +func getRemoteJSON(ctx context.Context, httpc *http.Client, rawURL string, headers http.Header) ([]byte, int, error) { + if httpc == nil { + httpc = http.DefaultClient + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return nil, 0, err + } + for k, vals := range headers { + for _, v := range vals { + req.Header.Add(k, v) + } + } + resp, err := httpc.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("failed to reach models endpoint: %w", err) + } + defer resp.Body.Close() + body, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("failed to read models response: %w", err) + } + return body, resp.StatusCode, nil +} \ No newline at end of file diff --git a/server/remote_models_test.go b/server/remote_models_test.go new file mode 100644 index 00000000..0f0e301c --- /dev/null +++ b/server/remote_models_test.go @@ -0,0 +1,165 @@ +package server + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestModelsListURL(t *testing.T) { + t.Parallel() + tests := []struct { + endpoint string + want string + }{ + {"https://api.openai.com/v1", "https://api.openai.com/v1/models"}, + {"https://api.anthropic.com/v1/messages", "https://api.anthropic.com/v1/models"}, + {"https://generativelanguage.googleapis.com/v1beta", "https://generativelanguage.googleapis.com/v1beta/models"}, + {"https://llm.example.com/v1/", "https://llm.example.com/v1/models"}, + {"https://llm.example.com/v1/models", "https://llm.example.com/v1/models"}, + {"https://llm.example.com/v1/chat/completions", "https://llm.example.com/v1/models"}, + } + for _, tc := range tests { + got, err := modelsListURL(tc.endpoint) + if err != nil { + t.Fatalf("modelsListURL(%q) error: %v", tc.endpoint, err) + } + if got != tc.want { + t.Errorf("modelsListURL(%q) = %q, want %q", tc.endpoint, got, tc.want) + } + } +} + +func TestDiscoverOpenAIModels(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Fatalf("unexpected auth header: %q", got) + } + json.NewEncoder(w).Encode(openAIModelsResponse{ + Data: []struct { + ID string `json:"id"` + }{ + {ID: "gpt-5.5"}, + {ID: "gpt-5.4"}, + }, + }) + })) + defer server.Close() + + models, err := discoverOpenAIModels(t.Context(), server.Client(), server.URL+"/v1", "test-key") + if err != nil { + t.Fatalf("discoverOpenAIModels failed: %v", err) + } + if len(models) != 2 { + t.Fatalf("expected 2 models, got %d", len(models)) + } + if models[0].ID != "gpt-5.5" || models[1].ID != "gpt-5.4" { + t.Fatalf("unexpected models: %+v", models) + } +} + +func TestDiscoverAnthropicModels(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := r.Header.Get("x-api-key"); got != "anthropic-key" { + t.Fatalf("unexpected api key header: %q", got) + } + json.NewEncoder(w).Encode(anthropicModelsResponse{ + Data: []struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + }{ + {ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6"}, + }, + }) + })) + defer server.Close() + + models, err := discoverAnthropicModels(t.Context(), server.Client(), server.URL+"/v1/messages", "anthropic-key") + if err != nil { + t.Fatalf("discoverAnthropicModels failed: %v", err) + } + if len(models) != 1 || models[0].ID != "claude-sonnet-4-6" || models[0].DisplayName != "Claude Sonnet 4.6" { + t.Fatalf("unexpected models: %+v", models) + } +} + +func TestDiscoverGeminiModels(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1beta/models" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := r.URL.Query().Get("key"); got != "gemini-key" { + t.Fatalf("unexpected key query: %q", got) + } + json.NewEncoder(w).Encode(geminiModelsResponse{ + Models: []struct { + Name string `json:"name"` + DisplayName string `json:"displayName"` + }{ + {Name: "models/gemini-3-pro-preview", DisplayName: "Gemini 3 Pro"}, + }, + }) + })) + defer server.Close() + + models, err := discoverGeminiModels(t.Context(), server.Client(), server.URL+"/v1beta", "gemini-key") + if err != nil { + t.Fatalf("discoverGeminiModels failed: %v", err) + } + if len(models) != 1 || models[0].ID != "gemini-3-pro-preview" || models[0].DisplayName != "Gemini 3 Pro" { + t.Fatalf("unexpected models: %+v", models) + } +} + +func TestHandleDiscoverModelsEndpoint(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(openAIModelsResponse{ + Data: []struct { + ID string `json:"id"` + }{{ID: "only-model"}}, + }) + })) + defer server.Close() + + h := NewTestHarness(t) + reqBody, err := json.Marshal(DiscoverModelsRequest{ + ProviderType: "openai", + Endpoint: server.URL + "/v1", + APIKey: "test-key", + }) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + req := httptest.NewRequest(http.MethodPost, "/api/custom-models-discover", bytes.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.server.handleDiscoverModels(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + var result map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatalf("parse response: %v", err) + } + if success, ok := result["success"].(bool); !ok || !success { + t.Fatalf("expected success=true, got %+v", result) + } + models, ok := result["models"].([]any) + if !ok || len(models) != 1 { + t.Fatalf("expected one model, got %+v", result["models"]) + } +} \ No newline at end of file diff --git a/server/server.go b/server/server.go index cf8ba2b5..68f44944 100644 --- a/server/server.go +++ b/server/server.go @@ -460,6 +460,7 @@ func (s *Server) RegisterRoutes(mux *http.ServeMux) { mux.Handle("/api/custom-models", http.HandlerFunc(s.handleCustomModels)) mux.Handle("/api/custom-models/", http.HandlerFunc(s.handleCustomModel)) mux.Handle("/api/custom-models-test", http.HandlerFunc(s.handleTestModel)) + mux.Handle("/api/custom-models-discover", http.HandlerFunc(s.handleDiscoverModels)) // Notification channels API mux.Handle("/api/notification-channels", http.HandlerFunc(s.handleNotificationChannels)) diff --git a/ui/src/components/ModelsModal.tsx b/ui/src/components/ModelsModal.tsx index 85fd6ec7..33e381a2 100644 --- a/ui/src/components/ModelsModal.tsx +++ b/ui/src/components/ModelsModal.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useCallback } from "react"; +import React, { useState, useEffect, useCallback, useMemo, useRef } from "react"; import Modal from "./Modal"; import { useI18n } from "../i18n"; import { @@ -167,6 +167,11 @@ function ModelsModal({ isOpen, onClose, onModelsChanged }: ModelsModalProps) { // Tooltip state const [showTagsTooltip, setShowTagsTooltip] = useState(false); + // Remote model discovery from /v1/models (or provider equivalent) + const [remoteModels, setRemoteModels] = useState<{ name: string; model_name: string }[]>([]); + const [loadingRemoteModels, setLoadingRemoteModels] = useState(false); + const discoverRequestId = useRef(0); + const loadModels = useCallback(async () => { try { setLoading(true); @@ -195,7 +200,95 @@ function ModelsModal({ isOpen, onClose, onModelsChanged }: ModelsModalProps) { } }, [isOpen, loadModels, setBuiltInFromModelList]); + const modelSuggestions = useMemo(() => { + const seen = new Set(); + const suggestions: { name: string; model_name: string }[] = []; + const add = (name: string, model_name: string) => { + if (!model_name || seen.has(model_name)) { + return; + } + seen.add(model_name); + suggestions.push({ name, model_name }); + }; + for (const preset of remoteModels) { + add(preset.name, preset.model_name); + } + for (const preset of DEFAULT_MODELS[form.provider_type]) { + add(preset.name, preset.model_name); + } + return suggestions; + }, [remoteModels, form.provider_type]); + + useEffect(() => { + if (!showForm) { + setRemoteModels([]); + return; + } + + const canDiscover = Boolean(form.endpoint && (form.api_key || editingModelId)); + if (!canDiscover) { + setRemoteModels([]); + return; + } + + const requestId = ++discoverRequestId.current; + const timer = window.setTimeout(async () => { + setLoadingRemoteModels(true); + try { + const result = await customModelsApi.discoverRemoteModels({ + model_id: editingModelId || undefined, + provider_type: form.provider_type, + endpoint: form.endpoint, + api_key: form.api_key, + }); + if (requestId !== discoverRequestId.current) { + return; + } + if (!result.success) { + setRemoteModels([]); + return; + } + const discovered = result.models.map((model) => ({ + name: model.display_name || model.id, + model_name: model.id, + })); + setRemoteModels(discovered); + if (discovered.length === 1) { + setForm((prev) => { + if (prev.model_name) { + return prev; + } + const only = discovered[0]; + return { + ...prev, + model_name: only.model_name, + display_name: prev.display_name || only.name, + }; + }); + } + } catch { + if (requestId === discoverRequestId.current) { + setRemoteModels([]); + } + } finally { + if (requestId === discoverRequestId.current) { + setLoadingRemoteModels(false); + } + } + }, 500); + + return () => window.clearTimeout(timer); + }, [ + showForm, + editingModelId, + form.provider_type, + form.endpoint, + form.api_key, + form.endpoint_custom, + ]); + const handleProviderChange = (provider: ProviderType) => { + setRemoteModels([]); setForm((prev) => ({ ...prev, provider_type: provider, @@ -204,6 +297,7 @@ function ModelsModal({ isOpen, onClose, onModelsChanged }: ModelsModalProps) { }; const handleEndpointModeChange = (custom: boolean) => { + setRemoteModels([]); setForm((prev) => ({ ...prev, endpoint_custom: custom, @@ -328,6 +422,7 @@ function ModelsModal({ isOpen, onClose, onModelsChanged }: ModelsModalProps) { setEditingModelId(null); setForm(emptyForm); setTestResult(null); + setRemoteModels([]); }; const handleAddNew = () => { @@ -335,6 +430,7 @@ function ModelsModal({ isOpen, onClose, onModelsChanged }: ModelsModalProps) { setForm(emptyForm); setShowForm(true); setTestResult(null); + setRemoteModels([]); }; const handleRefreshModels = async () => { @@ -460,9 +556,7 @@ function ModelsModal({ isOpen, onClose, onModelsChanged }: ModelsModalProps) { // If the user picked a known suggestion and the display // name is empty, pre-fill it from the preset's friendly // name. Never overwrite a non-empty display name. - const preset = DEFAULT_MODELS[prev.provider_type].find( - (p) => p.model_name === v, - ); + const preset = modelSuggestions.find((p) => p.model_name === v); return { ...prev, model_name: v, @@ -475,8 +569,11 @@ function ModelsModal({ isOpen, onClose, onModelsChanged }: ModelsModalProps) { list={`model-name-suggestions-${form.provider_type}`} autoComplete="off" /> + {loadingRemoteModels && ( +
{t("loadingModels")}
+ )} - {DEFAULT_MODELS[form.provider_type].map((preset) => ( + {modelSuggestions.map((preset) => ( diff --git a/ui/src/services/api.ts b/ui/src/services/api.ts index 76487254..29ad9809 100644 --- a/ui/src/services/api.ts +++ b/ui/src/services/api.ts @@ -711,6 +711,18 @@ export interface TestCustomModelRequest { reasoning_effort?: string; } +export interface DiscoverRemoteModelsRequest { + model_id?: string; // If provided with empty api_key, use stored key + provider_type: "anthropic" | "openai" | "openai-responses" | "gemini"; + endpoint: string; + api_key: string; +} + +export interface RemoteModelOption { + id: string; + display_name?: string; +} + class CustomModelsApi { private baseUrl = "/api"; @@ -787,6 +799,20 @@ class CustomModelsApi { } return response.json(); } + + async discoverRemoteModels( + request: DiscoverRemoteModelsRequest, + ): Promise<{ success: boolean; models: RemoteModelOption[]; message?: string }> { + const response = await fetch(`${this.baseUrl}/custom-models-discover`, { + method: "POST", + headers: this.postHeaders, + body: JSON.stringify(request), + }); + if (!response.ok) { + throw await responseError(response, "Failed to discover remote models"); + } + return response.json(); + } } export const customModelsApi = new CustomModelsApi();