From 642f734082cf0a5a526ded6f83b82ce95738942d Mon Sep 17 00:00:00 2001 From: samzong Date: Wed, 13 May 2026 12:11:08 +0800 Subject: [PATCH] fix(worktree): scope PR lookup to user branches Signed-off-by: samzong --- cmd/worktree.go | 2 +- internal/worktree/review.go | 254 +++++++++++++++++++++++++++---- internal/worktree/review_test.go | 229 ++++++++++++++++++++++++++-- 3 files changed, 442 insertions(+), 43 deletions(-) diff --git a/cmd/worktree.go b/cmd/worktree.go index 5b0f309..23e6859 100644 --- a/cmd/worktree.go +++ b/cmd/worktree.go @@ -554,7 +554,7 @@ func loadWorktreeReviews(wtClient *worktree.Client, worktrees []worktree.Info) w if !wtShowPR || len(worktrees) == 0 { return worktree.ReviewLookup{} } - return wtClient.ReviewStates() + return wtClient.ReviewStates(worktrees) } func printReviewWarning(w io.Writer, reviews worktree.ReviewLookup) { diff --git a/internal/worktree/review.go b/internal/worktree/review.go index fc30759..cff91f1 100644 --- a/internal/worktree/review.go +++ b/internal/worktree/review.go @@ -1,18 +1,23 @@ package worktree import ( + "crypto/sha256" "encoding/json" "errors" "fmt" "net/url" "os" "os/exec" + "path/filepath" "strings" + "time" ) const ( reviewProviderGitHub = "github" reviewProviderGitLab = "gitlab" + reviewCacheTTL = 5 * time.Minute + reviewCacheVersion = "v1" ) type ReviewInfo struct { @@ -34,6 +39,20 @@ type reviewRemote struct { provider string } +type reviewTarget struct { + Branch string + Commit string +} + +type reviewCandidate struct { + Provider string + Number int + State string + HeadBranch string + HeadCommit string + URL string +} + type missingReviewToolError struct { tool string } @@ -43,6 +62,7 @@ func (e missingReviewToolError) Error() string { } var reviewRunFunc = reviewRunDefault +var reviewCacheDirFunc = os.UserCacheDir func reviewRunDefault(repoDir string, tool string, args ...string) ([]byte, error) { if _, err := exec.LookPath(tool); err != nil { @@ -61,13 +81,18 @@ func reviewRunDefault(repoDir string, tool string, args ...string) ([]byte, erro return out, nil } -func (c *Client) ReviewStates() ReviewLookup { +func (c *Client) ReviewStates(worktrees []Info) ReviewLookup { result := ReviewLookup{Reviews: map[string]ReviewInfo{}} if err := c.ensureInit(); err != nil { result.Warning = "review lookup skipped: failed to initialize repository: " + err.Error() return result } + targets := c.reviewTargets(worktrees) + if len(targets) == 0 { + return result + } + remote, err := c.detectReviewRemote() if err != nil { result.Warning = "review lookup skipped: " + err.Error() @@ -77,9 +102,9 @@ func (c *Client) ReviewStates() ReviewLookup { var reviews map[string]ReviewInfo switch remote.provider { case reviewProviderGitHub: - reviews, err = githubReviewStates(c.repoDir, remote.url) + reviews, err = githubReviewStates(c.repoDir, remote.url, targets) case reviewProviderGitLab: - reviews, err = gitlabReviewStates(c.repoDir, remote.url) + reviews, err = gitlabReviewStates(c.repoDir, remote.url, targets) default: err = fmt.Errorf("unsupported review provider for remote %q", remote.name) } @@ -120,6 +145,35 @@ func (c *Client) detectReviewRemote() (reviewRemote, error) { return reviewRemote{}, errors.New("no usable git remote found") } +func (c *Client) reviewTargets(worktrees []Info) map[string]reviewTarget { + targets := make(map[string]reviewTarget) + mainBranch := "" + if branch, err := c.resolvedMainBranch(); err == nil { + mainBranch = branch + } + for _, wt := range worktrees { + branch := strings.TrimSpace(wt.Branch) + if branch == "" || branch == "(detached)" || branch == mainBranch { + continue + } + if !c.branchPushedToOrigin(branch) { + continue + } + if _, exists := targets[branch]; !exists { + targets[branch] = reviewTarget{Branch: branch, Commit: wt.Commit} + } + } + return targets +} + +func (c *Client) branchPushedToOrigin(branch string) bool { + if branch == "" { + return false + } + _, err := c.runner.Run("-C", c.repoDir, "show-ref", "--verify", "--quiet", "refs/remotes/origin/"+branch) + return err == nil +} + func (c *Client) remoteURL(remote string) (string, error) { result, err := c.runner.Run("-C", c.repoDir, "remote", "get-url", remote) if err != nil { @@ -188,17 +242,30 @@ type githubReviewInfo struct { Number int `json:"number"` State string `json:"state"` HeadRefName string `json:"headRefName"` + HeadRefOid string `json:"headRefOid"` URL string `json:"url"` } -func githubReviewStates(repoDir string, repoURL string) (map[string]ReviewInfo, error) { - out, err := reviewRunFunc(repoDir, - "gh", - "pr", "list", - "-R", repoURL, - "--state", "all", - "--json", "number,state,headRefName,url", - "--limit", "300", +func githubReviewStates( + repoDir string, + repoURL string, + targets map[string]reviewTarget, +) (map[string]ReviewInfo, error) { + out, err := cachedReviewOutput( + reviewProviderGitHub, + repoURL, + "me", + func() ([]byte, error) { + return reviewRunFunc(repoDir, + "gh", + "pr", "list", + "-R", repoURL, + "--author", "@me", + "--state", "all", + "--json", "number,state,headRefName,headRefOid,url", + "--limit", "1000", + ) + }, ) if err != nil { return nil, err @@ -209,37 +276,56 @@ func githubReviewStates(repoDir string, repoURL string) (map[string]ReviewInfo, return nil, err } - reviews := make(map[string]ReviewInfo, len(prs)) + candidates := make([]reviewCandidate, 0, len(prs)) for _, pr := range prs { - if _, exists := reviews[pr.HeadRefName]; exists { - continue - } - reviews[pr.HeadRefName] = ReviewInfo{ + candidates = append(candidates, reviewCandidate{ Provider: reviewProviderGitHub, Number: pr.Number, State: normalizeReviewState(pr.State), HeadBranch: pr.HeadRefName, + HeadCommit: pr.HeadRefOid, URL: pr.URL, - } + }) } - return reviews, nil + return selectReviewCandidates(candidates, targets), nil } type gitlabReviewInfo struct { IID int `json:"iid"` State string `json:"state"` SourceBranch string `json:"source_branch"` + SHA string `json:"sha"` WebURL string `json:"web_url"` } -func gitlabReviewStates(repoDir string, repoURL string) (map[string]ReviewInfo, error) { - out, err := reviewRunFunc(repoDir, - "glab", - "mr", "list", - "-R", repoURL, - "--all", - "--output", "json", - "--per-page", "100", +type gitlabUserInfo struct { + Username string `json:"username"` +} + +func gitlabReviewStates( + repoDir string, + repoURL string, + targets map[string]reviewTarget, +) (map[string]ReviewInfo, error) { + out, err := cachedReviewOutput( + reviewProviderGitLab, + repoURL, + "me", + func() ([]byte, error) { + username, err := gitlabCurrentUsername(repoDir) + if err != nil { + return nil, err + } + return reviewRunFunc(repoDir, + "glab", + "mr", "list", + "-R", repoURL, + "--all", + "--author", username, + "--output", "json", + "--per-page", "100", + ) + }, ) if err != nil { return nil, err @@ -250,20 +336,124 @@ func gitlabReviewStates(repoDir string, repoURL string) (map[string]ReviewInfo, return nil, err } - reviews := make(map[string]ReviewInfo, len(mrs)) + candidates := make([]reviewCandidate, 0, len(mrs)) for _, mr := range mrs { - if _, exists := reviews[mr.SourceBranch]; exists { - continue - } - reviews[mr.SourceBranch] = ReviewInfo{ + candidates = append(candidates, reviewCandidate{ Provider: reviewProviderGitLab, Number: mr.IID, State: normalizeReviewState(mr.State), HeadBranch: mr.SourceBranch, + HeadCommit: mr.SHA, URL: mr.WebURL, + }) + } + return selectReviewCandidates(candidates, targets), nil +} + +func gitlabCurrentUsername(repoDir string) (string, error) { + out, err := reviewRunFunc(repoDir, "glab", "api", "user") + if err != nil { + return "", err + } + var user gitlabUserInfo + if err := decodeReviewJSON(out, &user); err != nil { + return "", err + } + if user.Username == "" { + return "", errors.New("failed to determine GitLab username") + } + return user.Username, nil +} + +func selectReviewCandidates( + candidates []reviewCandidate, + targets map[string]reviewTarget, +) map[string]ReviewInfo { + reviews := make(map[string]ReviewInfo, len(targets)) + exact := make(map[string]bool, len(targets)) + for _, candidate := range candidates { + target, ok := targets[candidate.HeadBranch] + if !ok { + continue + } + if exact[candidate.HeadBranch] { + continue } + info := ReviewInfo{ + Provider: candidate.Provider, + Number: candidate.Number, + State: candidate.State, + HeadBranch: candidate.HeadBranch, + URL: candidate.URL, + } + if target.Commit != "" && candidate.HeadCommit != "" && target.Commit == candidate.HeadCommit { + reviews[candidate.HeadBranch] = info + exact[candidate.HeadBranch] = true + continue + } + if _, exists := reviews[candidate.HeadBranch]; !exists { + reviews[candidate.HeadBranch] = info + } + } + return reviews +} + +func cachedReviewOutput( + provider string, + repoURL string, + author string, + load func() ([]byte, error), +) ([]byte, error) { + if out, ok := readReviewCache(provider, repoURL, author); ok { + return out, nil + } + out, err := load() + if err != nil { + return nil, err + } + writeReviewCache(provider, repoURL, author, out) + return out, nil +} + +func readReviewCache(provider string, repoURL string, author string) ([]byte, bool) { + path, ok := reviewCachePath(provider, repoURL, author) + if !ok { + return nil, false + } + info, err := os.Stat(path) + if err != nil || time.Since(info.ModTime()) > reviewCacheTTL { + return nil, false + } + out, err := os.ReadFile(path) + if err != nil { + return nil, false + } + return out, true +} + +func writeReviewCache(provider string, repoURL string, author string, out []byte) { + path, ok := reviewCachePath(provider, repoURL, author) + if !ok { + return + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return + } + tmp := path + ".tmp" + if err := os.WriteFile(tmp, out, 0o644); err != nil { + return + } + _ = os.Rename(tmp, path) +} + +func reviewCachePath(provider string, repoURL string, author string) (string, bool) { + dir, err := reviewCacheDirFunc() + if err != nil || dir == "" { + return "", false } - return reviews, nil + key := strings.Join([]string{reviewCacheVersion, provider, repoURL, author}, "\x00") + sum := sha256.Sum256([]byte(key)) + return filepath.Join(dir, "gmc", "reviews", fmt.Sprintf("%x.json", sum)), true } func decodeReviewJSON(out []byte, target any) error { diff --git a/internal/worktree/review_test.go b/internal/worktree/review_test.go index 0020960..70d5979 100644 --- a/internal/worktree/review_test.go +++ b/internal/worktree/review_test.go @@ -32,16 +32,18 @@ func TestReviewProviderFromRemoteURL(t *testing.T) { } func TestReviewStates_GitHub(t *testing.T) { - runReviewLookupTest(t, "https://github.com/org/repo.git", "gh", `[ + useTempReviewCache(t) + worktrees := runReviewLookupTest(t, "https://github.com/org/repo.git", "gh", `[ { "number": 42, "state": "OPEN", "headRefName": "feature/github", + "headRefOid": "HEAD_COMMIT", "url": "https://github.com/org/repo/pull/42" } - ]`) + ]`, "feature/github") - result := NewClient(Options{}).ReviewStates() + result := NewClient(Options{}).ReviewStates(worktrees) if result.Warning != "" { t.Fatalf("Warning = %q, want empty", result.Warning) } @@ -51,17 +53,78 @@ func TestReviewStates_GitHub(t *testing.T) { } } +func TestReviewStates_GitHubOnlyMatchesPushedWorktreeBranches(t *testing.T) { + useTempReviewCache(t) + worktrees := runReviewLookupTest(t, "https://github.com/org/repo.git", "gh", `[ + { + "number": 42, + "state": "OPEN", + "headRefName": "not-local", + "url": "https://github.com/org/repo/pull/42" + }, + { + "number": 43, + "state": "OPEN", + "headRefName": "feature/github", + "url": "https://github.com/org/repo/pull/43" + } + ]`, "feature/github") + + result := NewClient(Options{}).ReviewStates(worktrees) + if result.Warning != "" { + t.Fatalf("Warning = %q, want empty", result.Warning) + } + if _, ok := result.Reviews["not-local"]; ok { + t.Fatalf("Reviews[not-local] = %+v, want no non-worktree branch match", result.Reviews["not-local"]) + } + review := result.Reviews["feature/github"] + if review.Number != 43 { + t.Fatalf("Reviews[feature/github].Number = %d, want 43", review.Number) + } +} + +func TestReviewStates_GitHubPrefersExactHeadCommit(t *testing.T) { + useTempReviewCache(t) + worktrees := runReviewLookupTest(t, "https://github.com/org/repo.git", "gh", `[ + { + "number": 42, + "state": "OPEN", + "headRefName": "feature/github", + "headRefOid": "old-commit", + "url": "https://github.com/org/repo/pull/42" + }, + { + "number": 43, + "state": "OPEN", + "headRefName": "feature/github", + "headRefOid": "HEAD_COMMIT", + "url": "https://github.com/org/repo/pull/43" + } + ]`, "feature/github") + + result := NewClient(Options{}).ReviewStates(worktrees) + if result.Warning != "" { + t.Fatalf("Warning = %q, want empty", result.Warning) + } + review := result.Reviews["feature/github"] + if review.Number != 43 { + t.Fatalf("Reviews[feature/github].Number = %d, want exact commit PR 43", review.Number) + } +} + func TestReviewStates_GitLab(t *testing.T) { - runReviewLookupTest(t, "https://gitlab.com/group/repo.git", "glab", `[ + useTempReviewCache(t) + worktrees := runReviewLookupTest(t, "https://gitlab.com/group/repo.git", "glab", `[ { "iid": 7, "state": "opened", "source_branch": "feature/gitlab", + "sha": "HEAD_COMMIT", "web_url": "https://gitlab.com/group/repo/-/merge_requests/7" } - ]`) + ]`, "feature/gitlab") - result := NewClient(Options{}).ReviewStates() + result := NewClient(Options{}).ReviewStates(worktrees) if result.Warning != "" { t.Fatalf("Warning = %q, want empty", result.Warning) } @@ -71,12 +134,111 @@ func TestReviewStates_GitLab(t *testing.T) { } } +func TestReviewStates_UsesFreshCache(t *testing.T) { + useTempReviewCache(t) + worktrees := runReviewLookupTest(t, "https://github.com/org/repo.git", "gh", `[ + { + "number": 42, + "state": "OPEN", + "headRefName": "feature/github", + "headRefOid": "HEAD_COMMIT", + "url": "https://github.com/org/repo/pull/42" + } + ]`, "feature/github") + + first := NewClient(Options{}).ReviewStates(worktrees) + if first.Warning != "" { + t.Fatalf("Warning = %q, want empty", first.Warning) + } + + oldRun := reviewRunFunc + reviewRunFunc = func(repoDir string, tool string, args ...string) ([]byte, error) { + t.Fatalf("review lookup should use fresh cache, got %s %v", tool, args) + return nil, nil + } + t.Cleanup(func() { reviewRunFunc = oldRun }) + + second := NewClient(Options{}).ReviewStates(worktrees) + if second.Warning != "" { + t.Fatalf("Warning = %q, want empty", second.Warning) + } + review := second.Reviews["feature/github"] + if review.Number != 42 { + t.Fatalf("Reviews[feature/github].Number = %d, want cached PR 42", review.Number) + } +} + +func TestReviewStates_GitLabUsesFreshCacheBeforeUserLookup(t *testing.T) { + useTempReviewCache(t) + worktrees := runReviewLookupTest(t, "https://gitlab.com/group/repo.git", "glab", `[ + { + "iid": 7, + "state": "opened", + "source_branch": "feature/gitlab", + "sha": "HEAD_COMMIT", + "web_url": "https://gitlab.com/group/repo/-/merge_requests/7" + } + ]`, "feature/gitlab") + + first := NewClient(Options{}).ReviewStates(worktrees) + if first.Warning != "" { + t.Fatalf("Warning = %q, want empty", first.Warning) + } + + oldRun := reviewRunFunc + reviewRunFunc = func(repoDir string, tool string, args ...string) ([]byte, error) { + t.Fatalf("GitLab review lookup should use fresh cache before user lookup, got %s %v", tool, args) + return nil, nil + } + t.Cleanup(func() { reviewRunFunc = oldRun }) + + second := NewClient(Options{}).ReviewStates(worktrees) + if second.Warning != "" { + t.Fatalf("Warning = %q, want empty", second.Warning) + } + review := second.Reviews["feature/gitlab"] + if review.Number != 7 { + t.Fatalf("Reviews[feature/gitlab].Number = %d, want cached MR 7", review.Number) + } +} + +func TestReviewStates_SkipsUnpushedBranches(t *testing.T) { + repoDir := initTestRepo(t) + runGit(t, repoDir, "remote", "add", "origin", "https://github.com/org/repo.git") + head := strings.TrimSpace(runGit(t, repoDir, "rev-parse", "HEAD")) + + oldRun := reviewRunFunc + t.Cleanup(func() { reviewRunFunc = oldRun }) + reviewRunFunc = func(repoDir string, tool string, args ...string) ([]byte, error) { + t.Fatalf("review lookup should not run for unpushed branch: %s %v", tool, args) + return nil, nil + } + + oldCwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = os.Chdir(oldCwd) }) + if err := os.Chdir(repoDir); err != nil { + t.Fatal(err) + } + + result := NewClient(Options{}).ReviewStates([]Info{{Branch: "feature/unpushed", Commit: head}}) + if result.Warning != "" { + t.Fatalf("Warning = %q, want empty", result.Warning) + } + if len(result.Reviews) != 0 { + t.Fatalf("Reviews = %+v, want empty", result.Reviews) + } +} + func TestReviewStates_WarnsOnMissingCLI(t *testing.T) { + useTempReviewCache(t) runReviewLookupFailureTest(t, func(repoDir string, tool string, args ...string) ([]byte, error) { return nil, missingReviewToolError{tool: tool} }) - result := NewClient(Options{}).ReviewStates() + result := NewClient(Options{}).ReviewStates([]Info{{Branch: "feature/github", Commit: ""}}) if !strings.Contains(result.Warning, "gh CLI not found") { t.Fatalf("Warning = %q, want missing gh warning", result.Warning) } @@ -86,31 +248,75 @@ func TestReviewStates_WarnsOnMissingCLI(t *testing.T) { } func TestReviewStates_WarnsOnAuthFailure(t *testing.T) { + useTempReviewCache(t) runReviewLookupFailureTest(t, func(repoDir string, tool string, args ...string) ([]byte, error) { return nil, fmt.Errorf("%s failed: authentication required", tool) }) - result := NewClient(Options{}).ReviewStates() + result := NewClient(Options{}).ReviewStates([]Info{{Branch: "feature/github", Commit: ""}}) if !strings.Contains(result.Warning, "check authentication") { t.Fatalf("Warning = %q, want authentication warning", result.Warning) } } -func runReviewLookupTest(t *testing.T, remoteURL string, wantTool string, output string) { +func useTempReviewCache(t *testing.T) { + t.Helper() + old := reviewCacheDirFunc + dir := t.TempDir() + reviewCacheDirFunc = func() (string, error) { + return dir, nil + } + t.Cleanup(func() { reviewCacheDirFunc = old }) +} + +func runReviewLookupTest( + t *testing.T, + remoteURL string, + wantTool string, + output string, + branches ...string, +) []Info { t.Helper() repoDir := initTestRepo(t) runGit(t, repoDir, "remote", "add", "origin", remoteURL) + head := strings.TrimSpace(runGit(t, repoDir, "rev-parse", "HEAD")) + worktrees := make([]Info, 0, len(branches)) + for _, branch := range branches { + runGit(t, repoDir, "update-ref", "refs/remotes/origin/"+branch, head) + worktrees = append(worktrees, Info{Branch: branch, Commit: head}) + } oldRun := reviewRunFunc t.Cleanup(func() { reviewRunFunc = oldRun }) + callCount := 0 reviewRunFunc = func(repoDir string, tool string, args ...string) ([]byte, error) { if tool != wantTool { t.Fatalf("tool = %q, want %s", tool, wantTool) } + callCount++ + if tool == "glab" && hasReviewArg(args, "api", "user") { + return []byte(`{"username":"test-user"}`), nil + } if !hasReviewArg(args, "-R", remoteURL) { t.Fatalf("args = %v, want -R %s", args, remoteURL) } - return []byte(output), nil + switch tool { + case "gh": + if !hasReviewArg(args, "--author", "@me") { + t.Fatalf("args = %v, want --author @me", args) + } + return []byte(strings.ReplaceAll(output, "HEAD_COMMIT", head)), nil + case "glab": + if !hasReviewArg(args, "--author", "test-user") { + t.Fatalf("args = %v, want --author test-user", args) + } + if callCount != 2 { + t.Fatalf("glab mr list call count = %d, want 2 after user lookup", callCount) + } + return []byte(strings.ReplaceAll(output, "HEAD_COMMIT", head)), nil + default: + return []byte(output), nil + } } oldCwd, err := os.Getwd() @@ -121,6 +327,7 @@ func runReviewLookupTest(t *testing.T, remoteURL string, wantTool string, output if err := os.Chdir(repoDir); err != nil { t.Fatal(err) } + return worktrees } func hasReviewArg(args []string, flag string, value string) bool { @@ -139,6 +346,8 @@ func runReviewLookupFailureTest( t.Helper() repoDir := initTestRepo(t) runGit(t, repoDir, "remote", "add", "origin", "https://github.com/org/repo.git") + head := strings.TrimSpace(runGit(t, repoDir, "rev-parse", "HEAD")) + runGit(t, repoDir, "update-ref", "refs/remotes/origin/feature/github", head) oldRun := reviewRunFunc t.Cleanup(func() { reviewRunFunc = oldRun })