Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,14 @@ async def setup_coverage_workflow(
repo: str,
token: str = Header(..., alias="X-GitHub-Token"),
api_key: str = Header(..., alias="X-API-Key"),
sender_id: int = Header(0, alias="X-Sender-Id"),
sender_name: str = Header("", alias="X-Sender-Name"),
):
verify_api_key(api_key)
return await setup_handler(
owner_name=owner,
repo_name=repo,
token=token,
sender_id=sender_id,
sender_name=sender_name,
)
66 changes: 53 additions & 13 deletions schemas/supabase/generate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,58 @@
print("Error: SUPABASE_DB_PASSWORD_DEV environment variable not set")
sys.exit(1)

PSQL_ARGS = [
"psql",
"-h",
"aws-0-us-west-1.pooler.supabase.com",
"-U",
"postgres.dkrxtcbaqzrodvsagwwn",
"-d",
"postgres",
"-p",
"6543",
"-t",
]
PSQL_ENV = {**os.environ, "PGPASSWORD": db_password}

print("Generating TypedDict schemas from PostgreSQL...")

# Step 1: Query enum types and their values
enum_result = subprocess.run(
[
*PSQL_ARGS,
"-c",
"""
SELECT t.typname, e.enumlabel
FROM pg_type t
JOIN pg_enum e ON t.oid = e.enumtypid
ORDER BY t.typname, e.enumsortorder;
""",
],
env=PSQL_ENV,
capture_output=True,
text=True,
check=False,
)

# Build enum_name -> Literal type string mapping
enum_types: dict[str, str] = {}
enum_values: dict[str, list[str]] = defaultdict(list)
for line in enum_result.stdout.split("\n"):
if line.strip():
parts = [p.strip() for p in line.split("|")]
if len(parts) == 2:
enum_name, enum_label = parts
enum_values[enum_name].append(enum_label)

for enum_name, labels in enum_values.items():
literal_values = ", ".join(f'"{label}"' for label in labels)
enum_types[enum_name] = f"Literal[{literal_values}]"

# Step 2: Query table columns
result = subprocess.run(
[
"psql",
"-h",
"aws-0-us-west-1.pooler.supabase.com",
"-U",
"postgres.dkrxtcbaqzrodvsagwwn",
"-d",
"postgres",
"-p",
"6543",
"-t",
*PSQL_ARGS,
"-c",
"""
SELECT table_name, column_name, data_type, is_nullable, udt_name
Expand All @@ -36,7 +74,7 @@
ORDER BY table_name, ordinal_position;
""",
],
env={**os.environ, "PGPASSWORD": db_password},
env=PSQL_ENV,
capture_output=True,
text=True,
check=False,
Expand Down Expand Up @@ -82,7 +120,9 @@
"_jsonb": "dict[str, Any]",
}

if data_type == "ARRAY":
if data_type == "USER-DEFINED" and udt_name in enum_types:
PYTHON_TYPE = enum_types[udt_name]
elif data_type == "ARRAY":
element_type = array_element_mapping.get(udt_name, "Any")
PYTHON_TYPE = f"list[{element_type}]"
else:
Expand All @@ -101,7 +141,7 @@
output_path = Path(__file__).parent / "types.py"
with open(output_path, "w", encoding="utf-8") as f:
f.write("import datetime\n")
f.write("from typing import Any\n")
f.write("from typing import Any, Literal\n")
f.write("from typing_extensions import TypedDict, NotRequired\n")
f.write("\n\n")

Expand Down
14 changes: 7 additions & 7 deletions schemas/supabase/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import Any
from typing import Any, Literal
from typing_extensions import TypedDict, NotRequired


Expand Down Expand Up @@ -181,7 +181,7 @@ class Installations(TypedDict):
installation_id: int
owner_name: str
uninstalled_at: datetime.datetime | None
owner_type: str
owner_type: Literal["User", "Organization"]
owner_id: int
created_by: str | None
uninstalled_by: str | None
Expand All @@ -191,7 +191,7 @@ class InstallationsInsert(TypedDict):
installation_id: int
owner_name: str
uninstalled_at: NotRequired[datetime.datetime | None]
owner_type: str
owner_type: Literal["User", "Organization"]
owner_id: int
created_by: NotRequired[str | None]
uninstalled_by: NotRequired[str | None]
Expand Down Expand Up @@ -262,7 +262,7 @@ class Owners(TypedDict):
created_by: str | None
owner_name: str
org_rules: str
owner_type: str
owner_type: Literal["User", "Organization"]
updated_by: str | None
updated_at: datetime.datetime
credit_balance_usd: int
Expand All @@ -278,7 +278,7 @@ class OwnersInsert(TypedDict):
created_by: NotRequired[str | None]
owner_name: str
org_rules: str
owner_type: str
owner_type: Literal["User", "Organization"]
updated_by: NotRequired[str | None]
credit_balance_usd: int
auto_reload_enabled: bool
Expand Down Expand Up @@ -495,7 +495,7 @@ class Usage(TypedDict):
created_by: str | None
total_seconds: int | None
owner_id: int
owner_type: str
owner_type: Literal["User", "Organization"]
owner_name: str
repo_id: int
repo_name: str
Expand All @@ -522,7 +522,7 @@ class UsageInsert(TypedDict):
created_by: NotRequired[str | None]
total_seconds: NotRequired[int | None]
owner_id: int
owner_type: str
owner_type: Literal["User", "Organization"]
owner_name: str
repo_id: int
repo_name: str
Expand Down
1 change: 1 addition & 0 deletions services/webhook/handle_installation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,6 @@ async def handle_installation_created(payload: InstallationPayload):
owner_name=owner_name,
repo_name=repositories[0]["name"],
token=token,
sender_id=user_id,
sender_name=sender_name,
)
1 change: 1 addition & 0 deletions services/webhook/handle_installation_repos_added.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,6 @@ async def handle_installation_repos_added(
owner_name=owner_name,
repo_name=repositories[0]["name"],
token=token,
sender_id=sender_id,
sender_name=sender_name,
)
63 changes: 42 additions & 21 deletions services/webhook/setup_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
from typing import cast

from anthropic.types import MessageParam

from constants.agent import MAX_ITERATIONS
Expand All @@ -24,7 +22,10 @@
from services.github.pulls.close_pull_request import close_pull_request
from services.github.pulls.create_pull_request import create_pull_request
from services.github.pulls.get_pull_request_files import get_pull_request_files
from services.github.repositories.is_repo_forked import is_repo_forked
from services.github.types.github_types import BaseArgs
from services.github.users.get_email_from_commits import get_email_from_commits
from services.github.users.get_user_public_email import get_user_public_info
from services.slack.slack_notify import slack_notify
from services.supabase.usage.insert_usage import insert_usage
from services.supabase.usage.update_usage import update_usage
Expand Down Expand Up @@ -54,6 +55,7 @@ async def setup_handler(
owner_name: str,
repo_name: str,
token: str,
sender_id: int,
sender_name: str,
):
set_owner_repo(owner_name, repo_name)
Expand Down Expand Up @@ -97,25 +99,44 @@ async def setup_handler(
f for f in os.listdir(efs_dir) if os.path.isfile(os.path.join(efs_dir, f))
]

# Look up sender info from GitHub
sender_info = get_user_public_info(username=sender_name, token=token)
sender_email = sender_info.email
if not sender_email:
sender_email = get_email_from_commits(
owner=owner_name, repo=repo_name, username=sender_name, token=token
)

# Create a branch for the coverage workflow PR
new_branch = generate_branch_name(trigger="setup")
base_args = cast(
BaseArgs,
{
"owner": owner_name,
"owner_id": owner_id,
"owner_type": owner_type,
"repo": repo_name,
"repo_id": repo_id,
"clone_url": clone_url,
"token": token,
"installation_id": installation_id,
"base_branch": target_branch,
"new_branch": new_branch,
"clone_dir": efs_dir,
"reviewers": [sender_name] if sender_name else [],
},
)
title = "Set up test coverage workflow"
base_args: BaseArgs = {
"owner": owner_name,
"owner_id": owner_id,
"owner_type": owner_type,
"repo": repo_name,
"repo_id": repo_id,
"clone_url": clone_url,
"token": token,
"installation_id": installation_id,
"base_branch": target_branch,
"new_branch": new_branch,
"clone_dir": efs_dir,
"is_fork": is_repo_forked(owner=owner_name, repo=repo_name, token=token),
"sender_id": sender_id,
"sender_name": sender_name,
"sender_email": sender_email,
"sender_display_name": sender_info.display_name,
"is_automation": False,
"reviewers": [sender_name] if sender_name else [],
"github_urls": [],
"other_urls": [],
"pr_number": 0, # Set after create_pull_request below
"pr_title": title,
"pr_body": SETUP_PR_BODY,
"pr_comments": [],
"pr_creator": sender_name,
}

sha = get_latest_remote_commit_sha(clone_url=clone_url, base_args=base_args)
create_remote_branch(sha=sha, base_args=base_args)
Expand All @@ -124,7 +145,7 @@ async def setup_handler(
)
pr_url, pr_number = create_pull_request(
body=SETUP_PR_BODY,
title="Set up test coverage workflow",
title=title,
base_args=base_args,
)
base_args["pr_number"] = pr_number
Expand All @@ -138,7 +159,7 @@ async def setup_handler(
repo_id=repo_id,
repo_name=repo_name,
pr_number=pr_number,
user_id=0,
user_id=sender_id,
user_name=sender_name,
installation_id=installation_id,
source="setup_handler",
Expand Down
Loading
Loading