diff --git a/.github/workflows/pr-summary-agent.yml b/.github/workflows/pr-summary-agent.yml deleted file mode 100644 index 88e50da0..00000000 --- a/.github/workflows/pr-summary-agent.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: PR Summary Agent - -on: - pull_request: - types: [opened, synchronize, reopened] - branches: - - main -jobs: - run-pr-agent: - permissions: write-all - name: run PR Summary Agent - runs-on: ubuntu-latest - steps: - - name: Check out the repository to the runner - uses: actions/checkout@v4 - with: - fetch-depth: 2 - - - name: Install uv - uses: astral-sh/setup-uv@v6 - - - name: Run agent - run: | - uv venv --python 3.12 - uv pip install -e "../PR_code_review-agent[all,dev]" --extra-index-url https://download.pytorch.org/whl/cpu --index-strategy unsafe-first-match - git diff --merge-base HEAD^1 HEAD > PRChanges.patch - cat PRChanges.patch - uv run pr_agent/test_files/mock_pr_agent.py - -#uv run pr_agent/PR_agent.py - - env: - OPENAI_API_KEY: ${{ secrets.PRAgentOpenAIKey }} - PR_ID: ${{ github.event.pull_request.number }} - GITHUB_API_KEY: ${{ secrets.GITHUB_TOKEN }} - REPO_OWNER: supercog-ai - REPO_NAME: PR_code_review-agent diff --git a/examples/pr_agent/PRChangesTest.patch b/examples/pr_agent/PRChangesTest.patch new file mode 100644 index 00000000..beaf620e --- /dev/null +++ b/examples/pr_agent/PRChangesTest.patch @@ -0,0 +1,886 @@ +diff --git a/.github/workflows/pr-summary-agent.yml b/.github/workflows/pr-summary-agent.yml +new file mode 100644 +index 0000000..dcb5428 +--- /dev/null ++++ b/.github/workflows/pr-summary-agent.yml +@@ -0,0 +1,52 @@ ++name: PR Summary Agent ++ ++on: ++ pull_request: ++ types: [opened, synchronize, reopened] ++ branches: ++ - main ++jobs: ++ run-pr-agent: ++ permissions: write-all ++ name: run PR Summary Agent ++ runs-on: ubuntu-latest ++ steps: ++ - name: Check out the repository to the runner ++ uses: actions/checkout@v4 ++ with: ++ fetch-depth: 2 ++ ++ - name: Install uv ++ uses: astral-sh/setup-uv@v6 ++ ++ - name: Download weaviate cache ++ id: download-artifact ++ uses: dawidd6/action-download-artifact@v11 ++ with: ++ github_token: ${{secrets.GITHUB_TOKEN}} ++ workflow: pr-summary-agent.yml ++ name: weaviate ++ if_no_artifact_found: warn ++ path: /home/runner/.cache/weaviate ++ ++ - name: Run agent ++ run: | ++ uv venv --python 3.12 ++ uv pip install -e "../${{ github.event.repository.name }}[all,dev]" --extra-index-url https://download.pytorch.org/whl/cpu --index-strategy unsafe-first-match ++ uv pip install litellm[proxy] ++ git diff --merge-base HEAD^1 HEAD > PRChanges.patch ++ cat PRChanges.patch ++ uv run pr_agent/PR_agent.py ++ ++ env: ++ OPENAI_API_KEY: ${{ secrets.PRAgentOpenAIKey }} ++ PR_ID: ${{ github.event.pull_request.number }} ++ GITHUB_API_KEY: ${{ secrets.GITHUB_TOKEN }} ++ REPO: ${{ github.repository }} ++ ++ - name: Update weaviate cache ++ uses: actions/upload-artifact@v4 ++ with: ++ name: weaviate ++ path: /home/runner/.cache/weaviate ++ overwrite: true +diff --git a/pr_agent/PR_agent.py b/pr_agent/PR_agent.py +new file mode 100644 +index 0000000..5d5abd0 +--- /dev/null ++++ b/pr_agent/PR_agent.py +@@ -0,0 +1,218 @@ ++import os ++from pathlib import Path ++import re ++import json ++import requests ++from typing import Dict, List, Any, Generator, Optional, Tuple ++from pydantic import Field, BaseModel ++from dotenv import load_dotenv ++from agentic.common import Agent, AgentRunner, ThreadContext ++from agentic.events import Event, ChatOutput, TurnEnd, PromptStarted, Prompt ++from agentic.models import GPT_4O_MINI ++from litellm import token_counter ++from code_rag_agent import CodeRagAgent ++from git_grep_agent import GitGrepAgent ++from summary_agent import SummaryAgent ++from code_rag_agent import CodeSection, CodeSections ++from pydantic import BaseModel ++ ++SUMMARY_MODEL = GPT_4O_MINI ++# SUMMARY_MODEL = CLAUDE ++ ++load_dotenv() ++ ++class SearchResult(BaseModel): ++ query: str = Field( ++ description="Query used in this search." ++ ) ++ file_path: str = Field( ++ description="Path of the file this code/documentation belongs to." ++ ) ++ content: str = Field( ++ description="Content returned from search." ++ ) ++ similarity_score: float = Field( ++ desciption="Similarity score returned from vector search." ++ ) ++ included_defs: List[str] = Field( ++ default_factory=list, ++ desciption="Similarity score returned from vector search." ++ ) ++ ++class Searches(BaseModel): ++ searches: List[str] = Field( ++ description="Search queries." ++ ) ++ ++class RelevanceResult(BaseModel): ++ relevant: bool ++ ++class PRReviewAgent(Agent): ++ ++ def __init__( ++ self, ++ name: str = "PR Review Agent", ++ model: str = GPT_4O_MINI, ++ verbose: bool = False, ++ **kwargs ++ ): ++ super().__init__( ++ name=name, ++ welcome="PR Review Agent initialized. Ready to process PRs.", ++ model=model, ++ **kwargs ++ ) ++ self.git_grep_agent = GitGrepAgent() ++ self.code_rag_agent = CodeRagAgent() ++ self.verbose = verbose ++ ++ self.queryAgent = Agent( ++ name="Code Query Agent", ++ instructions= ++""" ++You are an expert in generating code search queries from a patch file to get additional context about changes to a code base. Your response must include a 'searches' field with a list of strings. Example outputs: ["Weather_Tool", "SearchQuery", "format_sections"] ++""", ++ model=GPT_4O_MINI, ++ result_model=Searches, ++ ) ++ ++ self.summaryAgent = SummaryAgent() ++ ++ def prepare_summary(self, patch_content: str, filtered_results: Dict[str,SearchResult]) -> str: ++ ++ """Prepare for summary agent""" ++ formatted_str = "" ++ formatted_str += f"\n" ++ formatted_str += f"{patch_content}\n" ++ formatted_str += f"\n\n" ++ ++ final_str = formatted_str[:] ++ ++ for result in filtered_results.values(): ++ formatted_str += f"<{result.file_path}>\n" ++ formatted_str += f"{result.content}\n" ++ formatted_str += f"\n\n" ++ ++ if token_counter(model=SUMMARY_MODEL, messages=[{"role": "user", "content": {final_str}}]) > 115000: ++ break ++ else: ++ final_str = formatted_str[:] ++ ++ return final_str ++ ++ def post_to_github(self, summary: str) -> str: ++ """Post summary as a GitHub comment""" ++ repo = os.getenv("REPO") ++ pr_id = os.getenv("PR_ID") ++ gh_token = os.getenv("GITHUB_API_KEY") ++ ++ if not all([repo, pr_id, gh_token]): ++ raise ValueError("Missing required GitHub configuration") ++ ++ url = f"https://api.github.com/repos/{repo}/issues/{pr_id}/comments" ++ headers = { ++ "Authorization": f"token {gh_token}", ++ } ++ data = {"body": summary} ++ ++ response = requests.post(url, headers=headers, json=data) ++ response.raise_for_status() ++ return response.json().get("html_url") ++ ++ def next_turn( ++ self, ++ request: str, ++ request_context: dict = None, ++ request_id: str = None, ++ continue_result: dict = {}, ++ debug = "", ++ ) -> Generator[Event, Any, None]: ++ ++ query = request.payload if isinstance(request, Prompt) else request ++ yield PromptStarted(query, {"query": query}) ++ ++ # Generate search queries ++ queries = yield from self.queryAgent.final_result( ++ request_context.get("patch_content"), ++ request_context={ ++ "thread_id": request_context.get("thread_id") ++ } ++ ) ++ ++ print("queries: ", str(queries)) ++ ++ # RAG and Git-Grep queries ++ all_results = {} ++ for query in queries.searches[:10]: ++ searchResponse = yield from self.code_rag_agent.final_result( ++ "Search codebase", ++ request_context={ ++ "query": query, ++ "thread_id": request_context.get("thread_id") ++ } ++ ) ++ ++ # Process each result ++ for file, result in searchResponse.sections.items(): ++ if file not in all_results: ++ all_results[file] = SearchResult(query=query,file_path=result.file_path,content=result.search_result,similarity_score=result.similarity_score,included_defs=result.included_defs) ++ ++ searchResponse = yield from self.git_grep_agent.final_result( ++ "Search codebase with git grep", ++ request_context={ ++ "query": query, ++ "thread_id": request_context.get("thread_id") ++ } ++ ) ++ ++ # Process each result ++ # grep_response.sections is a dict of filepaths and CodeSection objects ++ for file, result in searchResponse.sections.items(): ++ if file not in all_results: ++ all_results[file] = SearchResult( ++ query=query, ++ file_path=result.file_path, ++ content=result.search_result, ++ similarity_score=result.similarity_score, ++ included_defs=result.included_defs ++ ) ++ ++ ++ print("all: ", all_results) ++ ++ # Prepare for summary ++ formatted_str = self.prepare_summary(request_context.get("patch_content"),all_results) ++ ++ print(formatted_str) ++ ++ summary = yield from self.summaryAgent.final_result( ++ formatted_str ++ ) ++ ++ comment_url = self.post_to_github(summary) ++ ++ # Return the final result ++ yield ChatOutput( ++ self.name, ++ [{"content": f"## PR Review Complete\n\nSummary posted to: {comment_url}"}] ++ ) ++ ++ yield TurnEnd( ++ self.name, ++ [{"content": summary}] ++ ) ++ ++ ++ ++if __name__ == "__main__": ++ ++ # test ++ # Change to PRChangesTest.patch for testing ++ with open("PRChangesTest.patch", "r") as f: ++ patch_content = f.read() ++ ++ # Create an instance of the agent ++ pr_review_agent = PRReviewAgent() ++ ++ # Run the agent ++ print(pr_review_agent.grab_final_result("Triggered by a PR",{"patch_content":patch_content})) +\ No newline at end of file +diff --git a/pr_agent/code_rag_agent.py b/pr_agent/code_rag_agent.py +new file mode 100644 +index 0000000..474550e +--- /dev/null ++++ b/pr_agent/code_rag_agent.py +@@ -0,0 +1,93 @@ ++from typing import Any, Generator, List ++from agentic.common import Agent, AgentRunner, ThreadContext ++from agentic.events import Event, ChatOutput, WaitForInput, Prompt, PromptStarted, TurnEnd, ResumeWithInput ++from agentic.models import GPT_4O_MINI # model (using GPT for testing) ++from pydantic import BaseModel, Field ++from agentic.tools.rag_tool import RAGTool ++import ast ++ ++class CodeSection(BaseModel): ++ search_result: str = Field( ++ description="Part returned from search.", ++ ) ++ file_path: str = Field( ++ description="Path of the file this code belongs to." ++ ) ++ included_defs: list[str] = Field( ++ description="Classes and functions defined in this file." ++ ) ++ similarity_score: float = Field( ++ desciption="Similarity score returned from vector search." ++ ) ++ ++class CodeSections(BaseModel): ++ sections: dict[str,CodeSection] = Field( ++ description="Sections of the codebase returned from the search.", ++ ) ++ search_query: str = Field( ++ description="Query used to return this section.", ++ ) ++ ++class CodeRagAgent(Agent): ++ def __init__(self, ++ name="Code Rag Agent", ++ welcome="I am the Code Rag Agent. Please give me a search query (function name,class name, etc.) and I'll return relevant parts of the code.", ++ model: str=GPT_4O_MINI, ++ result_model = CodeSections, ++ **kwargs ++ ): ++ super().__init__( ++ name=name, ++ welcome=welcome, ++ model=model, ++ result_model=result_model, ++ **kwargs ++ ) ++ ++ self.ragTool = RAGTool( ++ default_index="codebase", ++ index_paths=["../**/*.md"], ++ recursive=True ++ ) ++ ++ ++ def next_turn( ++ self, ++ request: str|Prompt, ++ request_context: dict = {}, ++ request_id: str = None, ++ continue_result: dict = {}, ++ debug = "", ++ ) -> Generator[Event, Any, Any]: ++ ++ query = request.payload if isinstance(request, Prompt) else request ++ yield PromptStarted(query, {"query": query}) ++ ++ searchQuery = request_context.get("query") ++ ++ searchResult = self.ragTool.search_knowledge_index(query=searchQuery,limit=5) ++ ++ allSections = CodeSections(sections={},search_query=query) ++ ++ for nextResult in searchResult: ++ file_path = nextResult["source_url"] ++ if not file_path in allSections.sections: ++ #print(nextResult) ++ ++ similarity_score = nextResult["distance"] if nextResult["distance"] else 0 ++ content = nextResult["content"] ++ ++ # Only works with Python files ++ included_defs = [] ++ try: ++ with open(file_path) as file: ++ node = ast.parse(file.read()) ++ included_defs = [n.name for n in node.body if isinstance(n, ast.ClassDef) or isinstance(n, ast.FunctionDef)] ++ except: ++ included_defs = [] ++ ++ allSections.sections[file_path] = CodeSection(search_result=content,file_path=file_path,included_defs=included_defs,similarity_score=similarity_score) ++ #else: ++ #print("Skipping Duplicate: ",file_path) ++ ++ yield TurnEnd(self.name, [{"content": allSections}]) +diff --git a/pr_agent/git_grep_agent.py b/pr_agent/git_grep_agent.py +new file mode 100644 +index 0000000..3567f53 +--- /dev/null ++++ b/pr_agent/git_grep_agent.py +@@ -0,0 +1,141 @@ ++from typing import Any, Generator, List ++from agentic.common import Agent ++from agentic.events import Event, Prompt, PromptStarted, TurnEnd ++from agentic.models import GPT_4O_MINI # model (using GPT for testing) ++import subprocess ++import ast ++import keyword ++import logging ++ ++from code_rag_agent import CodeSection, CodeSections ++ ++def find_full_function(file_path: str, line_number: int) -> str: ++ """Finds the full function definition given a file path and line number. Expects a properly formatted file.""" ++ ++ SUPPORTED_EXTENSIONS = [".py"] ++ line_number -= 1 # line numbers start at 1, not 0, bad for native zero indexing ++ try: ++ with open(file_path) as file: ++ text = file.read() ++ except Exception as e: ++ return f"Error with file: {e}" ++ ++ if not any(file_path.endswith(ext) for ext in SUPPORTED_EXTENSIONS): ++ logging.warning("File is not a supported extension, returning full file") ++ return text ++ ++ if file_path.endswith(".py"): ++ lines = text.splitlines() ++ ++ if len(lines) < line_number: # move this out of the if statement later ++ logging.error("Line number is out of bounds, returning full file") ++ return text ++ ++ # this function is "good enough" -- if there is any function that is not defined by the keyword module called outside of a class or function, it will keep it ++ def is_a_zero(line: str) -> bool: ++ if not line or line[0] == " " or not keyword.iskeyword(line.split()[0]): ++ return False ++ return True ++ ++ # The idea is to find the full class definition the function is embedded in, so we need to go up and down until we find this ++ top_index = line_number ++ ++ while (top_index > 0 and not is_a_zero(lines[top_index])): ++ top_index -= 1 ++ ++ bottom_index = line_number + 1 ++ while (bottom_index < len(lines) and not is_a_zero(lines[bottom_index])): ++ bottom_index += 1 ++ ++ return "\n".join(lines[top_index:bottom_index]) ++ ++ return text ++ ++# The actual sub-agent that runs git grep and returns structured results ++class GitGrepAgent(Agent): ++ def __init__(self, ++ name="Git-Grep Agent", ++ welcome="I am the Git Grep Agent. Please give me a search query (function name,class name, etc.) and I'll return exact matches from the codebase.", ++ model: str=GPT_4O_MINI, ++ result_model = CodeSections, ++ **kwargs ++ ): ++ super().__init__( ++ name=name, ++ welcome=welcome, ++ model=model, ++ result_model=result_model, ++ **kwargs ++ ) ++ ++ ++ def run_git_grep(self, query: str) -> List[tuple[str, str]]: ++ # Runs "git grep -n " for the given query to find exact matches in the codebase ++ # parses each result line into (file_path, matched_line) both of which are strs ++ # and returns a list of (file_path, matched_line) tuples ++ try: ++ result = subprocess.run( ++ ["git", "grep", "-n", query], # make sure that query is getting passed by the Main Agent!!! ++ capture_output=True, ++ text=True, ++ check=False ++ ) ++ ++ ++ # example git grep output: "code_rag_agent.py:6:from agentic.tools.rag_tool import RAGTool" ++ ++ ++ # TODO: need to determine if the line number is neccessary returning... ++ matches = [] # list of matches from the git grep command --> will hold all (file_path, matched_line) tuples found! ++ for line in result.stdout.splitlines(): ++ if not line: ++ continue ++ parts = line.split(":", 2) # file_path, line_number, line_text ++ if len(parts) >= 3: # if the output line is in the correct format ++ file_path, line_number, matched_line = parts ++ matches.append((file_path, find_full_function(file_path, int(line_number)))) ++ return matches ++ except Exception as e: ++ print(f"Error running git grep: {e}") ++ return [] ++ ++ ++ ++ # the entry point for running one turn (input -> processing -> output) ++ def next_turn( ++ self, ++ request: str | Prompt, ++ request_context: dict = {}, ++ request_id: str = None, ++ continue_result: dict = {}, ++ debug = "", ++ ) -> Generator[Event, Any, Any]: ++ # same as for the code_rag_context ++ ++ ++ # Either use query from request_context or from direct input ++ query = request.payload if isinstance(request, Prompt) else request # extracts the query from the incoming request ++ yield PromptStarted(query, {"query": query}) # yields a PromptStarted event to signal the beginning of processing ++ ++ ++ search_query = request_context.get("query") # pulls the actual search query from the request context ++ grep_results = self.run_git_grep(search_query) # runs git grep for that specific query ++ ++ ++ # TODO: verify that sections doesn't have to be a dictionary instead (like code_rag_agent implementation) ++ allSections = CodeSections(sections={}, search_query=search_query) # creates an empty CodeSections object ++ ++ # loops over each grep match ++ for file_path, content in grep_results: ++ if not file_path in allSections.sections: ++ allSections.sections[file_path] = CodeSection( ++ search_result=content, ++ file_path=file_path, ++ included_defs=[], ++ similarity_score=1.0 # grep doesn't do semantic scoring ++ ) ++ ++ yield TurnEnd(self.name, [{"content": allSections}]) ++ ++if __name__ == "__main__": ++ print(find_full_function("git_grep_agent.py", 172)) +\ No newline at end of file +diff --git a/pr_agent/summary_agent.py b/pr_agent/summary_agent.py +new file mode 100644 +index 0000000..8401bbc +--- /dev/null ++++ b/pr_agent/summary_agent.py +@@ -0,0 +1,100 @@ ++from agentic.common import Agent ++from agentic.models import CLAUDE ++from agentic.models import GPT_4O_MINI # model (using GPT for testing) ++ ++class SummaryAgent(Agent): ++ def __init__(self, ++ name="PR Summary Agent", ++ ++ # Agent instructions ++ instructions="""You are a code review assistant. Your task is to analyze a GitHub pull request using the provided information and generate helpful, precise feedback. Your response must include specific insights only about the files and code that were changed in the pull request. ++ ++Intended Purpose ++Given a pull request's patch file and supporting repository files for context, generate a high-quality comment summarizing the changes and providing constructive feedback. Focus exclusively on the changes introduced by the pull request. ++ ++Input Format ++You will be given: ++ ++A patch file describing the code changes. ++ ++Supporting code files from the repository that provide context for understanding the codebase. ++ ++The format will be as follows: ++ ++ ++ ++ ++ ++ ++ ++ ++ ++ ++ ++ ++ ++... ++Only the contents within the tags represent the actual changes. The rest are context files and should only be used to understand the repository structure and functionality. Do not comment on context files unless they are included in the patch. ++ ++Output Format ++Respond with the following structured sections: ++ ++Key Features ++A list of important or high-level features introduced by the changes. ++ ++Summary of Changes ++A clear and concise explanation of what was changed in the pull request, written in plain language. ++ ++New Unlocks from Functionality ++Describe any new capabilities or usage scenarios unlocked by these changes. ++ ++Code Suggestions with Line Number References ++Provide specific suggestions for improving the changed code, referring to lines by number as seen in the patch. ++ ++Formatting Suggestions ++Note any formatting or stylistic improvements that should be made. ++ ++Additional Notes ++Include any other relevant insights, such as potential edge cases, compatibility issues, or tests that should be added. ++ ++Important Rules ++ ++Only refer to files/lines that are explicitly changed in the patch file. ++ ++Use the provided file contents only to gain context for understanding the changes. ++ ++Be constructive, concise, and clear in your feedback. ++ """, ++ model=GPT_4O_MINI, ++ #model=CLAUDE, # model ++ **kwargs ++ ): ++ super().__init__( ++ name=name, ++ instructions=instructions, ++ model=model, ++ **kwargs ++ ) ++ ++ ++# Main to use the agent on the test files ++if __name__ == "__main__": ++ context = "\n" ++ with open("PR_code_review-agent/pr_agent/test_files/test_patch_file.txt", "r") as file: ++ context += file.read() ++ context += "\n\n" ++ context += "\n" ++ with open("PR_code_review-agent/pr_agent/test_files/agent_runner_copy.txt", "r") as file: ++ context += file.read() ++ context += "\n\n" ++ context += "\n" ++ with open("PR_code_review-agent/pr_agent/test_files/agent_copy.txt", "r") as file: ++ context += file.read() ++ context += "\n\n" ++ context += "\n" ++ with open("PR_code_review-agent/pr_agent/test_files/weather_tool_copy.txt", "r") as file: ++ context += file.read() ++ context += "" ++ ++ agent = SummaryAgent() ++ print(agent << context) +\ No newline at end of file +diff --git a/src/agentic/tools/rag_tool.py b/src/agentic/tools/rag_tool.py +index 43d27d3..0c03848 100644 +--- a/src/agentic/tools/rag_tool.py ++++ b/src/agentic/tools/rag_tool.py +@@ -14,6 +14,8 @@ from agentic.utils.rag_helper import ( + init_embedding_model, + init_chunker, + rag_index_file, ++ rag_index_multiple_files, ++ delete_document_from_index, + ) + + from agentic.utils.summarizer import generate_document_summary +@@ -22,6 +24,7 @@ from weaviate.classes.query import Filter, HybridFusion + from weaviate.collections.classes.grpc import Sort + from weaviate.classes.config import VectorDistances + ++from rich.console import Console + + @tool_registry.register( + name="RAGTool", +@@ -38,20 +41,59 @@ class RAGTool(BaseAgenticTool): + def __init__( + self, + default_index: str = "knowledge_base", +- index_paths: list[str] = [] ++ index_paths: list[str] = [], ++ recursive: bool = False, ++ overwrite_index = True, + ): + # Construct the RAG tool. You can pass a list of files and we will ensure that + # they are added to the index on startup. Paths can include glob patterns also, + # like './docs/*.md'. ++ # Enable recursive (**.md) glob patterns with recursive = True ++ + self.default_index = default_index + self.index_paths = index_paths + if self.index_paths: + client = init_weaviate() + if default_index not in list_collections(client): + create_collection(client, default_index, VectorDistances.COSINE) ++ ++ # Keep track of files found during initialization ++ if overwrite_index: ++ indexed_documents = {} ++ + for path in index_paths: +- for file_path in [path] if path.startswith("http") else glob.glob(path): +- rag_index_file(file_path, self.default_index, client=client, ignore_errors=True) ++ if path.startswith("http"): ++ document_id = rag_index_file(path, self.default_index, client=client, ignore_errors=True) ++ ++ if overwrite_index: ++ indexed_documents[document_id] = True ++ ++ else: ++ file_paths = glob.glob(path, recursive=recursive) ++ document_ids = rag_index_multiple_files(file_paths, self.default_index, client=client, ignore_errors=True) ++ ++ if overwrite_index: ++ for document_id in document_ids: ++ indexed_documents[document_id] = True ++ ++ # Delete indexed files not found during initialization ++ if overwrite_index: ++ try: ++ console = Console() ++ collection = client.collections.get(self.default_index) ++ documents = list_documents_in_collection(collection) ++ ++ for document in documents: ++ if not document["document_id"] in indexed_documents: ++ console.print(f"[bold green]✅ Removing deleted file {document['filename']} from index") ++ delete_document_from_index(collection=collection,document_id=document["document_id"],filename=document["filename"]) ++ ++ except Exception as e: ++ print(f"Error listing documents: {str(e)}") ++ return ++ finally: ++ if client: ++ client.close() + + def get_tools(self) -> List[Callable]: + return [ +@@ -59,7 +101,7 @@ class RAGTool(BaseAgenticTool): + #self.list_indexes, + self.search_knowledge_index, + self.list_documents, +- self.review_full_document ++ self.review_full_document, + ] + + def save_content_to_knowledge_index( +diff --git a/src/agentic/utils/file_reader.py b/src/agentic/utils/file_reader.py +index 3ea5ae2..70671fb 100644 +--- a/src/agentic/utils/file_reader.py ++++ b/src/agentic/utils/file_reader.py +@@ -57,6 +57,9 @@ def read_file(file_path: str, mime_type: str|None = None) -> tuple[str, str]: + return text, mime_type + elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": + return pd.read_excel(file_path).to_csv(), mime_type ++ elif mime_type == "text/x-python": ++ with open(file_path,"r") as f: ++ return f.read(), mime_type + else: + return textract.process(file_path).decode('utf-8'), mime_type + except Exception as e: +diff --git a/src/agentic/utils/rag_helper.py b/src/agentic/utils/rag_helper.py +index 94dd7f5..89a1f39 100644 +--- a/src/agentic/utils/rag_helper.py ++++ b/src/agentic/utils/rag_helper.py +@@ -91,7 +91,7 @@ def prepare_document_metadata( + + # Generate document ID from filename + metadata["document_id"] = hashlib.sha256( +- metadata["filename"].encode() ++ str(Path(file_path)).encode() + ).hexdigest() + + return metadata +@@ -155,7 +155,7 @@ def rag_index_file( + client: WeaviateClient|None = None, + ignore_errors: bool = False, + distance_metric: VectorDistances = VectorDistances.COSINE, +-): ++) -> str: + """Index a file using configurable Weaviate Embedded and chunking parameters""" + + console = Console() +@@ -186,10 +186,10 @@ def rag_index_file( + + if status == "unchanged": + console.print(f"[yellow]⏩ Document '{metadata['filename']}' unchanged[/yellow]") +- return ++ return metadata["document_id"] + elif status == "duplicate": + console.print(f"[yellow]⚠️ Content already exists under different filename[/yellow]") +- return ++ return metadata["document_id"] + elif status == "changed": + console.print(f"[yellow]🔄 Updating changed document '{metadata['filename']}'[/yellow]") + collection.data.delete_many( +@@ -233,8 +233,102 @@ def rag_index_file( + finally: + if client and client_created: + client.close() +- return "indexed" ++ return metadata["document_id"] ++ ++def rag_index_multiple_files( ++ file_paths: List[str], ++ index_name: str, ++ chunk_threshold: float = 0.5, ++ chunk_delimiters: str = ". ,! ,? ,\n", ++ embedding_model: str = "BAAI/bge-small-en-v1.5", ++ client: WeaviateClient|None = None, ++ ignore_errors: bool = False, ++ distance_metric: VectorDistances = VectorDistances.COSINE, ++) -> List[str]: ++ """Index a file using configurable Weaviate Embedded and chunking parameters""" ++ ++ console = Console() ++ client_created = False ++ ++ documents_indexed = [] ++ try: ++ with Status("[bold green]Initializing Weaviate..."): ++ if client is None: ++ client = init_weaviate() ++ client_created = True ++ create_collection(client, index_name, distance_metric) ++ ++ with Status("[bold green]Initializing models..."): ++ embed_model = init_embedding_model(embedding_model) ++ chunker = init_chunker(chunk_threshold, chunk_delimiters) + ++ for file_path in file_paths: ++ with Status(f"[bold green]Processing {file_path}...", console=console): ++ text, mime_type = read_file(str(file_path)) ++ metadata = prepare_document_metadata(file_path, text, mime_type, GPT_DEFAULT_MODEL) ++ ++ console.print(f"[bold green]Indexing {file_path}...") ++ ++ collection = client.collections.get(index_name) ++ exists, status = check_document_exists( ++ collection, ++ metadata["document_id"], ++ metadata["fingerprint"] ++ ) ++ ++ if status == "unchanged": ++ console.print(f"[yellow]⏩ Document '{metadata['filename']}' unchanged[/yellow]") ++ documents_indexed.append(metadata["document_id"]) ++ continue ++ elif status == "duplicate": ++ console.print(f"[yellow]⚠️ Content already exists under different filename[/yellow]") ++ documents_indexed.append(metadata["document_id"]) ++ continue ++ elif status == "changed": ++ console.print(f"[yellow]🔄 Updating changed document '{metadata['filename']}'[/yellow]") ++ collection.data.delete_many( ++ where=Filter.by_property("document_id").equal(metadata["document_id"]) ++ ) ++ ++ with Status("[bold green]Generating document summary...", console=console): ++ metadata["summary"] = generate_document_summary( ++ text=text[:12000], ++ mime_type=mime_type, ++ model=GPT_DEFAULT_MODEL ++ ) ++ ++ chunks = chunker(text) ++ chunks_text = [chunk.text for chunk in chunks] ++ if not chunks_text: ++ if ignore_errors: ++ return client ++ raise ValueError("No text chunks generated from document") ++ ++ batch_size = 128 ++ embeddings = [] ++ with Status("[bold green]Generating embeddings..."): ++ for i in range(0, len(chunks_text), batch_size): ++ batch = chunks_text[i:i+batch_size] ++ embeddings.extend(list(embed_model.embed(batch))) ++ ++ with Status("[bold green]Indexing chunks..."), collection.batch.dynamic() as batch: ++ for i, chunk in enumerate(chunks): ++ vector = embeddings[i].tolist() ++ batch.add_object( ++ properties={ ++ **metadata, ++ "content": chunk.text, ++ "chunk_index": i, ++ }, ++ vector=vector ++ ) ++ ++ documents_indexed.append(metadata["document_id"]) ++ console.print(f"[bold green]✅ Indexed {len(chunks)} chunks in {index_name}") ++ finally: ++ if client and client_created: ++ client.close() ++ return documents_indexed + + def delete_document_from_index( + collection: Any, +diff --git a/test b/test +new file mode 100644 +index 0000000..e69de29 diff --git a/examples/pr_agent/PR_agent.py b/examples/pr_agent/PR_agent.py new file mode 100644 index 00000000..77f42f4b --- /dev/null +++ b/examples/pr_agent/PR_agent.py @@ -0,0 +1,216 @@ +import os +from pathlib import Path +import re +import json +import requests +from typing import Dict, List, Any, Generator, Optional, Tuple +from pydantic import Field, BaseModel +from dotenv import load_dotenv +from agentic.common import Agent, AgentRunner, ThreadContext +from agentic.events import Event, ChatOutput, TurnEnd, PromptStarted, Prompt +from agentic.models import GPT_4O_MINI +from litellm import token_counter +from code_rag_agent import CodeRagAgent +from git_grep_agent import GitGrepAgent +from summary_agent import SummaryAgent +from code_rag_agent import CodeSection, CodeSections +from pydantic import BaseModel + +SUMMARY_MODEL = GPT_4O_MINI +# SUMMARY_MODEL = CLAUDE + +load_dotenv() + +class SearchResult(BaseModel): + query: str = Field( + description="Query used in this search." + ) + file_path: str = Field( + description="Path of the file this code/documentation belongs to." + ) + content: str = Field( + description="Content returned from search." + ) + similarity_score: float = Field( + desciption="Similarity score returned from vector search." + ) + included_defs: List[str] = Field( + default_factory=list, + desciption="Similarity score returned from vector search." + ) + +class Searches(BaseModel): + searches: List[str] = Field( + description="Search queries." + ) + +class RelevanceResult(BaseModel): + relevant: bool + +class PRReviewAgent(Agent): + + def __init__( + self, + name: str = "PR Review Agent", + model: str = GPT_4O_MINI, + verbose: bool = False, + **kwargs + ): + super().__init__( + name=name, + welcome="PR Review Agent initialized. Ready to process PRs.", + model=model, + **kwargs + ) + self.git_grep_agent = GitGrepAgent() + self.code_rag_agent = CodeRagAgent() + self.verbose = verbose + + self.queryAgent = Agent( + name="Code Query Agent", + instructions= +""" +You are an expert in generating code search queries from a patch file to get additional context about changes to a code base. Your response must include a 'searches' field with a list of strings. Example outputs: ["Weather_Tool", "SearchQuery", "format_sections"] +""", + model=GPT_4O_MINI, + result_model=Searches, + ) + + self.summaryAgent = SummaryAgent() + + def prepare_summary(self, patch_content: str, filtered_results: Dict[str,SearchResult]) -> str: + + """Prepare for summary agent""" + formatted_str = "" + formatted_str += f"\n" + formatted_str += f"{patch_content}\n" + formatted_str += f"\n\n" + + final_str = formatted_str[:] + + for result in filtered_results.values(): + formatted_str += f"<{result.file_path}>\n" + formatted_str += f"{result.content}\n" + formatted_str += f"\n\n" + + if token_counter(model=SUMMARY_MODEL, messages=[{"role": "user", "content": {final_str}}]) > 115000: + break + else: + final_str = formatted_str[:] + + return final_str + + def post_to_github(self, summary: str) -> str: + """Post summary as a GitHub comment""" + repo = os.getenv("REPO") + pr_id = os.getenv("PR_ID") + gh_token = os.getenv("GITHUB_API_KEY") + + if not all([repo, pr_id, gh_token]): + raise ValueError("Missing required GitHub configuration") + + url = f"https://api.github.com/repos/{repo}/issues/{pr_id}/comments" + headers = { + "Authorization": f"token {gh_token}", + } + data = {"body": summary} + + response = requests.post(url, headers=headers, json=data) + response.raise_for_status() + return response.json().get("html_url") + + def next_turn( + self, + request: str, + request_context: dict = None, + request_id: str = None, + continue_result: dict = {}, + debug = "", + ) -> Generator[Event, Any, None]: + + query = request.payload if isinstance(request, Prompt) else request + yield PromptStarted(query, {"query": query}) + + # Generate search queries + queries = yield from self.queryAgent.final_result( + request_context.get("patch_content"), + request_context={ + "thread_id": request_context.get("thread_id") + } + ) + + print("queries: ", str(queries)) + + # RAG and Git-Grep queries + all_results = {} + for query in queries.searches[:10]: + searchResponse = yield from self.code_rag_agent.final_result( + "Search codebase", + request_context={ + "query": query, + "thread_id": request_context.get("thread_id") + } + ) + + # Process each result + for file, result in searchResponse.sections.items(): + if file not in all_results: + all_results[file] = SearchResult(query=query,file_path=result.file_path,content=result.search_result,similarity_score=result.similarity_score,included_defs=result.included_defs) + + searchResponse = yield from self.git_grep_agent.final_result( + "Search codebase with git grep", + request_context={ + "query": query, + "thread_id": request_context.get("thread_id") + } + ) + + # Process each result + # grep_response.sections is a dict of filepaths and CodeSection objects + for file, result in searchResponse.sections.items(): + if file not in all_results: + all_results[file] = SearchResult( + query=query, + file_path=result.file_path, + content=result.search_result, + similarity_score=result.similarity_score, + included_defs=result.included_defs + ) + + + print("all: ", all_results) + + # Prepare for summary + formatted_str = self.prepare_summary(request_context.get("patch_content"),all_results) + + print(formatted_str) + + summary = yield from self.summaryAgent.final_result( + formatted_str + ) + + comment_url = self.post_to_github(summary) + + # Return the final result + yield ChatOutput( + self.name, + [{"content": f"## PR Review Complete\n\nSummary posted to: {comment_url}"}] + ) + + yield TurnEnd( + self.name, + [{"content": summary}] + ) + + + +if __name__ == "__main__": + # Change to PRChangesTest.patch to use the test patch file + with open("PRChanges.patch", "r") as f: + patch_content = f.read() + + # Create an instance of the agent + pr_review_agent = PRReviewAgent() + + # Run the agent + print(pr_review_agent.grab_final_result("Triggered by a PR",{"patch_content":patch_content})) \ No newline at end of file diff --git a/examples/pr_agent/README.md b/examples/pr_agent/README.md new file mode 100644 index 00000000..7f530f54 --- /dev/null +++ b/examples/pr_agent/README.md @@ -0,0 +1,39 @@ +# PR summary agent + +This is a PR summary agent with GitHub integrations to automatically attatch a summary when a pull request is opened/reopened. + +## Installation and running + +This example is designed to be run within a fork of the agentic repository, as it uses features that are still in development and must clone the repository. +However, once the features are released, importing the agentic repository with the following command in the [github actions workflow](pr-summary-agent.yml) should be sufficient. +``` +uv pip install "agentic-framework[all,dev]" --extra-index-url https://download.pytorch.org/whl/cpu +``` + +Create a fork of the agentic repository. +Add [pr-summary-agent.yml](pr-summary-agent.yml) to .github/workflows + +Add `PRAgentOpenAIKey` to the github repository secrets, containing a valid OpenAI API key. + +Move the folder `pr_agent` containing all the files for the pr agent into the top-most folder, with the repository name. + +Create a branch and make changes. + +Open a pull request to your main branch. + +The action should start running immediately once the PR is opened, leaving a comment under the PR when finished. + +## How it works + +When a PR opens, the github action clones the repository and puts the differences into a file named PRChanges.patch. +`git diff --merge-base HEAD^1 HEAD > PRChanges.patch` + +The main agent is then run, and it generates search queries using this inital patch file. These queries are then put through the [RAG agent](code_rag_agent.py) and the [grep agent](git_grep_agent.py) to return more context. +The RAG agent is configured to index markdown files at startup using agentic's RAGTool. This index is cached to reduce startup times in subsequent runs. +The RAG agent finds relevant documentation and returns it to the main agent. +The grep agent finds code files (currently .py only) with exact matches, which is useful for finding implementations for functions and classes. +The grep agent finds the full function/class the match was found in, which is then returned to the main agent. + +This is then all packaged together into a final context window, with the patch file, and search results, separated by file names. +This is sent to the summary agent, which creates a summary using the branch changes and relavant code/documentation for summary generation. +The main agent then uploads the summary to the PR as a comment, using the GitHub API. \ No newline at end of file diff --git a/examples/pr_agent/code_rag_agent.py b/examples/pr_agent/code_rag_agent.py new file mode 100644 index 00000000..474550ef --- /dev/null +++ b/examples/pr_agent/code_rag_agent.py @@ -0,0 +1,93 @@ +from typing import Any, Generator, List +from agentic.common import Agent, AgentRunner, ThreadContext +from agentic.events import Event, ChatOutput, WaitForInput, Prompt, PromptStarted, TurnEnd, ResumeWithInput +from agentic.models import GPT_4O_MINI # model (using GPT for testing) +from pydantic import BaseModel, Field +from agentic.tools.rag_tool import RAGTool +import ast + +class CodeSection(BaseModel): + search_result: str = Field( + description="Part returned from search.", + ) + file_path: str = Field( + description="Path of the file this code belongs to." + ) + included_defs: list[str] = Field( + description="Classes and functions defined in this file." + ) + similarity_score: float = Field( + desciption="Similarity score returned from vector search." + ) + +class CodeSections(BaseModel): + sections: dict[str,CodeSection] = Field( + description="Sections of the codebase returned from the search.", + ) + search_query: str = Field( + description="Query used to return this section.", + ) + +class CodeRagAgent(Agent): + def __init__(self, + name="Code Rag Agent", + welcome="I am the Code Rag Agent. Please give me a search query (function name,class name, etc.) and I'll return relevant parts of the code.", + model: str=GPT_4O_MINI, + result_model = CodeSections, + **kwargs + ): + super().__init__( + name=name, + welcome=welcome, + model=model, + result_model=result_model, + **kwargs + ) + + self.ragTool = RAGTool( + default_index="codebase", + index_paths=["../**/*.md"], + recursive=True + ) + + + def next_turn( + self, + request: str|Prompt, + request_context: dict = {}, + request_id: str = None, + continue_result: dict = {}, + debug = "", + ) -> Generator[Event, Any, Any]: + + query = request.payload if isinstance(request, Prompt) else request + yield PromptStarted(query, {"query": query}) + + searchQuery = request_context.get("query") + + searchResult = self.ragTool.search_knowledge_index(query=searchQuery,limit=5) + + allSections = CodeSections(sections={},search_query=query) + + for nextResult in searchResult: + file_path = nextResult["source_url"] + if not file_path in allSections.sections: + #print(nextResult) + + similarity_score = nextResult["distance"] if nextResult["distance"] else 0 + content = nextResult["content"] + + # Only works with Python files + included_defs = [] + try: + with open(file_path) as file: + node = ast.parse(file.read()) + included_defs = [n.name for n in node.body if isinstance(n, ast.ClassDef) or isinstance(n, ast.FunctionDef)] + except: + included_defs = [] + + allSections.sections[file_path] = CodeSection(search_result=content,file_path=file_path,included_defs=included_defs,similarity_score=similarity_score) + #else: + #print("Skipping Duplicate: ",file_path) + + yield TurnEnd(self.name, [{"content": allSections}]) diff --git a/examples/pr_agent/git_grep_agent.py b/examples/pr_agent/git_grep_agent.py new file mode 100644 index 00000000..3567f539 --- /dev/null +++ b/examples/pr_agent/git_grep_agent.py @@ -0,0 +1,141 @@ +from typing import Any, Generator, List +from agentic.common import Agent +from agentic.events import Event, Prompt, PromptStarted, TurnEnd +from agentic.models import GPT_4O_MINI # model (using GPT for testing) +import subprocess +import ast +import keyword +import logging + +from code_rag_agent import CodeSection, CodeSections + +def find_full_function(file_path: str, line_number: int) -> str: + """Finds the full function definition given a file path and line number. Expects a properly formatted file.""" + + SUPPORTED_EXTENSIONS = [".py"] + line_number -= 1 # line numbers start at 1, not 0, bad for native zero indexing + try: + with open(file_path) as file: + text = file.read() + except Exception as e: + return f"Error with file: {e}" + + if not any(file_path.endswith(ext) for ext in SUPPORTED_EXTENSIONS): + logging.warning("File is not a supported extension, returning full file") + return text + + if file_path.endswith(".py"): + lines = text.splitlines() + + if len(lines) < line_number: # move this out of the if statement later + logging.error("Line number is out of bounds, returning full file") + return text + + # this function is "good enough" -- if there is any function that is not defined by the keyword module called outside of a class or function, it will keep it + def is_a_zero(line: str) -> bool: + if not line or line[0] == " " or not keyword.iskeyword(line.split()[0]): + return False + return True + + # The idea is to find the full class definition the function is embedded in, so we need to go up and down until we find this + top_index = line_number + + while (top_index > 0 and not is_a_zero(lines[top_index])): + top_index -= 1 + + bottom_index = line_number + 1 + while (bottom_index < len(lines) and not is_a_zero(lines[bottom_index])): + bottom_index += 1 + + return "\n".join(lines[top_index:bottom_index]) + + return text + +# The actual sub-agent that runs git grep and returns structured results +class GitGrepAgent(Agent): + def __init__(self, + name="Git-Grep Agent", + welcome="I am the Git Grep Agent. Please give me a search query (function name,class name, etc.) and I'll return exact matches from the codebase.", + model: str=GPT_4O_MINI, + result_model = CodeSections, + **kwargs + ): + super().__init__( + name=name, + welcome=welcome, + model=model, + result_model=result_model, + **kwargs + ) + + + def run_git_grep(self, query: str) -> List[tuple[str, str]]: + # Runs "git grep -n " for the given query to find exact matches in the codebase + # parses each result line into (file_path, matched_line) both of which are strs + # and returns a list of (file_path, matched_line) tuples + try: + result = subprocess.run( + ["git", "grep", "-n", query], # make sure that query is getting passed by the Main Agent!!! + capture_output=True, + text=True, + check=False + ) + + + # example git grep output: "code_rag_agent.py:6:from agentic.tools.rag_tool import RAGTool" + + + # TODO: need to determine if the line number is neccessary returning... + matches = [] # list of matches from the git grep command --> will hold all (file_path, matched_line) tuples found! + for line in result.stdout.splitlines(): + if not line: + continue + parts = line.split(":", 2) # file_path, line_number, line_text + if len(parts) >= 3: # if the output line is in the correct format + file_path, line_number, matched_line = parts + matches.append((file_path, find_full_function(file_path, int(line_number)))) + return matches + except Exception as e: + print(f"Error running git grep: {e}") + return [] + + + + # the entry point for running one turn (input -> processing -> output) + def next_turn( + self, + request: str | Prompt, + request_context: dict = {}, + request_id: str = None, + continue_result: dict = {}, + debug = "", + ) -> Generator[Event, Any, Any]: + # same as for the code_rag_context + + + # Either use query from request_context or from direct input + query = request.payload if isinstance(request, Prompt) else request # extracts the query from the incoming request + yield PromptStarted(query, {"query": query}) # yields a PromptStarted event to signal the beginning of processing + + + search_query = request_context.get("query") # pulls the actual search query from the request context + grep_results = self.run_git_grep(search_query) # runs git grep for that specific query + + + # TODO: verify that sections doesn't have to be a dictionary instead (like code_rag_agent implementation) + allSections = CodeSections(sections={}, search_query=search_query) # creates an empty CodeSections object + + # loops over each grep match + for file_path, content in grep_results: + if not file_path in allSections.sections: + allSections.sections[file_path] = CodeSection( + search_result=content, + file_path=file_path, + included_defs=[], + similarity_score=1.0 # grep doesn't do semantic scoring + ) + + yield TurnEnd(self.name, [{"content": allSections}]) + +if __name__ == "__main__": + print(find_full_function("git_grep_agent.py", 172)) \ No newline at end of file diff --git a/examples/pr_agent/pr-summary-agent.yml b/examples/pr_agent/pr-summary-agent.yml new file mode 100644 index 00000000..dcb54287 --- /dev/null +++ b/examples/pr_agent/pr-summary-agent.yml @@ -0,0 +1,52 @@ +name: PR Summary Agent + +on: + pull_request: + types: [opened, synchronize, reopened] + branches: + - main +jobs: + run-pr-agent: + permissions: write-all + name: run PR Summary Agent + runs-on: ubuntu-latest + steps: + - name: Check out the repository to the runner + uses: actions/checkout@v4 + with: + fetch-depth: 2 + + - name: Install uv + uses: astral-sh/setup-uv@v6 + + - name: Download weaviate cache + id: download-artifact + uses: dawidd6/action-download-artifact@v11 + with: + github_token: ${{secrets.GITHUB_TOKEN}} + workflow: pr-summary-agent.yml + name: weaviate + if_no_artifact_found: warn + path: /home/runner/.cache/weaviate + + - name: Run agent + run: | + uv venv --python 3.12 + uv pip install -e "../${{ github.event.repository.name }}[all,dev]" --extra-index-url https://download.pytorch.org/whl/cpu --index-strategy unsafe-first-match + uv pip install litellm[proxy] + git diff --merge-base HEAD^1 HEAD > PRChanges.patch + cat PRChanges.patch + uv run pr_agent/PR_agent.py + + env: + OPENAI_API_KEY: ${{ secrets.PRAgentOpenAIKey }} + PR_ID: ${{ github.event.pull_request.number }} + GITHUB_API_KEY: ${{ secrets.GITHUB_TOKEN }} + REPO: ${{ github.repository }} + + - name: Update weaviate cache + uses: actions/upload-artifact@v4 + with: + name: weaviate + path: /home/runner/.cache/weaviate + overwrite: true diff --git a/examples/pr_agent/summary_agent.py b/examples/pr_agent/summary_agent.py new file mode 100644 index 00000000..8401bbc1 --- /dev/null +++ b/examples/pr_agent/summary_agent.py @@ -0,0 +1,100 @@ +from agentic.common import Agent +from agentic.models import CLAUDE +from agentic.models import GPT_4O_MINI # model (using GPT for testing) + +class SummaryAgent(Agent): + def __init__(self, + name="PR Summary Agent", + + # Agent instructions + instructions="""You are a code review assistant. Your task is to analyze a GitHub pull request using the provided information and generate helpful, precise feedback. Your response must include specific insights only about the files and code that were changed in the pull request. + +Intended Purpose +Given a pull request's patch file and supporting repository files for context, generate a high-quality comment summarizing the changes and providing constructive feedback. Focus exclusively on the changes introduced by the pull request. + +Input Format +You will be given: + +A patch file describing the code changes. + +Supporting code files from the repository that provide context for understanding the codebase. + +The format will be as follows: + + + + + + + + + + + + + +... +Only the contents within the tags represent the actual changes. The rest are context files and should only be used to understand the repository structure and functionality. Do not comment on context files unless they are included in the patch. + +Output Format +Respond with the following structured sections: + +Key Features +A list of important or high-level features introduced by the changes. + +Summary of Changes +A clear and concise explanation of what was changed in the pull request, written in plain language. + +New Unlocks from Functionality +Describe any new capabilities or usage scenarios unlocked by these changes. + +Code Suggestions with Line Number References +Provide specific suggestions for improving the changed code, referring to lines by number as seen in the patch. + +Formatting Suggestions +Note any formatting or stylistic improvements that should be made. + +Additional Notes +Include any other relevant insights, such as potential edge cases, compatibility issues, or tests that should be added. + +Important Rules + +Only refer to files/lines that are explicitly changed in the patch file. + +Use the provided file contents only to gain context for understanding the changes. + +Be constructive, concise, and clear in your feedback. + """, + model=GPT_4O_MINI, + #model=CLAUDE, # model + **kwargs + ): + super().__init__( + name=name, + instructions=instructions, + model=model, + **kwargs + ) + + +# Main to use the agent on the test files +if __name__ == "__main__": + context = "\n" + with open("PR_code_review-agent/pr_agent/test_files/test_patch_file.txt", "r") as file: + context += file.read() + context += "\n\n" + context += "\n" + with open("PR_code_review-agent/pr_agent/test_files/agent_runner_copy.txt", "r") as file: + context += file.read() + context += "\n\n" + context += "\n" + with open("PR_code_review-agent/pr_agent/test_files/agent_copy.txt", "r") as file: + context += file.read() + context += "\n\n" + context += "\n" + with open("PR_code_review-agent/pr_agent/test_files/weather_tool_copy.txt", "r") as file: + context += file.read() + context += "" + + agent = SummaryAgent() + print(agent << context) \ No newline at end of file diff --git a/pr_agent/PR_agent.py b/pr_agent/PR_agent.py deleted file mode 100644 index 7f6e7edf..00000000 --- a/pr_agent/PR_agent.py +++ /dev/null @@ -1,43 +0,0 @@ -# PR_agent.py -# Constructs a summary for a pull request - -from agentic.common import Agent, AgentRunner -from agentic.models import CLAUDE # model - -from dotenv import load_dotenv -import openai -import requests -import os - - -load_dotenv() # This loads variables from .env into os.environ -openai.api_key = os.getenv("OPENAI_API_KEY") # api key - - -# Define the agent -agent = Agent( - name="PR Summary Agent", - - # Agent instructions - instructions=""" - You are a helpful assistant that can summarize a pull request. - - - - """, - - model=CLAUDE, # model - tools=[], - memories=[] - -) - - - - - - - -# basic main function that allows us to run our agent locally in terminal -if __name__ == "__main__": - AgentRunner(agent).repl_loop() \ No newline at end of file diff --git a/pr_agent/rag_sub_agent.py b/pr_agent/rag_sub_agent.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pr_agent/summary_agent.py b/pr_agent/summary_agent.py deleted file mode 100644 index e1898d93..00000000 --- a/pr_agent/summary_agent.py +++ /dev/null @@ -1,87 +0,0 @@ -from agentic.common import Agent -from agentic.models import CLAUDE - -agent = Agent( - name="PR Summary Agent", - - # Agent instructions - instructions="""You are a Summary Agent responsible for reviewing GitHub pull requests. Your task is to analyze the provided patch file along with relevant context files from the repository and generate a helpful and precise comment for the pull request. - -Your Tasks: -Carefully analyze the provided data and complete the following outputs: - -Key Features -Summarize the most important or high-level features affected or introduced by the changes. -Summary of Changes -Concisely describe what was changed in the codebase based on the patch file. Focus only on the actual changes — do not comment on files that were not modified. -New Unlocks from Functionality -Describe any new capabilities or user-facing functionality that the changes enable. -Code Suggestions with Line Number References -Provide suggestions for improving or correcting the code, referencing the appropriate lines (line numbers from the patch file). Be specific and constructive. -Formatting Suggestions -Point out any code style or formatting issues in the changed lines only. Do not apply formatting critiques to unchanged code. -Additional Notes -Add any relevant observations, concerns, or questions that could help the author improve the PR or that might affect merging, such as missing tests or unclear logic. -Input Format: -The following data will be passed to you, clearly delimited: - -Comment - - -Patch file - - - - - - - - -... (additional files providing context) -You must use all relevant data available to infer meaning and context behind the code changes. However, do not generate feedback on files unless they appear in the patch file. - -Output Format: -Respond with structured sections using the following headers: - -Key Features: -... - -Summary of Changes: -... - -New Unlocks from Functionality: -... - -Code Suggestions with Line Number References: -... - -Formatting Suggestions: -... - -Additional Notes: -... - -Be precise, helpful, and technically insightful. Keep your tone professional and collaborative, as your output will be seen by developers during code review. - """, - - model=CLAUDE, # model -) - -# Main to use the agent on the test files -if __name__ == "__main__": - context = "Comment: \n" - with open("PR_code_review-agent/pr_agent/test_files/test_comment.txt", "r") as file: - context += file.read() - context += "\n\nPatch file: \n" - with open("PR_code_review-agent/pr_agent/test_files/test_patch_file.txt", "r") as file: - context += file.read() - context += "\n\nrunner.py\n" - with open("PR_code_review-agent/pr_agent/test_files/agent_runner_copy.txt", "r") as file: - context += file.read() - context += "\n\nactor_agents.py\n" - with open("PR_code_review-agent/pr_agent/test_files/agent_copy.txt", "r") as file: - context += file.read() - context += "\n\nweather_tool.py\n" - with open("PR_code_review-agent/pr_agent/test_files/weather_tool_copy.txt", "r") as file: - context += file.read() - print(agent << context) \ No newline at end of file diff --git a/pr_agent/test_files/agent_copy.txt b/pr_agent/test_files/agent_copy.txt deleted file mode 100644 index 6a4846a7..00000000 --- a/pr_agent/test_files/agent_copy.txt +++ /dev/null @@ -1,1834 +0,0 @@ -import asyncio -import inspect -import json -import litellm -import os -import re -import threading -import time -import traceback -import uuid -import yaml -from pprint import pprint - -from copy import deepcopy -from dataclasses import dataclass -from datetime import timedelta -from jinja2 import Template, DebugUndefined -from litellm.types.utils import Message -from pathlib import Path -from pydantic import BaseModel, ConfigDict -from queue import Queue -from typing import Any, Callable, List, Optional, Generator, Literal, Type - -from agentic.swarm.types import ( - agent_secret_key, - AgentFunction, - ChatCompletionMessage, - ChatCompletionMessageToolCall, - Function, - Response, - Result, - ThreadContext, - tool_name, - EVENT_QUEUE_KEY, -) -from agentic.swarm.util import ( - debug_print, - debug_completion_start, - debug_completion_end, - function_to_json, - looks_like_langchain_tool, - langchain_function_to_json, - wrap_llm_function, -) - -from agentic.events import ( - Event, - Prompt, - PromptStarted, - Output, - ChatOutput, - ToolCall, - ToolResult, - SubAgentCall, - SubAgentResult, - TurnCancelledError, - StartCompletion, - FinishCompletion, - FinishAgentResult, - TurnEnd, - SetState, - AddChild, - WaitForInput, - PauseForInputResult, - ResumeWithInput, - DebugLevel, - ToolError, - StartRequestResponse, - OAuthFlow, - OAuthFlowResult, - ReasoningContent, -) -from agentic.db.models import Thread, ThreadLog -from agentic.tools.utils.registry import tool_registry -from agentic.db.db_manager import DatabaseManager -from agentic.models import get_special_model_params, mock_provider - - -__CTX_VARS_NAME__ = "thread_context" -__LEGACY_CTX_VARS_NAME__ = "run_context" - -# define a CallbackType Enum with values: "handle_turn_start", "handle_event", "handle_turn_end" -CallbackType = Literal["handle_turn_start", "handle_event", "handle_turn_end"] - -# make a Callable type that expects a Prompt and ThreadContext -CallbackFunc = Callable[[Event, ThreadContext], None] - -@dataclass -class AgentPauseContext: - orig_history_length: int - tool_partial_response: Response - # sender: Optional[Actor] = None - tool_function: Optional[Function] = None - - -litellm.drop_params = True - -# Pick the agent runtime -if os.environ.get("AGENTIC_USE_RAY"): - # Use the actual Ray implementation - print("Using Ray engine for running agents") - import ray - -else: - print("Using simple Thread engine for running agents") - from .ray_mock import ray - -_AGENT_REGISTRY = [] - -@ray.remote -class ActorBaseAgent: - name: str = "Agent" - model: str = "gpt-4o" # Default model - instructions_str: str = "You are a helpful agent." - tools: list[str] = None - functions: List[AgentFunction] = None - tool_choice: str = None - parallel_tool_calls: bool = True - paused_context: Optional[AgentPauseContext] = None - debug: DebugLevel = DebugLevel(False) - depth: int = 0 - children: dict = {} - history: list = [] - # Memories are static facts that are always injected into the context on every turn - memories: list[str] = [] - # The Actor who sent us our Prompt - max_tokens: int = None - thread_context: ThreadContext = None - api_endpoint: str = None - _prompter = None - _callbacks: dict[CallbackType, CallbackFunc] = {} - result_model: Type[BaseModel]|None = None, - # Reasoning support - reasoning_effort: str = None # Can be "low", "medium", "high" or None - - model_config = ConfigDict( - arbitrary_types_allowed=True - ) - - def __init__(self, name: str): - super().__init__() - self.name = name - self.history: list = [] - - # Always register mock provider with litellm - litellm.custom_provider_map = [ - {"provider": "mock", "custom_handler": mock_provider} - ] - - def __repr__(self): - return self.name - - def _get_llm_completion( - self, - history: List, - thread_context: ThreadContext, - model_override: str, - stream: bool, - ) -> ChatCompletionMessage: - """Call the LLM completion endpoint""" - instructions = self.get_instructions(thread_context) - messages = [{"role": "system", "content": instructions}] + history - - tools = [function_to_json(f) for f in self.functions] - # hide thread_context from model - for tool in tools: - params = tool["function"]["parameters"] - params["properties"].pop(__CTX_VARS_NAME__, None) - if __CTX_VARS_NAME__ in params["required"]: - params["required"].remove(__CTX_VARS_NAME__) - params["properties"].pop(__LEGACY_CTX_VARS_NAME__, None) - if __LEGACY_CTX_VARS_NAME__ in params["required"]: - params["required"].remove(__LEGACY_CTX_VARS_NAME__) - - # Create parameters for litellm call - completion_params = { - "model": model_override or self.model, - "messages": messages, - "temperature": 0.0, # Will be adjusted if reasoning is enabled - "tools": tools or None, - "tool_choice": self.tool_choice, - "stream": stream, - "stream_options": {"include_usage": True}, - } - if self.result_model: - completion_params["response_format"] = self.result_model - - # Add reasoning support if model supports it - model_name = model_override or self.model - - # Check reasoning support with fallback for older LiteLLM versions - supports_reasoning = True # Default to True and let LiteLLM handle parameter validation - try: - if hasattr(litellm, 'supports_reasoning'): - supports_reasoning = litellm.supports_reasoning(model=model_name) - debug_print(self.debug.debug_all(), f"Model supports reasoning: {supports_reasoning}") - else: - debug_print(self.debug.debug_all(), f"LiteLLM version does not have supports_reasoning function, assuming support") - except Exception as e: - debug_print(self.debug.debug_all(), f"Error checking reasoning support: {e}, assuming support") - supports_reasoning = True - - # Debug reasoning support - if self.debug.debug_all(): - debug_print(self.debug.debug_all(), f"Checking reasoning support for model: {model_name}") - debug_print(self.debug.debug_all(), f"Reasoning effort: {self.reasoning_effort}") - - if self.reasoning_effort and supports_reasoning: - completion_params["reasoning_effort"] = self.reasoning_effort - debug_print(self.debug.debug_all(), f"Added reasoning_effort={self.reasoning_effort} to completion params") - - # Anthropic requires temperature=1 when reasoning is enabled - if "anthropic" in model_name.lower(): - completion_params["temperature"] = 1.0 - debug_print(self.debug.debug_all(), f"Set temperature=1.0 for Anthropic reasoning model") - elif self.reasoning_effort and not supports_reasoning: - debug_print(self.debug.debug_all(), f"Model {model_name} does not support reasoning. Skipping reasoning_effort parameter.") - - # Add any special parameters needed for specific model types - completion_params.update(get_special_model_params(completion_params["model"])) - - if self.max_tokens: - completion_params["max_tokens"] = self.max_tokens - - if tools: - completion_params["parallel_tool_calls"] = self.parallel_tool_calls - - # Create simplified version of params for debug logging - debug_params = completion_params.copy() - if debug_params.get("tools"): - debug_params["tools"] = [ - f["function"]["name"] for f in debug_params["tools"] - ] - - # Debug: Show all completion parameters - if self.debug.debug_all(): - debug_print(self.debug.debug_all(), f"Final completion params: {debug_params}") - - # Get model name - model_name = model_override or self.model - - # Import token estimation utilities - from agentic.utils.token_estimation import ( - should_compress_context, - create_compressed_messages - ) - - # Check if we need to compress context - needs_compression, current_tokens, max_allowed = should_compress_context( - messages=messages, - model=model_name, - safety_factor=0.3 # Use 30% safety margin - ) - - # Debug logging for token count - if self.debug.debug_all(): - print(f"[Token Count] Model: {model_name}, Current tokens: {current_tokens}, Max: {max_allowed}") - - # Compress context if needed - if needs_compression: - # Create compressed messages - truncated_messages = create_compressed_messages( - messages=messages, - model=model_name, - current_tokens=current_tokens, - debug=self.debug.debug_all() - ) - - # Update completion params with compressed messages - completion_params["messages"] = truncated_messages - - # Update history but preserve system message - self.history = [messages[0]] + truncated_messages[2:] - - debug_completion_start(self.debug, self.model, debug_params) - - try: - return litellm.completion(**completion_params) - except litellm.exceptions.ContextWindowExceededError as e: - # Emergency fallback - print(f"Emergency fallback: {str(e)}") - - # Keep only the system message and most recent message - emergency_messages = [messages[0], messages[-1]] - completion_params["messages"] = emergency_messages - self.history = emergency_messages - - # Try one more time with minimal context - return litellm.completion(**completion_params) - except Exception as e: - traceback.print_exc() - raise RuntimeError("Error calling LLM: " + str(e)) - - def _execute_tool_calls( - self, - tool_calls: List[ChatCompletionMessageToolCall], - functions: List[AgentFunction], - thread_context: ThreadContext, - ) -> tuple[Response, list[Event]]: - """When the LLM completion includes tool calls, now invoke the tool functions. - Returns the LLM processing response, and a list of events to publish - """ - - function_map = {f.__name__: f for f in functions} - partial_response = Response(messages=[], agent=None) - - events = [] - - for tool_call in tool_calls: - name = tool_call.function.name - is_subagent_call = False # Initialize for this tool call - target_agent = "" # Initialize target agent name - - # handle missing tool case, skip to next tool - if name not in function_map: - debug_print( - self.debug.debug_tools(), f"Tool {name} not found in function map." - ) - partial_response.messages.append( - { - "role": "tool", - "tool_call_id": tool_call.id, - "tool_name": name, - "content": f"Error: Tool {name} not found.", - } - ) - continue - - try: - args = json.loads(tool_call.function.arguments) - except Exception as e: - debug_print( - self.debug.debug_tools(), - f"Error parsing tool call arguments: {e}\n" - + f"Tool call: {tool_call.function.arguments}", - ) - args = {} - - func = function_map[name] - if __CTX_VARS_NAME__ in func.__code__.co_varnames: - args[__CTX_VARS_NAME__] = thread_context - if __LEGACY_CTX_VARS_NAME__ in func.__code__.co_varnames: - args[__LEGACY_CTX_VARS_NAME__] = thread_context - - # Check if this is a subagent call (includes both call_agent and handoff_to_agent) - is_subagent_call = ( - name in ["call_agent", "handoff_to_agent"] and - 'target_agent' in args and 'message' in args - ) - - if is_subagent_call: - # Extract target agent name from arguments (prefer display name if available) - target_agent = args.get('_target_agent_display_name', args.get('target_agent', 'Unknown Agent')) - message = args.get('message', str(args)) - events.append(SubAgentCall(self.name, target_agent, message, self.depth)) - else: - events.append(ToolCall( - agent=self.name, - name=name, - arguments=args, - depth=self.depth, - tool_call_id=tool_call.id - )) - - # Call the function!! - raw_result = None - try: - if asyncio.iscoroutinefunction(function_map[name]): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - raw_result = loop.run_until_complete(function_map[name](**args)) - elif inspect.isgeneratorfunction(function_map[name]): - # We use our generator for our call_child function. I guess we could let user's - # write generate functions as long as they yield events. Or we could catch - # strings and wrap them as events. - for child_event in function_map[name](**args): - if isinstance(child_event, TurnEnd): - raw_result = child_event.result - events.append(child_event) - elif isinstance(child_event, Result): - raw_result = child_event - else: - events.append(child_event) - if raw_result is None: - # Take last event as the function result - raw_result = events.pop() - - elif inspect.isasyncgenfunction(function_map[name]): - # Thread the async function in an event loop and yield events - async def run_async_gen(): - async for event in function_map[name](**args): - events.append(event) - asyncio.run(run_async_gen()) - # take the last yielded value as the function result - raw_result = events.pop() - - else: - raw_result = function_map[name](**args) - except Exception as e: - tb_list = traceback.format_exception(type(e), e, e.__traceback__) - # Join all lines and split them to get individual lines - full_traceback = "".join(tb_list).strip().split("\n") - # Get the last 3 lines (or all if less than 3) - if self.debug.debug_all(): - last_three = full_traceback - else: - last_three = ( - full_traceback[-3:] if len(full_traceback) >= 3 else full_traceback - ) - raw_result = f"Tool error: {name}: {last_three}" - - events.append(ToolError( - agent=self.name, - name=name, - error=raw_result, - depth=self.depth, - tool_call_id=tool_call.id - )) - - # Let tools return additional events to publish - if isinstance(raw_result, list): - for result in raw_result: - if isinstance(result, Event): - events.append(result) - raw_result = [result for result in raw_result if not isinstance(result, Event)] - if len(raw_result) == 0: - raw_result = "" - - result: Result = ( - raw_result - if isinstance(raw_result, Result) - else Result(value=str(raw_result)) - ) - - result.tool_function = Function( - name=name, - arguments=tool_call.function.arguments, - _request_id=tool_call.id, - ) - - # Functions can queue log events when they run, and we publish after - for log_event in thread_context.get_logs(): - events.append(log_event) - thread_context.reset_logs() - - # Check if this was a subagent call to emit appropriate result event - if is_subagent_call: - events.append(SubAgentResult(self.name, target_agent, result.value, self.depth)) - else: - events.append(ToolResult( - agent=self.name, - name=name, - result=result.value, - depth=self.depth, - intermediate_result=False, - tool_call_id=tool_call.id - )) - - partial_response.messages.append( - { - "role": "tool", - "tool_call_id": tool_call.id, - "name": name, - "content": result.value, - } - ) - partial_response.last_tool_result = result - # This was the simple way that Swarm did handoff - if result.agent: - partial_response.agent = result.agent - - return partial_response, events - - def handle_prompt_or_resume(self, actor_message: Prompt | ResumeWithInput): - request_id = getattr(actor_message, 'request_id', None) - if not request_id: - raise ValueError("Request ID is required") - - if isinstance(actor_message, Prompt): - self.thread_context = ( - ThreadContext( - agent_name=self.name, - agent=self, - debug_level=actor_message.debug, - api_endpoint=self.api_endpoint, - context=actor_message.request_context, - ) - if self.thread_context is None - else self.thread_context.update(actor_message.request_context) - ) - if not self.thread_context.thread_id and "thread_id" in actor_message.request_context: - self.thread_context.thread_id = actor_message.request_context["thread_id"] - - # Middleware to modify the input prompt (or change agent context) - if self._callbacks.get('handle_turn_start'): - self._callbacks['handle_turn_start'](actor_message, self.thread_context) - - self.debug = actor_message.debug - self.depth = actor_message.depth - self.history.append({"role": "user", "content": actor_message.payload}) - yield PromptStarted(self.name, {"content": actor_message.payload}, self.depth) - - elif isinstance(actor_message, ResumeWithInput): - if not self.paused_context: - self.thread_context.debug( - "Ignoring ResumeWithInput event, parent not paused: ", - actor_message, - ) - return - - init_len = self.paused_context.orig_history_length - self.thread_context.update(actor_message.request_keys.copy()) - - tool_function = self.paused_context.tool_function - if tool_function is None: - raise RuntimeError("Tool function not found on AgentResume event") - - partial_response, events = self._execute_tool_calls( - [ChatCompletionMessageToolCall( - id=(tool_function._request_id or ""), - function=tool_function, - type="function")], - self.functions, - self.thread_context - ) - yield from events - self.history.extend(partial_response.messages) - - # Main conversation loop: allow 25 tool calls in a row before stopping - init_len = len(self.history) - while len(self.history) - init_len < 50: - for event in self._yield_completion_steps(request_id): - # Wait to yield the FinishCompletion until after tool calls are executed - if not isinstance(event, FinishCompletion): - yield event - - assert isinstance(event, FinishCompletion) - response: Message = event.response - - self.history.append(response) - if not response.tool_calls: - # Now yield if no tool calls - yield event - break - - partial_response, events = self._execute_tool_calls( - response.tool_calls, - self.functions, - self.thread_context - ) - yield from events - # Now yield since the tool calls were executed - yield event - - # Yield to our publisher thread which will send queued events out to the client - time.sleep(0) - - if partial_response.last_tool_result: - if isinstance(partial_response.last_tool_result, PauseForInputResult): - self.paused_context = AgentPauseContext( - orig_history_length=init_len, - tool_partial_response=partial_response, - tool_function=partial_response.last_tool_result.tool_function - ) - yield WaitForInput(self.name, partial_response.last_tool_result.request_keys) - return - elif isinstance(partial_response.last_tool_result, OAuthFlowResult): - self.paused_context = AgentPauseContext( - orig_history_length=init_len, - tool_partial_response=partial_response, - tool_function=partial_response.last_tool_result.tool_function - ) - # Add tool result message before yielding OAuthFlow event - self.history.extend([{ - "role": "tool", - "content": "OAuth authentication required. Please complete the authorization flow.", - "tool_call_id": partial_response.last_tool_result.tool_function._request_id, - "name": partial_response.last_tool_result.tool_function.name - }]) - yield OAuthFlow( - self.name, - partial_response.last_tool_result.auth_url, - partial_response.last_tool_result.tool_name, - depth=self.depth - ) - return - - elif FinishAgentResult.matches_sentinel(partial_response.messages[-1]["content"]): - self.history.extend(partial_response.messages) - break - - self.history.extend(partial_response.messages) - - # We have already emitted history for intervening events, and I think we just look at the - # last message from TurnEnd anyway. So probably we just want to publish a single result here. - # You can see it in TurnEnd.result which just returns the "content" part of the last message. - yield TurnEnd( - self.name, - # result_model gets applied when TurnEnd is processed. We dont want to alter the the text response in history - deepcopy(self.history[init_len:]), - self.depth - ) - self.paused_context = None - - def _yield_completion_steps(self, request_id: str): - yield StartCompletion(self.name, self.depth) - - self._callback_params = {} - - def custom_callback(kwargs, completion_response, start_time, end_time): - try: - self._callback_params["cost"] = kwargs["response_cost"] - except: - pass - self._callback_params["elapsed"] = end_time - start_time - - litellm.success_callback = [custom_callback] - - try: - completion = self._get_llm_completion( - history=self.history, - thread_context=self.thread_context, - model_override=None, - stream=True, - ) - except RuntimeError as e: - yield FinishCompletion.create( - self.name, - Message(content=str(e), role="assistant"), - self.model, - 0, - 0, - timedelta(0), - self.depth - ) - return - - chunks = [] - for chunk in completion: - chunks.append(chunk) - delta = json.loads(chunk.choices[0].delta.model_dump_json()) - if delta["role"] == "assistant": - delta["sender"] = self.name - if not delta.get("tool_calls") and delta.get("content"): - yield ChatOutput(self.name, delta, self.depth) - delta.pop("role", None) - delta.pop("sender", None) - - llm_message = litellm.stream_chunk_builder(chunks, messages=self.history) - input = self.history[-1:] - output = llm_message.choices[0].message - - # Extract reasoning content if available - # Try multiple locations as different providers may place it differently - choice = llm_message.choices[0] - message = choice.message - - # Try to extract reasoning content from different locations - reasoning_content = None - - # Check choice level first - reasoning_content = getattr(choice, "reasoning_content", None) - - # If not found on choice level, check message level - if not reasoning_content: - reasoning_content = getattr(message, "reasoning_content", None) - - - # Debug: Check if we have reasoning content - if self.debug.debug_all(): - debug_print(self.debug.debug_all(), f"Checking for reasoning content...") - debug_print(self.debug.debug_all(), f"Choice object attributes: {[attr for attr in dir(choice) if not attr.startswith('_')]}") - debug_print(self.debug.debug_all(), f"Message object attributes: {[attr for attr in dir(message) if not attr.startswith('_')]}") - - # Debug the actual choice and message objects - debug_print(self.debug.debug_all(), f"Choice object dict: {choice.model_dump() if hasattr(choice, 'model_dump') else str(choice)}") - debug_print(self.debug.debug_all(), f"Message object dict: {message.model_dump() if hasattr(message, 'model_dump') else str(message)}") - - debug_print(self.debug.debug_all(), f"Reasoning content: {reasoning_content}") - - - # Also check the full llm_message structure - debug_print(self.debug.debug_all(), f"Full LLM message structure: {llm_message.model_dump() if hasattr(llm_message, 'model_dump') else str(llm_message)}") - - # Get usage directly from response - usage = getattr(llm_message, "usage", None) - if usage: - self._callback_params["input_tokens"] = usage.prompt_tokens - self._callback_params["output_tokens"] = usage.completion_tokens - else: - # Fallback to manual calculation if usage not in response - if len(input) > 0: - self._callback_params["input_tokens"] = litellm.token_counter( - self.model, messages=self.history[-1:] - ) - if output.content: - self._callback_params["output_tokens"] = litellm.token_counter( - self.model, text=output.content - ) - - debug_completion_end(self.debug, self.model, llm_message.choices[0].message) - - # Emit reasoning content FIRST as a separate event if present - if reasoning_content: - yield ReasoningContent( - self.name, - reasoning_content, - self.depth - ) - - yield FinishCompletion.create( - self.name, - llm_message.choices[0].message, - self.model, - self._callback_params.get("cost", 0), - self._callback_params.get("input_tokens"), - self._callback_params.get("output_tokens"), - self._callback_params.get("elapsed"), - self.depth, - reasoning_content=reasoning_content - ) - - def call_child( - self, - child_ref, - handoff: bool, - message, - ): - depth = self.depth if handoff else self.depth + 1 - if hasattr(child_ref.handle_prompt_or_resume, 'remote'): - remote_gen = child_ref.handle_prompt_or_resume.remote( - Prompt( - self.name, - message, - depth=depth, - debug=self.debug, - ) - ) - else: - remote_gen = child_ref.handle_prompt_or_resume( - Prompt( - self.name, - message, - depth=depth, - debug=self.debug, - request_context=self.thread_context.get_context(), - request_id=str(uuid.uuid4()) - ) - ) - - for remote_event in remote_gen: - event = ray.get(remote_event) - yield event - - if handoff: - # by definition we don't care about remembering the child result since - # the parent is gonna end anyway - yield FinishAgentResult() - - def _build_child_func(self, event: AddChild) -> Callable: - name = event.agent - llm_name = f"call_{name.lower().replace(' ', '_')}" - doc = f"Send a message to sub-agent {name}" - - return wrap_llm_function( - llm_name, doc, self.call_child, event.remote_ref, event.handoff - ) - - def add_child(self, actor_message: AddChild): - self.add_tool(actor_message) - - def add_tool(self, tool_func_or_cls): - if isinstance(tool_func_or_cls, AddChild): - tool_func_or_cls = self._build_child_func(tool_func_or_cls) - - if looks_like_langchain_tool(tool_func_or_cls): - # Langchain tools which are single functions in a whole class inheriting from BaseTool - self.functions.append(langchain_function_to_json(tool_func_or_cls)) - self.tools.append(self.functions[-1].__name__) - else: - if callable(tool_func_or_cls): - self.functions.append(tool_func_or_cls) - self.tools.append(self.functions[-1].__name__) - else: - if hasattr(tool_func_or_cls, "get_tools"): - self.functions.extend(tool_func_or_cls.get_tools()) - self.tools.append(tool_func_or_cls.__class__.__name__) - else: - print("ERROR: ", f"Tool {tool_func_or_cls} is not a callable, nor has 'get_tools' method") - - def reset_history(self): - self.history = [] - - def get_history(self): - return self.history - - def inject_secrets_into_env(self): - """Ensure the appropriate API key is set for the given model.""" - from agentic.agentic_secrets import agentic_secrets - - for key in agentic_secrets.list_secrets(): - if key not in os.environ: - value = agentic_secrets.get_secret(key) - if value: - os.environ[key] = value - - def get_instructions(self, context: ThreadContext): - # Support context var substitution in prompts - try: - prompt = Template( - self.instructions_str, - undefined=DebugUndefined - ).render( - context.get_context() - ) - if self.memories: - prompt += """ - - {% for memory in MEMORIES -%} - {{memory|trim}} - {%- endfor %} - - """ - # Fix for Jinja2 template rendering error if there is an unclosed comment, ensure all `{#` have a closing `#}` - while re.search(r"\{#(?!.*#\})", prompt, re.DOTALL): - prompt = re.sub(r"\{#(?!.*#\})", "{% raw %}{#{% endraw %}", prompt, count=1) - - return Template(prompt).render( - context.get_context() | {"MEMORIES": self.memories} - ) - except Exception as e: - print("Error in prompt template, using raw prompt without subsitutions:", e) - traceback.print_exc() - return prompt - - def set_state(self, actor_message: SetState): - self.inject_secrets_into_env() - state = actor_message.payload - remap = {"instructions": "instructions_str"} - - for key in [ - "name", - "instructions", - "model", - "max_tokens", - "memories", - "api_endpoint", - "result_model", - "reasoning_effort", - ]: - if key in state: - setattr(self, remap.get(key, key), state[key]) - - if 'history' in state: - self.history = state["history"] - - if "handle_turn_start" in state: - self._callbacks["handle_turn_start"] = state["handle_turn_start"] - - # Update our functions - if "functions" in state: - self.functions = [] - self.tools = [] - for f in state.get("functions"): - self.add_tool(f) - - return Output(self.name, f"State updated: {actor_message.payload}", self.depth) - - def set_debug_level(self, debug: DebugLevel): - self.debug = debug - print("agent set new debug level: ", debug) - - def get_callback(self, key: CallbackType) -> Optional[CallbackFunc]: - return self._callbacks.get(key) - - def set_callback(self, key: CallbackType, callback: Optional[CallbackFunc]): - if callback is None: - self._callbacks.pop(key, None) - else: - self._callbacks[key] = callback - - def list_tools(self) -> list[str]: - return self.tools - - def list_functions(self) -> list[str]: - def get_name(f): - if hasattr(f, "__name__"): - return f.__name__ - elif isinstance(f, dict): - return f["name"] - else: - return str(f) - - return [get_name(f) for f in self.functions] - - def handle_request(self, method: str, data: dict): - return f"Actor {self.name} processed {method} request with data: {data}" - - def webhook(self, thread_id: str, callback_name: str, args: dict) -> Any: - """Handle webhook callbacks by executing the specified tool function - - Args: - thread_id: ID of the agent thread this webhook is for - callback_name: Name of the tool function to call - args: Arguments to pass to the tool function - """ - # Get the thread context from the database - db_manager = DatabaseManager() - thread = db_manager.get_thread(thread_id) - if not thread: - raise ValueError(f"No thread found with ID {thread_id}") - # Recreate thread context - self.thread_context = ThreadContext( - agent=self, - agent_name=self.name, - debug_level=self.debug, - thread_id=thread_id, - api_endpoint=self.api_endpoint - ) - # Find the tool function - function_map = {f.__name__: f for f in self.functions} - if callback_name not in function_map: - raise ValueError(f"No tool function found named {callback_name}") - - # Execute the tool call - try: - # Create tool call object - tool_call = ChatCompletionMessageToolCall( - id="", - type="function", - function=Function( - name=callback_name, - arguments=json.dumps({"webhook_data":args}) - ) - ) - - # Execute the tool call - response, events = self._execute_tool_calls( - [tool_call], - self.functions, - self.thread_context - ) - return response - - except Exception as e: - raise RuntimeError(f"Error executing webhook {callback_name}: {str(e)}") - - # Add new methods to set mock configuration - def set_mock_params(self, pattern: str, response: str, tools: dict): - """Store mock parameters in the agent instance""" - # Import here to avoid circular imports - from agentic.models import mock_provider - - # Apply to the mock provider (happens in this worker process) - mock_provider.set_response(pattern, response) - mock_provider.clear_tools() - for name, tool in tools.items(): - mock_provider.register_tool(name, tool) - - # Make sure custom provider is registered in litellm - litellm.custom_provider_map = [ - {"provider": "mock", "custom_handler": mock_provider} - ] - -class HandoffAgentWrapper: - def __init__(self, agent): - self.agent = agent - - def get_agent(self): - return self.agent - - -def handoff(agent, **kwargs): - """Signal that a child agent should take over the execution context instead of being - called as a subroutine.""" - return HandoffAgentWrapper(agent) - -class ProcessRequest(BaseModel): - prompt: str - debug: Optional[str] = None - thread_id: Optional[str] = None - -class ResumeWithInputRequest(BaseModel): - continue_result: dict[str, str] - debug: Optional[str] = None - thread_id: Optional[str] = None - - -depthLocal = threading.local() -depthLocal.depth = -1 - -# The common agent proxy interface -# The core of the interface is 'start_request' and 'get_events'. Use these in -# pairs to request operation threads from the agent. -# It is deprecated to call 'next_turn' directly now. -# -# Subclasses can override next_turn to do their own orchestration logic. - -class BaseAgentProxy: - """Base agent proxy class with common functionality. Manages multiple parallel - requests, delegating each request to an instance of the agent class. - The proxy keeps a thread queue to dispatch agent events to the caller. - Subclasses will handle specific implementation details for different - execution environments (Ray, local, etc.) - """ - _agent: Any - - def __init__( - self, - name: str, - instructions: str | None = "You are a helpful assistant.", - welcome: str | None = None, - tools: list = None, - model: str | None = None, - template_path: str | Path | None = None, - max_tokens: int = None, - db_path: Optional[str | Path] = "./agent_threads.db", - memories: list[str] = [], - handle_turn_start: Callable[[Prompt, ThreadContext], None] = None, - result_model: Type[BaseModel]|None = None, - debug: DebugLevel = DebugLevel(os.environ.get("AGENTIC_DEBUG") or ""), - mock_settings: dict = None, - prompts: Optional[dict[str, str]] = None, - reasoning_effort: str = None, - ): - self.name = name - self.welcome = welcome or f"Hello, I am {name}." - self.model = model or "gpt-4o-mini" - self.prompts = prompts or {} - self.cancelled = False - self.mock_settings = mock_settings - self.reasoning_effort = reasoning_effort - - # Find template path if not provided - from agentic.utils.template import find_template_path - self.template_path = template_path or find_template_path() - - # Setup tools and other properties - self._tools = [] - if tools: - self._tools.extend(tools) - - self.max_tokens = max_tokens - self.memories = memories - self.debug = debug - self._handle_turn_start = handle_turn_start - self.request_queues: dict[str,Queue] = {} - self.result_model = result_model - self.queue_done_sentinel = "QUEUE_DONE" - - # Track active agent instances by request ID - self.agent_instances = {} - - # Process instructions - if instructions and instructions.strip(): - template = Template(instructions, undefined=DebugUndefined) - self.instructions = template.render(**self.prompt_variables) - # Allow one level of nested references - self.instructions = Template(self.instructions, undefined=DebugUndefined).render(**self.prompt_variables) - if self.instructions.strip() == "": - raise ValueError( - f"Instructions are required for {self.name}. Maybe interpolation failed from: {instructions}" - ) - else: - self.instructions = "You are a helpful assistant." - - # Check we have all the secrets - self._ensure_tool_secrets() - - # Initialize thread tracking - self.db_path = db_path - self.thread_id = None # This will be set per request - - # Ensure API key is set - self.ensure_api_key_for_model(self.model) - - # Handle mock settings - subclasses should implement this - self._handle_mock_settings(mock_settings) - - def _check_for_prompt_match(self, user_input: str) -> str: - """Check if user input matches a prompt key and return the corresponding content if it does.""" - if not self.prompts: - return user_input - - # Check if the input exactly matches a prompt key - if user_input in self.prompts: - return self.prompts[user_input] - - # Check if the input matches a prompt key when lowercase - lower_input = user_input.lower() - for key, value in self.prompts.items(): - if lower_input == key.lower(): - return value - - # No match found, return original input - return user_input - - def _handle_mock_settings(self, mock_settings): - """Handle mock settings - to be implemented by subclasses""" - pass - - def _ensure_tool_secrets(self): - """Ensure that all required secrets for tools are available""" - from .agentic_secrets import agentic_secrets - - for tool in self._tools: - if hasattr(tool, "required_secrets"): - for key, help in tool.required_secrets().items(): - value = agentic_secrets.get_secret( - agent_secret_key(self.name, key), - agentic_secrets.get_secret(key, os.environ.get(key)), - ) - if not value: - value = input(f"{tool_name(tool)} requires {help}: ") - if value: - agentic_secrets.set_secret( - agent_secret_key(self.name, key), value - ) - else: - raise ValueError( - f"Secret {key} is required for tool {tool_name(tool)}" - ) - - def cancel(self): - """Flag this agent to cancel whatever it is doing""" - self.cancelled = True - - def is_cancelled(self): - """Check if this agent has been cancelled""" - return self.cancelled - - def uncancel(self): - """Reset the cancelled flag""" - self.cancelled = False - - def add_tool(self, tool: Any): - """Add a tool to this agent""" - self._tools.append(tool) - self._update_state({"functions": self._get_funcs(self._tools)}) - - def add_child(self, child_agent): - """Add a child agent as a tool""" - self.add_tool(child_agent) - - def set_model(self, model: str): - """Set the model to use for this agent""" - self.model = model - self._update_state({"model": model}) - - def set_debug_level(self, level: DebugLevel): - """Set the debug level for this agent""" - self.debug.raise_level(level) - self._set_agent_debug_level(self.debug) - - def set_result_model(self, model: Type[BaseModel]): - """Set the result model for this agent""" - self.result_model = model - self._update_state({"result_model": model}) - - def reset_history(self): - """Reset the conversation history""" - self._reset_agent_history() - - def get_history(self): - """Get the conversation history""" - return self._get_agent_history() - - def init_thread_tracking(self, agent, thread_id: Optional[str] = None): - """Initialize thread tracking""" - pass - - def get_db_manager(self) -> DatabaseManager: - """Get the database manager for this agent""" - if self.db_path: - db_manager = DatabaseManager(self.db_path) - else: - db_manager = DatabaseManager() - return db_manager - - def get_threads(self, user_id: str|None) -> list[Thread]: - """Get all threads for this agent""" - db_manager = self.get_db_manager() - - try: - return db_manager.get_threads_by_agent(self.name, user_id=user_id) - except Exception as e: - print(f"Error getting threads: {e}") - return [] - - def get_thread_logs(self, thread_id: str) -> list[ThreadLog]: - """Get logs for a specific thread""" - db_manager = self.get_db_manager() - - try: - return db_manager.get_thread_logs(thread_id) - except Exception as e: - print(f"Error getting thread logs: {e}") - return [] - - @property - def prompt_variables(self) -> dict: - """Dictionary of variables to make available to prompt templates.""" - if self.template_path is None: - return {"name": self.name} # Return default values when no template path exists - - path = Path(self.template_path) - if path.exists(): - try: - with open(path, "r") as f: - prompts = yaml.safe_load(f) - return prompts or {"name": self.name} - except Exception as e: - print(f"Error loading prompt template: {e}") - - return {"name": self.name} - - @property - def safe_name(self) -> str: - """Renders the Agent's name, but filesystem safe.""" - return "".join(c if c.isalnum() else "_" for c in self.name).lower() - - def ensure_api_key_for_model(self, model: str): - """Ensure the appropriate API key is set for the given model.""" - from agentic.agentic_secrets import agentic_secrets - - for key in agentic_secrets.list_secrets(): - if key not in os.environ: - value = agentic_secrets.get_secret(key) - if value: - os.environ[key] = value - - def _get_funcs(self, thefuncs: list): - """Get the functions to provide to the agent implementation""" - useable = [] - for func in thefuncs: - if callable(func): - tool_registry.ensure_dependencies(func) - useable.append(func) - elif isinstance(func, HandoffAgentWrapper): - # add a child agent as a tool - useable.append( - AddChild( - func.get_agent().name, - func.get_agent()._agent, - handoff=True - ) - ) - elif isinstance(func, BaseAgentProxy): - useable.append( - AddChild( - func.name, - func._agent, - ) - ) - else: - tool_registry.ensure_dependencies(func) - useable.append(func) - - return useable - - def _update_state(self, state: dict): - """Update the agent's state""" - # To be overridden by subclasses - pass - - def _set_agent_debug_level(self, debug_level): - """Set the debug level on the agent implementation""" - # To be overridden by subclasses - pass - - def _reset_agent_history(self): - """Reset the agent's conversation history""" - # To be overridden by subclasses - pass - - def _get_agent_history(self): - """Get the agent's conversation history""" - # To be overridden by subclasses - pass - - def _reload_thread_history(self, thread_id: str): - # We load the thread history from the ThreadManager, and pass it to the agent. - # We have to keep a flag to avoid loading all of history every time that 'start_request' is - # is called. But the agent also has logic to only load its history once. - from .thread_manager import reconstruct_chat_history_from_thread_logs, validate_chat_history - - if thread_id == 'NEW': - history = [] - else: - history = validate_chat_history( - reconstruct_chat_history_from_thread_logs(self.get_thread_logs(thread_id)) - ) - update = {"history": history} - self._update_state(update) - - def _create_agent_instance(self, request_id: str): - """Create a new agent instance for a request""" - # This is implemented by the subclasses (e.g., RayAgentProxy, LocalAgentProxy) - raise NotImplementedError("Subclasses must implement _create_agent_instance") - - def _get_agent_for_request(self, request_id: str): - """Get the agent instance for a request, creating it if needed""" - # The logic here is to keep reusing the default '_agent' value created when the Proxy is - # first constructed. We only go to create a new instance if a request is started before - # the prior one finishes. - # TODO: RESOLVE THIS TO WORK WITH RESUME WITH INPUT - # if len(self.agent_instances) == 0: - self.agent_instances[request_id] = self._agent or self._create_agent_instance(request_id) - # else: - # self.agent_instances[request_id] = self._create_agent_instance(request_id) - return self.agent_instances[request_id] - - def _cleanup_agent_instance(self, request_id: str): - """Clean up an agent instance after a request is complete""" - # We remove the agent from our set, but the _agent default instance will stay around - if request_id in self.agent_instances: - del self.agent_instances[request_id] - - def start_request(self, request: str, request_context: dict = {}, - continue_result: dict = {}, thread_id: Optional[str] = None, - debug: DebugLevel = DebugLevel(DebugLevel.OFF)) -> StartRequestResponse: - """Start a new agent request""" - self.debug.raise_level(debug) - - if not hasattr(depthLocal, 'depth'): - depthLocal.depth = 0 - else: - depthLocal.depth += 1 - - for key, value in request_context.items(): - if isinstance(value, Callable): - request_context[key] = value() - - if isinstance(request, str): - request = self._check_for_prompt_match(request) - - if not thread_id and "thread_id" in request_context: - thread_id = request_context["thread_id"] - - # Create request ID if not provided in continue_result - request_id = continue_result.get("request_id") or str(uuid.uuid4()) - - agent_instance = self._get_agent_for_request(request_id) - if (self.thread_id != thread_id or not self.thread_id) and self.db_path: - if thread_id is not None: - self._reload_thread_history(thread_id) - - self.init_thread_tracking(agent_instance, thread_id or self.thread_id) - - # Initialize new request - queue = Queue() - # pass the queue down ultimate to ThreadContext by putting it in the request_context - request_context[EVENT_QUEUE_KEY] = queue - - request_obj = Prompt( - self.name, - request, - debug=self.debug, - depth=depthLocal.depth, - request_context=request_context, - request_id=request_id, - ) - - def producer(queue, request_obj, continue_result): - depthLocal.depth = request_obj.depth - for event in self._next_turn(request_obj, request_context=request_context, continue_result=continue_result, request_id=request_id): - queue.put(event) - queue.put(self.queue_done_sentinel) - # Cleanup the agent instance when done - self._cleanup_agent_instance(request_id) - - self.request_queues[request_id] = queue - - t = threading.Thread(target=producer, args=(queue, request_obj, continue_result)) - t.start() - return StartRequestResponse(request_id=request_id, thread_id=self.thread_id) - - def get_events(self, request_id: str, timeout: Optional[float] = None) -> Generator[Event, Any, Any]: - """Get events for a request""" - queue = self.request_queues[request_id] - start_time = time.time() - - while True: - try: - if timeout is not None: - remaining_time = timeout - (time.time() - start_time) - if remaining_time <= 0: - # Timeout exceeded, set sentinel and return - queue.put(self.queue_done_sentinel) - break - event = queue.get(timeout=remaining_time) - else: - event = queue.get() - - if event == self.queue_done_sentinel: - break - - yield event - # yield to subscriber thread so it can publish events to clients - time.sleep(0) - - except Exception: - # Timeout or other exception, set sentinel and break - queue.put(self.queue_done_sentinel) - break - - depthLocal.depth -= 1 - - def next_turn(self, request: str | Prompt, request_context: dict = {}, - request_id: str = None, continue_result: dict = {}, - debug: DebugLevel = DebugLevel(DebugLevel.OFF)) -> Generator[Event, Any, Any]: - """ - Default agent orchestration logic. Subclasses may override this. - If not overridden, this handles prompt/resume and returns generator from agent. - """ - # Get agent instance - agent_instance = self._get_agent_for_request(request_id) - - # Prepare the prompt or resume input - if not continue_result: - prompt = ( - request if isinstance(request, Prompt) - else Prompt( - self.name, - request, - debug=debug, - request_context=request_context, - request_id=request_id, - ) - ) - - - # Transmit depth through the Prompt - if hasattr(depthLocal, 'depth') and depthLocal.depth > prompt.depth: - prompt.depth = depthLocal.depth - - return self._get_prompt_generator(agent_instance, prompt) - - else: - resume_input = ResumeWithInput( - self.name, - continue_result, - request_id=request_id - ) - return self._get_resume_generator(agent_instance, resume_input) - - - def _next_turn(self, request: str | Prompt, request_context: dict = {}, - request_id: str = None, continue_result: dict = {}, - debug: DebugLevel = DebugLevel(DebugLevel.OFF)) -> Generator[Event, Any, Any]: - """ - Wraps `next_turn` to add thread tracking and handle_event logging. - Always used internally by the proxy to ensure consistent behavior. - """ - self.cancelled = False - self.debug.raise_level(debug) - - if not request_id: - request_id = continue_result.get("request_id") or str(uuid.uuid4()) - if isinstance(request, Prompt): - request.request_id = request_id - - self._handle_mock_settings(self.mock_settings) - - if not self.thread_id and "thread_id" in request_context: - self.thread_id = request_context["thread_id"] - - # Get agent instance for the thread - agent_instance = self._get_agent_for_request(request_id) - - # Initialize thread tracking if needed - if (not self.thread_id) and self.db_path: - self.init_thread_tracking(agent_instance, self.thread_id) - - # Add thread_id into context explicitly so child agents inherit it - request_context = {**request_context, "thread_id": self.thread_id} - - # Call the user’s or default next_turn - event_gen = self.next_turn( - request=request, - request_context=request_context, - request_id=request_id, - continue_result=continue_result, - debug=debug - ) - - # Central logging of all events - for event in self._process_generator(event_gen): - if self.cancelled: - raise TurnCancelledError() - - # Handle TurnEnd result validation - if isinstance(event, TurnEnd): - event = self._process_turn_end(event) - - yield event - - if hasattr(event, "agent") and event.agent != self.name: - # skipping event with wrong agent name - continue - - # Only now: do logging after yielding - callback = self._agent.get_callback("handle_event") if hasattr(self, "_agent") else None - if callback: - context = ThreadContext(agent=self._agent, agent_name=self.name, thread_id=self.thread_id, context=request_context) - try: - callback(event, context) - except Exception as e: - print(f"Error in handle_event callback: {e}") - - - def _get_prompt_generator(self, agent_instance, prompt): - """Get generator for a new prompt - to be implemented by subclasses""" - pass - - def _get_resume_generator(self, agent_instance, resume_input): - """Get generator for resuming with input - to be implemented by subclasses""" - pass - - def _process_generator(self, generator): - """Process generator events - to be implemented by subclasses""" - pass - - def _process_turn_end(self, event): - """Process TurnEnd event to handle result model validation""" - if isinstance(event.result, str) and self.result_model: - try: - event.set_result(self.result_model.model_validate_json(event.result)) - except Exception as e: - try: - # Hack for LLM poorly parsing Claude structured outputs - data = json.loads(event.result) - if 'values' in data: - event.set_result(self.result_model.model_validate(data['values'])) - except Exception as e: - # Create an error message event - error_event = ChatOutput.assistant_message( - self.name, - f"Error validating result: {e}", - depth=event.depth - ) - # We'll yield this error event and then the original event - return error_event - return event - - def final_result(self, request: str, request_context: dict = {}, - event_handler: Callable[[Event], None] = None) -> Any: - """Get the final result of a request""" - request_id = self.start_request( - request, - request_context=request_context, - debug=self.debug - ).request_id - turn_end = None - for event in self.get_events(request_id): - if event_handler: - event_handler(event) - yield event - if isinstance(event, TurnEnd): - turn_end = event - if turn_end: - return turn_end.result - else: - return event - - def grab_final_result(self, request: str, request_context: dict = {}) -> Any: - """Convenience method to get the final result of a request""" - try: - items = list(self.final_result(request, request_context)) - if isinstance(items[-1], TurnEnd): - return items[-1].result - else: - return items[-1] - except StopIteration as e: - return e.value - - def __lshift__(self, prompt: str): - """ - Implement the << operator for sending prompts to agents. - This allows syntax like: response = agent << "prompt" - - Args: - prompt: The prompt to send to the agent - - Returns: - The final response from the agent - """ - return self.grab_final_result(prompt) - -class RayAgentProxy(BaseAgentProxy): - """Ray-based implementation of the agent proxy. - The actual agent is run as a remote actor on Ray. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.agent_config = { - "name": self.name, - "instructions": self.instructions, - "model": self.model, - "max_tokens": self.max_tokens, - "memories": self.memories, - "debug": self.debug, - "result_model": self.result_model, - "prompts": self.prompts, - "reasoning_effort": self.reasoning_effort, - # Functions will be added when creating instances - } - _AGENT_REGISTRY.append(self) - self._create_agent_instance() - - def _create_agent_instance(self, request_id: str|None=None): - """Initialize the Ray actor""" - agent = ActorBaseAgent.remote(name=self.name) - - # Set initial state - obj_ref = agent.set_state.remote( - SetState( - self.name, - { - "name": self.name, - "instructions": self.instructions, - "functions": self._get_funcs(self._tools), - "model": self.model, - "max_tokens": self.max_tokens, - "memories": self.memories, - "handle_turn_start": self._handle_turn_start, - "result_model": self.result_model, - "reasoning_effort": self.reasoning_effort, - }, - ), - ) - ray.get(obj_ref) - if self._handle_turn_start: - agent.set_callback.remote("handle_turn_start", self._handle_turn_start) - - if request_id is None: - self._agent = agent - - return agent - - def init_thread_tracking(self, agent, thread_id: Optional[str] = None): - """Initialize thread tracking""" - from .thread_manager import init_thread_tracking - self.thread_id, callback = init_thread_tracking(self, db_path=self.db_path, resume_thread_id=thread_id) - agent.set_callback.remote('handle_event', callback) - - def _handle_mock_settings(self, mock_settings): - """Handle mock settings for Ray implementation""" - if mock_settings and self.model and "mock" in self.model: - from agentic.models import mock_provider - - pattern = mock_settings.get("pattern", "") - response = mock_settings.get("response", "This is a mock response.") - tools_dict = mock_settings.get("tools", {}) - - # Set in the local mock_provider directly - mock_provider.set_response(pattern, response) - mock_provider.clear_tools() - for tool_name, tool_func in tools_dict.items(): - mock_provider.register_tool(tool_name, tool_func) - - # Pass to the remote agent - try: - ray.get(self._agent.set_mock_params.remote(pattern, response, tools_dict)) - except Exception as e: - print(f"Warning: Failed to set mock params on remote agent: {e}") - - def _update_state(self, state: dict): - """Update the Ray agent's state""" - obj_ref = self._agent.set_state.remote(SetState(self.name, state)) - ray.get(obj_ref) - - def _set_agent_debug_level(self, debug_level): - """Set the debug level on the Ray agent""" - ray.get(self._agent.set_debug_level.remote(debug_level)) - - def _reset_agent_history(self): - """Reset the Ray agent's conversation history""" - ray.get(self._agent.reset_history.remote()) - - def _get_agent_history(self): - """Get the Ray agent's conversation history""" - return ray.get(self._agent.get_history.remote()) - - def list_tools(self) -> list[str]: - """Gets the current tool list from the running Ray agent""" - return ray.get(self._agent.list_tools.remote()) - - def list_functions(self) -> list[str]: - """Gets the current list of functions from the running Ray agent""" - return ray.get(self._agent.list_functions.remote()) - - def _get_prompt_generator(self, agent, prompt): - """Get generator for a new prompt from a Ray agent""" - return agent.handle_prompt_or_resume.remote(prompt) - - def _get_resume_generator(self, agent, resume_input): - """Get generator for resuming with input from a Ray agent""" - return agent.handle_prompt_or_resume.remote(resume_input) - - def _process_generator(self, generator): - """Process generator events - Ray implementation""" - for remote_next in generator: - yield ray.get(remote_next) - - -class LocalAgentProxy(BaseAgentProxy): - """Local dispatch implementation of the agent proxy. - This version makes calls directly to a local agent object. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.agent_config = { - "name": self.name, - "instructions": self.instructions, - "welcome": self.welcome, - "tools": self._tools.copy(), - "model": self.model, - "max_tokens": self.max_tokens, - "memories": self.memories, - "debug": self.debug, - "handle_turn_start": self._handle_turn_start, - "result_model": self.result_model, - "prompts": self.prompts, - "db_path": self.db_path, - "reasoning_effort": self.reasoning_effort, - # Functions will be added when creating instances - } - _AGENT_REGISTRY.append(self) - self._create_agent_instance() - - def _create_agent_instance(self, request_id: str|None=None): - """Create a new local agent instance for a request""" - agent = ActorBaseAgent(name=self.name) - - # Set initial state - agent.set_state( - SetState( - self.name, - { - "name": self.name, - "instructions": self.instructions, - "functions": self._get_funcs(self._tools), - "model": self.model, - "max_tokens": self.max_tokens, - "memories": self.memories, - "handle_turn_start": self._handle_turn_start, - "result_model": self.result_model, - "reasoning_effort": self.reasoning_effort, - }, - ), - ) - if self._handle_turn_start: - agent.set_callback("handle_turn_start", self._handle_turn_start) - - if request_id is None: - self._agent = agent - - return agent - - def init_thread_tracking(self, agent, thread_id: Optional[str] = None): - """Initialize thread tracking""" - from .thread_manager import init_thread_tracking - self.thread_id, callback = init_thread_tracking(self, db_path=self.db_path, resume_thread_id=thread_id) - agent.set_callback('handle_event', callback) - - def _handle_mock_settings(self, mock_settings): - """Handle mock settings for local implementation""" - if mock_settings and self.model and "mock" in self.model: - from agentic.models import mock_provider - - pattern = mock_settings.get("pattern", "") - response = mock_settings.get("response", "This is a mock response.") - tools_dict = mock_settings.get("tools", {}) - - # Set in the local mock_provider directly - mock_provider.set_response(pattern, response) - mock_provider.clear_tools() - for tool_name, tool_func in tools_dict.items(): - mock_provider.register_tool(tool_name, tool_func) - - # Pass to the local agent - try: - self._agent.set_mock_params(pattern, response, tools_dict) - except Exception as e: - print(f"Warning: Failed to set mock params on local agent: {e}") - - def _update_state(self, state: dict): - """Update the local agent's state""" - self._agent.set_state(SetState(self.name, state)) - - def _set_agent_debug_level(self, debug_level): - """Set the debug level on the local agent""" - self._agent.set_debug_level(debug_level) - - def _reset_agent_history(self): - """Reset the local agent's conversation history""" - self._agent.reset_history() - - def _get_agent_history(self): - """Get the local agent's conversation history""" - return self._agent.get_history() - - def list_tools(self) -> list[str]: - """Gets the current tool list from the local agent""" - return self._agent.list_tools() - - def list_functions(self) -> list[str]: - """Gets the current list of functions from the local agent""" - return self._agent.list_functions() - - def _get_prompt_generator(self, agent, prompt): - """Get generator for a new prompt - Local implementation""" - return agent.handle_prompt_or_resume(prompt) - - def _get_resume_generator(self, agent, resume_input): - """Get generator for resuming with input - Local implementation""" - return agent.handle_prompt_or_resume(resume_input) - - def _process_generator(self, generator): - """Process generator events - Local implementation""" - for event in generator: - yield event - -if os.environ.get("AGENTIC_USE_RAY"): - AgentProxyClass = RayAgentProxy -else: - AgentProxyClass = LocalAgentProxy diff --git a/pr_agent/test_files/agent_runner_copy.txt b/pr_agent/test_files/agent_runner_copy.txt deleted file mode 100644 index 80dff831..00000000 --- a/pr_agent/test_files/agent_runner_copy.txt +++ /dev/null @@ -1,388 +0,0 @@ -import time -import os -import readline -import traceback -import signal -from dataclasses import dataclass -from typing import Any, Dict, List, Type, Optional, Callable -import importlib.util -import inspect -import sys -from .fix_console import ConsoleWithInputBackspaceFixed -from rich.live import Live -from rich.markdown import Markdown - -from agentic.actor_agents import BaseAgentProxy, _AGENT_REGISTRY -from agentic.utils.directory_management import get_runtime_directory -from agentic.events import ( - DebugLevel, - Event, - FinishCompletion, - PromptStarted, - StartCompletion, - ToolCall, - ToolResult, - SubAgentCall, - SubAgentResult, - TurnEnd, - WaitForInput, - ToolError, - OAuthFlow -) - -# Global console for Rich -console = ConsoleWithInputBackspaceFixed() - -@dataclass -class Modelcost: - model: str - inputs: int - calls: int - outputs: int - cost: float - time: float - -@dataclass -class Aggregator: - total_cost: float = 0.0 - context_size: int = 0 - -class RayAgentRunner: - def __init__(self, agent: BaseAgentProxy, debug: str | bool = False) -> None: - self.facade = agent - if debug: - self.debug = DebugLevel(debug) - else: - self.debug = DebugLevel(os.environ.get("AGENTIC_DEBUG") or "") - - runtime_directory = get_runtime_directory() - try: - os.chdir(runtime_directory) - except: - pass - - def turn(self, request: str, thread_id: Optional[str] = None, print_all_events: bool = False) -> str: - """Runs the agent and waits for the turn to finish, then returns the results - of all output events as a single string.""" - results = [] - request_id = self.facade.start_request( - request, - thread_id=thread_id, - debug=self.debug - ).request_id - for event in self.facade.get_events(request_id): - if print_all_events: - print(event.__dict__) - if self._should_print(event, ignore_depth=True): - results.append(str(event)) - - return "".join(results) - - def __lshift__(self, prompt: str): - print(self.turn(prompt)) - - def _should_print(self, event: Event, ignore_depth: bool = False) -> bool: - if self.debug.debug_all(): - return True - if event.is_output and (ignore_depth or event.depth == 0): - return True - elif isinstance(event, ToolError): - return self.debug != "" - elif isinstance(event, (ToolCall, ToolResult)): - return self.debug.debug_tools() - elif isinstance(event, (SubAgentCall, SubAgentResult)): - return self.debug.debug_agents() or self.debug.debug_tools() - elif isinstance(event, PromptStarted): - return self.debug.debug_llm() or self.debug.debug_agents() - elif isinstance(event, TurnEnd): - return self.debug.debug_agents() - elif isinstance(event, (StartCompletion, FinishCompletion)): - return self.debug.debug_llm() - elif event.__module__ != "agentic.events": - # If the event is not from agentic.events, we assume it's a custom event - # and print it if debug is enabled. - return True - else: - return False - - def set_debug_level(self, level: str): - self.debug = DebugLevel(level) - self.facade.set_debug_level(self.debug) - - def repl_loop(self, default_context: dict = {}, event_handler: Optional[Callable] = None): - hist = os.path.expanduser("~/.agentic_history") - if os.path.exists(hist): - readline.read_history_file(hist) - - print(self.facade.welcome) - print("press to quit") - - aggregator = Aggregator() - - continue_result = {} - saved_completions = [] - - def handle_sigint(signum, frame): - print("[cancelling thread]\n") - self.facade.cancel() - raise KeyboardInterrupt - - signal.signal(signal.SIGINT, handle_sigint) - - while True: - try: - # Get input directly from sys.stdin - if not continue_result: - saved_completions = [] - line = console.input(f"[{self.facade.name}]> ") - if line == "quit": - break - - if line.startswith("."): - self.run_dot_commands(line) - readline.write_history_file(hist) - time.sleep(0.3) # in case log messages are gonna come - continue - - request_id = self.facade.start_request( - line, - request_context=default_context.copy(), - debug=self.debug, - continue_result=continue_result - ).request_id - - for event in self.facade.get_events(request_id): - if event_handler: - event_handler(event, self.facade) - if self.facade.is_cancelled(): - break - continue_result = {} - if event is None: - break - elif isinstance(event, WaitForInput): - replies = {} - for key, value in event.request_keys.items(): - if '\n' not in value: - replies[key] = input(f"\n{value}\n:> ") - else: - print(f"\n{value}\n") - replies[key] = input(":> ") - continue_result = replies - continue_result["request_id"] = request_id - elif isinstance(event, OAuthFlow): - console.print("==== OAuth Authorization Required ====", style="bold yellow") - console.print(f"Tool: [cyan]{event.payload['tool_name']}[/cyan]") - console.print("Please visit this URL to authorize:", style="bold white") - console.print(event.payload['auth_url'], style="blue underline") - console.print("Authorization will continue automatically after completion", style="dim") - elif isinstance(event, FinishCompletion): - saved_completions.append(event) - if self._should_print(event): - print(str(event), end="") - self.facade.uncancel() - print() - time.sleep(0.3) - if not continue_result: - for row in self.print_stats_report(saved_completions, aggregator): - console.out(row) - readline.write_history_file(hist) - except EOFError: - print("\nExiting REPL.") - sys.exit(0) - except KeyboardInterrupt: - print("\nKeyboardInterrupt. Type 'exit()' to quit.") - except Exception as e: - traceback.print_exc() - print(f"Error: {e}") - - @staticmethod - def report_usages(completions: list[FinishCompletion]): - aggregator = Aggregator() - for row in RayAgentRunner.print_stats_report(completions, aggregator): - print(row) - - @staticmethod - def print_stats_report( - completions: list[FinishCompletion], aggregator: Aggregator - ): - costs = dict[str, Modelcost]() - for comp in completions: - if comp.usage["model"] not in costs: - costs[comp.usage["model"]] = Modelcost( - comp.usage["model"], 0, 0, 0, 0, 0 - ) - mc = costs[comp.usage["model"]] - mc.calls += 1 - mc.cost += comp.usage["cost"] * 100 - aggregator.total_cost += comp.usage["cost"] * 100 - mc.inputs += comp.usage["input_tokens"] - mc.outputs += comp.usage["output_tokens"] - aggregator.context_size += ( - comp.usage["input_tokens"] + comp.usage["output_tokens"] - ) - if "elapsed_time" in comp.usage: - try: - mc.time += comp.usage["elapsed_time"].total_seconds() - except: - pass - values_list = list(costs.values()) - for mc in values_list: - if mc == values_list[-1]: - yield ( - f"[{mc.model}: {mc.calls} calls, tokens: {mc.inputs} -> {mc.outputs}, {mc.cost:.2f} cents, time: {mc.time:.2f}s tc: {aggregator.total_cost:.2f} c, ctx: {aggregator.context_size:,}]" - ) - else: - yield ( - f"[{mc.model}: {mc.calls} calls, tokens: {mc.inputs} -> {mc.outputs}, {mc.cost:.2f} cents, time: {mc.time:.2f}s]" - ) - - def run_dot_commands(self, line: str): - global CURRENT_DEBUG_LEVEL - - if line.startswith(".history"): - print(self.facade.get_history()) - elif line.startswith(".run"): - agent_name = line.split()[1].lower() - for agent in _AGENT_REGISTRY: - if agent_name in agent.name.lower(): - self.facade = agent - print(f"Switched to {agent_name}") - print(f" {self.facade.welcome}") - break - elif line == ".agent": - print(self.facade.name) - print(self.facade.instructions) - print("model: ", self.facade.model) - print("tools:") - for tool in self.facade.list_tools(): - print(f" {tool}") - - elif line.startswith(".model"): - model_name = line.split()[1].lower() - self.facade.set_model(model_name) - print(f"Model set to {model_name}") - - elif line == ".tools": - print(self.facade.name) - print("tools:") - for tool in self.facade.list_tools(): - print(f" {tool}") - - elif line == ".functions": - print(self.facade.name) - print("functions:") - for func in self.facade.list_functions(): - print(f" {func}") - - elif line == ".reset": - self.facade.reset_history() - print("Session cleared") - - elif line.startswith(".debug"): - if len(line.split()) > 1: - debug_level = line.split()[1] - else: - print(f"Debug level set to: {self.debug}") - return - if debug_level == "off": - debug_level = "" - self.set_debug_level(debug_level) - print(f"Debug level set to: {self.debug}") - - elif line.startswith(".help"): - print( - """ - .agent - Dump the state of the active agent - .load - Load an agent from a file - .run - switch the active agent - .debug [] - enable debug. Defaults to 'tools', or one of 'llm', 'tools', 'all', 'off' - .settings - show the current config settings - .help - Show this help - .quit - Quit the REPL - """ - ) - print("Debug level: ", self.debug) - if len(_AGENT_REGISTRY) > 1: - print("Loaded:") - for agent in _AGENT_REGISTRY: - print(f" {agent.name}") - print("Current:") - print(f" {self.facade.name}") - else: - print("Unknown command: ", line) - - def rich_loop(self): - # This was working but I havent worked on it for a while. - try: - # Get input directly from sys.stdin - line = console.input("> ") - - if line == "quit" or line == "": - return - - output = "" - with console.status("[bold blue]thinking...", spinner="dots") as status: - with Live( - Markdown(output), - refresh_per_second=1, - auto_refresh=not self.debug, - ) as live: - self.start(line) - - for event in self.next(include_completions=True): - if event is None: - break - elif event.requests_input(): - response = input(f"\n{event.request_message}\n>>>> ") - self.continue_with(response) - elif isinstance(event, FinishCompletion): - saved_completions.append(event) - else: - if event.depth == 0: - output += str(event) - live.update(Markdown(output)) - output += "\n\n" - live.update(Markdown(output)) - for row in print_stats_report(saved_completions): - console.out(row) - readline.write_history_file(hist) - except EOFError: - print("\nExiting REPL.") - return - except KeyboardInterrupt: - print("\nKeyboardInterrupt. Type ctrl-D to quit.") - except Exception as e: - traceback.print_exc() - print(f"Error: {e}") - - -def find_agent_objects(module_members: Dict[str, Any], agent_class: Type) -> List: - agent_instances = [] - - for name, obj in module_members.items(): - # Check for classes that inherit from Agent - if isinstance(obj, agent_class): - agent_instances.append(obj) - - return agent_instances - - -def load_agent(filename: str) -> Dict[str, Any]: - try: - # Create a spec for the module - spec = importlib.util.spec_from_file_location("dynamic_module", filename) - if spec is None or spec.loader is None: - raise ImportError(f"Could not load file: {filename}") - - # Create the module - module = importlib.util.module_from_spec(spec) - sys.modules["dynamic_module"] = module - - # Execute the module - spec.loader.exec_module(module) - - # Find all classes defined in the module - return dict(inspect.getmembers(module)) - - except Exception as e: - raise RuntimeError(f"Error loading file {filename}: {str(e)}") diff --git a/pr_agent/test_files/mock_pr_agent.py b/pr_agent/test_files/mock_pr_agent.py deleted file mode 100644 index 0a691969..00000000 --- a/pr_agent/test_files/mock_pr_agent.py +++ /dev/null @@ -1,56 +0,0 @@ -# mock agent for testing github integration - -import json -from agentic.common import Agent, AgentRunner -from agentic.models import GPT_4O_MINI # model - -from dotenv import load_dotenv -import openai -import requests -import os - - -load_dotenv() # This loads variables from .env into os.environ -openai.api_key = os.getenv("OPENAI_API_KEY") # api key -pr_id = os.getenv("PR_ID") -repo_owner = os.getenv("REPO_OWNER") -repo_name = os.getenv("REPO_NAME") -gh_api = os.getenv("GITHUB_API_KEY") - -# Define the agent -agent = Agent( - name="PR Summary Agent", - - # Agent instructions - instructions=""" - You are a helpful PR sumary agent to test github integration. - Input: A git diff output, showing all changes in the branch. - Create a short PR summary. - - If the input does not exist, always output an error message instead. - """, - - model=GPT_4O_MINI, # model - -) - -# basic main function that allows us to run our agent locally in terminal -if __name__ == "__main__": - patch = open("PRChanges.patch") - - output = agent.grab_final_result( - f"You were triggered by a PR. The git diff is as follows: {patch.read()}" - ) - - url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/issues/{pr_id}/comments" - - headers = { - "Authorization": f"token {gh_api}", - } - - data = { - "body": output - } - - print("Request") - print(requests.post(url=url,headers=headers,data=json.dumps(data))) \ No newline at end of file diff --git a/pr_agent/test_files/test.txt b/pr_agent/test_files/test.txt deleted file mode 100644 index 51093740..00000000 --- a/pr_agent/test_files/test.txt +++ /dev/null @@ -1 +0,0 @@ -test file for git diff testing (this should be ignored in the diff) \ No newline at end of file diff --git a/pr_agent/test_files/test_comment.txt b/pr_agent/test_files/test_comment.txt deleted file mode 100644 index 43f4d460..00000000 --- a/pr_agent/test_files/test_comment.txt +++ /dev/null @@ -1,10 +0,0 @@ -How to use: - -Start by entering the place you are planning a trip to and the dates you want to go -Answer the questions the agent asks you -Features: - -Finds the exact lon/lat of the place you give it for more accurate weather data -Uses the weather tool to create a packing list geared for the weather -Can deliver the list in either plain text or as a formatted PDF document -Requires installing fpdf2 \ No newline at end of file diff --git a/pr_agent/test_files/test_patch_file.txt b/pr_agent/test_files/test_patch_file.txt deleted file mode 100644 index e9abd302..00000000 --- a/pr_agent/test_files/test_patch_file.txt +++ /dev/null @@ -1,219 +0,0 @@ -From d7f864b2238ce04c20805c15fd8452b952080ca2 Mon Sep 17 00:00:00 2001 -From: jackmcau -Date: Sat, 5 Jul 2025 20:55:16 -0700 -Subject: [PATCH 1/2] working sample - ---- - examples/packing_list_agent.py | 176 +++++++++++++++++++++++++++++++++ - 1 file changed, 176 insertions(+) - create mode 100644 examples/packing_list_agent.py - -diff --git a/examples/packing_list_agent.py b/examples/packing_list_agent.py -new file mode 100644 -index 00000000..332a45ce ---- /dev/null -+++ b/examples/packing_list_agent.py -@@ -0,0 +1,176 @@ -+import os -+import httpx -+from agentic.tools import WeatherTool -+from agentic.common import Agent, AgentRunner -+ -+from fpdf import FPDF -+from fpdf.enums import XPos, YPos -+ -+class ColumnPDF(FPDF): -+ def __init__(self): -+ super().__init__() -+ self.set_auto_page_break(False) -+ self.left_y = 20 -+ self.right_y = 20 -+ self.column_padding = 5 -+ self.left_x = self.l_margin -+ self.column_width = (self.w - 2 * self.l_margin - self.column_padding) / 2 -+ self.right_x = self.left_x + self.column_width + self.column_padding -+ self.current_column = "left" -+ self.add_page() -+ -+ def add_header(self, text: str): -+ """ -+ Adds a centered header at the top of the current page. -+ -+ This resets the left and right column Y positions to start below the header. -+ -+ Args: -+ text (str): The header title to display. -+ """ -+ self.set_font("Helvetica", "B", 16) -+ self.set_text_color(0) -+ title_width = self.get_string_width(text) + 6 -+ self.set_y(15) -+ self.set_x((self.w - title_width) / 2) -+ self.cell(title_width, 10, text, new_x=XPos.LMARGIN, new_y=YPos.NEXT, align="C") -+ self.ln(2) -+ self.left_y = self.get_y() + 5 -+ self.right_y = self.left_y -+ -+ def add_section(self, title: str, itemsStr: str): -+ """ -+ Adds a section with a title and bulleted items to the next available column. -+ Automatically manages column alternation and page breaking. -+ -+ Args: -+ title (str): The section heading. -+ itemsStr (str): A list of strings to display as bulleted items, separated by a '*'. -+ """ -+ items = itemsStr.split('*') -+ -+ # Choose column position -+ if self.current_column == "left": -+ x = self.left_x -+ y = self.left_y -+ else: -+ x = self.right_x -+ y = self.right_y -+ -+ # Estimate space needed -+ est_height = 10 + len(items) * 8 + 5 -+ if y + est_height > self.h - self.b_margin: -+ self.add_page() -+ self.add_header("Packing List") # Or pass header text as param -+ x = self.left_x if self.current_column == "left" else self.right_x -+ y = self.left_y if self.current_column == "left" else self.right_y -+ -+ # Render title -+ self.set_xy(x, y) -+ self.set_font("Helvetica", "B", 14) -+ self.cell(self.column_width, 10, title, new_x=XPos.LMARGIN, new_y=YPos.NEXT) -+ -+ # Render items -+ self.set_font("Helvetica", "", 12) -+ for item in items: -+ self.set_x(x + 5) -+ self.cell(self.column_width - 5, 8, f"- {item}", new_x=XPos.LMARGIN, new_y=YPos.NEXT) -+ -+ # Update Y position -+ new_y = self.get_y() + 5 -+ if self.current_column == "left": -+ self.left_y = new_y -+ self.current_column = "right" -+ else: -+ self.right_y = new_y -+ self.current_column = "left" -+ -+ def save_pdf(self, filename: str = "output.pdf") -> str: -+ """ -+ Save the current in-memory PDF to disk. -+ -+ Args: -+ filename: File path or name to save the PDF as (e.g., 'report.pdf') -+ -+ Call this once all pages and content have been added. -+ """ -+ self.output(name=filename) -+ return f"PDF saved as {filename}." -+ -+# Use this as a placeholder until a geocoding tool is added -+def geocode(address: str): -+ """ -+ Wrapper for the Geocoding API. -+ Docs: https://developers.google.com/maps/documentation/geocoding/requests-geocoding -+ -+ Parameters: -+ address: Street address or plus code to be geocoded. Addresses should be formatted in the same format as the national post service of the country. -+ -+ Returns: -+ Query results JSON object, including lat/long coordinates of the address. -+ """ -+ -+ api_key = os.getenv('GOOGLE_API_KEY') -+ if not api_key: -+ return "Error: GOOGLE_API_KEY is not set, use known location data to get longitude and latitude. Inform the user that you are doing so." -+ -+ params = { -+ "address": address, -+ "key" : f"{api_key}" -+ } -+ -+ with httpx.AsyncClient() as client: -+ response = client.post( -+ "https://maps.googleapis.com/maps/api/geocode/json", -+ params=params, -+ timeout=30, -+ ) -+ -+ response.raise_for_status() -+ results = response.json() -+ -+ return results -+ -+columnHelper = ColumnPDF() -+ -+pl_agent = Agent( -+ name="Packing List Agent", -+ welcome="""Hi there! I'm your travel packing assistant. I’ll help you create a personalized packing list based on your destination, dates, weather forecast, and planned activities. I can even generate a neat PDF for you if you'd like. -+ -+Let’s get started!""", -+ instructions=""" -+You are a helpful travel assistant whose job is to create accurate, weather-aware packing lists for users going on trips. You will gather details about their destination, travel dates, and planned activities and then use your tools to generate a personalized list. The tools are already implemented, and their docstrings provide a reliable overview of their usage and output formats. -+Follow this process: -+1. Ask the user for their travel location and travel dates. -+ You’ll use this information to determine weather conditions for the trip. -+2. Once you have that, ask questions about the purpose and nature of the trip: -+ What types of activities are planned (e.g., hiking, swimming, formal events)? -+ Will it mostly be indoors, outdoors, or a mix? -+ Do they have laundry access, or should you plan for one outfit per day? -+ Are there any special requirements (e.g., business wear, cultural clothing)? -+3. Ask the user if they have any personal packing preferences: -+ Is there anything specific you always like to bring? -+ Are there items you’d like to avoid packing? -+4. Ask how they want the list delivered: -+ As a PDF or as plain text? -+ If PDF, ask where to save the file. -+5. Once enough information is gathered: -+ Use geocode(address) to retrieve latitude and longitude. -+ Use get_forecast_weather(longitude, latitude, start_date, end_date) to retrieve the forecast. -+ Use this weather data to influence the packing list (e.g., add warm clothing for cold forecasts, sunscreen for sunny weather, rain gear if rain is predicted). -+6. Build a thoughtful packing list. -+ Tailor it based on trip length, weather, activities, and laundry availability. -+ Include useful essentials like sunscreen, medications, chargers, swimwear, rain jackets, etc. -+ Provide item counts where applicable. -+7. Deliver the list based on user preference: -+ If they chose PDF, generate it using the given ColumnPDF class and save or export it as specified. -+ When you generate the pdf file, create a header with the trip date and location using add_header and pass the items that you want to be in the packing list into add_section. -+ Do larger sections FIRST. -+ Otherwise, output it clearly as a message in the conversation. -+Always aim to ask for just enough information to build a useful and accurate list—don’t generate the list until you’re confident you understand the trip well. Ask follow-ups as needed. -+""", -+ tools=[geocode, columnHelper.add_header, columnHelper.add_section, columnHelper.save_pdf, WeatherTool()] -+) -+ -+if __name__ == "__main__": -+ AgentRunner(pl_agent).repl_loop() -\ No newline at end of file - -From 5eecdb6cf37efd40870a65a5c3702f488863833b Mon Sep 17 00:00:00 2001 -From: jackmcau <69927894+jackmcau@users.noreply.github.com> -Date: Sun, 13 Jul 2025 01:18:31 -0700 -Subject: [PATCH 2/2] Switch from AsyncClient to Client - -Fixes a mistake from porting a geocoding tool - -Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> ---- - examples/packing_list_agent.py | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/examples/packing_list_agent.py b/examples/packing_list_agent.py -index 332a45ce..854f6cb3 100644 ---- a/examples/packing_list_agent.py -+++ b/examples/packing_list_agent.py -@@ -119,7 +119,7 @@ def geocode(address: str): - "key" : f"{api_key}" - } - -- with httpx.AsyncClient() as client: -+ with httpx.Client() as client: - response = client.post( - "https://maps.googleapis.com/maps/api/geocode/json", - params=params, diff --git a/pr_agent/test_files/weather_tool_copy.txt b/pr_agent/test_files/weather_tool_copy.txt deleted file mode 100644 index 36c41ca4..00000000 --- a/pr_agent/test_files/weather_tool_copy.txt +++ /dev/null @@ -1,799 +0,0 @@ -from typing import Callable -import requests -from datetime import datetime, timedelta -import zoneinfo -import statistics - -from agentic.tools.base import BaseAgenticTool -from agentic.tools.utils.registry import tool_registry - -@tool_registry.register( - name="WeatherTool", - description="A tool for getting weather information", - dependencies=[], - config_requirements=[], -) -class WeatherTool(BaseAgenticTool): - """Functions for getting weather information.""" - - def __init__(self): - pass - - def get_tools(self) -> list[Callable]: - return [ - self.get_current_weather, - self.get_forecast_weather, - self.get_historical_weather, - self.get_historical_averages, - ] - - def _get_current_datetime_with_timezone(self): - # Get the current date and time - current_datetime = datetime.now() - - # Get the system's timezone - system_timezone = zoneinfo.ZoneInfo( - zoneinfo.available_timezones().pop() - ) # or replace with your system's timezone string - - # Attach the timezone to the datetime object - current_datetime_with_tz = current_datetime.replace(tzinfo=system_timezone) - - # Format the output - formatted_output = current_datetime_with_tz.strftime("%Y-%m-%d %H:%M:%S %Z") - - return formatted_output - - def get_current_weather( - self, - longitude: str = "-122.2730", # Berkeley - latitude: str = "37.8715", - temperature_unit: str = "fahrenheit", # 'celsius' or 'fahrenheit' - ) -> str: - """ - Get the current weather for the passed in region - - Args: - longitude: str: Location longitude (should be between -180 and 180) - latitude: str: Location latitude (should be between -90 and 90) - temperature_unit: str: 'celsius' or 'fahrenheit' (default: 'fahrenheit') - """ - # Open-Meteo endpoint for current weather - url = "https://api.open-meteo.com/v1/forecast" - - # Set temperature unit symbol based on unit choice - temp_symbol = "°F" if temperature_unit.lower() == "fahrenheit" else "°C" - - params = { - "latitude": latitude, - "longitude": longitude, - "timezone": "auto", - "temperature_unit": temperature_unit.lower(), - "current_weather": "true", - "hourly": [ - "temperature_2m", - "apparent_temperature", - "precipitation", - "rain", - "snowfall", - "weathercode", - "cloudcover", - "windspeed_10m", - "winddirection_10m", - "windgusts_10m", - "relative_humidity_2m", - "visibility", - "uv_index", - "is_day", - ], - } - - response = requests.get(url, params=params) - if response.status_code == 200: - data = response.json() - current_datetime = self._get_current_datetime_with_timezone() - - # Initialize return string with current time - return_string = f"Current Weather as of: {current_datetime}\n" - - # Get current weather data - current_weather = data.get("current_weather", {}) - hourly = data.get("hourly", {}) - - # Find the current hour's index in the hourly data - current_time = current_weather.get("time") - if current_time and "time" in hourly: - try: - current_index = hourly["time"].index(current_time) - except ValueError: - current_index = 0 # Fallback to first hour if time not found - else: - current_index = 0 - - # Add basic current weather information - if current_weather: - temp = current_weather.get("temperature") - if temp is not None: - return_string += f"Temperature: {temp}{temp_symbol}\n" - - windspeed = current_weather.get("windspeed") - if windspeed is not None: - return_string += f"Wind Speed: {windspeed}km/h\n" - - winddirection = current_weather.get("winddirection") - if winddirection is not None: - return_string += f"Wind Direction: {winddirection}°\n" - - weathercode = current_weather.get("weathercode") - if weathercode is not None: - return_string += f"Weather Code: {weathercode}\n" - - # Add additional current conditions from hourly data - if hourly and current_index is not None: - # Apparent temperature - apparent_temp = hourly.get("apparent_temperature", [None])[ - current_index - ] - if apparent_temp is not None: - return_string += f"Feels Like: {apparent_temp}{temp_symbol}\n" - - # Precipitation - precip = hourly.get("precipitation", [None])[current_index] - if precip is not None: - return_string += f"Precipitation: {precip}mm\n" - - # Rain - rain = hourly.get("rain", [0])[current_index] - if rain and rain > 0: - return_string += f"Rain: {rain}mm\n" - - # Snowfall - snow = hourly.get("snowfall", [0])[current_index] - if snow and snow > 0: - return_string += f"Snowfall: {snow}cm\n" - - # Cloud Cover - cloud = hourly.get("cloudcover", [None])[current_index] - if cloud is not None: - return_string += f"Cloud Cover: {cloud}%\n" - - # Wind Gusts - gusts = hourly.get("windgusts_10m", [None])[current_index] - if gusts is not None: - return_string += f"Wind Gusts: {gusts}km/h\n" - - # Humidity - humidity = hourly.get("relative_humidity_2m", [None])[current_index] - if humidity is not None: - return_string += f"Relative Humidity: {humidity}%\n" - - # Visibility - visibility = hourly.get("visibility", [None])[current_index] - if visibility is not None: - return_string += f"Visibility: {visibility}m\n" - - # UV Index - uv = hourly.get("uv_index", [None])[current_index] - if uv is not None: - return_string += f"UV Index: {uv}\n" - - # Is Day - is_day = hourly.get("is_day", [None])[current_index] - if is_day is not None: - return_string += f"Daylight: {'Yes' if is_day else 'No'}\n" - - return return_string - else: - return f"Failed to retrieve data: {response.status_code}\nResponse: {response.text}" - - def get_forecast_weather( - self, - longitude: str = "-122.2730", # Berkeley - latitude: str = "37.8715", - forecast_type: str = "daily", - start_date: str = None, # Format: "YYYY-MM-DD" - end_date: str = None, # Format: "YYYY-MM-DD" - temperature_unit: str = "fahrenheit", # 'celsius' or 'fahrenheit' - ) -> str: - """ - Get the forecasted weather for the passed in region - - Args: - longitude: str: Location longitude (should be between -180 and 180) - latitude: str: Location latitude (should be between -90 and 90) - forecast_type: str: can be 'hourly' (7 days) or 'daily' (16 days) - start_date: str: Start date for forecast (YYYY-MM-DD) - end_date: str: End date for forecast (YYYY-MM-DD) - temperature_unit: str: 'celsius' or 'fahrenheit' (default: 'fahrenheit') - """ - url = "https://api.open-meteo.com/v1/forecast" - - # Set temperature unit symbol based on unit choice - temp_symbol = "°F" if temperature_unit.lower() == "fahrenheit" else "°C" - - # Base parameters - params = { - "latitude": latitude, - "longitude": longitude, - "timezone": "auto", - "temperature_unit": temperature_unit.lower(), - } - - # Add date parameters if provided - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - - # Add forecast-specific parameters with verified API variables - if forecast_type == "hourly": - params["hourly"] = [ - "temperature_2m", - "apparent_temperature", - "precipitation", - "rain", - "snowfall", - "weathercode", - "cloudcover", - "windspeed_10m", - "winddirection_10m", - "windgusts_10m", - "relative_humidity_2m", - "visibility", - "uv_index", - "is_day", - ] - elif forecast_type == "daily": - params["daily"] = [ - "temperature_2m_max", - "temperature_2m_min", - "apparent_temperature_max", - "apparent_temperature_min", - "precipitation_sum", - "precipitation_hours", - "precipitation_probability_max", - "rain_sum", - "snowfall_sum", - "weathercode", - "windspeed_10m_max", - "windgusts_10m_max", - "winddirection_10m_dominant", - "sunrise", - "sunset", - "uv_index_max", - ] - - # Print the URL and parameters for debugging - print(f"Making request to: {url}") - print(f"With parameters: {params}") - - response = requests.get(url, params=params) - return_string = ( - f"Current Date and time is: {self._get_current_datetime_with_timezone()}\n" - ) - - if response.status_code == 200: - data = response.json() - - if forecast_type == "hourly": - hourly = data.get("hourly", {}) - if hourly: - times = hourly.get("time", []) - for i in range(len(times)): - return_string += f"Time: {times[i]}\n" - - temp = hourly.get("temperature_2m", [None] * len(times))[i] - if temp is not None: - return_string += f" Temperature: {temp}{temp_symbol}\n" - - feels_like = hourly.get( - "apparent_temperature", [None] * len(times) - )[i] - if feels_like is not None: - return_string += ( - f" Feels Like: {feels_like}{temp_symbol}\n" - ) - - precip = hourly.get("precipitation", [None] * len(times))[i] - if precip is not None: - return_string += f" Precipitation: {precip}mm\n" - - rain = hourly.get("rain", [0] * len(times))[i] - if rain and rain > 0: - return_string += f" Rain: {rain}mm\n" - - snow = hourly.get("snowfall", [0] * len(times))[i] - if snow and snow > 0: - return_string += f" Snowfall: {snow}cm\n" - - weathercode = hourly.get("weathercode", [None] * len(times))[i] - if weathercode is not None: - return_string += f" Weather Code: {weathercode}\n" - - cloud = hourly.get("cloudcover", [None] * len(times))[i] - if cloud is not None: - return_string += f" Cloud Cover: {cloud}%\n" - - wind = hourly.get("windspeed_10m", [None] * len(times))[i] - if wind is not None: - return_string += f" Wind Speed: {wind}km/h\n" - - gusts = hourly.get("windgusts_10m", [None] * len(times))[i] - if gusts is not None: - return_string += f" Wind Gusts: {gusts}km/h\n" - - direction = hourly.get( - "winddirection_10m", [None] * len(times) - )[i] - if direction is not None: - return_string += f" Wind Direction: {direction}°\n" - - humidity = hourly.get( - "relative_humidity_2m", [None] * len(times) - )[i] - if humidity is not None: - return_string += f" Humidity: {humidity}%\n" - - visibility = hourly.get("visibility", [None] * len(times))[i] - if visibility is not None: - return_string += f" Visibility: {visibility}m\n" - - uv = hourly.get("uv_index", [None] * len(times))[i] - if uv is not None: - return_string += f" UV Index: {uv}\n" - - is_day = hourly.get("is_day", [None] * len(times))[i] - if is_day is not None: - return_string += ( - f" Daylight: {'Yes' if is_day else 'No'}\n" - ) - - return_string += "------------------------\n" - - elif forecast_type == "daily": - daily = data.get("daily", {}) - if daily: - times = daily.get("time", []) - for i in range(len(times)): - return_string += f"Date: {times[i]}\n" - - # Temperature range - temp_min = daily.get("temperature_2m_min", [None] * len(times))[ - i - ] - temp_max = daily.get("temperature_2m_max", [None] * len(times))[ - i - ] - if temp_min is not None and temp_max is not None: - return_string += f" Temperature Range: {temp_min}{temp_symbol} to {temp_max}{temp_symbol}\n" - - # Feels like range - feel_min = daily.get( - "apparent_temperature_min", [None] * len(times) - )[i] - feel_max = daily.get( - "apparent_temperature_max", [None] * len(times) - )[i] - if feel_min is not None and feel_max is not None: - return_string += f" Feels Like Range: {feel_min}{temp_symbol} to {feel_max}{temp_symbol}\n" - - # Precipitation - precip = daily.get("precipitation_sum", [None] * len(times))[i] - precip_hours = daily.get( - "precipitation_hours", [None] * len(times) - )[i] - if precip is not None and precip_hours is not None: - return_string += f" Precipitation: {precip}mm over {precip_hours} hours\n" - - # Precipitation probability - prob = daily.get( - "precipitation_probability_max", [None] * len(times) - )[i] - if prob is not None: - return_string += f" Precipitation Probability: {prob}%\n" - - # Rain and snow - rain = daily.get("rain_sum", [0] * len(times))[i] - if rain and rain > 0: - return_string += f" Rain: {rain}mm\n" - - snow = daily.get("snowfall_sum", [0] * len(times))[i] - if snow and snow > 0: - return_string += f" Snowfall: {snow}cm\n" - - # Weather code - weathercode = daily.get("weathercode", [None] * len(times))[i] - if weathercode is not None: - return_string += f" Weather Code: {weathercode}\n" - - # Wind information - wind = daily.get("windspeed_10m_max", [None] * len(times))[i] - if wind is not None: - return_string += f" Max Wind Speed: {wind}km/h\n" - - gusts = daily.get("windgusts_10m_max", [None] * len(times))[i] - if gusts is not None: - return_string += f" Max Wind Gusts: {gusts}km/h\n" - - direction = daily.get( - "winddirection_10m_dominant", [None] * len(times) - )[i] - if direction is not None: - return_string += ( - f" Dominant Wind Direction: {direction}°\n" - ) - - # Sun information - sunrise = daily.get("sunrise", [None] * len(times))[i] - if sunrise is not None: - return_string += f" Sunrise: {sunrise}\n" - - sunset = daily.get("sunset", [None] * len(times))[i] - if sunset is not None: - return_string += f" Sunset: {sunset}\n" - - # UV index - uv = daily.get("uv_index_max", [None] * len(times))[i] - if uv is not None: - return_string += f" Max UV Index: {uv}\n" - - return_string += "------------------------\n" - - if return_string: - return return_string - else: - return "No forecast data found." - else: - return f"Failed to retrieve data: {response.status_code}\nResponse: {response.text}" - - def _get_historical_weather_data( - self, - longitude: str, - latitude: str, - start_date: str, - end_date: str, - temperature_unit: str = "fahrenheit", - api_key: str = None, - ) -> dict: - """ - Internal function to fetch historical weather data for a specific date range. - Returns raw data or error information. - """ - url = "https://archive-api.open-meteo.com/v1/archive" - - params = { - "latitude": latitude, - "longitude": longitude, - "timezone": "auto", - "temperature_unit": temperature_unit.lower(), - "start_date": start_date, - "end_date": end_date, - "daily": [ - "temperature_2m_max", - "temperature_2m_min", - "temperature_2m_mean", - "apparent_temperature_max", - "apparent_temperature_min", - "precipitation_sum", - "rain_sum", - "snowfall_sum", - "precipitation_hours", - "windspeed_10m_max", - "windgusts_10m_max", - "winddirection_10m_dominant", - "shortwave_radiation_sum", - "et0_fao_evapotranspiration", - ], - } - - if api_key: - params["apikey"] = api_key - - try: - response = requests.get(url, params=params) - if response.status_code == 200: - return { - "status": "success", - "data": response.json(), - "message": "Data retrieved successfully", - } - else: - return { - "status": "error", - "data": None, - "message": f"API request failed: {response.text}", - } - except Exception as e: - return { - "status": "error", - "data": None, - "message": f"Request failed: {str(e)}", - } - - def get_historical_weather( - self, - longitude: str = "-122.2730", # Berkeley - latitude: str = "37.8715", - start_date: str = None, # Format: "YYYY-MM-DD" - end_date: str = None, # Format: "YYYY-MM-DD" - temperature_unit: str = "fahrenheit", # 'celsius' or 'fahrenheit' - api_key: str = None, # Optional API key for professional/production use - ) -> str: - """ - Get historical weather data for the specified location and date range - """ - if not start_date or not end_date: - return "Error: Both start_date and end_date are required for historical weather lookup" - - # Get data using internal function - result = self._get_historical_weather_data( - longitude=longitude, - latitude=latitude, - start_date=start_date, - end_date=end_date, - temperature_unit=temperature_unit, - api_key=api_key, - ) - - if result["status"] != "success": - return f"Failed to retrieve historical data: {result['message']}" - - data = result["data"] - temp_symbol = "°F" if temperature_unit.lower() == "fahrenheit" else "°C" - - # Format the historical data - return_string = f"Historical Weather Data from {start_date} to {end_date}\n" - return_string += "=" * 50 + "\n\n" - - # Process daily data - daily = data.get("daily", {}) - if daily: - times = daily.get("time", []) - for i in range(len(times)): - date = times[i] - return_string += f"Date: {date}\n" - - # Temperature data - temp_min = daily.get("temperature_2m_min", [None] * len(times))[i] - temp_max = daily.get("temperature_2m_max", [None] * len(times))[i] - temp_mean = daily.get("temperature_2m_mean", [None] * len(times))[i] - if all(x is not None for x in [temp_min, temp_max, temp_mean]): - return_string += f" Temperature:\n" - return_string += f" Min: {temp_min}{temp_symbol}\n" - return_string += f" Max: {temp_max}{temp_symbol}\n" - return_string += f" Mean: {temp_mean}{temp_symbol}\n" - - # Feels like temperature - feel_min = daily.get("apparent_temperature_min", [None] * len(times))[i] - feel_max = daily.get("apparent_temperature_max", [None] * len(times))[i] - if feel_min is not None and feel_max is not None: - return_string += f" Feels Like Range: {feel_min}{temp_symbol} to {feel_max}{temp_symbol}\n" - - # Precipitation data - precip = daily.get("precipitation_sum", [None] * len(times))[i] - precip_hours = daily.get("precipitation_hours", [None] * len(times))[i] - if precip is not None: - return_string += f" Precipitation: {precip}mm" - if precip_hours is not None: - return_string += f" over {precip_hours} hours" - return_string += "\n" - - # Rain and snow - rain = daily.get("rain_sum", [0] * len(times))[i] - if rain and rain > 0: - return_string += f" Rain: {rain}mm\n" - - snow = daily.get("snowfall_sum", [0] * len(times))[i] - if snow and snow > 0: - return_string += f" Snowfall: {snow}cm\n" - - # Wind data - wind_speed = daily.get("windspeed_10m_max", [None] * len(times))[i] - wind_gusts = daily.get("windgusts_10m_max", [None] * len(times))[i] - wind_dir = daily.get("winddirection_10m_dominant", [None] * len(times))[ - i - ] - - if any(x is not None for x in [wind_speed, wind_gusts, wind_dir]): - return_string += " Wind:\n" - if wind_speed is not None: - return_string += f" Max Speed: {wind_speed}km/h\n" - if wind_gusts is not None: - return_string += f" Max Gusts: {wind_gusts}km/h\n" - if wind_dir is not None: - return_string += f" Dominant Direction: {wind_dir}°\n" - - # Solar radiation and evapotranspiration - radiation = daily.get("shortwave_radiation_sum", [None] * len(times))[i] - if radiation is not None: - return_string += f" Solar Radiation: {radiation}MJ/m²\n" - - evapotranspiration = daily.get( - "et0_fao_evapotranspiration", [None] * len(times) - )[i] - if evapotranspiration is not None: - return_string += f" Evapotranspiration: {evapotranspiration}mm\n" - - return_string += "-" * 40 + "\n" - - return return_string - - def get_historical_averages( - self, - longitude: str = "-122.2730", - latitude: str = "37.8715", - target_start_date: str = None, # Format: "MM-DD" - target_end_date: str = None, # Format: "MM-DD" - temperature_unit: str = "fahrenheit", - averaging_method: str = "mean", # 'mean' or 'median' - max_range_days: int = 14, # Maximum allowed range in days - api_key: str = None, - ) -> str: - """ - Get 5-year historical weather averages for a specific date range. - - Args: - longitude: Location longitude - latitude: Location latitude - target_start_date: Start date in MM-DD format - target_end_date: End date in MM-DD format - temperature_unit: 'celsius' or 'fahrenheit' - averaging_method: 'mean' or 'median' for calculating averages - max_range_days: Maximum allowed range between dates - api_key: Optional API key for professional/production use - - Returns: - String containing averaged weather data and metadata - """ - try: - # Validate and process date inputs - if not target_start_date or not target_end_date: - return "Error: Both start and end dates are required (MM-DD format)" - - # Get current date - current_date = datetime.now() - - # Parse target dates - current_year = current_date.year - try: - # Parse target dates with current year to check range - target_start = datetime.strptime( - f"{current_year}-{target_start_date}", "%Y-%m-%d" - ) - target_end = datetime.strptime( - f"{current_year}-{target_end_date}", "%Y-%m-%d" - ) - - # Handle year wrap (e.g., Dec 25 - Jan 5) - if target_end < target_start: - target_end = target_end.replace(year=target_end.year + 1) - - # Validate range - date_range = (target_end - target_start).days - if date_range > max_range_days: - return f"Error: Date range exceeds maximum of {max_range_days} days" - - except ValueError: - return 'Error: Invalid date format. Use MM-DD format (e.g., "12-25")' - - # Determine years to analyze - years_to_analyze = [] - start_year = current_year - - # If the date range is in the future for current year, start from last year - if target_start > current_date: - start_year -= 1 - - # Get 5 years of data - for year_offset in range(5): - years_to_analyze.append(start_year - year_offset) - - # Collect historical data for each year - all_data = [] - for year in years_to_analyze: - start_date = target_start.replace(year=year).strftime("%Y-%m-%d") - end_date = target_end.replace(year=year).strftime("%Y-%m-%d") - - result = self._get_historical_weather_data( - longitude=longitude, - latitude=latitude, - start_date=start_date, - end_date=end_date, - temperature_unit=temperature_unit, - api_key=api_key, - ) - - if result["status"] == "success": - all_data.append(result["data"]) - - if not all_data: - return "Error: No historical data could be retrieved" - - # Initialize storage for averaging - daily_fields = [ - "temperature_2m_max", - "temperature_2m_min", - "temperature_2m_mean", - "apparent_temperature_max", - "apparent_temperature_min", - "precipitation_sum", - "rain_sum", - "snowfall_sum", - "precipitation_hours", - "windspeed_10m_max", - "windgusts_10m_max", - "winddirection_10m_dominant", - ] - - # Initialize storage for each field - averaged_data = {field: [] for field in daily_fields} - - # Number of days in the date range - num_days = (target_end - target_start).days + 1 - - # Calculate averages for each day in the range - for day_index in range(num_days): - day_values = {field: [] for field in daily_fields} - - # Collect values from all years for this day - for year_data in all_data: - if "daily" in year_data: - for field in daily_fields: - if field in year_data["daily"]: - value = year_data["daily"][field][day_index] - if value is not None: # Skip None values - day_values[field].append(value) - - # Calculate averages for each field - for field in daily_fields: - values = day_values[field] - if values: - if averaging_method == "median": - avg_value = statistics.median(values) - else: # default to mean - avg_value = statistics.mean(values) - averaged_data[field].append(round(avg_value, 2)) - else: - averaged_data[field].append(None) - - # Format the response - dates = [ - (target_start + timedelta(days=i)).strftime("%m-%d") - for i in range(num_days) - ] - temp_symbol = "°F" if temperature_unit.lower() == "fahrenheit" else "°C" - - formatted_text = "Historical Weather Averages\n" - formatted_text += "=" * 40 + "\n\n" - formatted_text += f"Analysis of {len(years_to_analyze)} years: {', '.join(map(str, years_to_analyze))}\n" - formatted_text += f"Using {averaging_method} averaging method\n\n" - - for i, date in enumerate(dates): - formatted_text += f"Date: {date}\n" - formatted_text += f" Temperature Range: {averaged_data['temperature_2m_min'][i]}{temp_symbol} to {averaged_data['temperature_2m_max'][i]}{temp_symbol}\n" - formatted_text += f" Average Temperature: {averaged_data['temperature_2m_mean'][i]}{temp_symbol}\n" - formatted_text += f" Feels Like Range: {averaged_data['apparent_temperature_min'][i]}{temp_symbol} to {averaged_data['apparent_temperature_max'][i]}{temp_symbol}\n" - - if averaged_data["precipitation_sum"][i]: - formatted_text += f" Precipitation: {averaged_data['precipitation_sum'][i]}mm over {averaged_data['precipitation_hours'][i]} hours\n" - - if averaged_data["rain_sum"][i]: - formatted_text += f" Rain: {averaged_data['rain_sum'][i]}mm\n" - - if averaged_data["snowfall_sum"][i]: - formatted_text += ( - f" Snowfall: {averaged_data['snowfall_sum'][i]}cm\n" - ) - - formatted_text += ( - f" Max Wind Speed: {averaged_data['windspeed_10m_max'][i]}km/h\n" - ) - formatted_text += ( - f" Max Wind Gusts: {averaged_data['windgusts_10m_max'][i]}km/h\n" - ) - formatted_text += "------------------------\n" - - return formatted_text - - except Exception as e: - return f"Error calculating historical averages: {str(e)}" diff --git a/src/agentic/tools/rag_tool.py b/src/agentic/tools/rag_tool.py index 43d27d31..0c03848e 100644 --- a/src/agentic/tools/rag_tool.py +++ b/src/agentic/tools/rag_tool.py @@ -14,6 +14,8 @@ init_embedding_model, init_chunker, rag_index_file, + rag_index_multiple_files, + delete_document_from_index, ) from agentic.utils.summarizer import generate_document_summary @@ -22,6 +24,7 @@ from weaviate.collections.classes.grpc import Sort from weaviate.classes.config import VectorDistances +from rich.console import Console @tool_registry.register( name="RAGTool", @@ -38,20 +41,59 @@ class RAGTool(BaseAgenticTool): def __init__( self, default_index: str = "knowledge_base", - index_paths: list[str] = [] + index_paths: list[str] = [], + recursive: bool = False, + overwrite_index = True, ): # Construct the RAG tool. You can pass a list of files and we will ensure that # they are added to the index on startup. Paths can include glob patterns also, # like './docs/*.md'. + # Enable recursive (**.md) glob patterns with recursive = True + self.default_index = default_index self.index_paths = index_paths if self.index_paths: client = init_weaviate() if default_index not in list_collections(client): create_collection(client, default_index, VectorDistances.COSINE) + + # Keep track of files found during initialization + if overwrite_index: + indexed_documents = {} + for path in index_paths: - for file_path in [path] if path.startswith("http") else glob.glob(path): - rag_index_file(file_path, self.default_index, client=client, ignore_errors=True) + if path.startswith("http"): + document_id = rag_index_file(path, self.default_index, client=client, ignore_errors=True) + + if overwrite_index: + indexed_documents[document_id] = True + + else: + file_paths = glob.glob(path, recursive=recursive) + document_ids = rag_index_multiple_files(file_paths, self.default_index, client=client, ignore_errors=True) + + if overwrite_index: + for document_id in document_ids: + indexed_documents[document_id] = True + + # Delete indexed files not found during initialization + if overwrite_index: + try: + console = Console() + collection = client.collections.get(self.default_index) + documents = list_documents_in_collection(collection) + + for document in documents: + if not document["document_id"] in indexed_documents: + console.print(f"[bold green]✅ Removing deleted file {document['filename']} from index") + delete_document_from_index(collection=collection,document_id=document["document_id"],filename=document["filename"]) + + except Exception as e: + print(f"Error listing documents: {str(e)}") + return + finally: + if client: + client.close() def get_tools(self) -> List[Callable]: return [ @@ -59,7 +101,7 @@ def get_tools(self) -> List[Callable]: #self.list_indexes, self.search_knowledge_index, self.list_documents, - self.review_full_document + self.review_full_document, ] def save_content_to_knowledge_index( diff --git a/src/agentic/utils/file_reader.py b/src/agentic/utils/file_reader.py index 3ea5ae2f..70671fba 100644 --- a/src/agentic/utils/file_reader.py +++ b/src/agentic/utils/file_reader.py @@ -57,6 +57,9 @@ def read_file(file_path: str, mime_type: str|None = None) -> tuple[str, str]: return text, mime_type elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": return pd.read_excel(file_path).to_csv(), mime_type + elif mime_type == "text/x-python": + with open(file_path,"r") as f: + return f.read(), mime_type else: return textract.process(file_path).decode('utf-8'), mime_type except Exception as e: diff --git a/src/agentic/utils/rag_helper.py b/src/agentic/utils/rag_helper.py index 94dd7f52..89a1f390 100644 --- a/src/agentic/utils/rag_helper.py +++ b/src/agentic/utils/rag_helper.py @@ -91,7 +91,7 @@ def prepare_document_metadata( # Generate document ID from filename metadata["document_id"] = hashlib.sha256( - metadata["filename"].encode() + str(Path(file_path)).encode() ).hexdigest() return metadata @@ -155,7 +155,7 @@ def rag_index_file( client: WeaviateClient|None = None, ignore_errors: bool = False, distance_metric: VectorDistances = VectorDistances.COSINE, -): +) -> str: """Index a file using configurable Weaviate Embedded and chunking parameters""" console = Console() @@ -186,10 +186,10 @@ def rag_index_file( if status == "unchanged": console.print(f"[yellow]⏩ Document '{metadata['filename']}' unchanged[/yellow]") - return + return metadata["document_id"] elif status == "duplicate": console.print(f"[yellow]⚠️ Content already exists under different filename[/yellow]") - return + return metadata["document_id"] elif status == "changed": console.print(f"[yellow]🔄 Updating changed document '{metadata['filename']}'[/yellow]") collection.data.delete_many( @@ -233,8 +233,102 @@ def rag_index_file( finally: if client and client_created: client.close() - return "indexed" + return metadata["document_id"] + +def rag_index_multiple_files( + file_paths: List[str], + index_name: str, + chunk_threshold: float = 0.5, + chunk_delimiters: str = ". ,! ,? ,\n", + embedding_model: str = "BAAI/bge-small-en-v1.5", + client: WeaviateClient|None = None, + ignore_errors: bool = False, + distance_metric: VectorDistances = VectorDistances.COSINE, +) -> List[str]: + """Index a file using configurable Weaviate Embedded and chunking parameters""" + + console = Console() + client_created = False + + documents_indexed = [] + try: + with Status("[bold green]Initializing Weaviate..."): + if client is None: + client = init_weaviate() + client_created = True + create_collection(client, index_name, distance_metric) + + with Status("[bold green]Initializing models..."): + embed_model = init_embedding_model(embedding_model) + chunker = init_chunker(chunk_threshold, chunk_delimiters) + for file_path in file_paths: + with Status(f"[bold green]Processing {file_path}...", console=console): + text, mime_type = read_file(str(file_path)) + metadata = prepare_document_metadata(file_path, text, mime_type, GPT_DEFAULT_MODEL) + + console.print(f"[bold green]Indexing {file_path}...") + + collection = client.collections.get(index_name) + exists, status = check_document_exists( + collection, + metadata["document_id"], + metadata["fingerprint"] + ) + + if status == "unchanged": + console.print(f"[yellow]⏩ Document '{metadata['filename']}' unchanged[/yellow]") + documents_indexed.append(metadata["document_id"]) + continue + elif status == "duplicate": + console.print(f"[yellow]⚠️ Content already exists under different filename[/yellow]") + documents_indexed.append(metadata["document_id"]) + continue + elif status == "changed": + console.print(f"[yellow]🔄 Updating changed document '{metadata['filename']}'[/yellow]") + collection.data.delete_many( + where=Filter.by_property("document_id").equal(metadata["document_id"]) + ) + + with Status("[bold green]Generating document summary...", console=console): + metadata["summary"] = generate_document_summary( + text=text[:12000], + mime_type=mime_type, + model=GPT_DEFAULT_MODEL + ) + + chunks = chunker(text) + chunks_text = [chunk.text for chunk in chunks] + if not chunks_text: + if ignore_errors: + return client + raise ValueError("No text chunks generated from document") + + batch_size = 128 + embeddings = [] + with Status("[bold green]Generating embeddings..."): + for i in range(0, len(chunks_text), batch_size): + batch = chunks_text[i:i+batch_size] + embeddings.extend(list(embed_model.embed(batch))) + + with Status("[bold green]Indexing chunks..."), collection.batch.dynamic() as batch: + for i, chunk in enumerate(chunks): + vector = embeddings[i].tolist() + batch.add_object( + properties={ + **metadata, + "content": chunk.text, + "chunk_index": i, + }, + vector=vector + ) + + documents_indexed.append(metadata["document_id"]) + console.print(f"[bold green]✅ Indexed {len(chunks)} chunks in {index_name}") + finally: + if client and client_created: + client.close() + return documents_indexed def delete_document_from_index( collection: Any,