Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 103 additions & 16 deletions openviking/retrieve/hierarchical_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import math
import time
from datetime import datetime
from pathlib import PurePath
from typing import Any, Dict, List, Optional, Tuple

from openviking.core.retrieval_targets import default_target_directories
Expand All @@ -21,7 +22,7 @@
from openviking.retrieve.memory_lifecycle import hotness_score
from openviking.retrieve.retrieval_stats import get_stats_collector
from openviking.server.identity import RequestContext
from openviking.storage import VikingDBManager, VikingDBManagerProxy
from openviking.storage import VikingDBManager, VikingDBManagerProxy, get_viking_fs
from openviking.telemetry import get_current_telemetry
from openviking.utils.time_utils import parse_iso_datetime
from openviking_cli.retrieve.types import (
Expand Down Expand Up @@ -49,7 +50,44 @@ class HierarchicalRetriever:
# Directory score must exceed max child score
DIRECTORY_DOMINANCE_RATIO = 1.2
# Global retrieval count (more candidates = better rerank precision)
GLOBAL_SEARCH_TOPK = 10
# Limit per-request fan-out against remote vector stores
MAX_PARALLEL_CHILD_SEARCHES = 4
# Maximum number of characters of file content used as the rerank document
# when a result has no abstract.
FILE_RERANK_FALLBACK_MAX_CHARS = 4000
# Keyed by level number (0 and 1); NOTE(review): consumers are outside this view.
LEVEL_URI_SUFFIX = {0: ".abstract.md", 1: ".overview.md"}
# Lower-cased URI suffixes whose file content may stand in for a missing
# abstract during rerank; any other suffix yields an empty rerank document.
_FILE_RERANK_FALLBACK_SUFFIXES = {
    # Plain text, markup and tabular/structured data
    ".txt", ".md", ".csv", ".json", ".xml",
    # General-purpose source code
    ".py", ".js", ".ts", ".java", ".cpp", ".c", ".h",
    ".go", ".rs", ".lua", ".rb", ".php",
    # Shell scripts
    ".sh", ".bash", ".zsh", ".fish",
    # Further languages
    ".sql", ".kt", ".swift", ".scala", ".r", ".m", ".pl",
    # Configuration files
    ".toml", ".yaml", ".yml", ".ini", ".cfg", ".conf",
}

def __init__(
self,
Expand Down Expand Up @@ -176,10 +214,11 @@ async def retrieve(
)

# Step 3: Merge starting points
starting_points = self._merge_starting_points(
starting_points = await self._merge_starting_points(
query.query,
root_uris,
global_results,
ctx=ctx,
mode=mode,
)

Expand All @@ -189,9 +228,10 @@ async def retrieve(
else:
initial_candidates = [r for r in global_results if r.get("level", 2) == 2]

initial_candidates = self._prepare_initial_candidates(
initial_candidates = await self._prepare_initial_candidates(
query.query,
initial_candidates,
ctx=ctx,
mode=mode,
)

Expand All @@ -212,6 +252,7 @@ async def retrieve(
scope_dsl=scope_dsl,
initial_candidates=initial_candidates,
level=level,
ctx=ctx,
)

# Step 6: Convert results
Expand Down Expand Up @@ -273,33 +314,77 @@ def _rerank_scores(
if not self._rerank_client or not documents:
return fallback_scores

rerank_documents: List[str] = []
rerank_indices: List[int] = []
for index, document in enumerate(documents):
if not str(document).strip():
continue
rerank_documents.append(document)
rerank_indices.append(index)

if not rerank_documents:
return fallback_scores

try:
scores = self._rerank_client.rerank_batch(query, documents)
scores = self._rerank_client.rerank_batch(query, rerank_documents)
except Exception as e:
logger.warning(
"[HierarchicalRetriever] Rerank failed, fallback to vector scores: %s", e
)
return fallback_scores

if not scores or len(scores) != len(documents):
if not scores or len(scores) != len(rerank_documents):
logger.warning(
"[HierarchicalRetriever] Invalid rerank result, fallback to vector scores"
)
return fallback_scores

normalized_scores: List[float] = []
for score, fallback in zip(scores, fallback_scores, strict=True):
normalized_scores = list(fallback_scores)
for index, score in zip(rerank_indices, scores, strict=True):
if isinstance(score, (int, float)):
normalized_scores.append(float(score))
else:
normalized_scores.append(fallback)
normalized_scores[index] = float(score)
return normalized_scores

def _merge_starting_points(
async def _build_rerank_document(
self,
result: Dict[str, Any],
ctx: RequestContext,
) -> str:
abstract = str(result.get("abstract") or "").strip()
if abstract:
return abstract

if result.get("level") != 2:
return ""

uri = str(result.get("uri") or "").strip()
if not uri:
return ""

if PurePath(uri).suffix.lower() not in self._FILE_RERANK_FALLBACK_SUFFIXES:
return ""

try:
viking_fs = getattr(ctx, "viking_fs", None) or get_viking_fs()
content = await viking_fs.read_file(uri, ctx=ctx)
except Exception:
return ""

return str(content or "")[: self.FILE_RERANK_FALLBACK_MAX_CHARS]

async def _build_rerank_documents(
self,
results: List[Dict[str, Any]],
ctx: RequestContext,
) -> List[str]:
return await asyncio.gather(*(self._build_rerank_document(result, ctx) for result in results))

async def _merge_starting_points(
self,
query: str,
root_uris: List[str],
global_results: List[Dict[str, Any]],
ctx: RequestContext,
mode: str = "thinking",
) -> List[Tuple[str, float]]:
"""Merge starting points.
Expand All @@ -316,7 +401,7 @@ def _merge_starting_points(
s if math.isfinite(s) else 0.0 for s in (r.get("_score", 0.0) for r in global_results)
]
if self._rerank_client and mode == RetrieverMode.THINKING:
docs = [str(r.get("abstract", "")) for r in global_results]
docs = await self._build_rerank_documents(global_results, ctx)
query_scores = self._rerank_scores(query, docs, default_scores)
for i, r in enumerate(global_results):
# Only add non-level-2 items to the starting points
Expand All @@ -338,10 +423,11 @@ def _merge_starting_points(

return points

def _prepare_initial_candidates(
async def _prepare_initial_candidates(
self,
query: str,
global_results: List[Dict[str, Any]],
ctx: RequestContext,
mode: str = RetrieverMode.THINKING,
) -> List[Dict[str, Any]]:
"""Preserve rerank scores for global hits added to the result pool."""
Expand All @@ -354,7 +440,7 @@ def _prepare_initial_candidates(
for s in (r.get("_score", 0.0) for r in initial_candidates)
]
if self._rerank_client and mode == RetrieverMode.THINKING:
docs = [str(r.get("abstract", "")) for r in initial_candidates]
docs = await self._build_rerank_documents(initial_candidates, ctx)
query_scores = self._rerank_scores(query, docs, default_scores)
else:
query_scores = default_scores
Expand All @@ -380,6 +466,7 @@ async def _recursive_search(
scope_dsl: Optional[Dict[str, Any]] = None,
initial_candidates: Optional[List[Dict[str, Any]]] = None,
level: Optional[List[int]] = None,
ctx: Optional[RequestContext] = None,
) -> List[Dict[str, Any]]:
"""
Recursive search with directory priority return and score propagation.
Expand Down Expand Up @@ -477,8 +564,8 @@ async def search_children(current_uri: str) -> List[Dict[str, Any]]:
query_scores = [
s if math.isfinite(s) else 0.0 for s in (r.get("_score", 0.0) for r in results)
]
if self._rerank_client and mode == RetrieverMode.THINKING:
documents = [str(r.get("abstract", "")) for r in results]
if self._rerank_client and mode == RetrieverMode.THINKING and ctx is not None:
documents = await self._build_rerank_documents(results, ctx)
query_scores = self._rerank_scores(query, documents, query_scores)

for r, score in zip(results, query_scores, strict=True):
Expand Down
Loading