diff --git a/.github/workflows/pr-summary-agent.yml b/.github/workflows/pr-summary-agent.yml index d3a23e93..dcb54287 100644 --- a/.github/workflows/pr-summary-agent.yml +++ b/.github/workflows/pr-summary-agent.yml @@ -32,7 +32,8 @@ jobs: - name: Run agent run: | uv venv --python 3.12 - uv pip install -e "../agentic[all,dev]" --extra-index-url https://download.pytorch.org/whl/cpu --index-strategy unsafe-first-match + uv pip install -e "../${{ github.event.repository.name }}[all,dev]" --extra-index-url https://download.pytorch.org/whl/cpu --index-strategy unsafe-first-match + uv pip install litellm[proxy] git diff --merge-base HEAD^1 HEAD > PRChanges.patch cat PRChanges.patch uv run pr_agent/PR_agent.py @@ -41,7 +42,7 @@ jobs: OPENAI_API_KEY: ${{ secrets.PRAgentOpenAIKey }} PR_ID: ${{ github.event.pull_request.number }} GITHUB_API_KEY: ${{ secrets.GITHUB_TOKEN }} - REPO: ${{ github.GITHUB_REPOSITORY }} + REPO: ${{ github.repository }} - name: Update weaviate cache uses: actions/upload-artifact@v4 diff --git a/pr_agent/PR_agent.py b/pr_agent/PR_agent.py index 449d3a32..e71e843f 100644 --- a/pr_agent/PR_agent.py +++ b/pr_agent/PR_agent.py @@ -76,16 +76,9 @@ def __init__( result_model=Searches, ) - self.relevanceAgent = Agent( - name="Code Relevange Agent", - instructions="""You are an expert in determining if a snippet of code or documentation is directly relevant to understand the changes listed under . Your response must include a 'relevant' field boolean.""", - model=GPT_4O_MINI, - result_model=RelevanceResult, - ) - self.summaryAgent = SummaryAgent() - def prepare_summary(self, patch_content: str, filtered_results: List[SearchResult]) -> str: + def prepare_summary(self, patch_content: str, filtered_results: Dict[str,SearchResult]) -> str: """Prepare for summary agent""" formatted_str = "" @@ -95,12 +88,12 @@ def prepare_summary(self, patch_content: str, filtered_results: List[SearchResul final_str = formatted_str[:] - for result in filtered_results: + for result in filtered_results.values(): formatted_str += f"<{result.file_path}>\n" formatted_str += f"{result.content}\n" formatted_str += f"\n\n" - if token_counter(model=SUMMARY_MODEL, messages=[{"role": "user", "content": {final_str}}]) < 115000: + if token_counter(model=SUMMARY_MODEL, messages=[{"role": "user", "content": {final_str}}]) > 115000: break else: final_str = formatted_str[:] @@ -109,7 +102,7 @@ def prepare_summary(self, patch_content: str, filtered_results: List[SearchResul def post_to_github(self, summary: str) -> str: """Post summary as a GitHub comment""" - repo = os.getenv("repo") + repo = os.getenv("REPO") pr_id = os.getenv("PR_ID") gh_token = os.getenv("GITHUB_API_KEY") @@ -188,27 +181,28 @@ def next_turn( print("all: ", all_results) # Filter rag search results using LLM-based relevance checking - filtered_results = [] - for result in all_results.values(): + #filtered_results = [] + #for result in all_results.values(): - try: - relevance_check = yield from self.relevanceAgent.grab_final_result( - f"\n{request_context.get('patch_content')}\n\n\n{result.content}{result.query}" - ) - - if relevance_check.relevant: - filtered_results.append(result) - except Exception as e: + # try: + # relevance_check = yield from self.relevanceAgent.grab_final_result( + # "True" + # ) + # print(relevance_check) + #f"\n{request_context.get('patch_content')}\n\n\n{result.content}{result.query}" + #if relevance_check.relevant: + # filtered_results.append(result) + # except Exception as e: # LLM error - print(e) + # print(e) - for result in all_results.values(): - filtered_results.append(result) + #for result in all_results.values(): + # filtered_results.append(result) - print("filtered: ", str(filtered_results)) + #print("filtered: ", str(filtered_results)) # Prepare for summary - formatted_str = self.prepare_summary(request_context.get("patch_content"),filtered_results) + formatted_str = self.prepare_summary(request_context.get("patch_content"),all_results) print(formatted_str) @@ -233,6 +227,7 @@ def next_turn( if __name__ == "__main__": + # test # Change to PRChangesTest.patch for testing with open("PRChangesTest.patch", "r") as f: patch_content = f.read() diff --git a/pr_agent/code_rag_agent.py b/pr_agent/code_rag_agent.py index 2f134c5a..474550ef 100644 --- a/pr_agent/code_rag_agent.py +++ b/pr_agent/code_rag_agent.py @@ -46,7 +46,7 @@ def __init__(self, self.ragTool = RAGTool( default_index="codebase", - index_paths=[], + index_paths=["../**/*.md"], recursive=True )