diff --git a/pkg/config/model_alias_test.go b/pkg/config/model_alias_test.go index 8cc7c44fc..e758b4838 100644 --- a/pkg/config/model_alias_test.go +++ b/pkg/config/model_alias_test.go @@ -16,8 +16,6 @@ func TestResolveModelAliases(t *testing.T) { mockData := &modelsdev.Database{ Providers: map[string]modelsdev.Provider{ "anthropic": { - ID: "anthropic", - Name: "Anthropic", Models: map[string]modelsdev.Model{ "claude-sonnet-4-5": {Name: "Claude Sonnet 4.5 (latest)"}, "claude-sonnet-4-5-20250929": {Name: "Claude Sonnet 4.5"}, diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index a8607a56d..e179ef9d8 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -31,12 +31,14 @@ type Store struct { cacheFile string mu sync.Mutex db *Database - etag string // ETag from last successful fetch, used for conditional requests } -// singleton holds the process-wide Store instance. It is initialised lazily -// on the first call to NewStore. All subsequent calls return the same value. -var singleton = sync.OnceValues(func() (*Store, error) { +// NewStore returns the process-wide singleton Store. +// +// The database is loaded lazily on the first call to GetDatabase and +// then cached in memory so that every caller shares one copy. +// The first call creates the cache directory if it does not exist. +var NewStore = sync.OnceValues(func() (*Store, error) { homeDir, err := os.UserHomeDir() if err != nil { return nil, fmt.Errorf("failed to get user home directory: %w", err) @@ -52,15 +54,6 @@ var singleton = sync.OnceValues(func() (*Store, error) { }, nil }) -// NewStore returns the process-wide singleton Store. -// -// The database is loaded lazily on the first call to GetDatabase and -// then cached in memory so that every caller shares one copy. -// The first call creates the cache directory if it does not exist. -func NewStore() (*Store, error) { - return singleton() -} - // NewDatabaseStore creates a Store pre-populated with the given database. // The returned store serves data entirely from memory and never fetches // from the network or touches the filesystem, making it suitable for @@ -78,18 +71,17 @@ func (s *Store) GetDatabase(ctx context.Context) (*Database, error) { return s.db, nil } - db, etag, err := loadDatabase(ctx, s.cacheFile) + db, err := loadDatabase(ctx, s.cacheFile) if err != nil { return nil, err } s.db = db - s.etag = etag return db, nil } -// GetProvider returns a specific provider by ID. -func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) { +// getProvider returns a specific provider by ID. +func (s *Store) getProvider(ctx context.Context, providerID string) (*Provider, error) { db, err := s.GetDatabase(ctx) if err != nil { return nil, err @@ -112,30 +104,23 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { providerID := parts[0] modelID := parts[1] - provider, err := s.GetProvider(ctx, providerID) + provider, err := s.getProvider(ctx, providerID) if err != nil { return nil, err } model, exists := provider.Models[modelID] - if !exists { - // For amazon-bedrock, try stripping region/inference profile prefixes - // Bedrock uses prefixes for cross-region inference profiles, - // but models.dev stores models without these prefixes. - // - // Strip known region prefixes and retry lookup. - if providerID == "amazon-bedrock" { - if before, after, ok := strings.Cut(modelID, "."); ok { - possibleRegionPrefix := before - if isBedrockRegionPrefix(possibleRegionPrefix) { - normalizedModelID := after - model, exists = provider.Models[normalizedModelID] - if exists { - return &model, nil - } - } - } + + // For amazon-bedrock, try stripping region/inference profile prefixes. + // Bedrock uses prefixes for cross-region inference profiles, + // but models.dev stores models without these prefixes. + if !exists && providerID == "amazon-bedrock" { + if prefix, after, ok := strings.Cut(modelID, "."); ok && bedrockRegionPrefixes[prefix] { + model, exists = provider.Models[after] } + } + + if !exists { return nil, fmt.Errorf("model %q not found in provider %q", modelID, providerID) } @@ -144,12 +129,11 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { // loadDatabase loads the database from the local cache file or // falls back to fetching from the models.dev API. -// It returns the database and the ETag associated with the data. -func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, error) { +func loadDatabase(ctx context.Context, cacheFile string) (*Database, error) { // Try to load from cache first cached, err := loadFromCache(cacheFile) if err == nil && time.Since(cached.LastRefresh) < refreshInterval { - return &cached.Database, cached.ETag, nil + return &cached.Database, nil } // Cache is stale or doesn't exist — try a conditional fetch with the ETag. @@ -163,9 +147,9 @@ func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, err // If API fetch fails but we have cached data, use it regardless of age. if cached != nil { slog.Debug("API fetch failed, using stale cache", "error", fetchErr) - return &cached.Database, cached.ETag, nil + return &cached.Database, nil } - return nil, "", fmt.Errorf("failed to fetch from API and no cached data available: %w", fetchErr) + return nil, fmt.Errorf("failed to fetch from API and no cached data available: %w", fetchErr) } // database is nil when the server returned 304 Not Modified. @@ -175,7 +159,7 @@ func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, err if saveErr := saveToCache(cacheFile, &cached.Database, cached.ETag); saveErr != nil { slog.Warn("Failed to update cache timestamp", "error", saveErr) } - return &cached.Database, cached.ETag, nil + return &cached.Database, nil } // Save the fresh data to cache. @@ -183,7 +167,7 @@ func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, err slog.Warn("Failed to save to cache", "error", saveErr) } - return database, newETag, nil + return database, nil } // fetchFromAPI fetches the models.dev database. @@ -230,7 +214,6 @@ func fetchFromAPI(ctx context.Context, etag string) (*Database, string, error) { return &Database{ Providers: providers, - UpdatedAt: time.Now(), }, newETag, nil } @@ -249,11 +232,9 @@ func loadFromCache(cacheFile string) (*CachedData, error) { } func saveToCache(cacheFile string, database *Database, etag string) error { - now := time.Now() cached := CachedData{ Database: *database, - CachedAt: now, - LastRefresh: now, + LastRefresh: time.Now(), ETag: etag, } @@ -286,8 +267,7 @@ func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName str return modelName } - // Get the provider from the database - provider, err := s.GetProvider(ctx, providerID) + provider, err := s.getProvider(ctx, providerID) if err != nil { return modelName } @@ -319,46 +299,8 @@ func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName str // stores models without regional prefixes. AWS uses these for cross-region inference profiles. // See: https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html var bedrockRegionPrefixes = map[string]bool{ - "us": true, // US region inference profile - "eu": true, // EU region inference profile - "apac": true, // Asia Pacific region inference profile - "global": true, // Global inference profile (routes to any available region) -} - -// isBedrockRegionPrefix returns true if the prefix is a known Bedrock regional/inference profile prefix. -func isBedrockRegionPrefix(prefix string) bool { - return bedrockRegionPrefixes[prefix] -} - -// ModelSupportsReasoning checks if the given model ID supports reasoning/thinking. -// -// This function implements fail-open semantics: -// - If modelID is empty or not in "provider/model" format, returns true (fail-open) -// - If models.dev lookup fails for any reason, returns true (fail-open) -// - If lookup succeeds, returns the model's Reasoning field value -func ModelSupportsReasoning(ctx context.Context, modelID string) bool { - // Fail-open for empty model ID - if modelID == "" { - return true - } - - // Fail-open if not in provider/model format - if !strings.Contains(modelID, "/") { - slog.Debug("Model ID not in provider/model format, assuming reasoning supported to allow user choice", "model_id", modelID) - return true - } - - store, err := NewStore() - if err != nil { - slog.Debug("Failed to create modelsdev store, assuming reasoning supported to allow user choice", "error", err) - return true - } - - model, err := store.GetModel(ctx, modelID) - if err != nil { - slog.Debug("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice", "model_id", modelID, "error", err) - return true - } - - return model.Reasoning + "us": true, + "eu": true, + "apac": true, + "global": true, } diff --git a/pkg/modelsdev/types.go b/pkg/modelsdev/types.go index a9a63a847..449d02d6b 100644 --- a/pkg/modelsdev/types.go +++ b/pkg/modelsdev/types.go @@ -5,36 +5,20 @@ import "time" // Database represents the complete models.dev database type Database struct { Providers map[string]Provider `json:"providers"` - UpdatedAt time.Time `json:"updated_at"` } // Provider represents an AI model provider type Provider struct { - ID string `json:"id"` - Name string `json:"name"` - Doc string `json:"doc,omitempty"` - API string `json:"api,omitempty"` - NPM string `json:"npm,omitempty"` - Env []string `json:"env,omitempty"` Models map[string]Model `json:"models"` } // Model represents an AI model with its specifications and capabilities type Model struct { - ID string `json:"id"` - Name string `json:"name"` - Family string `json:"family,omitempty"` - Attachment bool `json:"attachment"` - Reasoning bool `json:"reasoning"` - Temperature bool `json:"temperature"` - ToolCall bool `json:"tool_call"` - Knowledge string `json:"knowledge,omitempty"` - ReleaseDate string `json:"release_date"` - LastUpdated string `json:"last_updated"` - OpenWeights bool `json:"open_weights"` - Cost *Cost `json:"cost,omitempty"` - Limit Limit `json:"limit"` - Modalities Modalities `json:"modalities"` + Name string `json:"name"` + Family string `json:"family,omitempty"` + Cost *Cost `json:"cost,omitempty"` + Limit Limit `json:"limit"` + Modalities Modalities `json:"modalities"` } // Cost represents the pricing information for a model @@ -60,7 +44,6 @@ type Modalities struct { // CachedData represents the cached models.dev data with metadata type CachedData struct { Database Database `json:"database"` - CachedAt time.Time `json:"cached_at"` LastRefresh time.Time `json:"last_refresh"` ETag string `json:"etag,omitempty"` } diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go index a4dd0038c..5f9d74e02 100644 --- a/pkg/runtime/model_switcher_test.go +++ b/pkg/runtime/model_switcher_test.go @@ -244,25 +244,20 @@ func TestBuildCatalogChoices(t *testing.T) { db := &modelsdev.Database{ Providers: map[string]modelsdev.Provider{ "openai": { - ID: "openai", - Name: "OpenAI", Models: map[string]modelsdev.Model{ "gpt-4o": { - ID: "gpt-4o", Name: "GPT-4o", Modalities: modelsdev.Modalities{ Output: []string{"text"}, }, }, "dall-e-3": { - ID: "dall-e-3", Name: "DALL-E 3", Modalities: modelsdev.Modalities{ Output: []string{"image"}, // Not a text model }, }, "text-embedding-3-large": { - ID: "text-embedding-3-large", Name: "Text Embedding 3 Large", Family: "text-embedding", Modalities: modelsdev.Modalities{ @@ -272,11 +267,8 @@ func TestBuildCatalogChoices(t *testing.T) { }, }, "anthropic": { - ID: "anthropic", - Name: "Anthropic", Models: map[string]modelsdev.Model{ "claude-sonnet-4-0": { - ID: "claude-sonnet-4-0", Name: "Claude Sonnet 4", Modalities: modelsdev.Modalities{ Output: []string{"text"}, @@ -285,11 +277,8 @@ func TestBuildCatalogChoices(t *testing.T) { }, }, "unsupported": { - ID: "unsupported", - Name: "Unsupported Provider", Models: map[string]modelsdev.Model{ "some-model": { - ID: "some-model", Name: "Some Model", Modalities: modelsdev.Modalities{ Output: []string{"text"}, @@ -348,11 +337,8 @@ func TestBuildCatalogChoicesWithDuplicates(t *testing.T) { db := &modelsdev.Database{ Providers: map[string]modelsdev.Provider{ "openai": { - ID: "openai", - Name: "OpenAI", Models: map[string]modelsdev.Model{ "gpt-4o": { - ID: "gpt-4o", Name: "GPT-4o", Modalities: modelsdev.Modalities{ Output: []string{"text"},