Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,23 @@ def test_function_accumulates_tokens():

**Remember**: Tests should provide confidence that the code works, not just that it compiles.

## CRITICAL: NEVER Dismiss Test Failures When Fixing PRs

**NEVER say "this test is not part of this PR" or "our PR's test passed fine" or "this is a pre-existing issue on the base branch".**

This is the WORST possible response. CI failed. The PR is blocked. Fix it.

- Do NOT distinguish between "our test" and "other tests" - ALL failing tests are your problem
- Do NOT blame the base branch, other developers, or previous PRs
- Do NOT report which tests "passed fine" as if that's relevant when CI is red
- When reporting root cause and fix, focus on WHAT failed and HOW you fixed it - not on whose fault it is or which PR introduced it

**BAD (will make user angry):**
- "The failing test ApplicationTest::test_getByApplicationNo is not part of this PR - it's a pre-existing test in the base branch. Our PR's test AnnotationElectricalOutletSpotModelTest passed fine."

**GOOD:**
- "ApplicationTest::test_getByApplicationNo failed because line 351 was missing `application_no` in the factory call. Fixed by adding the missing parameter."

## Proactive Code Fixes

When refactoring or replacing old systems, always be PROACTIVE and think comprehensively. Don't wait for the user to point out every piece of old code that needs to fix.
Expand Down
35 changes: 28 additions & 7 deletions services/github/branches/get_required_status_checks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Standard imports
from dataclasses import dataclass
from typing import cast

# Third party imports
Expand All @@ -12,7 +13,14 @@
from utils.logging.logging_config import logger


@handle_exceptions(default_return_value=(201, None), raise_on_error=False)
@dataclass
class StatusChecksResult:
status_code: int = 201
checks: list[str] | None = None
strict: bool = True


@handle_exceptions(default_return_value=StatusChecksResult(), raise_on_error=False)
def get_required_status_checks(owner: str, repo: str, branch: str, token: str):
"""https://docs.github.com/en/rest/branches/branch-protection#get-branch-protection"""
url = f"{GITHUB_API_URL}/repos/{owner}/{repo}/branches/{branch}/protection"
Expand All @@ -21,16 +29,26 @@ def get_required_status_checks(owner: str, repo: str, branch: str, token: str):

# NOTE: 403 happens when GitHub App lacks "Administration: Read" permission
if response.status_code == 403:
strict = True
logger.warning(
"No permission to read branch protection for %s/%s/%s", owner, repo, branch
"No permission to read branch protection for %s/%s/%s, assuming strict=%s",
owner,
repo,
branch,
strict,
)
return 403, None
return StatusChecksResult(status_code=403, checks=None, strict=strict)

if response.status_code == 404:
strict = False
logger.warning(
"No branch protection configured for %s/%s/%s", owner, repo, branch
"No branch protection configured for %s/%s/%s, assuming strict=%s",
owner,
repo,
branch,
strict,
)
return 404, []
return StatusChecksResult(status_code=404, checks=[], strict=strict)

response.raise_for_status()
protection = cast(BranchProtection, response.json())
Expand All @@ -43,11 +61,14 @@ def get_required_status_checks(owner: str, repo: str, branch: str, token: str):
repo,
branch,
)
return 200, []
return StatusChecksResult(status_code=200, checks=[], strict=False)

strict = required_status_checks.get("strict", False)
contexts = set(required_status_checks.get("contexts", []))
checks = {
check.get("context") for check in required_status_checks.get("checks", [])
}

return 200, list(contexts | checks)
return StatusChecksResult(
status_code=200, checks=list(contexts | checks), strict=strict
)
58 changes: 33 additions & 25 deletions services/github/branches/test_get_required_status_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_get_required_status_checks_success(
mock_get.return_value = mock_response
mock_headers.return_value = {"Authorization": "Bearer test_token"}

status_code, result = get_required_status_checks(
result = get_required_status_checks(
owner=test_owner, repo=test_repo, branch="main", token=test_token
)

Expand All @@ -48,14 +48,15 @@ def test_get_required_status_checks_success(
headers={"Authorization": "Bearer test_token"},
timeout=120,
)
assert status_code == 200
assert result is not None
assert set(result) == {
assert result.status_code == 200
assert result.checks is not None
assert set(result.checks) == {
"ci/circleci: test",
"Codecov",
"CircleCI Checks",
"Aikido Security",
}
assert result.strict is True


def test_get_required_status_checks_403_no_permission(
Expand All @@ -71,12 +72,13 @@ def test_get_required_status_checks_403_no_permission(
mock_get.return_value = mock_response
mock_headers.return_value = {"Authorization": "Bearer test_token"}

status_code, result = get_required_status_checks(
result = get_required_status_checks(
owner=test_owner, repo=test_repo, branch="main", token=test_token
)

assert status_code == 403
assert result is None
assert result.status_code == 403
assert result.checks is None
assert result.strict is True


def test_get_required_status_checks_404_no_protection(
Expand All @@ -92,12 +94,13 @@ def test_get_required_status_checks_404_no_protection(
mock_get.return_value = mock_response
mock_headers.return_value = {"Authorization": "Bearer test_token"}

status_code, result = get_required_status_checks(
result = get_required_status_checks(
owner=test_owner, repo=test_repo, branch="main", token=test_token
)

assert status_code == 404
assert not result
assert result.status_code == 404
assert not result.checks
assert result.strict is False


def test_get_required_status_checks_no_required_checks(
Expand All @@ -114,12 +117,13 @@ def test_get_required_status_checks_no_required_checks(
mock_get.return_value = mock_response
mock_headers.return_value = {"Authorization": "Bearer test_token"}

status_code, result = get_required_status_checks(
result = get_required_status_checks(
owner=test_owner, repo=test_repo, branch="main", token=test_token
)

assert status_code == 200
assert not result
assert result.status_code == 200
assert not result.checks
assert result.strict is False


def test_get_required_status_checks_only_contexts(test_owner, test_repo, test_token):
Expand All @@ -140,12 +144,13 @@ def test_get_required_status_checks_only_contexts(test_owner, test_repo, test_to
mock_get.return_value = mock_response
mock_headers.return_value = {"Authorization": "Bearer test_token"}

status_code, result = get_required_status_checks(
result = get_required_status_checks(
owner=test_owner, repo=test_repo, branch="main", token=test_token
)

assert status_code == 200
assert result == ["ci/circleci: test"]
assert result.status_code == 200
assert result.checks == ["ci/circleci: test"]
assert result.strict is True


def test_get_required_status_checks_only_checks(test_owner, test_repo, test_token):
Expand All @@ -166,12 +171,13 @@ def test_get_required_status_checks_only_checks(test_owner, test_repo, test_toke
mock_get.return_value = mock_response
mock_headers.return_value = {"Authorization": "Bearer test_token"}

status_code, result = get_required_status_checks(
result = get_required_status_checks(
owner=test_owner, repo=test_repo, branch="main", token=test_token
)

assert status_code == 200
assert result == ["CircleCI Checks"]
assert result.status_code == 200
assert result.checks == ["CircleCI Checks"]
assert result.strict is False


def test_get_required_status_checks_http_error_500(test_owner, test_repo, test_token):
Expand All @@ -188,12 +194,13 @@ def test_get_required_status_checks_http_error_500(test_owner, test_repo, test_t
mock_get.return_value = mock_response
mock_headers.return_value = {"Authorization": "Bearer test_token"}

status_code, result = get_required_status_checks(
result = get_required_status_checks(
owner=test_owner, repo=test_repo, branch="main", token=test_token
)

assert status_code == 201
assert result is None
assert result.status_code == 201
assert result.checks is None
assert result.strict is True


def test_get_required_status_checks_network_error(test_owner, test_repo, test_token):
Expand All @@ -205,9 +212,10 @@ def test_get_required_status_checks_network_error(test_owner, test_repo, test_to
mock_get.side_effect = requests.exceptions.ConnectionError("Network error")
mock_headers.return_value = {"Authorization": "Bearer test_token"}

status_code, result = get_required_status_checks(
result = get_required_status_checks(
owner=test_owner, repo=test_repo, branch="main", token=test_token
)

assert status_code == 201
assert result is None
assert result.status_code == 201
assert result.checks is None
assert result.strict is True
8 changes: 8 additions & 0 deletions services/github/types/webhook/push.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from typing import TypedDict

from services.github.types.installation import Installation
from services.github.types.repository import Repository
from services.github.types.sender import Sender


class PushCommit(TypedDict):
added: list[str]
modified: list[str]
removed: list[str]


class PushWebhookPayload(TypedDict):
ref: str
commits: list[PushCommit]
repository: Repository
sender: Sender
installation: Installation
23 changes: 23 additions & 0 deletions services/webhook/push_handler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from services.github.branches.get_required_status_checks import (
get_required_status_checks,
)
from services.github.pulls.get_open_pull_requests import get_open_pull_requests
from services.github.pulls.update_pull_request_branch import update_pull_request_branch
from services.github.token.get_installation_token import get_installation_access_token
from services.github.types.webhook.push import PushWebhookPayload
from services.supabase.repositories.get_repository import get_repository
from utils.error.handle_exceptions import handle_exceptions
from utils.files.is_test_file import is_test_file
from utils.logging.logging_config import logger, set_trigger


Expand Down Expand Up @@ -39,6 +43,25 @@ def handle_push(payload: PushWebhookPayload):

token = get_installation_access_token(installation_id=installation_id)

# Check if this is a test-only push and the repo doesn't require up-to-date branches
commits = payload.get("commits", [])
if commits:
all_files: set[str] = set()
for commit in commits:
all_files.update(commit.get("added", []))
all_files.update(commit.get("modified", []))
all_files.update(commit.get("removed", []))
if all_files and all(is_test_file(f) for f in all_files):
protection = get_required_status_checks(
owner=owner_name, repo=repo_name, branch=target_branch, token=token
)
if not protection.strict:
logger.info(
"Skipping PR branch updates: test-only push to %s and strict=False",
target_branch,
)
return None

open_prs = get_open_pull_requests(owner=owner_name, repo=repo_name, token=token)

if not open_prs:
Expand Down
12 changes: 6 additions & 6 deletions services/webhook/successful_check_suite_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,27 +119,27 @@ def handle_successful_check_suite(payload: CheckSuiteCompletedPayload):
logger.error(msg)
raise RuntimeError(msg)

status_code, required_checks = get_required_status_checks(
protection = get_required_status_checks(
owner=owner_name, repo=repo_name, branch=base_branch, token=token
)

if required_checks:
logger.info("Using required status checks: %s", required_checks)
if protection.checks:
logger.info("Using required status checks: %s", protection.checks)
for suite in all_suites:
app_name = suite["app"]["name"]
status = suite["status"]
if app_name in required_checks and status != "completed":
if app_name in protection.checks and status != "completed":
logger.info(
"Required check '%s' not completed: status=%s", app_name, status
)
return
logger.info("All required checks completed")
else:
if required_checks is None:
if protection.checks is None:
logger.info(
"Could not read branch protection (status=%s), "
"waiting for all check suites to complete",
status_code,
protection.status_code,
)
else:
logger.info(
Expand Down
Loading
Loading