diff --git a/.github/workflows/pr-summary-agent.yml b/.github/workflows/pr-summary-agent.yml
index 1fe312bf..ac8cb038 100644
--- a/.github/workflows/pr-summary-agent.yml
+++ b/.github/workflows/pr-summary-agent.yml
@@ -18,7 +18,13 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v6
-
+
+ - name: Download weaviate cache
+ uses: actions/download-artifact@v4
+ with:
+ name: weaviate
+ path: home/.cache/weaviate
+
- name: Run agent
run: |
uv venv --python 3.12
@@ -33,3 +36,9 @@ jobs:
GITHUB_API_KEY: ${{ secrets.GITHUB_TOKEN }}
REPO_OWNER: supercog-ai
REPO_NAME: PR_code_review-agent
+
+ - name: Update weaviate cache
+ uses: actions/upload-artifact@v4
+ with:
+ name: weaviate
+ path: home/.cache/weaviate
diff --git a/PRChangesTest.patch b/PRChangesTest.patch
new file mode 100644
index 00000000..80fe4bfc
--- /dev/null
+++ b/PRChangesTest.patch
@@ -0,0 +1,348 @@
+diff --git a/.github/workflows/pr-summary-agent.yml b/.github/workflows/pr-summary-agent.yml
+index 88e50da..1fe312b 100644
+--- a/.github/workflows/pr-summary-agent.yml
++++ b/.github/workflows/pr-summary-agent.yml
+@@ -25,10 +25,8 @@ jobs:
+ uv pip install -e "../PR_code_review-agent[all,dev]" --extra-index-url https://download.pytorch.org/whl/cpu --index-strategy unsafe-first-match
+ git diff --merge-base HEAD^1 HEAD > PRChanges.patch
+ cat PRChanges.patch
+- uv run pr_agent/test_files/mock_pr_agent.py
+-
+-#uv run pr_agent/PR_agent.py
+-
++ uv run pr_agent/PR_agent.py
++
+ env:
+ OPENAI_API_KEY: ${{ secrets.PRAgentOpenAIKey }}
+ PR_ID: ${{ github.event.pull_request.number }}
+diff --git a/pr_agent/PR_agent.py b/pr_agent/PR_agent.py
+index 46b8f1f..c99d863 100644
+--- a/pr_agent/PR_agent.py
++++ b/pr_agent/PR_agent.py
+@@ -29,14 +29,6 @@ class SearchResult(BaseModel):
+ similarity_score: float = Field(
+ description="Similarity score returned from vector search."
+ )
+- is_relevant: bool = Field(
+- default = True,
+- description="Boolean describing if the search result is relevant to the query."
+- )
+- relevance_reason: str = Field(
+- default = "",
+- description="Boolean describing if the search result is relevant to the query."
+- )
+ included_defs: List[str] = Field(
+ default_factory=list,
+ description="Classes and functions defined in this file."
+@@ -49,7 +41,6 @@ class Searches(BaseModel):
+
+ class RelevanceResult(BaseModel):
+ relevant: bool
+- reason: str
+
+ class PRReviewAgent(Agent):
+
+@@ -81,7 +72,7 @@ You are an expert in generating NON-NATURAL LANGUAGE CODE search queries from a
+
+ self.relevanceAgent = Agent(
+ name="Code Relevance Agent",
+- instructions="""You are an expert in determining if a snippet of code or documentation is needed to determine the purpose of a code change from the patch file. Your response must include a 'relevant' field boolean and a 'reason' field with a brief explanation.""",
++ instructions="""You are an expert in determining if a snippet of code or documentation is directly relevant to a query. Your response must include a boolean 'relevant' field.""",
+ model=GPT_4O_MINI,
+ result_model=RelevanceResult,
+ )
+@@ -142,10 +133,10 @@ You are an expert in generating NON-NATURAL LANGUAGE CODE search queries from a
+ }
+ )
+
+- print("quer"+str(queries))
++ print("queries: "+str(queries))
+
+- all_results = []
+-
++ # RAG queries
++ all_results = {}
+ for query in queries.searches[:10]:
+ searchResponse = yield from self.code_rag_agent.final_result(
+ f"Search codebase",
+@@ -156,35 +147,34 @@ You are an expert in generating NON-NATURAL LANGUAGE CODE search queries from a
+ )
+
+ # Process each result
+- for result in searchResponse.sections:
+- all_results.append(SearchResult(query=query,file_path=result.file_path,content=result.search_result,similarity_score=result.similarity_score,included_defs=result.included_defs))
++ for key, result in searchResponse.sections.items():
++ if key not in all_results:
++ all_results[key] = SearchResult(query=query,file_path=result.file_path,content=result.search_result,similarity_score=result.similarity_score,included_defs=result.included_defs)
+
+- print("fil"+str(all_results))
++ print("all: "+str(all_results))
+
+ # Filter search results using LLM-based relevance checking
+ filtered_results = []
+-
+- for result in all_results:
+- if result.similarity_score < 0.5:
+- continue
+-
+- relevance_check = yield from self.relevanceAgent.final_result(
+- f"\n{request_context.get("patch_content")}\n\n\n{result.content}{result.query}"
+- )
+-
+- print(relevance_check)
+-
+- result.is_relevant = relevance_check.relevant
+- result.relevance_reason = relevance_check.reason
++ for result in all_results.values():
+
+- if result.is_relevant:
+- filtered_results.append(result)
++ try:
++ relevance_check = yield from self.relevanceAgent.final_result(
++ f"\n{request_context.get('patch_content')}\n\n\n{result.content}{result.query}"
++ )
++
++ if relevance_check.relevant:
++ filtered_results.append(result)
++ except Exception as e:
++ # Skip this result if the relevance check errors out
++ print(f"Relevance check failed: {e}")
+
+- print(str(filtered_results))
++ print("filtered: ",str(filtered_results))
+
+ # Prepare for summary
+ formatted_str = self.prepare_summary(request_context.get("patch_content"),filtered_results)
+
++ print(formatted_str)
++
+ summary = yield from self.summaryAgent.final_result(
+ formatted_str
+ )
+diff --git a/pr_agent/code_rag_agent.py b/pr_agent/code_rag_agent.py
+index b7f36df..32d0a34 100644
+--- a/pr_agent/code_rag_agent.py
++++ b/pr_agent/code_rag_agent.py
+@@ -21,7 +21,7 @@ class CodeSection(BaseModel):
+ )
+
+ class CodeSections(BaseModel):
+- sections: List[CodeSection] = Field(
++ sections: dict[str,CodeSection] = Field(
+ description="Sections of the codebase returned from the search.",
+ )
+ search_query: str = Field(
+@@ -46,7 +46,8 @@ class CodeRagAgent(Agent):
+
+ self.ragTool = RAGTool(
+ default_index="codebase",
+- index_paths=["../*.md","../*.py"],
++ index_paths=["../**/*.py","../**/*.md"],
++ recursive=True
+ )
+
+
+@@ -65,24 +66,28 @@ class CodeRagAgent(Agent):
+ searchQuery = request_context.get("query")
+
+ searchResult = self.ragTool.search_knowledge_index(query=searchQuery,limit=5)
+-
+- allSections = CodeSections(sections=[],search_query=query)
++
++ allSections = CodeSections(sections={},search_query=query)
+
+ for nextResult in searchResult:
+- print(nextResult)
+ file_path = nextResult["source_url"]
+- similarity_score = nextResult["distance"] if nextResult["distance"] else 0
+- content = nextResult["content"]
++ if file_path not in allSections.sections:
++ #print(nextResult)
++
++ similarity_score = nextResult["distance"] if nextResult["distance"] else 0
++ content = nextResult["content"]
+
+- # Only works with Python files
+- included_defs = []
+- try:
+- with open(file_path) as file:
+- node = ast.parse(file.read())
+- included_defs = [n.name for n in node.body if isinstance(n, ast.ClassDef) or isinstance(n, ast.FunctionDef)]
+- except:
++ # Only works with Python files
+ included_defs = []
++ try:
++ with open(file_path) as file:
++ node = ast.parse(file.read())
++ included_defs = [n.name for n in node.body if isinstance(n, ast.ClassDef) or isinstance(n, ast.FunctionDef)]
++ except Exception:
++ included_defs = []
+
+- allSections.sections.append(CodeSection(search_result=content,file_path=file_path,included_defs=included_defs,similarity_score=similarity_score))
++ allSections.sections[file_path] = CodeSection(search_result=content,file_path=file_path,included_defs=included_defs,similarity_score=similarity_score)
++ #else:
++ #print("Skipping Duplicate: ",file_path)
+
+ yield TurnEnd(self.name, [{"content": allSections}])
+diff --git a/src/agentic/events.py b/src/agentic/events.py
+index 01aa68e..da12979 100644
+--- a/src/agentic/events.py
++++ b/src/agentic/events.py
+@@ -656,7 +656,6 @@ class TurnEnd(Event):
+ def result(self):
+ """Safe result access with fallback"""
+ try:
+- print(self.agent,self.messages)
+ return self.messages[-1]["content"] if self.messages else "No response generated"
+ except (IndexError, KeyError):
+ return "Error: Malformed response"
+diff --git a/src/agentic/tools/rag_tool.py b/src/agentic/tools/rag_tool.py
+index e3e7280..9891e7c 100644
+--- a/src/agentic/tools/rag_tool.py
++++ b/src/agentic/tools/rag_tool.py
+@@ -14,6 +14,7 @@ from agentic.utils.rag_helper import (
+ init_embedding_model,
+ init_chunker,
+ rag_index_file,
++ rag_index_multiple_files,
+ )
+
+ from agentic.utils.summarizer import generate_document_summary
+@@ -44,6 +45,8 @@ class RAGTool(BaseAgenticTool):
+ # Construct the RAG tool. You can pass a list of files and we will ensure that
+ # they are added to the index on startup. Paths can include glob patterns also,
+ # like './docs/*.md'.
++ # Set recursive=True to enable recursive glob patterns like '../**/*.md'
++
+ self.default_index = default_index
+ self.index_paths = index_paths
+ if self.index_paths:
+@@ -51,8 +54,11 @@ class RAGTool(BaseAgenticTool):
+ if default_index not in list_collections(client):
+ create_collection(client, default_index, VectorDistances.COSINE)
+ for path in index_paths:
+- for file_path in [path] if path.startswith("http") else glob.glob(path, recursive=recursive):
+- rag_index_file(file_path, self.default_index, client=client, ignore_errors=True)
++ if path.startswith("http"):
++ rag_index_file(path, self.default_index, client=client, ignore_errors=True)
++ else:
++ file_paths = glob.glob(path, recursive=recursive)
++ rag_index_multiple_files(file_paths, self.default_index, client=client, ignore_errors=True)
+
+ def get_tools(self) -> List[Callable]:
+ return [
+diff --git a/src/agentic/utils/file_reader.py b/src/agentic/utils/file_reader.py
+index 3ea5ae2..70671fb 100644
+--- a/src/agentic/utils/file_reader.py
++++ b/src/agentic/utils/file_reader.py
+@@ -57,6 +57,9 @@ def read_file(file_path: str, mime_type: str|None = None) -> tuple[str, str]:
+ return text, mime_type
+ elif mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
+ return pd.read_excel(file_path).to_csv(), mime_type
++ elif mime_type == "text/x-python":
++ with open(file_path,"r") as f:
++ return f.read(), mime_type
+ else:
+ return textract.process(file_path).decode('utf-8'), mime_type
+ except Exception as e:
+diff --git a/src/agentic/utils/rag_helper.py b/src/agentic/utils/rag_helper.py
+index 94dd7f5..1d585f2 100644
+--- a/src/agentic/utils/rag_helper.py
++++ b/src/agentic/utils/rag_helper.py
+@@ -234,7 +234,96 @@ def rag_index_file(
+ if client and client_created:
+ client.close()
+ return "indexed"
++
++def rag_index_multiple_files(
++ file_paths: List[str],
++ index_name: str,
++ chunk_threshold: float = 0.5,
++ chunk_delimiters: str = ". ,! ,? ,\n",
++ embedding_model: str = "BAAI/bge-small-en-v1.5",
++ client: WeaviateClient|None = None,
++ ignore_errors: bool = False,
++ distance_metric: VectorDistances = VectorDistances.COSINE,
++):
++ """Index multiple files using configurable Weaviate Embedded and chunking parameters"""
++
++ console = Console()
++ client_created = False
++ try:
++ with Status("[bold green]Initializing Weaviate..."):
++ if client is None:
++ client = init_weaviate()
++ client_created = True
++ create_collection(client, index_name, distance_metric)
++
++ with Status("[bold green]Initializing models..."):
++ embed_model = init_embedding_model(embedding_model)
++ chunker = init_chunker(chunk_threshold, chunk_delimiters)
+
++ for file_path in file_paths:
++ with Status(f"[bold green]Processing {file_path}...", console=console):
++ text, mime_type = read_file(str(file_path))
++ metadata = prepare_document_metadata(file_path, text, mime_type, GPT_DEFAULT_MODEL)
++
++ console.print(f"[bold green]Indexing {file_path}...")
++
++ collection = client.collections.get(index_name)
++ exists, status = check_document_exists(
++ collection,
++ metadata["document_id"],
++ metadata["fingerprint"]
++ )
++
++ if status == "unchanged":
++ console.print(f"[yellow]⏩ Document '{metadata['filename']}' unchanged[/yellow]")
++ continue
++ elif status == "duplicate":
++ console.print(f"[yellow]⚠️ Content already exists under different filename[/yellow]")
++ continue
++ elif status == "changed":
++ console.print(f"[yellow]🔄 Updating changed document '{metadata['filename']}'[/yellow]")
++ collection.data.delete_many(
++ where=Filter.by_property("document_id").equal(metadata["document_id"])
++ )
++
++ with Status("[bold green]Generating document summary...", console=console):
++ metadata["summary"] = generate_document_summary(
++ text=text[:12000],
++ mime_type=mime_type,
++ model=GPT_DEFAULT_MODEL
++ )
++
++ chunks = chunker(text)
++ chunks_text = [chunk.text for chunk in chunks]
++ if not chunks_text:
++ if ignore_errors:
++ continue # skip this file; returning here would abort the whole batch
++ raise ValueError("No text chunks generated from document")
++
++ batch_size = 128
++ embeddings = []
++ with Status("[bold green]Generating embeddings..."):
++ for i in range(0, len(chunks_text), batch_size):
++ batch = chunks_text[i:i+batch_size]
++ embeddings.extend(list(embed_model.embed(batch)))
++
++ with Status("[bold green]Indexing chunks..."), collection.batch.dynamic() as batch:
++ for i, chunk in enumerate(chunks):
++ vector = embeddings[i].tolist()
++ batch.add_object(
++ properties={
++ **metadata,
++ "content": chunk.text,
++ "chunk_index": i,
++ },
++ vector=vector
++ )
++
++ console.print(f"[bold green]✅ Indexed {len(chunks)} chunks in {index_name}")
++ finally:
++ if client and client_created:
++ client.close()
++ return "indexed"
+
+ def delete_document_from_index(
+ collection: Any,
diff --git a/pr_agent/PR_agent.py b/pr_agent/PR_agent.py
index 7cf3c4e2..3867aeff 100644
--- a/pr_agent/PR_agent.py
+++ b/pr_agent/PR_agent.py
@@ -8,6 +8,9 @@
from git_grep_agent import GitGrepAgent
from summary_agent import SummaryAgent
from pydantic import BaseModel
+from typing import Dict, List, Any, Generator, Optional, Tuple
+from agentic.common import Agent, AgentRunner, ThreadContext
+from agentic.events import Event, ChatOutput, TurnEnd, PromptStarted, Prompt
load_dotenv()
@@ -129,28 +132,61 @@ def post_to_github(self, summary: str) -> str:
response.raise_for_status()
return response.json().get("html_url")
-
- def generate(self, patch_content: str) -> str:
+ def next_turn(
+ self,
+ request: str,
+ request_context: dict = None,
+ request_id: str = None,
+ continue_result: dict = {},
+ debug = "",
+ ) -> Generator[Event, Any, None]:
+
+ query = request.payload if isinstance(request, Prompt) else request
+ yield PromptStarted(query, {"query": query})
+
# Generate search queries
- queries = self.queryAgent << patch_content
+ queries = yield from self.queryAgent.final_result(
+ request_context.get("patch_content"),
+ request_context={
+ "thread_id": request_context.get("thread_id")
+ }
+ )
+
+ print("queries: "+str(queries))
- # Git-Grep queries
+ # RAG and Git-Grep queries
all_results = {}
for query in queries.searches[:10]:
searchResponse = self.git_grep_agent.get_search(query)
- if len(searchResponse.sections) > 0:
- # Process each result
- # grep_response.sections is a list of CodeSection objects
- for result in searchResponse.sections:
- if result.file_path not in all_results:
- all_results[result.file_path] = SearchResult(
+ # Process each result
+ for file, result in searchResponse.sections.items():
+ if file not in all_results:
+ all_results[file] = SearchResult(query=query,file_path=result.file_path,content=result.search_result,similarity_score=result.similarity_score,included_defs=result.included_defs)
+
+ searchResponse = yield from self.code_rag_agent.final_result(
+ "Search codebase",
+ request_context={
+ "query": query,
+ "thread_id": request_context.get("thread_id")
+ }
+ )
+
+ # Process each RAG result
+ # searchResponse.sections is a dict of CodeSection objects keyed by file path
+ for file, result in searchResponse.sections.items():
+ if file not in all_results:
+ all_results[file] = SearchResult(
query=query,
file_path=result.file_path,
content=result.search_result,
+ similarity_score=result.similarity_score,
included_defs=result.included_defs
- )
-
+ )
+
+
+ print("all: "+str(all_results))
+
# Filter search results using LLM-based relevance checking
filtered_results = []
for result in all_results.values():
@@ -176,8 +212,9 @@ def generate(self, patch_content: str) -> str:
pr_review_agent = PRReviewAgent()
if __name__ == "__main__":
+ # Change to PRChangesTest.patch for testing
with open("PRChanges.patch", "r") as f:
patch_content = f.read()
# Run the agent
- print(pr_review_agent.generate(patch_content))
\ No newline at end of file
+ print(pr_review_agent.final_result(patch_content, request_context={"patch_content": patch_content}))
\ No newline at end of file
diff --git a/pr_agent/code_rag_agent.py b/pr_agent/code_rag_agent.py
index 32d0a346..474550ef 100644
--- a/pr_agent/code_rag_agent.py
+++ b/pr_agent/code_rag_agent.py
@@ -46,7 +46,7 @@ def __init__(self,
self.ragTool = RAGTool(
default_index="codebase",
- index_paths=["../**/*.py","../**/*.md"],
+ index_paths=["../**/*.md"],
recursive=True
)
diff --git a/pr_agent/code_rag_agent.txt b/pr_agent/code_rag_agent.txt
deleted file mode 100644
index b7f36df7..00000000
--- a/pr_agent/code_rag_agent.txt
+++ /dev/null
@@ -1,88 +0,0 @@
-from typing import Any, Generator, List
-from agentic.common import Agent, AgentRunner, ThreadContext
-from agentic.events import Event, ChatOutput, WaitForInput, Prompt, PromptStarted, TurnEnd, ResumeWithInput
-from agentic.models import GPT_4O_MINI # model (using GPT for testing)
-from pydantic import BaseModel, Field
-from agentic.tools.rag_tool import RAGTool
-import ast
-
-class CodeSection(BaseModel):
- search_result: str = Field(
- description="Part returned from search.",
- )
- file_path: str = Field(
- description="Path of the file this code belongs to."
- )
- included_defs: list[str] = Field(
- description="Classes and functions defined in this file."
- )
- similarity_score: float = Field(
- desciption="Similarity score returned from vector search."
- )
-
-class CodeSections(BaseModel):
- sections: List[CodeSection] = Field(
- description="Sections of the codebase returned from the search.",
- )
- search_query: str = Field(
- description="Query used to return this section.",
- )
-
-class CodeRagAgent(Agent):
- def __init__(self,
- name="Code Rag Agent",
- welcome="I am the Code Rag Agent. Please give me a search query (function name,class name, etc.) and I'll return relevant parts of the code.",
- model: str=GPT_4O_MINI,
- result_model = CodeSections,
- **kwargs
- ):
- super().__init__(
- name=name,
- welcome=welcome,
- model=model,
- result_model=result_model,
- **kwargs
- )
-
- self.ragTool = RAGTool(
- default_index="codebase",
- index_paths=["../*.md","../*.py"],
- )
-
-
- def next_turn(
- self,
- request: str|Prompt,
- request_context: dict = {},
- request_id: str = None,
- continue_result: dict = {},
- debug = "",
- ) -> Generator[Event, Any, Any]:
-
- query = request.payload if isinstance(request, Prompt) else request
- yield PromptStarted(query, {"query": query})
-
- searchQuery = request_context.get("query")
-
- searchResult = self.ragTool.search_knowledge_index(query=searchQuery,limit=5)
-
- allSections = CodeSections(sections=[],search_query=query)
-
- for nextResult in searchResult:
- print(nextResult)
- file_path = nextResult["source_url"]
- similarity_score = nextResult["distance"] if nextResult["distance"] else 0
- content = nextResult["content"]
-
- # Only works with Python files
- included_defs = []
- try:
- with open(file_path) as file:
- node = ast.parse(file.read())
- included_defs = [n.name for n in node.body if isinstance(n, ast.ClassDef) or isinstance(n, ast.FunctionDef)]
- except:
- included_defs = []
-
- allSections.sections.append(CodeSection(search_result=content,file_path=file_path,included_defs=included_defs,similarity_score=similarity_score))
-
- yield TurnEnd(self.name, [{"content": allSections}])
diff --git a/pr_agent/git_grep_agent.py b/pr_agent/git_grep_agent.py
index 906f2a4e..817bdd6e 100644
--- a/pr_agent/git_grep_agent.py
+++ b/pr_agent/git_grep_agent.py
@@ -3,33 +3,7 @@
import subprocess
import ast
-
-# Defines structured data containers for the serach/query results
-# each CodeSection object will represent one match from a git grep search
-class CodeSection(BaseModel):
- search_result: str = Field(
- description="Matching line returned from git grep.",
- )
- file_path: str = Field(
- description="Path of the file containing the match ."
- )
- included_defs: list[str] = Field(
- description="Classes and functions defined in this file."
- )
-
-
-
-# Represents the collection of matches for one serach query
-class CodeSections(BaseModel):
- # list of CodeSection objects
- sections: List[CodeSection] = Field(
- description="Sections of the codebase returned from the git grep search.",
- )
- # This is the query used for git grep
- search_query: str = Field(
- description="Query used to return this section.",
- )
-
+from code_rag_agent import CodeSection, CodeSections
# The actual sub-agent that runs git grep and returns structured results
class GitGrepAgent():
@@ -71,11 +45,11 @@ def get_search(self, search_query: str) -> CodeSections:
- # TODO: verify that sections doesn't have to be a dictionary instead (like code_rag_agent implementation)
+ # sections is a dict keyed by file_path, matching the code_rag_agent implementation
- allSections = CodeSections(sections=[], search_query=search_query) # creates an empty CodeSections object
+ allSections = CodeSections(sections={}, search_query=search_query) # creates an empty CodeSections object
# loops over each grep match
for file_path, matched_line in grep_results:
if file_path not in allSections.sections:
included_defs = []
try:
if file_path.endswith(".py"): # if a python file, parse the AST, and collect all function/class names
@@ -85,16 +59,17 @@ def get_search(self, search_query: str) -> CodeSections:
n.name for n in node.body
if isinstance (n, ast.ClassDef) or isinstance(n, ast.FunctionDef)
]
+ else:
+ continue # skip non-Python files
except:
included_defs = []
- # Only add if this file_path hasn’t already been added
- if not any(sec.file_path == file_path for sec in allSections.sections):
- allSections.sections.append(CodeSection(
- search_result=matched_line,
- file_path=file_path,
- included_defs=included_defs
- ))
+ allSections.sections[file_path] = CodeSection(
+ search_result=matched_line,
+ file_path=file_path,
+ included_defs=included_defs,
+ similarity_score=1.0 # grep doesn't do semantic scoring
+ )
return allSections
diff --git a/pr_agent/rag_sub_agent.py b/pr_agent/rag_sub_agent.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/agentic/tools/rag_tool.py b/src/agentic/tools/rag_tool.py
index 9891e7c9..b42bea28 100644
--- a/src/agentic/tools/rag_tool.py
+++ b/src/agentic/tools/rag_tool.py
@@ -15,6 +15,8 @@
init_chunker,
rag_index_file,
rag_index_multiple_files,
+ delete_document_from_index,
+ list_documents_in_collection,  # assumed to live in rag_helper; used by the overwrite pass below
)
from agentic.utils.summarizer import generate_document_summary
@@ -23,6 +24,7 @@
from weaviate.collections.classes.grpc import Sort
from weaviate.classes.config import VectorDistances
+from rich.console import Console
@tool_registry.register(
name="RAGTool",
@@ -41,6 +43,7 @@ def __init__(
default_index: str = "knowledge_base",
index_paths: list[str] = [],
recursive: bool = False,
+ overwrite_index: bool = True,
):
# Construct the RAG tool. You can pass a list of files and we will ensure that
# they are added to the index on startup. Paths can include glob patterns also,
@@ -53,12 +56,44 @@ def __init__(
client = init_weaviate()
if default_index not in list_collections(client):
create_collection(client, default_index, VectorDistances.COSINE)
+
+ # Keep track of files found during initialization
+ if overwrite_index:
+ indexed_documents = {}
+
for path in index_paths:
if path.startswith("http"):
- rag_index_file(path, self.default_index, client=client, ignore_errors=True)
+ document_id = rag_index_file(path, self.default_index, client=client, ignore_errors=True)
+
+ if overwrite_index:
+ indexed_documents[document_id] = True
+
else:
file_paths = glob.glob(path, recursive=recursive)
- rag_index_multiple_files(file_paths, self.default_index, client=client, ignore_errors=True)
+ document_ids = rag_index_multiple_files(file_paths, self.default_index, client=client, ignore_errors=True)
+
+ if overwrite_index:
+ for document_id in document_ids:
+ indexed_documents[document_id] = True
+
+ # Delete indexed files not found during initialization
+ if overwrite_index:
+ try:
+ console = Console()
+ collection = client.collections.get(self.default_index)
+ documents = list_documents_in_collection(collection)
+
+ for document in documents:
+ if document["document_id"] not in indexed_documents:
+ console.print(f"[bold green]✅ Removing deleted file {document['filename']} from index")
+ delete_document_from_index(collection=collection,document_id=document["document_id"],filename=document["filename"])
+
+ except Exception as e:
+ print(f"Error pruning stale documents: {e}")
+ return
+ finally:
+ if client:
+ client.close()
def get_tools(self) -> List[Callable]:
return [
@@ -66,7 +101,7 @@ def get_tools(self) -> List[Callable]:
#self.list_indexes,
self.search_knowledge_index,
self.list_documents,
- self.review_full_document
+ self.review_full_document,
]
def save_content_to_knowledge_index(
diff --git a/src/agentic/utils/rag_helper.py b/src/agentic/utils/rag_helper.py
index 1d585f2d..89a1f390 100644
--- a/src/agentic/utils/rag_helper.py
+++ b/src/agentic/utils/rag_helper.py
@@ -91,7 +91,7 @@ def prepare_document_metadata(
# Generate document ID from filename
metadata["document_id"] = hashlib.sha256(
- metadata["filename"].encode()
+ str(Path(file_path)).encode()
).hexdigest()
return metadata
@@ -155,7 +155,7 @@ def rag_index_file(
client: WeaviateClient|None = None,
ignore_errors: bool = False,
distance_metric: VectorDistances = VectorDistances.COSINE,
-):
+) -> str:
"""Index a file using configurable Weaviate Embedded and chunking parameters"""
console = Console()
@@ -186,10 +186,10 @@ def rag_index_file(
if status == "unchanged":
console.print(f"[yellow]⏩ Document '{metadata['filename']}' unchanged[/yellow]")
- return
+ return metadata["document_id"]
elif status == "duplicate":
console.print(f"[yellow]⚠️ Content already exists under different filename[/yellow]")
- return
+ return metadata["document_id"]
elif status == "changed":
console.print(f"[yellow]🔄 Updating changed document '{metadata['filename']}'[/yellow]")
collection.data.delete_many(
@@ -233,7 +233,7 @@ def rag_index_file(
finally:
if client and client_created:
client.close()
- return "indexed"
+ return metadata["document_id"]
def rag_index_multiple_files(
file_paths: List[str],
@@ -244,11 +244,13 @@ def rag_index_multiple_files(
client: WeaviateClient|None = None,
ignore_errors: bool = False,
distance_metric: VectorDistances = VectorDistances.COSINE,
-):
+) -> List[str]:
"""Index a file using configurable Weaviate Embedded and chunking parameters"""
console = Console()
client_created = False
+
+ documents_indexed = []
try:
with Status("[bold green]Initializing Weaviate..."):
if client is None:
@@ -276,9 +278,11 @@ def rag_index_multiple_files(
if status == "unchanged":
console.print(f"[yellow]⏩ Document '{metadata['filename']}' unchanged[/yellow]")
+ documents_indexed.append(metadata["document_id"])
continue
elif status == "duplicate":
console.print(f"[yellow]⚠️ Content already exists under different filename[/yellow]")
+ documents_indexed.append(metadata["document_id"])
continue
elif status == "changed":
console.print(f"[yellow]🔄 Updating changed document '{metadata['filename']}'[/yellow]")
@@ -318,12 +322,13 @@ def rag_index_multiple_files(
},
vector=vector
)
-
- console.print(f"[bold green]✅ Indexed {len(chunks)} chunks in {index_name}")
+
+ documents_indexed.append(metadata["document_id"])
+ console.print(f"[bold green]✅ Indexed {len(chunks)} chunks in {index_name}")
finally:
if client and client_created:
client.close()
- return "indexed"
+ return documents_indexed
def delete_document_from_index(
collection: Any,