From a5ee7e25173f136360bb726578e41338c060aab8 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 3 May 2026 16:46:25 -0500 Subject: [PATCH] Fix provider-prefixed model routing --- .gitignore | 3 + autorouter.go | 7 +- autorouter_test.go | 2 + detection.go | 2 +- detection_test.go | 1 + interceptors/billing_test.go | 1 + models_live_test.go | 524 +++++++++++++++++++++ providers/anthropic/parser.go | 6 +- providers/anthropic/parser_test.go | 13 + providers/googleai/parser.go | 7 +- providers/googleai/parser_test.go | 33 ++ providers/googleai/resolver.go | 4 + providers/openai_compatible/parser_test.go | 18 + providers/perplexity/provider_test.go | 17 + providers/perplexity/resolver.go | 11 +- 15 files changed, 640 insertions(+), 9 deletions(-) create mode 100644 models_live_test.go diff --git a/.gitignore b/.gitignore index 1521c8b..ec14841 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ dist +.env +out +api.json diff --git a/autorouter.go b/autorouter.go index 7744431..f6ebb53 100644 --- a/autorouter.go +++ b/autorouter.go @@ -553,6 +553,7 @@ var knownProviderPrefixes = map[string]bool{ "perplexity": true, "bedrock": true, "azure": true, + "mistral": true, } func stripProviderPrefix(model string) (stripped string, hasPrefix bool) { @@ -562,7 +563,11 @@ func stripProviderPrefix(model string) (stripped string, hasPrefix bool) { } prefix := model[:idx] if knownProviderPrefixes[prefix] { - return model[idx+1:], true + stripped = model[idx+1:] + if strings.HasPrefix(stripped, prefix+"/") { + stripped = strings.TrimPrefix(stripped, prefix+"/") + } + return stripped, true } return model, false } diff --git a/autorouter_test.go b/autorouter_test.go index 4f06b66..5af8a68 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -292,8 +292,10 @@ func TestStripProviderPrefix(t *testing.T) { {"fireworks prefix", "fireworks/accounts/fireworks/models/llama", "accounts/fireworks/models/llama", true}, {"xai prefix", "xai/grok-1", "grok-1", true}, {"perplexity prefix", "perplexity/sonar-small", "sonar-small", true}, + {"repeated perplexity prefix", "perplexity/perplexity/sonar", "sonar", true}, {"bedrock prefix", "bedrock/anthropic.claude-3", "anthropic.claude-3", true}, {"azure prefix", "azure/gpt-4-deployment", "gpt-4-deployment", true}, + {"mistral prefix", "mistral/codestral-2508", "codestral-2508", true}, {"multiple slashes preserved", "openai/gpt-4/turbo", "gpt-4/turbo", true}, {"empty string", "", "", false}, {"slash only - not a provider", "/", "/", false}, diff --git a/detection.go b/detection.go index 879b0ed..28609f8 100644 --- a/detection.go +++ b/detection.go @@ -84,7 +84,7 @@ func DetectProviderFromModel(model string) string { if idx := strings.Index(model, "/"); idx >= 0 { prefix := model[:idx] switch prefix { - case "openai", "anthropic", "googleai", "groq", "fireworks", "xai", "perplexity", "bedrock", "azure": + case "openai", "anthropic", "googleai", "groq", "fireworks", "xai", "perplexity", "bedrock", "azure", "mistral": return prefix } } diff --git a/detection_test.go b/detection_test.go index b4571b2..3157aff 100644 --- a/detection_test.go +++ b/detection_test.go @@ -138,6 +138,7 @@ func TestDetectProviderFromModel(t *testing.T) { {"perplexity/sonar prefix", "perplexity/sonar-small", "perplexity"}, {"bedrock/claude prefix", "bedrock/anthropic.claude-3", "bedrock"}, {"azure/gpt-4 prefix", "azure/gpt-4", "azure"}, + {"mistral/codestral prefix", "mistral/codestral-2508", "mistral"}, {"unknown/ prefix returns unknown", "unknown/model", ""}, {"single slash only", "/", ""}, } diff --git a/interceptors/billing_test.go b/interceptors/billing_test.go index 47bf7e0..4837ac8 100644 --- a/interceptors/billing_test.go +++ b/interceptors/billing_test.go @@ -285,6 +285,7 @@ func TestDetectProvider(t *testing.T) { {"gemini-1.5-flash", "googleai"}, {"llama-3-70b", "openai_compatible"}, {"mixtral-8x7b", "openai_compatible"}, + {"mistral/codestral-2508", "mistral"}, {"unknown-model", ""}, } diff --git a/models_live_test.go b/models_live_test.go new file mode 100644 index 0000000..814bad6 --- /dev/null +++ b/models_live_test.go @@ -0,0 +1,524 @@ +//go:build integration + +package llmproxy + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "slices" + "strconv" + "strings" + "testing" + "time" +) + +const defaultLiveModelsURL = "https://aigateway-usw.agentuity.cloud/models" + +func TestLiveModelsSmoke(t *testing.T) { + if os.Getenv("LLMPROXY_LIVE_MODEL_SMOKE") != "1" { + t.Skip("set LLMPROXY_LIVE_MODEL_SMOKE=1 to run live model smoke tests") + } + + modelsURL := envOrDefault("LLMPROXY_MODELS_URL", defaultLiveModelsURL) + client := &http.Client{Timeout: envDuration("LLMPROXY_LIVE_MODEL_TIMEOUT", 60*time.Second)} + + models, err := fetchLiveModels(client, modelsURL) + if err != nil { + t.Fatalf("fetch models: %v", err) + } + if len(models) == 0 { + t.Fatal("models API returned no models") + } + + if limit := envInt("LLMPROXY_LIVE_MODEL_LIMIT", 0); limit > 0 && limit < len(models) { + models = models[:limit] + } + if allowlist := envSet("LLMPROXY_LIVE_MODEL_IDS"); len(allowlist) > 0 { + models = filterLiveModels(models, allowlist) + if len(models) == 0 { + t.Fatalf("no models matched LLMPROXY_LIVE_MODEL_IDS") + } + } + + envByProvider, missing := liveProviderEnv(models) + if len(missing) > 0 { + t.Fatalf("missing provider env vars:\n%s", strings.Join(missing, "\n")) + } + + concurrency := envInt("LLMPROXY_LIVE_MODEL_CONCURRENCY", 3) + if concurrency < 1 { + concurrency = 1 + } + sem := make(chan struct{}, concurrency) + + for _, model := range models { + model := model + t.Run(model.ID, func(t *testing.T) { + t.Parallel() + if reason := liveSkipReason(model); reason != "" { + t.Skip(reason) + } + sem <- struct{}{} + defer func() { <-sem }() + + apiKey := envByProvider[model.ProviderName] + if apiKey == "" { + t.Fatalf("no API key resolved for provider %q", model.ProviderName) + } + + req, err := newLiveModelRequest(t.Context(), model, apiKey) + if err != nil { + t.Fatalf("build request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read response: %v", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + t.Fatalf("status %d: %s", resp.StatusCode, truncateForLog(body)) + } + if err := validateLiveModelResponseShape(model, body); err != nil { + t.Fatalf("invalid response shape: %v\nbody: %s", err, truncateForLog(body)) + } + }) + } +} + +type liveModelsResponse struct { + Data map[string][]liveModel `json:"data"` +} + +type liveModel struct { + ID string `json:"id"` + API string `json:"api"` + InputModalities []string `json:"input_modalities"` + OutputModalities []string `json:"output_modalities"` + Provider liveModelProvider `json:"provider"` + ProviderName string `json:"-"` +} + +type liveModelProvider struct { + Env []string `json:"env"` + API string `json:"api"` +} + +func fetchLiveModels(client *http.Client, modelsURL string) ([]liveModel, error) { + req, err := http.NewRequest(http.MethodGet, modelsURL, nil) + if err != nil { + return nil, err + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("status %d: %s", resp.StatusCode, truncateForLog(body)) + } + + var parsed liveModelsResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, err + } + + var models []liveModel + for provider, providerModels := range parsed.Data { + for _, model := range providerModels { + model.ProviderName = provider + models = append(models, model) + } + } + slices.SortFunc(models, func(a, b liveModel) int { + return strings.Compare(a.ID, b.ID) + }) + return models, nil +} + +func liveProviderEnv(models []liveModel) (map[string]string, []string) { + envByProvider := make(map[string]string) + missingByProvider := make(map[string][]string) + + for _, model := range models { + if envByProvider[model.ProviderName] != "" { + continue + } + envNames := liveEnvNames(model) + for _, name := range envNames { + if value := os.Getenv(name); value != "" { + envByProvider[model.ProviderName] = value + break + } + } + if envByProvider[model.ProviderName] == "" { + missingByProvider[model.ProviderName] = envNames + } + } + + var missing []string + for provider, envNames := range missingByProvider { + missing = append(missing, fmt.Sprintf("%s: set one of %s", provider, strings.Join(envNames, ", "))) + } + slices.Sort(missing) + return envByProvider, missing +} + +func liveEnvNames(model liveModel) []string { + envNames := append([]string(nil), model.Provider.Env...) + if model.ProviderName == "googleai" { + envNames = append(envNames, "GOOGLE_AI_API_KEY") + } + slices.Sort(envNames) + return slices.Compact(envNames) +} + +func newLiveModelRequest(ctx context.Context, model liveModel, apiKey string) (*http.Request, error) { + upstreamModel := stripLiveProviderPrefix(model.ID) + endpoint, body, err := liveModelEndpointAndBody(model, upstreamModel) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + setLiveAuthHeaders(req, model.ProviderName, apiKey) + return req, nil +} + +func liveModelEndpointAndBody(model liveModel, upstreamModel string) (string, []byte, error) { + if liveUseChatCompletions(model) { + body := map[string]any{ + "model": upstreamModel, + "max_tokens": 16, + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + } + return liveChatCompletionsURL(model), mustJSON(body), nil + } + if liveUseCompletions(model) { + body := map[string]any{ + "model": upstreamModel, + "max_tokens": 16, + "prompt": "hi", + } + return liveCompletionsURL(model), mustJSON(body), nil + } + + switch model.API { + case "anthropic-messages": + body := map[string]any{ + "model": upstreamModel, + "max_tokens": 16, + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + } + return joinLiveURLPath(model.Provider.API, "v1", "messages"), mustJSON(body), nil + case "google-generative-ai": + body := map[string]any{ + "contents": []map[string]any{ + { + "role": "user", + "parts": []map[string]any{ + {"text": "hi"}, + }, + }, + }, + "generationConfig": map[string]any{ + "maxOutputTokens": 16, + }, + } + return joinLiveURLPath(model.Provider.API, "v1beta", "models", upstreamModel+":generateContent"), mustJSON(body), nil + case "mistral-conversations", "openai-completions": + body := map[string]any{ + "model": upstreamModel, + "max_tokens": 16, + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + } + return liveChatCompletionsURL(model), mustJSON(body), nil + case "openai-codex-responses", "openai-responses": + body := map[string]any{ + "model": upstreamModel, + "input": "hi", + "max_output_tokens": liveMaxOutputTokens(model), + } + if liveIsOpenAIReasoningModel(model) { + setLiveDefaultReasoningEffort(body, liveReasoningEffort(model)) + } + return joinLiveURLPath(model.Provider.API, "v1", "responses"), mustJSON(body), nil + default: + return "", nil, fmt.Errorf("unsupported API %q for model %s", model.API, model.ID) + } +} + +func liveChatCompletionsURL(model liveModel) string { + switch model.ProviderName { + case "cohere": + return joinLiveURLPath(model.Provider.API, "compatibility", "v1", "chat", "completions") + case "deepseek": + return joinLiveURLPath(model.Provider.API, "chat", "completions") + case "perplexity": + return joinLiveURLPath(model.Provider.API, "v1", "sonar") + default: + return joinLiveURLPath(model.Provider.API, "v1", "chat", "completions") + } +} + +func liveCompletionsURL(model liveModel) string { + return joinLiveURLPath(model.Provider.API, "v1", "completions") +} + +func setLiveAuthHeaders(req *http.Request, provider, apiKey string) { + switch provider { + case "anthropic": + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + case "googleai": + req.Header.Set("x-goog-api-key", apiKey) + default: + req.Header.Set("Authorization", "Bearer "+apiKey) + } +} + +func validateLiveModelResponseShape(model liveModel, body []byte) error { + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return err + } + if raw["error"] != nil { + return fmt.Errorf("response contains error: %v", raw["error"]) + } + + if liveUseChatCompletions(model) || liveUseCompletions(model) { + if nonEmptyArray(raw["choices"]) { + return nil + } + return fmt.Errorf("missing expected fields for %s", model.API) + } + + switch model.API { + case "anthropic-messages": + if nonEmptyString(raw["id"]) && nonEmptyArray(raw["content"]) { + return nil + } + case "google-generative-ai": + if nonEmptyArray(raw["candidates"]) { + return nil + } + case "mistral-conversations", "openai-completions": + if nonEmptyArray(raw["choices"]) || nonEmptyArray(raw["outputs"]) { + return nil + } + case "openai-codex-responses", "openai-responses": + if nonEmptyString(raw["id"]) && (nonEmptyString(raw["output_text"]) || nonEmptyArray(raw["output"])) { + return nil + } + default: + return fmt.Errorf("unsupported API %q", model.API) + } + + return fmt.Errorf("missing expected fields for %s", model.API) +} + +func liveUseChatCompletions(model liveModel) bool { + if model.ProviderName == "perplexity" { + return true + } + return model.ProviderName == "openai" && strings.Contains(model.ID, "search-preview") +} + +func liveUseCompletions(model liveModel) bool { + return model.ProviderName == "openai" && strings.HasSuffix(model.ID, "-instruct") +} + +func liveMaxOutputTokens(model liveModel) int { + if liveIsOpenAIReasoningModel(model) { + return 1024 + } + return 16 +} + +func liveIsOpenAIReasoningModel(model liveModel) bool { + return model.ProviderName == "openai" && (strings.Contains(model.ID, "codex") || strings.Contains(model.ID, "-pro")) +} + +func liveReasoningEffort(model liveModel) string { + if strings.Contains(model.ID, "-pro") { + return "high" + } + return "low" +} + +func setLiveDefaultReasoningEffort(body map[string]any, effort string) { + reasoning, ok := body["reasoning"].(map[string]any) + if !ok { + body["reasoning"] = map[string]any{"effort": effort} + return + } + if _, exists := reasoning["effort"]; !exists { + reasoning["effort"] = effort + } +} + +func liveSkipReason(model liveModel) string { + switch { + case strings.Contains(model.ID, "-tts"): + return "model is TTS/audio-only and does not support a text response shape" + case strings.Contains(model.ID, "mistral-embed"): + return "embedding model is not callable through chat/completion smoke test" + case strings.Contains(model.ID, "deep-research"): + return "deep research model requires web_search_preview, mcp, or file_search tools" + case strings.Contains(model.ID, "multi-agent"): + return "multi-agent model is not callable through chat completions" + default: + return "" + } +} + +func stripLiveProviderPrefix(model string) string { + idx := strings.Index(model, "/") + if idx < 0 { + return model + } + prefix := model[:idx] + stripped := model[idx+1:] + if strings.HasPrefix(stripped, prefix+"/") { + stripped = strings.TrimPrefix(stripped, prefix+"/") + } + return stripped +} + +func joinLiveURLPath(base string, elems ...string) string { + u, err := url.Parse(base) + if err != nil { + panic(err) + } + if len(elems) > 0 && elems[0] == "v1" && strings.HasSuffix(strings.TrimRight(u.Path, "/"), "/v1") { + elems = elems[1:] + } + all := []string{strings.TrimRight(u.Path, "/")} + all = append(all, elems...) + u.Path = pathJoin(all...) + u.RawQuery = "" + return u.String() +} + +func pathJoin(elems ...string) string { + var parts []string + for _, elem := range elems { + elem = strings.Trim(elem, "/") + if elem != "" { + parts = append(parts, elem) + } + } + return "/" + strings.Join(parts, "/") +} + +func mustJSON(v any) []byte { + data, err := json.Marshal(v) + if err != nil { + panic(err) + } + return data +} + +func nonEmptyString(v any) bool { + s, ok := v.(string) + return ok && s != "" +} + +func nonEmptyArray(v any) bool { + items, ok := v.([]any) + return ok && len(items) > 0 +} + +func truncateForLog(body []byte) string { + const max = 2048 + if len(body) <= max { + return string(body) + } + return string(body[:max]) + "..." +} + +func envOrDefault(name, fallback string) string { + if value := os.Getenv(name); value != "" { + return value + } + return fallback +} + +func envInt(name string, fallback int) int { + value := os.Getenv(name) + if value == "" { + return fallback + } + parsed, err := strconv.Atoi(value) + if err != nil { + return fallback + } + return parsed +} + +func envSet(name string) map[string]bool { + value := os.Getenv(name) + if value == "" { + return nil + } + result := make(map[string]bool) + for _, item := range strings.Split(value, ",") { + item = strings.TrimSpace(item) + if item != "" { + result[item] = true + } + } + return result +} + +func filterLiveModels(models []liveModel, allowlist map[string]bool) []liveModel { + filtered := make([]liveModel, 0, len(allowlist)) + for _, model := range models { + if allowlist[model.ID] { + filtered = append(filtered, model) + } + } + return filtered +} + +func envDuration(name string, fallback time.Duration) time.Duration { + value := os.Getenv(name) + if value == "" { + return fallback + } + parsed, err := time.ParseDuration(value) + if err == nil { + return parsed + } + seconds, err := strconv.Atoi(value) + if err != nil { + return fallback + } + return time.Duration(seconds) * time.Second +} diff --git a/providers/anthropic/parser.go b/providers/anthropic/parser.go index ec9d012..843bc4a 100644 --- a/providers/anthropic/parser.go +++ b/providers/anthropic/parser.go @@ -55,8 +55,8 @@ func (p *Parser) Parse(body io.ReadCloser) (llmproxy.BodyMetadata, []byte, error } } - if req.System != "" { - meta.Custom["system"] = req.System + if system := contentToString(req.System); system != "" { + meta.Custom["system"] = system } for k, v := range req.Custom { @@ -87,7 +87,7 @@ type Request struct { Model string `json:"model"` Messages []Message `json:"messages"` MaxTokens int `json:"max_tokens,omitempty"` - System string `json:"system,omitempty"` + System Content `json:"system,omitempty"` Custom map[string]interface{} `json:"-"` } diff --git a/providers/anthropic/parser_test.go b/providers/anthropic/parser_test.go index b5b57fd..943a1bf 100644 --- a/providers/anthropic/parser_test.go +++ b/providers/anthropic/parser_test.go @@ -45,6 +45,19 @@ func TestParser(t *testing.T) { } }) + t.Run("parses request with system prompt array", func(t *testing.T) { + body := `{"model":"anthropic/claude-sonnet-4-6","max_tokens":1024,"system":[{"type":"text","text":"You are helpful."}],"messages":[{"role":"user","content":"hello"}]}` + parser := &Parser{} + + meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if meta.Custom["system"] != "You are helpful." { + t.Errorf("expected system prompt, got %v", meta.Custom["system"]) + } + }) + t.Run("parses content as array", func(t *testing.T) { body := `{"model":"claude-3-opus-20240229","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}` parser := &Parser{} diff --git a/providers/googleai/parser.go b/providers/googleai/parser.go index ea40faa..691122a 100644 --- a/providers/googleai/parser.go +++ b/providers/googleai/parser.go @@ -88,7 +88,7 @@ func extractTextFromParts(parts []Part) string { // Request represents a Google AI generateContent request. type Request struct { - Model string `json:"-"` // Extracted from path + Model string `json:"model,omitempty"` Contents []Content `json:"contents,omitempty"` SystemInstruction *Content `json:"systemInstruction,omitempty"` GenerationConfig GenerationConfig `json:"generationConfig,omitempty"` @@ -149,8 +149,9 @@ func (r *Request) UnmarshalJSON(data []byte) error { r.Custom = make(map[string]interface{}) known := map[string]bool{ - "contents": true, "systemInstruction": true, "generationConfig": true, - "safetySettings": true, "tools": true, "toolConfig": true, + "model": true, "contents": true, "systemInstruction": true, + "generationConfig": true, "safetySettings": true, "tools": true, + "toolConfig": true, } for k, v := range raw { if !known[k] { diff --git a/providers/googleai/parser_test.go b/providers/googleai/parser_test.go index 21b8d35..e4f7d0a 100644 --- a/providers/googleai/parser_test.go +++ b/providers/googleai/parser_test.go @@ -32,6 +32,22 @@ func TestParser(t *testing.T) { } }) + t.Run("parses model from request body", func(t *testing.T) { + body := `{"model":"gemini-2.0-flash","contents":[{"role":"user","parts":[{"text":"hello"}]}]}` + parser := &Parser{} + + meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if meta.Model != "gemini-2.0-flash" { + t.Errorf("expected model gemini-2.0-flash, got %s", meta.Model) + } + if _, ok := meta.Custom["model"]; ok { + t.Error("model should not be captured as a custom field") + } + }) + t.Run("parses request with generation config", func(t *testing.T) { body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}],"generationConfig":{"maxOutputTokens":100}}` parser := &Parser{} @@ -130,6 +146,23 @@ func TestResolver(t *testing.T) { t.Errorf("expected %s, got %s", expected, u.String()) } }) + + t.Run("strips googleai provider prefix from resolved model", func(t *testing.T) { + resolver, err := NewResolver("https://generativelanguage.googleapis.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + meta := llmproxy.BodyMetadata{Model: "googleai/gemini-2.0-flash"} + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent" + if u.String() != expected { + t.Errorf("expected %s, got %s", expected, u.String()) + } + }) } func TestExtractor(t *testing.T) { diff --git a/providers/googleai/resolver.go b/providers/googleai/resolver.go index 4bbc8b5..0a6b9fd 100644 --- a/providers/googleai/resolver.go +++ b/providers/googleai/resolver.go @@ -3,6 +3,7 @@ package googleai import ( "fmt" "net/url" + "strings" "github.com/agentuity/llmproxy" ) @@ -21,6 +22,9 @@ type Resolver struct { // If meta.Model is empty, defaults to "gemini-pro". func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) { model := meta.Model + if stripped, ok := strings.CutPrefix(model, "googleai/"); ok { + model = stripped + } if model == "" { model = "gemini-pro" } diff --git a/providers/openai_compatible/parser_test.go b/providers/openai_compatible/parser_test.go index 337c704..9b63032 100644 --- a/providers/openai_compatible/parser_test.go +++ b/providers/openai_compatible/parser_test.go @@ -392,6 +392,24 @@ func TestResolver_CustomBaseURL(t *testing.T) { } } +func TestResolver_CustomVersionedBaseURL(t *testing.T) { + resolver, err := NewResolver("https://api.fireworks.ai/inference/v1/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + meta := llmproxy.BodyMetadata{Model: "accounts/fireworks/models/test"} + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := "https://api.fireworks.ai/inference/v1/chat/completions" + if u.String() != expected { + t.Errorf("URL = %q, want %q", u.String(), expected) + } +} + func TestResolver_InvalidURL(t *testing.T) { _, err := NewResolver("://invalid-url") if err == nil { diff --git a/providers/perplexity/provider_test.go b/providers/perplexity/provider_test.go index fb56772..9221456 100644 --- a/providers/perplexity/provider_test.go +++ b/providers/perplexity/provider_test.go @@ -49,6 +49,23 @@ func TestResolver_PerplexityURL(t *testing.T) { } } +func TestResolver_PerplexityURLWithVersionedBase(t *testing.T) { + resolver, err := NewResolver("https://api.perplexity.ai/v1/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + u, err := resolver.Resolve(llmproxy.BodyMetadata{Model: "sonar"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := "https://api.perplexity.ai/v1/sonar" + if u.String() != expected { + t.Errorf("URL = %q, want %q", u.String(), expected) + } +} + func TestResolver_InvalidURL(t *testing.T) { _, err := NewResolver("://invalid-url") if err == nil { diff --git a/providers/perplexity/resolver.go b/providers/perplexity/resolver.go index 5f5409e..daab826 100644 --- a/providers/perplexity/resolver.go +++ b/providers/perplexity/resolver.go @@ -2,6 +2,7 @@ package perplexity import ( "net/url" + "strings" "github.com/agentuity/llmproxy" ) @@ -15,9 +16,17 @@ func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) { } func NewResolver(baseURL string) (*Resolver, error) { - u, err := url.Parse(baseURL) + u, err := url.Parse(normalizeBaseURL(baseURL)) if err != nil { return nil, err } return &Resolver{BaseURL: u}, nil } + +func normalizeBaseURL(raw string) string { + raw = strings.TrimRight(raw, "/") + if strings.HasSuffix(raw, "/v1") { + raw = raw[:len(raw)-3] + } + return raw +}