diff --git a/git_sync/__init__.py b/git_sync/__init__.py
index ba7598d..031acf2 100644
--- a/git_sync/__init__.py
+++ b/git_sync/__init__.py
@@ -97,8 +97,11 @@ async def git_sync() -> None:
         print("Error: Not in a git repository", file=sys.stderr)
         sys.exit(2)
 
-    if remote_urls:
-        pull_request_task = create_task(fetch_pull_requests(github_token, remote_urls))
+    pull_request_task = (
+        create_task(fetch_pull_requests(github_token, remote_urls))
+        if remote_urls
+        else None
+    )
 
     try:
         branches = await get_branches_with_remote_upstreams()
@@ -111,7 +114,7 @@ async def git_sync() -> None:
     if push_remote:
         await fast_forward_to_downstream(push_remote, branches)
 
-    if remote_urls:
+    if pull_request_task is not None:
         pull_requests = await pull_request_task
         push_remote_url = next(
             remote.url for remote in remotes if remote.name == push_remote
diff --git a/git_sync/github.py b/git_sync/github.py
index 5f1cf78..0045a07 100644
--- a/git_sync/github.py
+++ b/git_sync/github.py
@@ -1,9 +1,12 @@
 import re
+import ssl
 from asyncio import Semaphore, gather
-from collections.abc import AsyncIterator, Callable, Iterable
+from collections.abc import AsyncIterator, Callable, Iterable, Iterator
 from dataclasses import dataclass
-from typing import TypeVar
+from typing import Any, TypeVar
 
+import aiohttp
+import truststore
 from aiographql.client import GraphQLClient  # type: ignore[import-untyped]
 
 T = TypeVar("T")
@@ -101,6 +104,21 @@ def join_queries(queries: Iterable[str]) -> str:
     return "{" + "\n".join(f"q{i}: {query}" for i, query in enumerate(queries)) + "}"
 
 
+def client_session() -> aiohttp.ClientSession:
+    """Configure aiohttp to trust local SSL credentials and environment variables."""
+    connector = aiohttp.TCPConnector(ssl=truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT))
+    return aiohttp.ClientSession(trust_env=True, connector=connector)
+
+
+def repo_urls(pr_data: dict[str, Any]) -> Iterator[str]:
+    head_repo = pr_data.get("headRepository") or {}
+    if ssh_url := head_repo.get("sshUrl"):
+        yield ssh_url
+    if http_url := head_repo.get("url"):
+        yield http_url
+        yield http_url + ".git"
+
+
 async def fetch_pull_requests_from_domain(
     token: str, domain: str, repos: list[Repository]
 ) -> AsyncIterator[PullRequest]:
@@ -109,42 +127,45 @@ async def fetch_pull_requests_from_domain(
         if domain.count(".") == 1
         else f"https://{domain}/api/graphql"
     )
 
-    client = GraphQLClient(
-        endpoint=endpoint, headers={"Authorization": f"Bearer {token}"}
-    )
-    # Query for PRs and commit counts
-    initial_queries = [
-        pr_initial_query(repo.owner, repo.name) for i, repo in enumerate(repos, 1)
-    ]
-    initial_response = await client.query(join_queries(initial_queries))
-    assert not initial_response.errors
-
-    # Determine what follow-up queries to make
-    details_queries = [
-        pr_details_query(pr_data["id"], pr_data["commits"]["totalCount"])
-        for repo_data in initial_response.data.values()
-        for pr_data in repo_data["pullRequests"]["nodes"]
-    ]
-
-    # Query for detailed PR information
-    details_response = await client.query(join_queries(details_queries))
-    assert not details_response.errors
-
-    # Yield response data as PullRequest objects
-    for pr_data in details_response.data.values():
-        head_repo = pr_data.get("headRepository") or {}
-        repo_urls = [head_repo.get("sshUrl"), head_repo.get("url")]
-        hashes = tuple(
-            commit["commit"]["oid"] for commit in reversed(pr_data["commits"]["nodes"])
-        )
-        yield PullRequest(
-            branch_name=pr_data["headRefName"],
-            repo_urls=frozenset(url for url in repo_urls if url is not None),
-            hashes=hashes,
-            merged_hash=(pr_data.get("mergeCommit") or {}).get("oid"),
+    async with client_session() as session:
+        client = GraphQLClient(
+            endpoint=endpoint,
+            headers={"Authorization": f"Bearer {token}"},
+            session=session,
         )
 
+        # Query for PRs and commit counts
+        initial_queries = [
+            pr_initial_query(repo.owner, repo.name) for i, repo in enumerate(repos, 1)
+        ]
+        initial_response = await client.query(join_queries(initial_queries))
+        assert not initial_response.errors
+
+        # Determine what follow-up queries to make
+        details_queries = [
+            pr_details_query(pr_data["id"], pr_data["commits"]["totalCount"])
+            for repo_data in initial_response.data.values()
+            for pr_data in repo_data["pullRequests"]["nodes"]
+        ]
+
+        # Query for detailed PR information
+        details_response = await client.query(join_queries(details_queries))
+        assert not details_response.errors
+
+        # Yield response data as PullRequest objects
+        for pr_data in details_response.data.values():
+            hashes = tuple(
+                commit["commit"]["oid"]
+                for commit in reversed(pr_data["commits"]["nodes"])
+            )
+            yield PullRequest(
+                branch_name=pr_data["headRefName"],
+                repo_urls=frozenset(repo_urls(pr_data)),
+                hashes=hashes,
+                merged_hash=(pr_data.get("mergeCommit") or {}).get("oid"),
+            )
+
 
 async def fetch_pull_requests(
     tokens: Callable[[str], str | None],
diff --git a/pyproject.toml b/pyproject.toml
index 2c46e4c..ef17650 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "git-sync"
-version = "0.4.1"
+version = "0.4.2"
 description = "Synchronize local git repo with remotes"
 authors = [{ name = "Alice Purcell", email = "alicederyn@gmail.com" }]
 requires-python = ">= 3.12"
diff --git a/requirements.txt b/requirements.txt
index d0268a9..4d3a478 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
 aiographql-client >= 1.0.3
+truststore >= 0.10.1