diff --git a/pkg/cli/mcp_server.go b/pkg/cli/mcp_server.go index 19371793ec..ead520a888 100644 --- a/pkg/cli/mcp_server.go +++ b/pkg/cli/mcp_server.go @@ -38,44 +38,21 @@ func mcpErrorData(v any) json.RawMessage { return data } -// actorPermissionCache stores cached actor permission lookups with TTL -type actorPermissionCache struct { - permission string - timestamp time.Time -} - -// repositoryCache stores cached repository information with TTL -type repositoryCache struct { - repository string - timestamp time.Time -} - -var ( - permissionCache = make(map[string]*actorPermissionCache) - permissionCacheTTL = 1 * time.Hour - repoCache *repositoryCache - repoCacheTTL = 1 * time.Hour -) - // getRepository retrieves the current repository name (owner/repo format). // Results are cached for 1 hour to avoid repeated queries. // Checks GITHUB_REPOSITORY environment variable first, then falls back to gh repo view. func getRepository() (string, error) { // Check cache first - if repoCache != nil && time.Since(repoCache.timestamp) < repoCacheTTL { - mcpLog.Printf("Using cached repository: %s (age: %v)", repoCache.repository, time.Since(repoCache.timestamp)) - return repoCache.repository, nil + if repo, ok := mcpCache.GetRepo(); ok { + mcpLog.Printf("Using cached repository: %s", repo) + return repo, nil } // Try GITHUB_REPOSITORY environment variable first repo := os.Getenv("GITHUB_REPOSITORY") if repo != "" { mcpLog.Printf("Got repository from GITHUB_REPOSITORY: %s", repo) - // Cache the result - repoCache = &repositoryCache{ - repository: repo, - timestamp: time.Now(), - } + mcpCache.SetRepo(repo) return repo, nil } @@ -94,11 +71,7 @@ func getRepository() (string, error) { } mcpLog.Printf("Got repository from gh repo view: %s", repo) - // Cache the result - repoCache = &repositoryCache{ - repository: repo, - timestamp: time.Now(), - } + mcpCache.SetRepo(repo) return repo, nil } @@ -114,15 +87,9 @@ func queryActorRole(ctx context.Context, actor string, repo string) (string, err } // Check cache first - cacheKey := fmt.Sprintf("%s:%s", actor, repo) - if cached, ok := permissionCache[cacheKey]; ok { - if time.Since(cached.timestamp) < permissionCacheTTL { - mcpLog.Printf("Using cached permission for %s in %s: %s (age: %v)", actor, repo, cached.permission, time.Since(cached.timestamp)) - return cached.permission, nil - } - // Cache expired, remove it - delete(permissionCache, cacheKey) - mcpLog.Printf("Permission cache expired for %s in %s", actor, repo) + if perm, ok := mcpCache.GetPermission(actor, repo); ok { + mcpLog.Printf("Using cached permission for %s in %s: %s", actor, repo, perm) + return perm, nil } // Query GitHub API for user's permission level @@ -142,11 +109,7 @@ func queryActorRole(ctx context.Context, actor string, repo string) (string, err return "", fmt.Errorf("no permission found for actor %s in repository %s", actor, repo) } - // Cache the result - permissionCache[cacheKey] = &actorPermissionCache{ - permission: permission, - timestamp: time.Now(), - } + mcpCache.SetPermission(actor, repo, permission) mcpLog.Printf("Cached permission for %s in %s: %s", actor, repo, permission) return permission, nil diff --git a/pkg/cli/mcp_server_cache.go b/pkg/cli/mcp_server_cache.go new file mode 100644 index 0000000000..561a6bcfb2 --- /dev/null +++ b/pkg/cli/mcp_server_cache.go @@ -0,0 +1,87 @@ +package cli + +import ( + "sync" + "time" +) + +// mcpCacheStore provides thread-safe caching for actor permissions and repository lookups. +// All exported methods are safe for concurrent use. +type mcpCacheStore struct { + mu sync.RWMutex + permissions map[string]*permissionEntry + permissionTTL time.Duration + repo *repoEntry + repoTTL time.Duration +} + +type permissionEntry struct { + permission string + timestamp time.Time +} + +type repoEntry struct { + repository string + timestamp time.Time +} + +func newMCPCacheStore() *mcpCacheStore { + return &mcpCacheStore{ + permissions: make(map[string]*permissionEntry), + permissionTTL: 1 * time.Hour, + repoTTL: 1 * time.Hour, + } +} + +// GetPermission returns the cached permission for the given actor and repo, or ("", false) on cache miss. +func (c *mcpCacheStore) GetPermission(actor, repo string) (string, bool) { + cacheKey := actor + ":" + repo + c.mu.RLock() + entry, ok := c.permissions[cacheKey] + if ok && time.Since(entry.timestamp) < c.permissionTTL { + perm := entry.permission + c.mu.RUnlock() + return perm, true + } + c.mu.RUnlock() + if ok { + // Expired — remove it + c.mu.Lock() + delete(c.permissions, cacheKey) + c.mu.Unlock() + } + return "", false +} + +// SetPermission stores a permission in the cache. +func (c *mcpCacheStore) SetPermission(actor, repo, permission string) { + cacheKey := actor + ":" + repo + c.mu.Lock() + c.permissions[cacheKey] = &permissionEntry{ + permission: permission, + timestamp: time.Now(), + } + c.mu.Unlock() +} + +// GetRepo returns the cached repository name, or ("", false) on cache miss. +func (c *mcpCacheStore) GetRepo() (string, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + if c.repo != nil && time.Since(c.repo.timestamp) < c.repoTTL { + return c.repo.repository, true + } + return "", false +} + +// SetRepo stores a repository name in the cache. +func (c *mcpCacheStore) SetRepo(repository string) { + c.mu.Lock() + c.repo = &repoEntry{ + repository: repository, + timestamp: time.Now(), + } + c.mu.Unlock() +} + +var mcpCache = newMCPCacheStore() diff --git a/pkg/cli/mcp_server_cache_test.go b/pkg/cli/mcp_server_cache_test.go new file mode 100644 index 0000000000..b944b8fb47 --- /dev/null +++ b/pkg/cli/mcp_server_cache_test.go @@ -0,0 +1,106 @@ +//go:build !integration + +package cli + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestMCPCacheStore_ConcurrentPermissionAccess(t *testing.T) { + cache := newMCPCacheStore() + cache.permissionTTL = 50 * time.Millisecond + + // Pre-populate + for i := 0; i < 5; i++ { + cache.SetPermission(fmt.Sprintf("actor%d", i), "owner/repo", "write") + } + + const numGoroutines = 20 + const numIterations = 100 + + var wg sync.WaitGroup + + for g := 0; g < numGoroutines; g++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < numIterations; i++ { + actor := fmt.Sprintf("actor%d", i%10) + cache.GetPermission(actor, "owner/repo") + cache.SetPermission(actor, "owner/repo", "write") + } + }() + } + + wg.Wait() +} + +func TestMCPCacheStore_ConcurrentRepoAccess(t *testing.T) { + cache := newMCPCacheStore() + cache.repoTTL = 50 * time.Millisecond + + const numGoroutines = 20 + const numIterations = 100 + + var wg sync.WaitGroup + + for g := 0; g < numGoroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < numIterations; i++ { + cache.GetRepo() + cache.SetRepo(fmt.Sprintf("owner/repo-%d", id)) + } + }(g) + } + + wg.Wait() +} + +func TestMCPCacheStore_PermissionExpiry(t *testing.T) { + cache := newMCPCacheStore() + cache.permissionTTL = 10 * time.Millisecond + + cache.SetPermission("actor", "owner/repo", "admin") + + // Should hit cache + perm, ok := cache.GetPermission("actor", "owner/repo") + if !ok || perm != "admin" { + t.Errorf("GetPermission() = (%q, %v), want (\"admin\", true)", perm, ok) + } + + // Wait for expiry + time.Sleep(20 * time.Millisecond) + + // Should miss cache + _, ok = cache.GetPermission("actor", "owner/repo") + if ok { + t.Error("GetPermission() should return false after TTL expiry") + } +} + +func TestMCPCacheStore_RepoExpiry(t *testing.T) { + cache := newMCPCacheStore() + cache.repoTTL = 10 * time.Millisecond + + cache.SetRepo("owner/repo") + + // Should hit cache + repo, ok := cache.GetRepo() + if !ok || repo != "owner/repo" { + t.Errorf("GetRepo() = (%q, %v), want (\"owner/repo\", true)", repo, ok) + } + + // Wait for expiry + time.Sleep(20 * time.Millisecond) + + // Should miss cache + _, ok = cache.GetRepo() + if ok { + t.Error("GetRepo() should return false after TTL expiry") + } +}