diff --git a/desktop/app.go b/desktop/app.go index 71e6fc1b1..7bd4209b9 100644 --- a/desktop/app.go +++ b/desktop/app.go @@ -4910,7 +4910,50 @@ func modelProviderAccessAllowed(access map[string]bool, name string) bool { if len(access) == 0 { return true } - return access[strings.TrimSpace(name)] + return access[config.CanonicalDesktopOfficialProviderName(strings.TrimSpace(name))] +} + +func resolveAccessibleModelWithFallback(cfg *config.Config, ref string) (resolvedRef string, fallback bool, ok bool) { + if cfg == nil { + return "", false, false + } + access := providerAccessSet(cfg.Desktop.ProviderAccess) + tryResolve := func(candidate string, fallback bool) (string, bool, bool) { + entry, found := cfg.ResolveModel(candidate) + if !found || !entry.Configured() || !modelProviderAccessAllowed(access, entry.Name) { + return "", false, false + } + return entry.Name + "/" + entry.Model, fallback, true + } + + ref = strings.TrimSpace(ref) + if ref != "" { + if resolved, fallback, ok := tryResolve(ref, false); ok { + return resolved, fallback, true + } + } + defaultRef := strings.TrimSpace(cfg.DefaultModel) + if ref != defaultRef && defaultRef != "" { + if resolved, fallback, ok := tryResolve(defaultRef, true); ok { + return resolved, fallback, true + } + } + for i := range cfg.Providers { + p := &cfg.Providers[i] + if len(p.ModelList()) == 0 || !p.Configured() || !modelProviderAccessAllowed(access, p.Name) { + continue + } + return p.Name + "/" + p.DefaultModel(), true, true + } + return "", false, false +} + +func noAccessibleModelError(ref string) error { + ref = strings.TrimSpace(ref) + if ref == "" { + ref = "" + } + return fmt.Errorf("model %q is no longer available and no accessible fallback is configured", ref) } func controllerHasActiveRuntimeWork(ctrl *control.Controller) bool { @@ -5652,9 +5695,9 @@ func (a *App) currentProviderEntryForTab(tabID string) (*config.ProviderEntry, e if strings.TrimSpace(ref) == "" { ref = cfg.DefaultModel } - resolved, _, ok := cfg.ResolveModelWithFallback(ref) + resolved, _, ok := resolveAccessibleModelWithFallback(cfg, ref) if !ok { - return nil, fmt.Errorf("unknown model %q", ref) + return nil, noAccessibleModelError(ref) } entry, ok := cfg.ResolveModel(resolved) if !ok { @@ -5678,9 +5721,9 @@ func (a *App) resolvedModelForTab(tab *WorkspaceTab) (string, bool, error) { if ref == "" { ref = cfg.DefaultModel } - resolved, fallback, ok := cfg.ResolveModelWithFallback(ref) + resolved, fallback, ok := resolveAccessibleModelWithFallback(cfg, ref) if !ok { - return "", false, fmt.Errorf("unknown model %q", ref) + return "", false, noAccessibleModelError(ref) } return resolved, fallback, nil } diff --git a/desktop/app_test.go b/desktop/app_test.go index 800cbc23a..f8ef7795b 100644 --- a/desktop/app_test.go +++ b/desktop/app_test.go @@ -1176,6 +1176,165 @@ func TestSetTokenModeMigratesStaleOfficialDeepSeekTabModel(t *testing.T) { } } +func TestSetTokenModeFallsBackWhenTabProviderAccessWasRemoved(t *testing.T) { + isolateDesktopUserDirs(t) + t.Setenv("DEEPSEEK_API_KEY", "sk-test") + t.Setenv("PROV_A_KEY", "sk-test") + + cfg := config.Default() + cfg.DefaultModel = "prov-a/model-a" + cfg.Desktop.ProviderAccess = []string{"prov-a"} + cfg.Providers = append(cfg.Providers, config.ProviderEntry{ + Name: "prov-a", + Kind: "openai", + BaseURL: "https://a.example.com/v1", + Model: "model-a", + APIKeyEnv: "PROV_A_KEY", + }) + if err := cfg.SaveTo(config.UserConfigPath()); err != nil { + t.Fatalf("save config: %v", err) + } + + app := NewApp() + app.ctx = context.Background() + app.readyHook = func() {} + old := control.New(control.Options{Label: "old-controller"}) + app.setTestCtrl(old, "deepseek-flash/deepseek-v4-flash") + defer func() { + if c := app.activeCtrl(); c != nil { + c.Close() + } + }() + + if err := app.SetTokenMode("economy"); err != nil { + t.Fatalf("SetTokenMode(economy): %v", err) + } + tab := app.activeTab() + if tab == nil { + t.Fatal("active tab missing") + } + if tab.model != "prov-a/model-a" { + t.Fatalf("tab model = %q, want prov-a/model-a fallback", tab.model) + } +} + +func TestBuildTabControllerFallsBackWhenSavedTabProviderAccessWasRemoved(t *testing.T) { + isolateDesktopUserDirs(t) + t.Setenv("DEEPSEEK_API_KEY", "sk-test") + t.Setenv("PROV_A_KEY", "sk-test") + + cfg := config.Default() + cfg.DefaultModel = "prov-a/model-a" + cfg.Desktop.ProviderAccess = []string{"prov-a"} + cfg.Providers = append(cfg.Providers, config.ProviderEntry{ + Name: "prov-a", + Kind: "openai", + BaseURL: "https://a.example.com/v1", + Model: "model-a", + APIKeyEnv: "PROV_A_KEY", + }) + if err := cfg.SaveTo(config.UserConfigPath()); err != nil { + t.Fatalf("save config: %v", err) + } + + project := t.TempDir() + app := NewApp() + tab := app.createTabEntryWithID("project", project, "", "tab_access_removed") + tab.model = "deepseek-flash/deepseek-v4-flash" + tab.sink = &tabEventSink{tabID: tab.ID, app: app} + app.tabs = map[string]*WorkspaceTab{tab.ID: tab} + app.tabOrder = []string{tab.ID} + app.activeTabID = tab.ID + + app.buildTabController(tab) + if tab.Ctrl == nil { + t.Fatalf("tab controller was not built: %s", tab.StartupErr) + } + defer tab.Ctrl.Close() + + if tab.model != "prov-a/model-a" { + t.Fatalf("tab model = %q, want prov-a/model-a fallback", tab.model) + } + saved := loadTabsFile() + if len(saved.Tabs) != 1 || saved.Tabs[0].Model != "prov-a/model-a" { + t.Fatalf("saved tabs = %+v, want prov-a/model-a", saved.Tabs) + } +} + +func TestBuildTabControllerErrorsWhenNoAccessibleFallbackExists(t *testing.T) { + isolateDesktopUserDirs(t) + t.Setenv("PROV_A_KEY", "") + + cfg := config.Default() + cfg.DefaultModel = "prov-a/model-a" + cfg.Desktop.ProviderAccess = []string{"prov-a"} + cfg.Providers = []config.ProviderEntry{{ + Name: "prov-a", + Kind: "openai", + BaseURL: "https://a.example.com/v1", + Model: "model-a", + APIKeyEnv: "PROV_A_KEY", + }} + if err := cfg.SaveTo(config.UserConfigPath()); err != nil { + t.Fatalf("save config: %v", err) + } + + project := t.TempDir() + app := NewApp() + tab := app.createTabEntryWithID("project", project, "", "tab_no_accessible_fallback") + tab.model = "deepseek-flash/deepseek-v4-flash" + tab.sink = &tabEventSink{tabID: tab.ID, app: app} + app.tabs = map[string]*WorkspaceTab{tab.ID: tab} + app.tabOrder = []string{tab.ID} + app.activeTabID = tab.ID + + app.buildTabController(tab) + if tab.Ctrl != nil { + t.Fatalf("tab controller should not be built when no accessible fallback exists") + } + if tab.StartupErr == "" || !strings.Contains(tab.StartupErr, "no accessible fallback") { + t.Fatalf("startup error = %q, want no accessible fallback message", tab.StartupErr) + } +} + +func TestRebuildKeepsExistingControllerWhenNoAccessibleFallbackExists(t *testing.T) { + isolateDesktopUserDirs(t) + t.Setenv("PROV_A_KEY", "") + + cfg := config.Default() + cfg.DefaultModel = "prov-a/model-a" + cfg.Desktop.ProviderAccess = []string{"prov-a"} + cfg.Providers = []config.ProviderEntry{{ + Name: "prov-a", + Kind: "openai", + BaseURL: "https://a.example.com/v1", + Model: "model-a", + APIKeyEnv: "PROV_A_KEY", + }} + if err := cfg.SaveTo(config.UserConfigPath()); err != nil { + t.Fatalf("save config: %v", err) + } + + app := NewApp() + app.ctx = context.Background() + app.readyHook = func() {} + old := control.New(control.Options{Label: "old-controller"}) + app.setTestCtrl(old, "deepseek-flash/deepseek-v4-flash") + defer func() { + if c := app.activeCtrl(); c != nil { + c.Close() + } + }() + + err := app.rebuild() + if err == nil || !strings.Contains(err.Error(), "no accessible fallback") { + t.Fatalf("rebuild error = %v, want no accessible fallback", err) + } + if got := app.activeCtrl(); got != old { + t.Fatalf("active controller = %v, want existing controller preserved", got) + } +} + func TestSetTokenModeKeepsControllerWhenRebuildFails(t *testing.T) { isolateDesktopUserDirs(t) t.Setenv("DEEPSEEK_API_KEY", "") diff --git a/desktop/frontend/src/lib/bridge.ts b/desktop/frontend/src/lib/bridge.ts index 516839e98..c4626621f 100644 --- a/desktop/frontend/src/lib/bridge.ts +++ b/desktop/frontend/src/lib/bridge.ts @@ -1264,8 +1264,16 @@ function makeMockApp(): AppBindings { }, ]; const mockModelCatalog = [ - { ref: "deepseek/deepseek-v4-flash", provider: "deepseek", model: "deepseek-v4-flash" }, - { ref: "deepseek/deepseek-v4-pro", provider: "deepseek", model: "deepseek-v4-pro" }, + { + ref: "deepseek/deepseek-v4-flash", + provider: "deepseek", + model: "deepseek-v4-flash", + }, + { + ref: "deepseek/deepseek-v4-pro", + provider: "deepseek", + model: "deepseek-v4-pro", + }, ]; const defaultMockModelRef = mockModelCatalog[0].ref; const mockModelRef = (name: string): string => { diff --git a/desktop/settings_app.go b/desktop/settings_app.go index 243625bd5..1234d9e0a 100644 --- a/desktop/settings_app.go +++ b/desktop/settings_app.go @@ -237,7 +237,7 @@ func providerAccessSet(names []string) map[string]bool { for _, name := range names { name = strings.TrimSpace(name) if name != "" { - out[name] = true + out[config.CanonicalDesktopOfficialProviderName(name)] = true } } return out @@ -246,7 +246,7 @@ func providerAccessSet(names []string) map[string]bool { func addProviderAccess(c *config.Config, names ...string) { seen := providerAccessSet(c.Desktop.ProviderAccess) for _, name := range names { - name = strings.TrimSpace(name) + name = config.CanonicalDesktopOfficialProviderName(strings.TrimSpace(name)) if name == "" || seen[name] { continue } @@ -262,7 +262,7 @@ func removeProviderAccess(c *config.Config, names ...string) { } out := c.Desktop.ProviderAccess[:0] for _, name := range c.Desktop.ProviderAccess { - if !remove[name] { + if !remove[config.CanonicalDesktopOfficialProviderName(strings.TrimSpace(name))] { out = append(out, name) } } @@ -305,7 +305,7 @@ func officialProviderAddedSet(cfg *config.Config) map[string]bool { access := providerAccessSet(cfg.Desktop.ProviderAccess) for i := range cfg.Providers { p := cfg.Providers[i] - if !access[p.Name] { + if !access[config.CanonicalDesktopOfficialProviderName(strings.TrimSpace(p.Name))] { continue } if kind := officialProviderKindFromEntry(p); kind != "" { @@ -409,7 +409,7 @@ func (a *App) Settings() SettingsView { v.OfficialProviders = officialProviderViews(officialProviderAddedSet(cfg)) for i := range cfg.Providers { p := &cfg.Providers[i] - v.Providers = append(v.Providers, providerViewFromEntry(*p, isOfficialBuiltInProvider(*p), added[p.Name])) + v.Providers = append(v.Providers, providerViewFromEntry(*p, isOfficialBuiltInProvider(*p), added[config.CanonicalDesktopOfficialProviderName(strings.TrimSpace(p.Name))])) } return v } @@ -687,22 +687,24 @@ func (a *App) rebuild() error { if controllerHasActiveRuntimeWork(tab.Ctrl) { return rebuildControllerActiveWorkError("settings") } - var carried []provider.Message - prevPath := "" - if tab.Ctrl != nil { - prevPath = tab.Ctrl.SessionPath() - _ = a.snapshotTab(tab) - carried = tab.Ctrl.History() - tab.Ctrl.Close() - } model := tab.model if cfg, err := config.LoadForRoot(tab.WorkspaceRoot); err == nil { - if resolved, fallback, ok := cfg.ResolveModelWithFallback(model); ok { - if fallback && strings.TrimSpace(model) != "" { - a.noticeForTab(tab.ID, fmt.Sprintf("model %q is no longer available; switched to %s", model, resolved)) - } - model = resolved + resolved, fallback, ok := resolveAccessibleModelWithFallback(cfg, model) + if !ok { + return noAccessibleModelError(model) + } + if fallback && strings.TrimSpace(model) != "" { + a.noticeForTab(tab.ID, fmt.Sprintf("model %q is no longer available; switched to %s", model, resolved)) } + model = resolved + } + var carried []provider.Message + prevPath := "" + oldCtrl := tab.Ctrl + if oldCtrl != nil { + prevPath = oldCtrl.SessionPath() + _ = a.snapshotTab(tab) + carried = oldCtrl.History() } ctrl, err := boot.Build(a.bootContext(), boot.Options{ Model: model, RequireKey: false, @@ -713,14 +715,12 @@ func (a *App) rebuild() error { TokenMode: currentTabTokenMode(tab), }) if err != nil { - a.mu.Lock() - tab.StartupErr = err.Error() - tab.Ready = true - a.mu.Unlock() - a.emitReady(a.ctx) return err } a.bindControllerDisplayRecorder(ctrl) + if oldCtrl != nil { + oldCtrl.Close() + } a.mu.Lock() tab.Ctrl = ctrl tab.model = model @@ -1068,17 +1068,17 @@ type providerRemovalTab struct { } func providerAccessFallbackRef(c *config.Config, name string) string { - name = strings.TrimSpace(name) + name = config.CanonicalDesktopOfficialProviderName(strings.TrimSpace(name)) for _, candidate := range c.Desktop.ProviderAccess { - candidate = strings.TrimSpace(candidate) + candidate = config.CanonicalDesktopOfficialProviderName(strings.TrimSpace(candidate)) if candidate == "" || candidate == name { continue } - p, ok := c.Provider(candidate) - if !ok || len(p.ModelList()) == 0 { + p, ok := c.ResolveModel(candidate) + if !ok || !p.Configured() { continue } - return p.Name + "/" + p.DefaultModel() + return p.Name + "/" + p.Model } return "" } diff --git a/desktop/settings_app_test.go b/desktop/settings_app_test.go index f61f838d5..26b738981 100644 --- a/desktop/settings_app_test.go +++ b/desktop/settings_app_test.go @@ -149,6 +149,47 @@ func TestOfficialMimoAPITemplateIncludesVisionModels(t *testing.T) { } } +func TestSettingsTreatsLegacyExplicitProviderAccessAsAdded(t *testing.T) { + isolateDesktopUserDirs(t) + t.Setenv("DEEPSEEK_API_KEY", "sk-test") + t.Setenv("MIMO_API_KEY", "sk-test") + if err := os.MkdirAll(filepath.Dir(config.UserConfigPath()), 0o755); err != nil { + t.Fatalf("mkdir config dir: %v", err) + } + if err := os.WriteFile(config.UserConfigPath(), []byte(` +default_model = "deepseek-flash/deepseek-v4-flash" + +[desktop] +provider_access = ["deepseek-flash", "mimo-pro"] + +[[providers]] +name = "deepseek-flash" +kind = "openai" +base_url = "https://api.deepseek.com" +models = ["deepseek-v4-flash", "deepseek-v4-pro"] +default = "deepseek-v4-flash" +api_key_env = "DEEPSEEK_API_KEY" + +[[providers]] +name = "mimo-pro" +kind = "openai" +base_url = "https://token-plan-cn.xiaomimimo.com/v1" +model = "mimo-v2.5-pro" +api_key_env = "MIMO_API_KEY" +`), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + got := NewApp().Settings() + providers := map[string]ProviderView{} + for _, p := range got.Providers { + providers[p.Name] = p + } + if !providers["deepseek"].Added || !providers["mimo-token-plan"].Added { + t.Fatalf("providers = %+v, want canonical official providers marked added", providers) + } +} + func TestSetAgentParamsPersistsStepLimitsToUserConfig(t *testing.T) { isolateDesktopUserDirs(t) diff --git a/desktop/tabs.go b/desktop/tabs.go index fa826c8a9..1b62e0ba6 100644 --- a/desktop/tabs.go +++ b/desktop/tabs.go @@ -1410,12 +1410,19 @@ func (a *App) buildTabController(tab *WorkspaceTab) { model = cfg.DefaultModel } requestedModel := model - if resolved, fallback, ok := cfg.ResolveModelWithFallback(model); ok { - if fallback && strings.TrimSpace(tab.model) != "" { - a.noticeForTab(tab.ID, fmt.Sprintf("model %q is no longer available; switched to %s", requestedModel, resolved)) - } - model = resolved + resolved, fallback, ok := resolveAccessibleModelWithFallback(cfg, model) + if !ok { + a.mu.Lock() + tab.StartupErr = noAccessibleModelError(requestedModel).Error() + tab.Ready = true + a.mu.Unlock() + a.emitReady(wailsCtx) + return + } + if fallback && strings.TrimSpace(tab.model) != "" { + a.noticeForTab(tab.ID, fmt.Sprintf("model %q is no longer available; switched to %s", requestedModel, resolved)) } + model = resolved a.mu.Lock() tab.model = model diff --git a/internal/config/config.go b/internal/config/config.go index 061cceec9..154761228 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1330,6 +1330,16 @@ func LoadForRoot(root string) (*Config, error) { return nil, err } cfg.Plugins = plugins + // Apply the same later-wins merge for [[providers]], but only across actual + // TOML sources. This preserves the long-standing rule that declaring + // [[providers]] replaces the built-in defaults, while preventing a project + // reasonix.toml from accidentally dropping the user's global custom + // providers. + if providers, ok, err := mergeTOMLProviders(tomlSources); err != nil { + return nil, err + } else if ok { + cfg.Providers = providers + } // Claude Code's .mcp.json (project root) is read last and merged into // [[plugins]], so a server configured for Claude works here unchanged. @@ -1481,6 +1491,38 @@ func mergeTOMLPlugins(paths []string) ([]PluginEntry, error) { return merged, nil } +// mergeTOMLProviders merges [[providers]] across TOML sources by name (later +// source wins). The bool reports whether any TOML source explicitly declared a +// provider, so callers can distinguish "no provider blocks anywhere" from an +// intentional merged list. +func mergeTOMLProviders(paths []string) ([]ProviderEntry, bool, error) { + var merged []ProviderEntry + index := map[string]int{} + found := false + for _, path := range paths { + if _, err := os.Stat(path); err != nil { + continue + } + var f Config + if _, err := toml.DecodeFile(path, &f); err != nil { + return nil, false, fmt.Errorf("config %s: %w", path, err) + } + if len(f.Providers) == 0 { + continue + } + found = true + for _, p := range f.Providers { + if i, ok := index[p.Name]; ok { + merged[i] = p + continue + } + index[p.Name] = len(merged) + merged = append(merged, p) + } + } + return merged, found, nil +} + // LoadForEdit returns a config to seed the `reasonix setup` wizard when reconfiguring: // the built-in defaults with the file at path (if present) decoded on top, so a // reconfigure preserves the user's existing providers and agent settings instead diff --git a/internal/config/loadedit_test.go b/internal/config/loadedit_test.go index d843d444a..0e621fe03 100644 --- a/internal/config/loadedit_test.go +++ b/internal/config/loadedit_test.go @@ -79,3 +79,73 @@ model = "m" t.Fatalf("migration should preserve ordinary config:\n%s", updated) } } + +func TestLoadForRootMergesProvidersAcrossUserAndProjectConfigs(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("XDG_CONFIG_HOME", filepath.Join(home, ".config")) + t.Setenv("USERPROFILE", home) + t.Setenv("AppData", filepath.Join(home, "AppData")) + + userPath := userConfigPath() + if err := os.MkdirAll(filepath.Dir(userPath), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(userPath, []byte(` +default_model = "corp/corp-model" + +[[providers]] +name = "corp" +kind = "openai" +base_url = "https://corp.example.com/v1" +model = "corp-model" +api_key_env = "CORP_KEY" + +[[providers]] +name = "shared" +kind = "openai" +base_url = "https://global.example.com/v1" +model = "global-model" +api_key_env = "GLOBAL_KEY" +`), 0o644); err != nil { + t.Fatal(err) + } + + project := t.TempDir() + if err := os.WriteFile(filepath.Join(project, "reasonix.toml"), []byte(` +default_model = "shared/project-model" + +[[providers]] +name = "shared" +kind = "openai" +base_url = "https://project.example.com/v1" +model = "project-model" +api_key_env = "PROJECT_KEY" + +[[providers]] +name = "project-only" +kind = "openai" +base_url = "https://project-only.example.com/v1" +model = "project-only-model" +api_key_env = "PROJECT_ONLY_KEY" +`), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := LoadForRoot(project) + if err != nil { + t.Fatalf("LoadForRoot: %v", err) + } + if len(cfg.Providers) != 3 { + t.Fatalf("providers = %+v, want 3 merged providers", cfg.Providers) + } + if p, ok := cfg.Provider("corp"); !ok || p.BaseURL != "https://corp.example.com/v1" || p.Model != "corp-model" { + t.Fatalf("corp provider = %+v, want preserved global provider", p) + } + if p, ok := cfg.Provider("shared"); !ok || p.BaseURL != "https://project.example.com/v1" || p.Model != "project-model" { + t.Fatalf("shared provider = %+v, want project override", p) + } + if p, ok := cfg.Provider("project-only"); !ok || p.BaseURL != "https://project-only.example.com/v1" { + t.Fatalf("project-only provider = %+v, want project-specific provider", p) + } +}