diff --git a/internal/watcher/clients.go b/internal/watcher/clients.go index 2697fa05c9..64b97d0f2b 100644 --- a/internal/watcher/clients.go +++ b/internal/watcher/clients.go @@ -17,6 +17,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff" + "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" ) @@ -75,6 +76,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.lastAuthHashes = make(map[string]string) w.lastAuthContents = make(map[string]*coreauth.Auth) + w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth) if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil { log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir) } else if resolvedAuthDir != "" { @@ -92,6 +94,17 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string if errParse := json.Unmarshal(data, &auth); errParse == nil { w.lastAuthContents[normalizedPath] = &auth } + ctx := &synthesizer.SynthesisContext{ + Config: cfg, + AuthDir: resolvedAuthDir, + Now: time.Now(), + IDGenerator: synthesizer.NewStableIDGenerator(), + } + if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 { + if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 { + w.fileAuthsByPath[normalizedPath] = pathAuths + } + } } } return nil @@ -143,13 +156,14 @@ func (w *Watcher) addOrUpdateClient(path string) { } w.clientsMutex.Lock() - - cfg := w.config - if cfg == nil { + if w.config == nil { log.Error("config is nil, cannot add or update client") w.clientsMutex.Unlock() return } + if w.fileAuthsByPath == nil { + w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth) + } if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash { log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path)) w.clientsMutex.Unlock() @@ -177,34 +191,86 @@ func (w *Watcher) addOrUpdateClient(path string) { } w.lastAuthContents[normalized] = &newAuth - w.clientsMutex.Unlock() // Unlock before the callback - - w.refreshAuthState(false) + oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized])) + for id, a := range w.fileAuthsByPath[normalized] { + oldByID[id] = a + } - if w.reloadCallback != nil { - log.Debugf("triggering server update callback after add/update") - w.triggerServerUpdate(cfg) + // Build synthesized auth entries for this single file only. + sctx := &synthesizer.SynthesisContext{ + Config: w.config, + AuthDir: w.effectiveAuthDir(), + Now: time.Now(), + IDGenerator: synthesizer.NewStableIDGenerator(), + } + generated := synthesizer.SynthesizeAuthFile(sctx, path, data) + newByID := authSliceToMap(generated) + if len(newByID) > 0 { + w.fileAuthsByPath[normalized] = newByID + } else { + delete(w.fileAuthsByPath, normalized) } + updates := w.computePerPathUpdatesLocked(oldByID, newByID) + w.clientsMutex.Unlock() + w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path) + w.dispatchAuthUpdates(updates) } func (w *Watcher) removeClient(path string) { normalized := w.normalizeAuthPath(path) w.clientsMutex.Lock() - - cfg := w.config + oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized])) + for id, a := range w.fileAuthsByPath[normalized] { + oldByID[id] = a + } delete(w.lastAuthHashes, normalized) delete(w.lastAuthContents, normalized) + delete(w.fileAuthsByPath, normalized) - w.clientsMutex.Unlock() // Release the lock before the callback + updates := w.computePerPathUpdatesLocked(oldByID, map[string]*coreauth.Auth{}) + w.clientsMutex.Unlock() - w.refreshAuthState(false) + w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path) + w.dispatchAuthUpdates(updates) +} - if w.reloadCallback != nil { - log.Debugf("triggering server update callback after removal") - w.triggerServerUpdate(cfg) +func (w *Watcher) computePerPathUpdatesLocked(oldByID, newByID map[string]*coreauth.Auth) []AuthUpdate { + if w.currentAuths == nil { + w.currentAuths = make(map[string]*coreauth.Auth) + } + updates := make([]AuthUpdate, 0, len(oldByID)+len(newByID)) + for id, newAuth := range newByID { + existing, ok := w.currentAuths[id] + if !ok { + w.currentAuths[id] = newAuth.Clone() + updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: newAuth.Clone()}) + continue + } + if !authEqual(existing, newAuth) { + w.currentAuths[id] = newAuth.Clone() + updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: newAuth.Clone()}) + } } - w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path) + for id := range oldByID { + if _, stillExists := newByID[id]; stillExists { + continue + } + delete(w.currentAuths, id) + updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id}) + } + return updates +} + +func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth { + byID := make(map[string]*coreauth.Auth, len(auths)) + for _, a := range auths { + if a == nil || strings.TrimSpace(a.ID) == "" { + continue + } + byID[a.ID] = a + } + return byID } func (w *Watcher) loadFileClients(cfg *config.Config) int { @@ -304,78 +370,3 @@ func (w *Watcher) persistAuthAsync(message string, paths ...string) { }() } -func (w *Watcher) stopServerUpdateTimer() { - w.serverUpdateMu.Lock() - defer w.serverUpdateMu.Unlock() - if w.serverUpdateTimer != nil { - w.serverUpdateTimer.Stop() - w.serverUpdateTimer = nil - } - w.serverUpdatePend = false -} - -func (w *Watcher) triggerServerUpdate(cfg *config.Config) { - if w == nil || w.reloadCallback == nil || cfg == nil { - return - } - if w.stopped.Load() { - return - } - - now := time.Now() - - w.serverUpdateMu.Lock() - if w.serverUpdateLast.IsZero() || now.Sub(w.serverUpdateLast) >= serverUpdateDebounce { - w.serverUpdateLast = now - if w.serverUpdateTimer != nil { - w.serverUpdateTimer.Stop() - w.serverUpdateTimer = nil - } - w.serverUpdatePend = false - w.serverUpdateMu.Unlock() - w.reloadCallback(cfg) - return - } - - if w.serverUpdatePend { - w.serverUpdateMu.Unlock() - return - } - - delay := serverUpdateDebounce - now.Sub(w.serverUpdateLast) - if delay < 10*time.Millisecond { - delay = 10 * time.Millisecond - } - w.serverUpdatePend = true - if w.serverUpdateTimer != nil { - w.serverUpdateTimer.Stop() - w.serverUpdateTimer = nil - } - var timer *time.Timer - timer = time.AfterFunc(delay, func() { - if w.stopped.Load() { - return - } - w.clientsMutex.RLock() - latestCfg := w.config - w.clientsMutex.RUnlock() - - w.serverUpdateMu.Lock() - if w.serverUpdateTimer != timer || !w.serverUpdatePend { - w.serverUpdateMu.Unlock() - return - } - w.serverUpdateTimer = nil - w.serverUpdatePend = false - if latestCfg == nil || w.reloadCallback == nil || w.stopped.Load() { - w.serverUpdateMu.Unlock() - return - } - - w.serverUpdateLast = time.Now() - w.serverUpdateMu.Unlock() - w.reloadCallback(latestCfg) - }) - w.serverUpdateTimer = timer - w.serverUpdateMu.Unlock() -} diff --git a/internal/watcher/dispatcher.go b/internal/watcher/dispatcher.go index ff3c5b632c..ffc5c06654 100644 --- a/internal/watcher/dispatcher.go +++ b/internal/watcher/dispatcher.go @@ -14,6 +14,8 @@ import ( coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) +var snapshotCoreAuthsFunc = snapshotCoreAuths + func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) { w.clientsMutex.Lock() defer w.clientsMutex.Unlock() @@ -76,7 +78,11 @@ func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool { } func (w *Watcher) refreshAuthState(force bool) { - auths := w.SnapshotCoreAuths() + w.clientsMutex.RLock() + cfg := w.config + authDir := w.effectiveAuthDir() + w.clientsMutex.RUnlock() + auths := snapshotCoreAuthsFunc(cfg, authDir) w.clientsMutex.Lock() if len(w.runtimeAuths) > 0 { for _, a := range w.runtimeAuths { @@ -271,3 +277,4 @@ func snapshotCoreAuths(cfg *config.Config, authDir string) []*coreauth.Auth { return out } + diff --git a/internal/watcher/events.go b/internal/watcher/events.go index 250cf75cb4..f90d623432 100644 --- a/internal/watcher/events.go +++ b/internal/watcher/events.go @@ -33,11 +33,12 @@ func (w *Watcher) start(ctx context.Context) error { } log.Debugf("watching config file: %s", w.configPath) - if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil { - log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir) + authDir := w.effectiveAuthDir() + if errAddAuthDir := w.watcher.Add(authDir); errAddAuthDir != nil { + log.Errorf("failed to watch auth directory %s: %v", authDir, errAddAuthDir) return errAddAuthDir } - log.Debugf("watching auth directory: %s", w.authDir) + log.Debugf("watching auth directory: %s", authDir) go w.processEvents(ctx) @@ -69,10 +70,9 @@ func (w *Watcher) handleEvent(event fsnotify.Event) { configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename normalizedName := w.normalizeAuthPath(event.Name) normalizedConfigPath := w.normalizeAuthPath(w.configPath) - normalizedAuthDir := w.normalizeAuthPath(w.authDir) isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0 authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename - isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0 + isAuthJSON := strings.HasSuffix(normalizedName, ".json") && pathBelongsToDir(event.Name, w.effectiveAuthDir()) && event.Op&authOps != 0 if !isConfigEvent && !isAuthJSON { // Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise. return @@ -192,3 +192,6 @@ func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) boo w.clientsMutex.Unlock() return false } + + + diff --git a/internal/watcher/mirrored_auth_dir_test.go b/internal/watcher/mirrored_auth_dir_test.go new file mode 100644 index 0000000000..88ff1c9045 --- /dev/null +++ b/internal/watcher/mirrored_auth_dir_test.go @@ -0,0 +1,119 @@ +package watcher + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/fsnotify/fsnotify" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestStartUsesMirroredAuthDirForWatching(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + mirroredAuthDir := filepath.Join(tmpDir, "mirrored-auth") + if err := os.MkdirAll(mirroredAuthDir, 0o755); err != nil { + t.Fatalf("failed to create mirrored auth dir: %v", err) + } + if err := os.WriteFile(configPath, []byte("auth_dir: "+mirroredAuthDir+"\n"), 0o644); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + w, err := NewWatcher(configPath, filepath.Join(tmpDir, "missing-auth"), nil) + if err != nil { + t.Fatalf("failed to create watcher: %v", err) + } + defer w.Stop() + w.mirroredAuthDir = mirroredAuthDir + w.SetConfig(&config.Config{AuthDir: mirroredAuthDir}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := w.Start(ctx); err != nil { + t.Fatalf("expected Start to watch mirrored auth dir, got error: %v", err) + } +} + +func TestHandleEventUsesMirroredAuthDir(t *testing.T) { + tmpDir := t.TempDir() + originalAuthDir := filepath.Join(tmpDir, "auth") + mirroredAuthDir := filepath.Join(tmpDir, "mirror") + if err := os.MkdirAll(mirroredAuthDir, 0o755); err != nil { + t.Fatalf("failed to create mirrored auth dir: %v", err) + } + authFile := filepath.Join(mirroredAuthDir, "demo.json") + if err := os.WriteFile(authFile, []byte(`{"type":"demo","email":"demo@example.com"}`), 0o644); err != nil { + t.Fatalf("failed to write auth file: %v", err) + } + + w := &Watcher{ + authDir: originalAuthDir, + mirroredAuthDir: mirroredAuthDir, + lastAuthHashes: make(map[string]string), + fileAuthsByPath: make(map[string]map[string]*coreauth.Auth), + } + w.SetConfig(&config.Config{AuthDir: mirroredAuthDir}) + + w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write}) + + normalized := w.normalizeAuthPath(authFile) + w.clientsMutex.RLock() + _, ok := w.lastAuthHashes[normalized] + w.clientsMutex.RUnlock() + if !ok { + t.Fatal("expected mirrored auth file event to update watcher state") + } +} + +func TestRefreshAuthStateUsesMirroredAuthDir(t *testing.T) { + oldSnapshot := snapshotCoreAuthsFunc + defer func() { snapshotCoreAuthsFunc = oldSnapshot }() + + mirroredAuthDir := t.TempDir() + calledDir := "" + snapshotCoreAuthsFunc = func(cfg *config.Config, authDir string) []*coreauth.Auth { + calledDir = authDir + return nil + } + + w := &Watcher{authDir: filepath.Join(t.TempDir(), "auth")} + w.mirroredAuthDir = mirroredAuthDir + w.SetConfig(&config.Config{AuthDir: mirroredAuthDir}) + + w.refreshAuthState(false) + + if calledDir != mirroredAuthDir { + t.Fatalf("expected refreshAuthState to use mirrored auth dir %s, got %s", mirroredAuthDir, calledDir) + } +} + +func TestSnapshotCoreAuthsUsesMirroredAuthDir(t *testing.T) { + originalAuthDir := filepath.Join(t.TempDir(), "original-auth") + if err := os.MkdirAll(originalAuthDir, 0o755); err != nil { + t.Fatalf("failed to create original auth dir: %v", err) + } + mirroredAuthDir := t.TempDir() + if err := os.WriteFile(filepath.Join(mirroredAuthDir, "demo.json"), []byte(`{"type":"demo","email":"mirror@example.com"}`), 0o644); err != nil { + t.Fatalf("failed to write mirrored auth file: %v", err) + } + + w := &Watcher{ + authDir: originalAuthDir, + mirroredAuthDir: mirroredAuthDir, + } + w.SetConfig(&config.Config{AuthDir: originalAuthDir}) + + auths := w.SnapshotCoreAuths() + if len(auths) != 1 { + t.Fatalf("expected mirrored auth dir to provide 1 auth, got %d", len(auths)) + } + if auths[0] == nil || auths[0].Provider != "demo" { + t.Fatalf("unexpected auths from mirrored dir: %+v", auths) + } +} + + diff --git a/internal/watcher/path_scope.go b/internal/watcher/path_scope.go new file mode 100644 index 0000000000..dd4a58e6fd --- /dev/null +++ b/internal/watcher/path_scope.go @@ -0,0 +1,57 @@ +package watcher + +import ( + "path/filepath" + "runtime" + "strings" +) + +// pathBelongsToDir reports whether path resolves to dir or one of its descendants. +func pathBelongsToDir(path, dir string) bool { + normalizedPath, okPath := normalizeAbsolutePath(path) + if !okPath { + return false + } + normalizedDir, okDir := normalizeAbsolutePath(dir) + if !okDir { + return false + } + + relPath, errRel := filepath.Rel(normalizedDir, normalizedPath) + if errRel != nil { + return false + } + relPath = filepath.Clean(relPath) + if relPath == "." { + return true + } + + parentPrefix := ".." + string(filepath.Separator) + return relPath != ".." && !strings.HasPrefix(relPath, parentPrefix) +} + +func normalizeAbsolutePath(path string) (string, bool) { + trimmed := strings.TrimSpace(path) + if trimmed == "" { + return "", false + } + if runtime.GOOS == "windows" { + trimmed = strings.TrimPrefix(trimmed, `\\?\`) + } + + normalizedPath := trimmed + if resolvedPath, errEval := filepath.EvalSymlinks(trimmed); errEval == nil { + normalizedPath = resolvedPath + } + + absolutePath, errAbs := filepath.Abs(normalizedPath) + if errAbs != nil { + return "", false + } + cleaned := filepath.Clean(absolutePath) + if runtime.GOOS == "windows" { + cleaned = strings.TrimPrefix(cleaned, `\\?\`) + cleaned = strings.ToLower(cleaned) + } + return cleaned, true +} diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index ea96118b5e..02a0cefac8 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -36,9 +36,6 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e return out, nil } - now := ctx.Now - cfg := ctx.Config - for _, e := range entries { if e.IsDir() { continue @@ -52,97 +49,118 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e if errRead != nil || len(data) == 0 { continue } - var metadata map[string]any - if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil { - continue - } - t, _ := metadata["type"].(string) - if t == "" { + auths := synthesizeFileAuths(ctx, full, data) + if len(auths) == 0 { continue } - provider := strings.ToLower(t) - if provider == "gemini" { - provider = "gemini-cli" - } - label := provider - if email, _ := metadata["email"].(string); email != "" { - label = email - } - // Use relative path under authDir as ID to stay consistent with the file-based token store - id := full - if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" { + out = append(out, auths...) + } + return out, nil +} + +// SynthesizeAuthFile generates Auth entries for one auth JSON file payload. +// It shares exactly the same mapping behavior as FileSynthesizer.Synthesize. +func SynthesizeAuthFile(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth { + return synthesizeFileAuths(ctx, fullPath, data) +} + +func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth { + if ctx == nil || len(data) == 0 { + return nil + } + now := ctx.Now + cfg := ctx.Config + var metadata map[string]any + if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil { + return nil + } + t, _ := metadata["type"].(string) + if t == "" { + return nil + } + provider := strings.ToLower(t) + if provider == "gemini" { + provider = "gemini-cli" + } + label := provider + if email, _ := metadata["email"].(string); email != "" { + label = email + } + // Use relative path under authDir as ID to stay consistent with the file-based token store. + id := fullPath + if strings.TrimSpace(ctx.AuthDir) != "" { + if rel, errRel := filepath.Rel(ctx.AuthDir, fullPath); errRel == nil && rel != "" { id = rel } - // On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths. - if runtime.GOOS == "windows" { - id = strings.ToLower(id) - } + } + if runtime.GOOS == "windows" { + id = strings.ToLower(id) + } - proxyURL := "" - if p, ok := metadata["proxy_url"].(string); ok { - proxyURL = p - } + proxyURL := "" + if p, ok := metadata["proxy_url"].(string); ok { + proxyURL = p + } - prefix := "" - if rawPrefix, ok := metadata["prefix"].(string); ok { - trimmed := strings.TrimSpace(rawPrefix) - trimmed = strings.Trim(trimmed, "/") - if trimmed != "" && !strings.Contains(trimmed, "/") { - prefix = trimmed - } + prefix := "" + if rawPrefix, ok := metadata["prefix"].(string); ok { + trimmed := strings.TrimSpace(rawPrefix) + trimmed = strings.Trim(trimmed, "/") + if trimmed != "" && !strings.Contains(trimmed, "/") { + prefix = trimmed } + } - disabled, _ := metadata["disabled"].(bool) - status := coreauth.StatusActive - if disabled { - status = coreauth.StatusDisabled - } + disabled, _ := metadata["disabled"].(bool) + status := coreauth.StatusActive + if disabled { + status = coreauth.StatusDisabled + } - // Read per-account excluded models from the OAuth JSON file - perAccountExcluded := extractExcludedModelsFromMetadata(metadata) + // Read per-account excluded models from the OAuth JSON file. + perAccountExcluded := extractExcludedModelsFromMetadata(metadata) - a := &coreauth.Auth{ - ID: id, - Provider: provider, - Label: label, - Prefix: prefix, - Status: status, - Disabled: disabled, - Attributes: map[string]string{ - "source": full, - "path": full, - }, - ProxyURL: proxyURL, - Metadata: metadata, - CreatedAt: now, - UpdatedAt: now, - } - // Read priority from auth file - if rawPriority, ok := metadata["priority"]; ok { - switch v := rawPriority.(type) { - case float64: - a.Attributes["priority"] = strconv.Itoa(int(v)) - case string: - priority := strings.TrimSpace(v) - if _, errAtoi := strconv.Atoi(priority); errAtoi == nil { - a.Attributes["priority"] = priority - } + a := &coreauth.Auth{ + ID: id, + Provider: provider, + Label: label, + Prefix: prefix, + Status: status, + Disabled: disabled, + Attributes: map[string]string{ + "source": fullPath, + "path": fullPath, + }, + ProxyURL: proxyURL, + Metadata: metadata, + CreatedAt: now, + UpdatedAt: now, + } + // Read priority from auth file. + if rawPriority, ok := metadata["priority"]; ok { + switch v := rawPriority.(type) { + case float64: + a.Attributes["priority"] = strconv.Itoa(int(v)) + case string: + priority := strings.TrimSpace(v) + if _, errAtoi := strconv.Atoi(priority); errAtoi == nil { + a.Attributes["priority"] = priority } } - ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth") - if provider == "gemini-cli" { - if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { - for _, v := range virtuals { - ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth") - } - out = append(out, a) - out = append(out, virtuals...) - continue + } + ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth") + if provider == "gemini-cli" { + if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { + for _, v := range virtuals { + ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth") } + out := make([]*coreauth.Auth, 0, 1+len(virtuals)) + out = append(out, a) + out = append(out, virtuals...) + return out } - out = append(out, a) } - return out, nil + return []*coreauth.Auth{a} } // SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials. diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 76e2dee5ac..ec9a40f287 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -6,7 +6,6 @@ import ( "context" "strings" "sync" - "sync/atomic" "time" "github.com/fsnotify/fsnotify" @@ -36,15 +35,11 @@ type Watcher struct { clientsMutex sync.RWMutex configReloadMu sync.Mutex configReloadTimer *time.Timer - serverUpdateMu sync.Mutex - serverUpdateTimer *time.Timer - serverUpdateLast time.Time - serverUpdatePend bool - stopped atomic.Bool reloadCallback func(*config.Config) watcher *fsnotify.Watcher lastAuthHashes map[string]string lastAuthContents map[string]*coreauth.Auth + fileAuthsByPath map[string]map[string]*coreauth.Auth lastRemoveTimes map[string]time.Time lastConfigHash string authQueue chan<- AuthUpdate @@ -82,7 +77,6 @@ const ( replaceCheckDelay = 50 * time.Millisecond configReloadDebounce = 150 * time.Millisecond authRemoveDebounceWindow = 1 * time.Second - serverUpdateDebounce = 1 * time.Second ) // NewWatcher creates a new file watcher instance @@ -92,11 +86,12 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config)) return nil, errNewWatcher } w := &Watcher{ - configPath: configPath, - authDir: authDir, - reloadCallback: reloadCallback, - watcher: watcher, - lastAuthHashes: make(map[string]string), + configPath: configPath, + authDir: authDir, + reloadCallback: reloadCallback, + watcher: watcher, + lastAuthHashes: make(map[string]string), + fileAuthsByPath: make(map[string]map[string]*coreauth.Auth), } w.dispatchCond = sync.NewCond(&w.dispatchMu) if store := sdkAuth.GetTokenStore(); store != nil { @@ -106,6 +101,7 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config)) } if provider, ok := store.(authDirProvider); ok { if fixed := strings.TrimSpace(provider.AuthDir()); fixed != "" { + w.authDir = fixed w.mirroredAuthDir = fixed log.Debugf("mirrored auth directory locked to %s", fixed) } @@ -121,10 +117,8 @@ func (w *Watcher) Start(ctx context.Context) error { // Stop stops the file watcher func (w *Watcher) Stop() error { - w.stopped.Store(true) w.stopDispatch() w.stopConfigReloadTimer() - w.stopServerUpdateTimer() return w.watcher.Close() } @@ -153,5 +147,15 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { w.clientsMutex.RLock() cfg := w.config w.clientsMutex.RUnlock() - return snapshotCoreAuths(cfg, w.authDir) + return snapshotCoreAuths(cfg, w.effectiveAuthDir()) +} + +func (w *Watcher) effectiveAuthDir() string { + if w == nil { + return "" + } + if fixed := strings.TrimSpace(w.mirroredAuthDir); fixed != "" { + return fixed + } + return w.authDir } diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go index 0f9cd019ae..486ed095b0 100644 --- a/internal/watcher/watcher_test.go +++ b/internal/watcher/watcher_test.go @@ -406,8 +406,8 @@ func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) { w.addOrUpdateClient(authFile) - if got := atomic.LoadInt32(&reloads); got != 1 { - t.Fatalf("expected reload callback once, got %d", got) + if got := atomic.LoadInt32(&reloads); got != 0 { + t.Fatalf("expected no reload callback for auth update, got %d", got) } // Use normalizeAuthPath to match how addOrUpdateClient stores the key normalized := w.normalizeAuthPath(authFile) @@ -436,48 +436,110 @@ func TestRemoveClientRemovesHash(t *testing.T) { if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok { t.Fatal("expected hash to be removed after deletion") } - if got := atomic.LoadInt32(&reloads); got != 1 { - t.Fatalf("expected reload callback once, got %d", got) + if got := atomic.LoadInt32(&reloads); got != 0 { + t.Fatalf("expected no reload callback for auth removal, got %d", got) } } -func TestTriggerServerUpdateCancelsPendingTimerOnImmediate(t *testing.T) { +func TestAuthFileEventsDoNotInvokeSnapshotCoreAuths(t *testing.T) { tmpDir := t.TempDir() - cfg := &config.Config{AuthDir: tmpDir} + authFile := filepath.Join(tmpDir, "sample.json") + if err := os.WriteFile(authFile, []byte(`{"type":"codex","email":"u@example.com"}`), 0o644); err != nil { + t.Fatalf("failed to create auth file: %v", err) + } + + origSnapshot := snapshotCoreAuthsFunc + var snapshotCalls int32 + snapshotCoreAuthsFunc = func(cfg *config.Config, authDir string) []*coreauth.Auth { + atomic.AddInt32(&snapshotCalls, 1) + return origSnapshot(cfg, authDir) + } + defer func() { snapshotCoreAuthsFunc = origSnapshot }() - var reloads int32 w := &Watcher{ - reloadCallback: func(*config.Config) { - atomic.AddInt32(&reloads, 1) - }, + authDir: tmpDir, + lastAuthHashes: make(map[string]string), + lastAuthContents: make(map[string]*coreauth.Auth), + fileAuthsByPath: make(map[string]map[string]*coreauth.Auth), } - w.SetConfig(cfg) + w.SetConfig(&config.Config{AuthDir: tmpDir}) - w.serverUpdateMu.Lock() - w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce - 100*time.Millisecond)) - w.serverUpdateMu.Unlock() - w.triggerServerUpdate(cfg) + w.addOrUpdateClient(authFile) + w.removeClient(authFile) - if got := atomic.LoadInt32(&reloads); got != 0 { - t.Fatalf("expected no immediate reload, got %d", got) + if got := atomic.LoadInt32(&snapshotCalls); got != 0 { + t.Fatalf("expected auth file events to avoid full snapshot, got %d calls", got) } +} - w.serverUpdateMu.Lock() - if !w.serverUpdatePend || w.serverUpdateTimer == nil { - w.serverUpdateMu.Unlock() - t.Fatal("expected a pending server update timer") - } - w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce + 10*time.Millisecond)) - w.serverUpdateMu.Unlock() +func TestAuthSliceToMap(t *testing.T) { + t.Parallel() - w.triggerServerUpdate(cfg) - if got := atomic.LoadInt32(&reloads); got != 1 { - t.Fatalf("expected immediate reload once, got %d", got) + valid1 := &coreauth.Auth{ID: "a"} + valid2 := &coreauth.Auth{ID: "b"} + dupOld := &coreauth.Auth{ID: "dup", Label: "old"} + dupNew := &coreauth.Auth{ID: "dup", Label: "new"} + empty := &coreauth.Auth{ID: " "} + + tests := []struct { + name string + in []*coreauth.Auth + want map[string]*coreauth.Auth + }{ + { + name: "nil input", + in: nil, + want: map[string]*coreauth.Auth{}, + }, + { + name: "empty input", + in: []*coreauth.Auth{}, + want: map[string]*coreauth.Auth{}, + }, + { + name: "filters invalid auths", + in: []*coreauth.Auth{nil, empty}, + want: map[string]*coreauth.Auth{}, + }, + { + name: "keeps valid auths", + in: []*coreauth.Auth{valid1, nil, valid2}, + want: map[string]*coreauth.Auth{"a": valid1, "b": valid2}, + }, + { + name: "last duplicate wins", + in: []*coreauth.Auth{dupOld, dupNew}, + want: map[string]*coreauth.Auth{"dup": dupNew}, + }, } - time.Sleep(250 * time.Millisecond) - if got := atomic.LoadInt32(&reloads); got != 1 { - t.Fatalf("expected pending timer to be cancelled, got %d reloads", got) + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := authSliceToMap(tc.in) + if len(tc.want) == 0 { + if got == nil { + t.Fatal("expected empty map, got nil") + } + if len(got) != 0 { + t.Fatalf("expected empty map, got %#v", got) + } + return + } + if len(got) != len(tc.want) { + t.Fatalf("unexpected map length: got %d, want %d", len(got), len(tc.want)) + } + for id, wantAuth := range tc.want { + gotAuth, ok := got[id] + if !ok { + t.Fatalf("missing id %q in result map", id) + } + if !authEqual(gotAuth, wantAuth) { + t.Fatalf("unexpected auth for id %q: got %#v, want %#v", id, gotAuth, wantAuth) + } + } + }) } } @@ -695,8 +757,8 @@ func TestHandleEventRemovesAuthFile(t *testing.T) { w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove}) - if atomic.LoadInt32(&reloads) != 1 { - t.Fatalf("expected reload callback once, got %d", reloads) + if atomic.LoadInt32(&reloads) != 0 { + t.Fatalf("expected no reload callback for auth removal, got %d", reloads) } if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok { t.Fatal("expected hash entry to be removed") @@ -893,8 +955,8 @@ func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) { w.SetConfig(&config.Config{AuthDir: authDir}) w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write}) - if atomic.LoadInt32(&reloads) != 1 { - t.Fatalf("expected auth write to trigger reload callback, got %d", reloads) + if atomic.LoadInt32(&reloads) != 0 { + t.Fatalf("expected auth write to avoid global reload callback, got %d", reloads) } } @@ -990,8 +1052,8 @@ func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) { w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:]) w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename}) - if atomic.LoadInt32(&reloads) != 1 { - t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads) + if atomic.LoadInt32(&reloads) != 0 { + t.Fatalf("expected changed atomic replace to avoid global reload, got %d", reloads) } } @@ -1045,8 +1107,8 @@ func TestHandleEventRemoveKnownFileDeletes(t *testing.T) { w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash" w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove}) - if atomic.LoadInt32(&reloads) != 1 { - t.Fatalf("expected known remove to trigger reload, got %d", reloads) + if atomic.LoadInt32(&reloads) != 0 { + t.Fatalf("expected known remove to avoid global reload, got %d", reloads) } if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok { t.Fatal("expected known auth hash to be deleted") @@ -1422,6 +1484,9 @@ func TestNewWatcherDetectsPersisterAndAuthDir(t *testing.T) { if w.mirroredAuthDir != tmp { t.Fatalf("expected mirroredAuthDir %s, got %s", tmp, w.mirroredAuthDir) } + if w.authDir != tmp { + t.Fatalf("expected runtime authDir to switch to mirrored path %s, got %s", tmp, w.authDir) + } } func TestPersistConfigAndAuthAsyncInvokePersister(t *testing.T) { @@ -1589,3 +1654,5 @@ func TestScheduleProcessEventsStopsOnContextDone(t *testing.T) { func hexString(data []byte) string { return strings.ToLower(fmt.Sprintf("%x", data)) } + +