diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index 62782bf36c..f5937dd0ea 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -13,6 +13,7 @@ import math import time from datetime import datetime +from pathlib import PurePath from typing import Any, Dict, List, Optional, Tuple from openviking.core.retrieval_targets import default_target_directories @@ -21,7 +22,7 @@ from openviking.retrieve.memory_lifecycle import hotness_score from openviking.retrieve.retrieval_stats import get_stats_collector from openviking.server.identity import RequestContext -from openviking.storage import VikingDBManager, VikingDBManagerProxy +from openviking.storage import VikingDBManager, VikingDBManagerProxy, get_viking_fs from openviking.telemetry import get_current_telemetry from openviking.utils.time_utils import parse_iso_datetime from openviking_cli.retrieve.types import ( @@ -49,7 +50,44 @@ class HierarchicalRetriever: DIRECTORY_DOMINANCE_RATIO = 1.2 # Directory score must exceed max child score GLOBAL_SEARCH_TOPK = 10 # Global retrieval count (more candidates = better rerank precision) MAX_PARALLEL_CHILD_SEARCHES = 4 # Limit per-request fan-out against remote vector stores + FILE_RERANK_FALLBACK_MAX_CHARS = 4000 LEVEL_URI_SUFFIX = {0: ".abstract.md", 1: ".overview.md"} + _FILE_RERANK_FALLBACK_SUFFIXES = { + ".txt", + ".md", + ".csv", + ".json", + ".xml", + ".py", + ".js", + ".ts", + ".java", + ".cpp", + ".c", + ".h", + ".go", + ".rs", + ".lua", + ".rb", + ".php", + ".sh", + ".bash", + ".zsh", + ".fish", + ".sql", + ".kt", + ".swift", + ".scala", + ".r", + ".m", + ".pl", + ".toml", + ".yaml", + ".yml", + ".ini", + ".cfg", + ".conf", + } def __init__( self, @@ -176,10 +214,11 @@ async def retrieve( ) # Step 3: Merge starting points - starting_points = self._merge_starting_points( + starting_points = await self._merge_starting_points( query.query, root_uris, global_results, + ctx=ctx, mode=mode, ) @@ -189,9 +228,10 @@ async def retrieve( else: initial_candidates = [r for r in global_results if r.get("level", 2) == 2] - initial_candidates = self._prepare_initial_candidates( + initial_candidates = await self._prepare_initial_candidates( query.query, initial_candidates, + ctx=ctx, mode=mode, ) @@ -212,6 +252,7 @@ async def retrieve( scope_dsl=scope_dsl, initial_candidates=initial_candidates, level=level, + ctx=ctx, ) # Step 6: Convert results @@ -273,33 +314,77 @@ def _rerank_scores( if not self._rerank_client or not documents: return fallback_scores + rerank_documents: List[str] = [] + rerank_indices: List[int] = [] + for index, document in enumerate(documents): + if not str(document).strip(): + continue + rerank_documents.append(document) + rerank_indices.append(index) + + if not rerank_documents: + return fallback_scores + try: - scores = self._rerank_client.rerank_batch(query, documents) + scores = self._rerank_client.rerank_batch(query, rerank_documents) except Exception as e: logger.warning( "[HierarchicalRetriever] Rerank failed, fallback to vector scores: %s", e ) return fallback_scores - if not scores or len(scores) != len(documents): + if not scores or len(scores) != len(rerank_documents): logger.warning( "[HierarchicalRetriever] Invalid rerank result, fallback to vector scores" ) return fallback_scores - normalized_scores: List[float] = [] - for score, fallback in zip(scores, fallback_scores, strict=True): + normalized_scores = list(fallback_scores) + for index, score in zip(rerank_indices, scores, strict=True): if isinstance(score, (int, float)): - normalized_scores.append(float(score)) - else: - normalized_scores.append(fallback) + normalized_scores[index] = float(score) return normalized_scores - def _merge_starting_points( + async def _build_rerank_document( + self, + result: Dict[str, Any], + ctx: RequestContext, + ) -> str: + abstract = str(result.get("abstract") or "").strip() + if abstract: + return abstract + + if result.get("level") != 2: + return "" + + uri = str(result.get("uri") or "").strip() + if not uri: + return "" + + if PurePath(uri).suffix.lower() not in self._FILE_RERANK_FALLBACK_SUFFIXES: + return "" + + try: + viking_fs = getattr(ctx, "viking_fs", None) or get_viking_fs() + content = await viking_fs.read_file(uri, ctx=ctx) + except Exception: + return "" + + return str(content or "")[: self.FILE_RERANK_FALLBACK_MAX_CHARS] + + async def _build_rerank_documents( + self, + results: List[Dict[str, Any]], + ctx: RequestContext, + ) -> List[str]: + return await asyncio.gather(*(self._build_rerank_document(result, ctx) for result in results)) + + async def _merge_starting_points( self, query: str, root_uris: List[str], global_results: List[Dict[str, Any]], + ctx: RequestContext, mode: str = "thinking", ) -> List[Tuple[str, float]]: """Merge starting points. @@ -316,7 +401,7 @@ def _merge_starting_points( s if math.isfinite(s) else 0.0 for s in (r.get("_score", 0.0) for r in global_results) ] if self._rerank_client and mode == RetrieverMode.THINKING: - docs = [str(r.get("abstract", "")) for r in global_results] + docs = await self._build_rerank_documents(global_results, ctx) query_scores = self._rerank_scores(query, docs, default_scores) for i, r in enumerate(global_results): # 只添加非 level 2 的项目到起始点 @@ -338,10 +423,11 @@ def _merge_starting_points( return points - def _prepare_initial_candidates( + async def _prepare_initial_candidates( self, query: str, global_results: List[Dict[str, Any]], + ctx: RequestContext, mode: str = RetrieverMode.THINKING, ) -> List[Dict[str, Any]]: """Preserve rerank scores for global hits added to the result pool.""" @@ -354,7 +440,7 @@ def _prepare_initial_candidates( for s in (r.get("_score", 0.0) for r in initial_candidates) ] if self._rerank_client and mode == RetrieverMode.THINKING: - docs = [str(r.get("abstract", "")) for r in initial_candidates] + docs = await self._build_rerank_documents(initial_candidates, ctx) query_scores = self._rerank_scores(query, docs, default_scores) else: query_scores = default_scores @@ -380,6 +466,7 @@ async def _recursive_search( scope_dsl: Optional[Dict[str, Any]] = None, initial_candidates: Optional[List[Dict[str, Any]]] = None, level: Optional[List[int]] = None, + ctx: Optional[RequestContext] = None, ) -> List[Dict[str, Any]]: """ Recursive search with directory priority return and score propagation. @@ -477,8 +564,8 @@ async def search_children(current_uri: str) -> List[Dict[str, Any]]: query_scores = [ s if math.isfinite(s) else 0.0 for s in (r.get("_score", 0.0) for r in results) ] - if self._rerank_client and mode == RetrieverMode.THINKING: - documents = [str(r.get("abstract", "")) for r in results] + if self._rerank_client and mode == RetrieverMode.THINKING and ctx is not None: + documents = await self._build_rerank_documents(results, ctx) query_scores = self._rerank_scores(query, documents, query_scores) for r, score in zip(results, query_scores, strict=True): diff --git a/tests/retrieve/test_hierarchical_retriever_rerank.py b/tests/retrieve/test_hierarchical_retriever_rerank.py index 1c5ab89de8..5c57615ff3 100644 --- a/tests/retrieve/test_hierarchical_retriever_rerank.py +++ b/tests/retrieve/test_hierarchical_retriever_rerank.py @@ -188,6 +188,73 @@ async def search_children_in_tenant( return [] +class EmptyAbstractLevelTwoGlobalStorage(DummyStorage): + async def search_global_roots_in_tenant( + self, + ctx, + query_vector=None, + sparse_query_vector=None, + context_type=None, + target_directories=None, + extra_filter=None, + limit: int = 10, + ): + self.global_search_calls.append( + { + "ctx": ctx, + "query_vector": query_vector, + "sparse_query_vector": sparse_query_vector, + "context_type": context_type, + "target_directories": target_directories, + "extra_filter": extra_filter, + "limit": limit, + } + ) + return [ + { + "uri": "viking://resources/file-a.py", + "abstract": "", + "_score": 0.2, + "level": 2, + "context_type": "resource", + "category": "doc", + }, + { + "uri": "viking://resources/file-b.py", + "abstract": "child B", + "_score": 0.8, + "level": 2, + "context_type": "resource", + "category": "doc", + }, + ] + + async def search_children_in_tenant( + self, + ctx, + parent_uri: str, + query_vector=None, + sparse_query_vector=None, + context_type=None, + target_directories=None, + extra_filter=None, + limit: int = 10, + ): + self.child_search_calls.append( + { + "ctx": ctx, + "parent_uri": parent_uri, + "query_vector": query_vector, + "sparse_query_vector": sparse_query_vector, + "context_type": context_type, + "target_directories": target_directories, + "extra_filter": extra_filter, + "limit": limit, + } + ) + return [] + + class DirectChildProxy: async def search_children_in_tenant( self, @@ -231,6 +298,16 @@ def rerank_batch(self, query: str, documents: list[str]): return list(self.scores[start:end]) +class FakeVikingFS: + def __init__(self, files): + self.files = files + self.calls = [] + + async def read_file(self, uri: str, ctx=None): + self.calls.append((uri, ctx)) + return self.files[uri] + + def _ctx() -> RequestContext: return RequestContext(user=UserIdentifier("acc1", "user1"), role=Role.USER) @@ -260,7 +337,8 @@ def test_retriever_initializes_rerank_client(monkeypatch): assert retriever._rerank_client is fake_client -def test_merge_starting_points_prefers_rerank_scores_in_thinking_mode(monkeypatch): +@pytest.mark.asyncio +async def test_merge_starting_points_prefers_rerank_scores_in_thinking_mode(monkeypatch): fake_client = FakeRerankClient([0.95, 0.05]) monkeypatch.setattr( "openviking.retrieve.hierarchical_retriever.RerankClient.from_config", @@ -273,7 +351,7 @@ def test_merge_starting_points_prefers_rerank_scores_in_thinking_mode(monkeypatc rerank_config=_config(), ) - starting_points = retriever._merge_starting_points( + starting_points = await retriever._merge_starting_points( "hello", ["viking://resources"], [ @@ -290,6 +368,7 @@ def test_merge_starting_points_prefers_rerank_scores_in_thinking_mode(monkeypatc "level": 1, }, ], + ctx=_ctx(), mode=RetrieverMode.THINKING, ) @@ -347,6 +426,68 @@ async def test_retrieve_reranks_level_two_initial_candidates_in_thinking_mode(mo assert fake_client.calls == [("hello", ["child A", "child B"])] +@pytest.mark.asyncio +async def test_retrieve_uses_file_content_fallback_for_empty_level_two_abstract(monkeypatch): + class RejectEmptyRerankClient(FakeRerankClient): + def rerank_batch(self, query: str, documents: list[str]): + self.calls.append((query, list(documents))) + if any(not document.strip() for document in documents): + raise ValueError("empty rerank document") + start = self._cursor + end = start + len(documents) + self._cursor = end + return list(self.scores[start:end]) + + fake_client = RejectEmptyRerankClient([0.95, 0.11]) + fake_fs = FakeVikingFS({"viking://resources/file-a.py": "fallback file contents"}) + monkeypatch.setattr( + "openviking.retrieve.hierarchical_retriever.RerankClient.from_config", + lambda config: fake_client, + ) + + retriever = HierarchicalRetriever( + storage=EmptyAbstractLevelTwoGlobalStorage(), + embedder=DummyEmbedder(), + rerank_config=_config(), + ) + ctx = _ctx() + ctx.viking_fs = fake_fs + + result = await retriever.retrieve(_query(), ctx=ctx, limit=2, mode=RetrieverMode.THINKING) + + assert [matched_ctx.uri for matched_ctx in result.matched_contexts] == [ + "viking://resources/file-a.py", + "viking://resources/file-b.py", + ] + assert fake_client.calls == [ + ("hello", ["fallback file contents", "child B"]), + ] + assert fake_fs.calls == [("viking://resources/file-a.py", ctx)] + + +def test_rerank_scores_filters_empty_documents_before_rerank(monkeypatch): + fake_client = FakeRerankClient([0.9]) + monkeypatch.setattr( + "openviking.retrieve.hierarchical_retriever.RerankClient.from_config", + lambda config: fake_client, + ) + + retriever = HierarchicalRetriever( + storage=DummyStorage(), + embedder=DummyEmbedder(), + rerank_config=_config(), + ) + + scores = retriever._rerank_scores( + "hello", + ["", " ", "document body"], + [0.1, 0.2, 0.3], + ) + + assert scores == [0.1, 0.2, 0.9] + assert fake_client.calls == [("hello", ["document body"])] + + @pytest.mark.asyncio async def test_retrieve_falls_back_to_vector_scores_when_rerank_returns_none(monkeypatch): class NoneRerankClient(FakeRerankClient):