Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 6 additions & 23 deletions git_sync/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,6 @@ async def get_remote_branches(remote: bytes) -> list[bytes]:
return raw_bytes.splitlines()


async def is_ancestor(commit1: _ExecArg, commit2: _ExecArg) -> bool:
    """Check whether commit1 is reachable from commit2's history.

    Runs ``git merge-base --is-ancestor``, which exits with status 0 when
    commit1 is an ancestor of commit2 and status 1 when it is not; any
    other git failure propagates to the caller as GitError.
    """
    try:
        await git("merge-base", "--is-ancestor", commit1, commit2)
    except GitError as err:
        # Exit status 1 is git's documented "not an ancestor" answer,
        # not a real error; anything else is re-raised.
        if err.returncode != 1:
            raise
        return False
    return True


async def fetch_and_fast_forward_to_upstream(branches: Iterable[Branch]) -> None:
if any(b.is_current for b in branches):
await git("pull", "--all")
Expand Down Expand Up @@ -211,17 +200,11 @@ async def update_merged_prs(
if (
merged_hash
and branch_name in branch_hashes
and merged_hash != branch_hashes[branch_name]
and branch_hashes[branch_name] in pr.hashes
and push_remote_url in pr.repo_urls
):
try:
branch_is_ancestor = await is_ancestor(branch_name, pr.branch_hash)
except GitError:
pass # Probably no longer have the commit hash
else:
if branch_is_ancestor:
await update_merged_pr_branch(
branch_name=branch_name,
merged_hash=merged_hash,
allow_delete=allow_delete,
)
await update_merged_pr_branch(
branch_name=branch_name,
merged_hash=merged_hash,
allow_delete=allow_delete,
)
97 changes: 66 additions & 31 deletions git_sync/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,37 +49,58 @@ def repos_by_domain(urls: Iterable[str]) -> dict[str, list[Repository]]:
# Immutable record describing a GitHub pull request fetched via GraphQL.
# NOTE(review): this span is rendered from a diff view with removed and added
# lines interleaved — `branch_hash` appears to be a line deleted by the change
# (superseded by `hashes`), and the "Git and SSH URLs" docstring appears to
# belong to `repo_urls`. Confirm field order against the merged file.
@dataclass(frozen=True)
class PullRequest:
    branch_name: str
    """Name of the branch that backed the PR."""
    repo_urls: frozenset[str]
    branch_hash: str
    """Git and SSH URLs of the repository where the PR is located."""
    hashes: tuple[str, ...]
    """All commits pushed to the PR, newest first."""
    merged_hash: str | None
    """The commit hash of the PR merge commit, if it exists."""


def gql_query(owner: str, name: str) -> str:
def pr_initial_query(owner: str, name: str) -> str:
    """Build a GraphQL fragment listing a repo's 50 most recently updated PRs.

    Fetches per-PR metadata (head ref/repo, node id, merge commit) plus the
    commit ``totalCount`` so a follow-up ``pr_details_query`` can request
    exactly that many commits.

    NOTE(review): rendered from a diff view — the ``nodes { commit { oid } }``
    lines inside ``commits`` appear to be removed lines from the old query
    (superseded by ``totalCount``), and the ``repository {`` brace is left
    unbalanced in this rendering. Confirm against the merged file.
    """
    return f"""
repository(owner: "{owner}", name: "{name}" ) {{
pullRequests(orderBy: {{ field: UPDATED_AT, direction: ASC }}, last: 50) {{
nodes {{
headRefName
headRepository {{
sshUrl
url
}}
id
commits (last: 1) {{
nodes {{
commit {{
oid
}}
}}
totalCount
}}
mergeCommit {{
oid
}}
}}
}}
"""


def pr_details_query(pr_node_id: str, commit_count: int) -> str:
    """Build a GraphQL fragment fetching full details for one pull request node.

    Args:
        pr_node_id: GraphQL node id of the pull request (from the initial query).
        commit_count: Number of trailing commits to request. Clamped to the
            1-100 range, since the GitHub GraphQL API rejects connection page
            sizes (``last``) outside that range; a PR with more than 100
            commits still gets its newest 100, and a 0-count still produces a
            valid query.

    Returns:
        A query fragment suitable for combining with ``join_queries``.
    """
    last = max(1, min(commit_count, 100))
    return f"""
node(id: "{pr_node_id}") {{
  ... on PullRequest {{
    headRefName
    headRepository {{
      sshUrl
      url
    }}
    commits (last: {last}) {{
      nodes {{
        commit {{
          oid
        }}
      }}
    }}
    mergeCommit {{
      oid
    }}
  }}
}}
"""


def join_queries(queries: Iterable[str]) -> str:
    """Merge sub-queries into one GraphQL document, aliased as q0, q1, ..."""
    aliased = [f"q{index}: {subquery}" for index, subquery in enumerate(queries)]
    body = "\n".join(aliased)
    return "{" + body + "}"


async def fetch_pull_requests_from_domain(
token: str, domain: str, repos: list[Repository]
) -> AsyncIterator[PullRequest]:
Expand All @@ -91,24 +112,38 @@ async def fetch_pull_requests_from_domain(
client = GraphQLClient(
endpoint=endpoint, headers={"Authorization": f"Bearer {token}"}
)
queries = [
f"repo{i}: {gql_query(repo.owner, repo.name)}"
for i, repo in enumerate(repos, 1)

# Query for PRs and commit counts
initial_queries = [
pr_initial_query(repo.owner, repo.name) for i, repo in enumerate(repos, 1)
]
query = "{" + "\n".join(queries) + "}"
response = await client.query(query)
assert not response.errors
for repo_data in response.data.values():
for pr_data in repo_data["pullRequests"]["nodes"]:
head_repo = pr_data.get("headRepository") or {}
repo_urls = [head_repo.get("sshUrl"), head_repo.get("url")]
if pr_data["commits"]["nodes"]:
yield PullRequest(
branch_name=pr_data["headRefName"],
repo_urls=frozenset(url for url in repo_urls if url is not None),
branch_hash=pr_data["commits"]["nodes"][0]["commit"]["oid"],
merged_hash=(pr_data.get("mergeCommit") or {}).get("oid"),
)
initial_response = await client.query(join_queries(initial_queries))
assert not initial_response.errors

# Determine what follow-up queries to make
details_queries = [
pr_details_query(pr_data["id"], pr_data["commits"]["totalCount"])
for repo_data in initial_response.data.values()
for pr_data in repo_data["pullRequests"]["nodes"]
]

# Query for detailed PR information
details_response = await client.query(join_queries(details_queries))
assert not details_response.errors

# Yield response data as PullRequest objects
for pr_data in details_response.data.values():
head_repo = pr_data.get("headRepository") or {}
repo_urls = [head_repo.get("sshUrl"), head_repo.get("url")]
hashes = tuple(
commit["commit"]["oid"] for commit in reversed(pr_data["commits"]["nodes"])
)
yield PullRequest(
branch_name=pr_data["headRefName"],
repo_urls=frozenset(url for url in repo_urls if url is not None),
hashes=hashes,
merged_hash=(pr_data.get("mergeCommit") or {}).get("oid"),
)


async def fetch_pull_requests(
Expand Down
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "git-sync"
version = "0.4"
version = "0.4.1"
description = "Synchronize local git repo with remotes"
authors = [{ name = "Alice Purcell", email = "alicederyn@gmail.com" }]
requires-python = ">= 3.12"
Expand All @@ -22,9 +22,6 @@ target-version = "py310"

[tool.ruff.lint]
select = ["ANN", "B", "C4", "E", "F", "I", "PGH", "PLR", "PYI", "RUF", "SIM", "UP", "W"]
ignore = [
"ANN101", # Deprecated, will be removed
]
isort.split-on-trailing-comma = false

[tool.setuptools.dynamic]
Expand Down
20 changes: 10 additions & 10 deletions tests/integration/test_fast_forward_merged_prs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_delete_merged_inactive_pr_branch() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_b,
hashes=(commit_b,),
merged_hash=commit_c,
)

Expand All @@ -57,7 +57,7 @@ async def test_force_inactive_upstream_branch_to_merged_commit() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_b,
hashes=(commit_b,),
merged_hash=commit_c,
)

Expand Down Expand Up @@ -85,7 +85,7 @@ async def test_merged_inactive_pr_branch_with_deletion_disabled() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_b,
hashes=(commit_b,),
merged_hash=commit_c,
)

Expand Down Expand Up @@ -113,7 +113,7 @@ async def test_delete_merged_active_pr_branch() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_b,
hashes=(commit_b,),
merged_hash=commit_c,
)

Expand Down Expand Up @@ -142,7 +142,7 @@ async def test_force_active_upstream_branch_to_merged_commit() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_b,
hashes=(commit_b,),
merged_hash=commit_c,
)

Expand Down Expand Up @@ -170,7 +170,7 @@ async def test_merged_active_upstream_branch_with_deletion_disabled() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_b,
hashes=(commit_b,),
merged_hash=commit_c,
)

Expand All @@ -195,7 +195,7 @@ async def test_staged_changes_not_lost() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_b,
hashes=(commit_b,),
merged_hash=commit_c,
)

Expand All @@ -221,7 +221,7 @@ async def test_unstaged_changes_to_committed_files_not_lost() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_b,
hashes=(commit_b,),
merged_hash=commit_c,
)

Expand Down Expand Up @@ -251,7 +251,7 @@ async def test_fastforward_when_pr_had_additional_commits() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_c,
hashes=(commit_c, commit_b),
merged_hash=commit_d,
)

Expand All @@ -276,7 +276,7 @@ async def test_no_fastforward_when_branch_has_additional_commits() -> None:
pr = PullRequest(
branch_name="my_pr",
repo_urls=frozenset([REPO_URL]),
branch_hash=commit_b,
hashes=(commit_b,),
merged_hash=commit_d,
)

Expand Down
13 changes: 0 additions & 13 deletions tests/integration/test_is_ancestor.py

This file was deleted.