Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions git_sync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ async def git_sync() -> None:
print("Error: Not in a git repository", file=sys.stderr)
sys.exit(2)

if remote_urls:
pull_request_task = create_task(fetch_pull_requests(github_token, remote_urls))
pull_request_task = (
create_task(fetch_pull_requests(github_token, remote_urls))
if remote_urls
else None
)

try:
branches = await get_branches_with_remote_upstreams()
Expand All @@ -111,7 +114,7 @@ async def git_sync() -> None:
if push_remote:
await fast_forward_to_downstream(push_remote, branches)

if remote_urls:
if pull_request_task is not None:
pull_requests = await pull_request_task
push_remote_url = next(
remote.url for remote in remotes if remote.name == push_remote
Expand Down
91 changes: 56 additions & 35 deletions git_sync/github.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import re
import ssl
from asyncio import Semaphore, gather
from collections.abc import AsyncIterator, Callable, Iterable
from collections.abc import AsyncIterator, Callable, Iterable, Iterator
from dataclasses import dataclass
from typing import TypeVar
from typing import Any, TypeVar

import aiohttp
import truststore
from aiographql.client import GraphQLClient # type: ignore[import-untyped]

T = TypeVar("T")
Expand Down Expand Up @@ -101,6 +104,21 @@ def join_queries(queries: Iterable[str]) -> str:
return "{" + "\n".join(f"q{i}: {query}" for i, query in enumerate(queries)) + "}"


def client_session() -> aiohttp.ClientSession:
    """Build an aiohttp session that trusts the OS certificate store and proxy env vars.

    Uses truststore so TLS verification consults the local system credentials,
    and ``trust_env=True`` so proxy settings are picked up from the environment.
    """
    ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    return aiohttp.ClientSession(
        trust_env=True,
        connector=aiohttp.TCPConnector(ssl=ssl_context),
    )


def repo_urls(pr_data: dict[str, Any]) -> Iterator[str]:
    """Yield every URL the PR's head repository may be configured under.

    Emits the SSH URL (if present), then the HTTPS URL both bare and with a
    trailing ``.git``, so callers can match any remote-URL spelling. A missing
    or null ``headRepository`` yields nothing.
    """
    head_repo = pr_data.get("headRepository") or {}
    ssh_url = head_repo.get("sshUrl")
    http_url = head_repo.get("url")
    if ssh_url:
        yield ssh_url
    if http_url:
        yield http_url
        yield f"{http_url}.git"


async def fetch_pull_requests_from_domain(
token: str, domain: str, repos: list[Repository]
) -> AsyncIterator[PullRequest]:
Expand All @@ -109,42 +127,45 @@ async def fetch_pull_requests_from_domain(
if domain.count(".") == 1
else f"https://{domain}/api/graphql"
)
client = GraphQLClient(
endpoint=endpoint, headers={"Authorization": f"Bearer {token}"}
)

# Query for PRs and commit counts
initial_queries = [
pr_initial_query(repo.owner, repo.name) for i, repo in enumerate(repos, 1)
]
initial_response = await client.query(join_queries(initial_queries))
assert not initial_response.errors

# Determine what follow-up queries to make
details_queries = [
pr_details_query(pr_data["id"], pr_data["commits"]["totalCount"])
for repo_data in initial_response.data.values()
for pr_data in repo_data["pullRequests"]["nodes"]
]

# Query for detailed PR information
details_response = await client.query(join_queries(details_queries))
assert not details_response.errors

# Yield response data as PullRequest objects
for pr_data in details_response.data.values():
head_repo = pr_data.get("headRepository") or {}
repo_urls = [head_repo.get("sshUrl"), head_repo.get("url")]
hashes = tuple(
commit["commit"]["oid"] for commit in reversed(pr_data["commits"]["nodes"])
)
yield PullRequest(
branch_name=pr_data["headRefName"],
repo_urls=frozenset(url for url in repo_urls if url is not None),
hashes=hashes,
merged_hash=(pr_data.get("mergeCommit") or {}).get("oid"),
async with client_session() as session:
client = GraphQLClient(
endpoint=endpoint,
headers={"Authorization": f"Bearer {token}"},
session=session,
)

# Query for PRs and commit counts
initial_queries = [
pr_initial_query(repo.owner, repo.name) for i, repo in enumerate(repos, 1)
]
initial_response = await client.query(join_queries(initial_queries))
assert not initial_response.errors

# Determine what follow-up queries to make
details_queries = [
pr_details_query(pr_data["id"], pr_data["commits"]["totalCount"])
for repo_data in initial_response.data.values()
for pr_data in repo_data["pullRequests"]["nodes"]
]

# Query for detailed PR information
details_response = await client.query(join_queries(details_queries))
assert not details_response.errors

# Yield response data as PullRequest objects
for pr_data in details_response.data.values():
hashes = tuple(
commit["commit"]["oid"]
for commit in reversed(pr_data["commits"]["nodes"])
)
yield PullRequest(
branch_name=pr_data["headRefName"],
repo_urls=frozenset(repo_urls(pr_data)),
hashes=hashes,
merged_hash=(pr_data.get("mergeCommit") or {}).get("oid"),
)


async def fetch_pull_requests(
tokens: Callable[[str], str | None],
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "git-sync"
version = "0.4.1"
version = "0.4.2"
description = "Synchronize local git repo with remotes"
authors = [{ name = "Alice Purcell", email = "alicederyn@gmail.com" }]
requires-python = ">= 3.12"
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
aiographql-client >= 1.0.3
truststore >= 0.10.1