From 36021ac3882f174e2ed5b7fd59c19c5b16df0a27 Mon Sep 17 00:00:00 2001 From: Greg Hogue Date: Thu, 22 Jan 2026 16:36:16 -0500 Subject: [PATCH 1/2] rewrite HybridRetriever class --- src/retrievers/csv_chroma.py | 210 +++++++++++++++++++++++++++++------ 1 file changed, 177 insertions(+), 33 deletions(-) diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py index 691b884..7ab322e 100644 --- a/src/retrievers/csv_chroma.py +++ b/src/retrievers/csv_chroma.py @@ -1,20 +1,60 @@ from pathlib import Path +from typing import Annotated, Any, TypedDict import chromadb.config from langchain.chains.query_constructor.schema import AttributeInfo -from langchain.retrievers import EnsembleRetriever +from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever from langchain.retrievers.merger_retriever import MergerRetriever from langchain.retrievers.self_query.base import SelfQueryRetriever from langchain_chroma.vectorstores import Chroma from langchain_community.document_loaders.csv_loader import CSVLoader from langchain_community.retrievers import BM25Retriever +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.retrievers import BaseRetriever +from langchain_core.prompts.prompt import PromptTemplate from nltk.tokenize import word_tokenize +from pydantic import AfterValidator, Field +from pydantic.json_schema import SkipJsonSchema chroma_settings = chromadb.config.Settings(anonymized_telemetry=False) +multi_query_prompt = PromptTemplate( + input_variables=["question"], + template="""You are a biomedical question expansion engine for information retrieval over the Reactome biological pathway database. + +Given a single user question, generate **exactly 4** alternate standalone questions. These should be: + +- Semantically related to the original question. +- Lexically diverse to improve retrieval via vector search and RAG-fusion. +- Biologically enriched with inferred or associated details. + +Your goal is to improve recall of relevant documents by expanding the original query using: +- Synonymous gene/protein names (e.g., EGFR, ErbB1, HER1) +- Pathway or process-level context (e.g., signal transduction, apoptosis) +- Known diseases, phenotypes, or biological functions +- Cellular localization (e.g., nucleus, cytoplasm, membrane) +- Upstream/downstream molecular interactions + +Rules: +- Each question must be **fully standalone** (no "this"/"it"). +- Do not change the core intent—preserve the user's informational goal. +- Use appropriate biological terminology and Reactome-relevant concepts. +- Vary the **phrasing**, **focus**, or **biological angle** of each question. +- If the input is ambiguous, infer a biologically meaningful interpretation. + +Output: +Return only the 4 alternative questions separated by newlines. +Do not include any explanations or metadata. + +Original Question: {question}""", +) + + +ExcludedField = SkipJsonSchema[ + Annotated[Any, Field(default=None, exclude=True), AfterValidator(lambda x: None)] +] + def list_chroma_subdirectories(directory: Path) -> list[str]: subdirectories = list( @@ -31,40 +71,144 @@ def create_bm25_chroma_ensemble_retriever( descriptions_info: dict[str, str], field_info: dict[str, list[AttributeInfo]], ) -> MergerRetriever: - retriever_list: list[BaseRetriever] = [] - for subdirectory in list_chroma_subdirectories(embeddings_directory): - # set up BM25 retriever - csv_file_name = subdirectory + ".csv" - reactome_csvs_dir: Path = embeddings_directory / "csv_files" - loader = CSVLoader(file_path=reactome_csvs_dir / csv_file_name) - data = loader.load() - bm25_retriever = BM25Retriever.from_documents( - data, - preprocess_func=lambda text: word_tokenize( - text.casefold(), language="english" - ), - ) - bm25_retriever.k = 10 + return HybridRetriever.from_subdirectory( + llm, + embedding, + embeddings_directory, + descriptions_info=descriptions_info, + field_info=field_info, + include_original=True, + ) - # set up vectorstore SelfQuery retriever - vectordb = Chroma( - persist_directory=str(embeddings_directory / subdirectory), - embedding_function=embedding, - client_settings=chroma_settings, - ) - selfq_retriever = SelfQueryRetriever.from_llm( - llm=llm, - vectorstore=vectordb, - document_contents=descriptions_info[subdirectory], - metadata_field_info=field_info[subdirectory], - search_kwargs={"k": 10}, +class RetrieverDict(TypedDict): + bm25: BM25Retriever + vector: SelfQueryRetriever + + +class HybridRetriever(MultiQueryRetriever, EnsembleRetriever): + retriever: ExcludedField = None + retrievers: ExcludedField = None + _retrievers: dict[str, RetrieverDict] + + @classmethod + def from_subdirectory( + cls, + llm: BaseChatModel, + embedding: Embeddings, + embeddings_directory: Path, + *, + descriptions_info: dict[str, str], + field_info: dict[str, list[AttributeInfo]], + include_original=False, + ): + _retrievers: dict[str, RetrieverDict] = {} + for subdirectory in list_chroma_subdirectories(embeddings_directory): + # set up BM25 retriever + csv_file_name = subdirectory + ".csv" + reactome_csvs_dir: Path = embeddings_directory / "csv_files" + loader = CSVLoader(file_path=reactome_csvs_dir / csv_file_name) + data = loader.load() + bm25_retriever = BM25Retriever.from_documents( + data, + preprocess_func=lambda text: word_tokenize( + text.casefold(), language="english" + ), + ) + bm25_retriever.k = 10 + + # set up vectorstore SelfQuery retriever + vectordb = Chroma( + persist_directory=str(embeddings_directory / subdirectory), + embedding_function=embedding, + client_settings=chroma_settings, + ) + + selfq_retriever = SelfQueryRetriever.from_llm( + llm=llm, + vectorstore=vectordb, + document_contents=descriptions_info[subdirectory], + metadata_field_info=field_info[subdirectory], + search_kwargs={"k": 10}, + ) + + _retrievers[subdirectory] = { + "bm25": bm25_retriever, + "vector": selfq_retriever, + } + llm_chain = ( + super() + .from_llm(bm25_retriever, llm, multi_query_prompt, None, include_original) + .llm_chain ) - rrf_retriever = EnsembleRetriever( - retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8] + return cls( + _retrievers=_retrievers, + llm_chain=llm_chain, + include_original=include_original, + weights=[0.2] * 5, ) - retriever_list.append(rrf_retriever) - reactome_retriever = MergerRetriever(retrievers=retriever_list) + def retrieve_documents(self, queries: list[str], run_manager) -> list[Document]: + subdirectory_docs: list[Document] = [] + for subdirectory, retrievers in self._retrievers.items(): + bm25_retriever = retrievers["bm25"] + vector_retriever = retrievers["vector"] + doc_lists: list[list[Document]] = [] + for i, query in enumerate(queries): + bm25_docs = bm25_retriever.invoke( + query, + config={ + "callbacks": run_manager.get_child( + tag=f"{subdirectory}-bm25-{i}" + ) + }, + ) + vector_docs = vector_retriever.invoke( + query, + config={ + "callbacks": run_manager.get_child( + tag=f"{subdirectory}-vector-{i}" + ) + }, + ) + doc_lists.append(bm25_docs + vector_docs) + subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists)) + return subdirectory_docs - return reactome_retriever + async def aretrieve_documents( + self, queries: list[Document], run_manager + ) -> list[Document]: + subdirectory_results = {} + for subdirectory, retrievers in self._retrievers.items(): + bm25_retriever = retrievers["bm25"] + vector_retriever = retrievers["vector"] + subdirectory_results[subdirectory] = [] + for i, query in enumerate(queries): + bm25_results = bm25_retriever.ainvoke( + query, + config={ + "callbacks": run_manager.get_child( + tag=f"{subdirectory}-bm25-{i}" + ) + }, + ) + vector_results = vector_retriever.ainvoke( + query, + config={ + "callbacks": run_manager.get_child( + tag=f"{subdirectory}-vector-{i}" + ) + }, + ) + subdirectory_results[subdirectory].append( + (bm25_results, vector_results) + ) + subdirectory_docs: list[Document] = [] + for subdir_results in subdirectory_results.values(): + doc_lists: list[list[Document]] = [] + for bm25_results, vector_results in subdir_results: + bm25_docs = await bm25_results + vector_docs = await vector_results + doc_lists.append(bm25_docs + vector_docs) + subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists)) + return subdirectory_docs From e80edded1569980c5a4817c8def590fb5f2d0c5d Mon Sep 17 00:00:00 2001 From: Greg Hogue Date: Thu, 22 Jan 2026 16:51:32 -0500 Subject: [PATCH 2/2] fix types --- src/retrievers/csv_chroma.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py index 7ab322e..8f59411 100644 --- a/src/retrievers/csv_chroma.py +++ b/src/retrievers/csv_chroma.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Annotated, Any, TypedDict +from typing import Annotated, Any, Coroutine, TypedDict import chromadb.config from langchain.chains.query_constructor.schema import AttributeInfo @@ -176,9 +176,17 @@ def retrieve_documents(self, queries: list[str], run_manager) -> list[Document]: return subdirectory_docs async def aretrieve_documents( - self, queries: list[Document], run_manager + self, queries: list[str], run_manager ) -> list[Document]: - subdirectory_results = {} + subdirectory_results: dict[ + str, + list[ + tuple[ + Coroutine[Any, Any, list[Document]], + Coroutine[Any, Any, list[Document]], + ] + ], + ] = {} for subdirectory, retrievers in self._retrievers.items(): bm25_retriever = retrievers["bm25"] vector_retriever = retrievers["vector"]