# helpers.py -- forked from rjmacarthy/gpt-code-reviewer
import yaml
import os
import tiktoken
import requests
from rich.console import Console
from rich.markdown import Markdown
# Load the YAML configuration once at import time. The context manager closes
# the file handle as soon as parsing is done instead of leaving it open until
# the garbage collector gets to it.
with open("config.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Entries offered for selection in get_repo_and_pr().
repositories = config["repositories"]
# Model name used to pick the tokenizer below.
MODEL_ENGINE = config["model_engine"]
# Token budget that get_diff() measures the diff against.
MAX_LENGTH = 4000
# Account/organisation segment used when building pull-request URLs.
user = config["user"]
console = Console()
encoding = tiktoken.encoding_for_model(MODEL_ENGINE)
github_repository_base_url = "https://api.github.com/repos"
# May be None when the environment variable is not set.
GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
def count_tokens(string: str) -> int:
    """Return the number of tokens ``string`` encodes to with the module-level ``encoding``."""
    return len(encoding.encode(string))
def get_diff(diff: str, num_template_tokens: int) -> str:
    """Return ``diff``, truncated when it does not fit beside the prompt template.

    Args:
        diff: Text of the diff to be sent for review.
        num_template_tokens: Number of tokens already taken by the template.

    Returns:
        ``diff`` unchanged when its token count is at most
        ``MAX_LENGTH - num_template_tokens``; otherwise the text decoded from
        its first ``MAX_LENGTH - 2 * num_template_tokens`` tokens (an empty
        string when that bound is zero or negative).
    """
    # Tokenize once and reuse the result for both the check and the slice.
    tokens = encoding.encode(diff)
    if len(tokens) <= MAX_LENGTH - num_template_tokens:
        return diff

    # NOTE(review): the truncation bound subtracts the template size twice
    # while the check above subtracts it once -- presumably extra headroom;
    # confirm this is intended.
    # Clamp at zero: a negative bound would make the slice count from the end
    # and keep almost the whole diff instead of truncating it.
    keep = max(0, MAX_LENGTH - num_template_tokens * 2)
    return encoding.decode(tokens[:keep])
def print_options(repository: str, pull_request: str):
    """Print the selected repository / pull request and the available commands."""
    # The message is rendered as Markdown so the backticked keys stand out.
    menu = Markdown(
        f"""You have chosen to review {repository} pull request {pull_request}
enter `r` to review the code, `q` to quit, `h` for help and `n`
to review a different pull request."""
    )
    console.print(menu)
def get_repo_and_pr() -> tuple:
    """Prompt for a repository and a pull request number.

    The numbered repository menu is shown again until a number inside the
    listed range is entered. The pull request number is read as free text and
    only stripped of surrounding whitespace.

    Returns:
        A ``(repository, pull_request)`` tuple.
    """
    while True:
        console.print("Select a repository:")
        for index, repo in enumerate(repositories):
            console.print(f"{index + 1}. {repo}")

        answer = input("Enter the number of the repository: ")
        try:
            choice = int(answer)
        except ValueError:
            # Non-numeric input is handled like an out-of-range number.
            choice = None

        if choice is not None and 1 <= choice <= len(repositories):
            repository = repositories[choice - 1]
            break

        console.print(
            f"Invalid input. Please enter a number between 1 and {len(repositories)}"
        )

    pull_request = input("Enter the number of the pull request: ").strip()
    return repository, pull_request
def add_message(messages, message: str, role, pr: str, repository):
    """Append a chat message to ``messages`` and to the on-disk transcript.

    Args:
        messages: List that receives a ``{"role": ..., "content": ...}`` dict.
        message: Text of the message.
        role: Role label stored with the message and written to the transcript.
        pr: Pull request identifier; part of the transcript file name.
        repository: Repository identifier; part of the transcript file name.

    The transcript is appended to ``./transcripts/{pr}-{repository}.md``.
    """
    messages.append({"role": role, "content": message})

    transcript_path = f"./transcripts/{pr}-{repository}.md"
    # Create every missing parent directory: this covers ./transcripts itself
    # and the extra level that appears when ``repository`` contains a "/".
    # exist_ok avoids the check-then-create race.
    os.makedirs(os.path.dirname(transcript_path), exist_ok=True)

    # Explicit UTF-8 so non-ASCII message text is written the same way on
    # every platform, independent of the locale's default encoding.
    with open(transcript_path, "a", encoding="utf-8") as f:
        f.write(role + "\n" + message + "\n")
def fetch_commits(repository: str, num_commits: int) -> list:
    """Fetch recent commits of a repository from the GitHub API.

    Args:
        repository: ``owner/name``, or a bare repository name. A bare name is
            prefixed with the configured ``user``, the same way
            ``fetch_repository_data`` builds its URL.
        num_commits: Sent as the ``per_page`` query parameter.

    Returns:
        The decoded JSON body on HTTP 200. On any other status, or when the
        request itself fails, a message is printed and ``[]`` is returned.
    """
    # The endpoint is /repos/{owner}/{repo}/commits, so a name without an
    # owner segment cannot address a repository on its own.
    full_name = repository if "/" in repository else f"{user}/{repository}"
    url = f"{github_repository_base_url}/{full_name}/commits"

    # Send credentials only when a token is configured; otherwise the header
    # would literally read "Bearer None".
    headers = {"Authorization": f"Bearer {GITHUB_TOKEN}"} if GITHUB_TOKEN else {}
    params = {"per_page": num_commits}

    try:
        response = requests.get(url, headers=headers, params=params, timeout=10)
    except requests.RequestException as error:
        # Timeouts and connection errors follow the same "print and return
        # an empty list" contract as a non-200 status.
        console.print(f"Failed to fetch commits: {error}")
        return []

    if response.status_code != 200:
        console.print(f"Failed to fetch commits: {response.status_code}")
        return []

    # NOTE(review): expected to be a list of commit dictionaries -- confirm
    # against callers; the payload is returned as decoded, without validation.
    return response.json()
def fetch_repository_data(
    repository: str, pull_request: str, accept="application/vnd.github.v3.diff"
) -> requests.Response:
    """Request a pull request of ``{user}/{repository}`` from the GitHub API.

    Args:
        repository: Repository name, placed after the configured ``user``.
        pull_request: Pull request number, placed at the end of the URL.
        accept: Value of the ``Accept`` header; the default is the
            ``application/vnd.github.v3.diff`` media type.

    Returns:
        The ``requests.Response`` as received; the status code is not checked
        here, so callers must inspect it themselves.
    """
    url = f"{github_repository_base_url}/{user}/{repository}/pulls/{pull_request}"

    headers = {"Accept": accept}
    # Send credentials only when a token is configured; otherwise the header
    # would literally read "Bearer None".
    if GITHUB_TOKEN:
        headers["Authorization"] = f"Bearer {GITHUB_TOKEN}"

    return requests.get(url, headers=headers, timeout=10)