diff --git a/provider/github.go b/provider/github.go index bf00bb2..ef1801d 100644 --- a/provider/github.go +++ b/provider/github.go @@ -317,10 +317,21 @@ func (g *githubProvider) ListBranches(ctx context.Context, owner, repo string) ( } func (g *githubProvider) CreateBranch(ctx context.Context, owner, repo, branch, ref string) (*PlatformBranch, error) { + sha := ref + if !isCommitSHA(ref) { + commits, err := g.ListCommits(ctx, owner, repo, ListCommitsOptions{Branch: ref, PerPage: 1}) + if err != nil { + return nil, fmt.Errorf("failed to resolve ref %q to commit SHA: %w", ref, err) + } + if len(commits) == 0 { + return nil, fmt.Errorf("no commits found on ref %q", ref) + } + sha = commits[0].SHA + } _, _, err := g.client.Git.CreateRef(ctx, owner, repo, &github.Reference{ Ref: github.String("refs/heads/" + branch), Object: &github.GitObject{ - SHA: github.String(ref), + SHA: github.String(sha), }, }) if err != nil { diff --git a/provider/github_test.go b/provider/github_test.go index 3cb831d..2add9d7 100644 --- a/provider/github_test.go +++ b/provider/github_test.go @@ -560,6 +560,12 @@ func TestGitHub_DeleteNote(t *testing.T) { func TestGitHub_CreateBranch(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") + if strings.Contains(r.URL.Path, "/commits") { + json.NewEncoder(w).Encode([]*github.RepositoryCommit{ + {SHA: github.String("abc123def456abc123def456abc123def456abc1")}, + }) + return + } json.NewEncoder(w).Encode(&github.Reference{Ref: github.String("refs/heads/new-branch"), Object: &github.GitObject{SHA: github.String("abc")}}) })) defer srv.Close() @@ -575,6 +581,24 @@ func TestGitHub_CreateBranch(t *testing.T) { } } +func TestGitHub_CreateBranch_WithSHA(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(&github.Reference{Ref: github.String("refs/heads/new-branch"), Object: &github.GitObject{SHA: github.String("abc")}}) + })) + defer srv.Close() + client, _ := github.NewEnterpriseClient(srv.URL+"/api/v3", "", srv.Client()) + p := &githubProvider{client: client, baseURL: srv.URL} + + b, err := p.CreateBranch(context.Background(), "owner", "repo", "new-branch", "abc123def456abc123def456abc123def456abc1") + if err != nil { + t.Fatal(err) + } + if b.Name != "new-branch" { + t.Errorf("expected new-branch, got %s", b.Name) + } +} + func TestGitHub_GetCRDiff(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") diff --git a/provider/util.go b/provider/util.go index 1d84023..8629a9b 100644 --- a/provider/util.go +++ b/provider/util.go @@ -1,10 +1,19 @@ package provider import ( + "encoding/hex" "net/http" "strconv" ) +func isCommitSHA(s string) bool { + if len(s) != 40 { + return false + } + _, err := hex.DecodeString(s) + return err == nil +} + func parseTotalCount(headers http.Header, fallback int) int { for _, key := range []string{"X-Total-Count", "X-Total"} { if v := headers.Get(key); v != "" {