diff --git a/openviking/async_client.py b/openviking/async_client.py index f351a6cc..1ba053d2 100644 --- a/openviking/async_client.py +++ b/openviking/async_client.py @@ -58,11 +58,13 @@ def __init__( self.user = UserIdentifier.the_default_user() self._initialized = False - self._singleton_initialized = True + # Mark initialized only after LocalClient is successfully constructed. + self._singleton_initialized = False self._client: BaseClient = LocalClient( path=path, ) + self._singleton_initialized = True # ============= Lifecycle methods ============= @@ -78,7 +80,9 @@ async def _ensure_initialized(self): async def close(self) -> None: """Close OpenViking and release resources.""" - await self._client.close() + client = getattr(self, "_client", None) + if client is not None: + await client.close() self._initialized = False self._singleton_initialized = False @@ -88,8 +92,6 @@ async def reset(cls) -> None: with cls._lock: if cls._instance is not None: await cls._instance.close() - cls._instance._initialized = False - cls._instance._singleton_initialized = False cls._instance = None # ============= Session methods ============= diff --git a/openviking/core/directories.py b/openviking/core/directories.py index 2033a29e..080cf7ca 100644 --- a/openviking/core/directories.py +++ b/openviking/core/directories.py @@ -228,13 +228,9 @@ async def _ensure_directory( logger.debug(f"[VikingFS] Directory {uri} already exists") # 2. 
Ensure record exists in vector storage - from openviking_cli.utils.config import get_openviking_config - - config = get_openviking_config() - - existing = await self.vikingdb.filter( - collection=config.storage.vectordb.name, - filter={"op": "must", "field": "uri", "conds": [uri]}, + existing = await self.vikingdb.get_context_by_uri( + account_id=ctx.account_id, + uri=uri, limit=1, ) if not existing: diff --git a/openviking/eval/ragas/playback.py b/openviking/eval/ragas/playback.py index 6ec9e10c..3a9b75e7 100644 --- a/openviking/eval/ragas/playback.py +++ b/openviking/eval/ragas/playback.py @@ -415,29 +415,87 @@ async def _play_vikingdb_operation(self, record: IORecord) -> PlaybackResult: kwargs = request.get("kwargs", {}) if operation == "insert": - await self._vector_store.insert(*args, **kwargs) + if args: + payload = args[-1] + else: + payload = kwargs.get("data", request.get("data", {})) + await self._vector_store.upsert(payload) elif operation == "update": - await self._vector_store.update(*args, **kwargs) + if len(args) >= 3: + record_id = args[-2] + payload = args[-1] + elif len(args) == 2: + record_id = args[0] + payload = args[1] + else: + record_id = kwargs.get("id", request.get("id")) + payload = kwargs.get("data", request.get("data", {})) + existing = await self._vector_store.get([record_id]) + if existing: + merged = {**existing[0], **payload, "id": record_id} + await self._vector_store.upsert(merged) elif operation == "upsert": - await self._vector_store.upsert(*args, **kwargs) + if args: + payload = args[-1] + else: + payload = kwargs.get("data", request.get("data", {})) + await self._vector_store.upsert(payload) elif operation == "delete": - await self._vector_store.delete(*args, **kwargs) + if args: + ids = args[-1] + else: + ids = kwargs.get("ids", request.get("ids", [])) + await self._vector_store.delete(ids) elif operation == "get": - await self._vector_store.get(*args, **kwargs) + if args: + ids = args[-1] + else: + ids = kwargs.get("ids", 
request.get("ids", [])) + await self._vector_store.get(ids) elif operation == "exists": - await self._vector_store.exists(*args, **kwargs) + if len(args) >= 2: + record_id = args[-1] + elif len(args) == 1: + record_id = args[0] + else: + record_id = kwargs.get("id", request.get("id")) + await self._vector_store.exists(record_id) elif operation == "search": - await self._vector_store.search(*args, **kwargs) + if len(args) >= 4: + query_vector = args[1] + limit = args[2] + where = args[3] + elif args: + query_vector = args[0] + limit = kwargs.get("top_k", kwargs.get("limit", 10)) + where = kwargs.get("filter") + else: + query_vector = kwargs.get("vector", kwargs.get("query_vector")) + limit = kwargs.get("top_k", kwargs.get("limit", request.get("top_k", 10))) + where = kwargs.get("filter", request.get("filter")) + await self._vector_store.search( + query_vector=query_vector, filter=where, limit=limit + ) elif operation == "filter": - await self._vector_store.filter(*args, **kwargs) + if len(args) >= 4: + where = args[1] + limit = args[2] + offset = args[3] + elif args: + where = args[0] + limit = kwargs.get("limit", 100) + offset = kwargs.get("offset", 0) + else: + where = kwargs.get("filter", request.get("filter", {})) + limit = kwargs.get("limit", request.get("limit", 100)) + offset = kwargs.get("offset", request.get("offset", 0)) + await self._vector_store.filter(filter=where, limit=limit, offset=offset) elif operation == "create_collection": await self._vector_store.create_collection(*args, **kwargs) elif operation == "drop_collection": - await self._vector_store.drop_collection(*args, **kwargs) + await self._vector_store.drop_collection() elif operation == "collection_exists": - await self._vector_store.collection_exists(*args, **kwargs) - elif operation == "list_collections": - await self._vector_store.list_collections(*args, **kwargs) + await self._vector_store.collection_exists() else: raise ValueError(f"Unknown VikingDB operation: {operation}") diff --git 
a/openviking/eval/recorder/wrapper.py b/openviking/eval/recorder/wrapper.py index e8fb2686..685bfb3b 100644 --- a/openviking/eval/recorder/wrapper.py +++ b/openviking/eval/recorder/wrapper.py @@ -105,12 +105,35 @@ def __getattr__(self, name: str) -> Any: if not callable(original_attr) or name.startswith("_"): return original_attr # viking_fs文件操作 - if name not in ("ls", "mkdir", "stat", "rm", "mv", "read", "write", "grep", "glob", "tree", - "abstract", "overview", "relations", "link", "unlink", - "write_file", "read_file", "read_file_bytes", "write_file_bytes", "append_file", "move_file", - "delete_temp", "write_context", "get_relations", "get_relations_with_content", - "find", "search", - ): + if name not in ( + "ls", + "mkdir", + "stat", + "rm", + "mv", + "read", + "write", + "grep", + "glob", + "tree", + "abstract", + "overview", + "relations", + "link", + "unlink", + "write_file", + "read_file", + "read_file_bytes", + "write_file_bytes", + "append_file", + "move_file", + "delete_temp", + "write_context", + "get_relations", + "get_relations_with_content", + "find", + "search", + ): return original_attr async def wrapped_async(*args, **kwargs): @@ -179,6 +202,7 @@ def wrapped_sync(*args, **kwargs): raise import inspect + if inspect.iscoroutinefunction(original_attr) or name.startswith("_"): return wrapped_async @@ -201,6 +225,7 @@ def _build_request(self, name: str, args: tuple, kwargs: dict) -> Dict[str, Any] param_names = [] try: import inspect + original_attr = getattr(self._fs, name, None) if original_attr and callable(original_attr): sig = inspect.signature(original_attr) @@ -223,7 +248,7 @@ def _build_request(self, name: str, args: tuple, kwargs: dict) -> Dict[str, Any] class RecordingVikingDB: """ - Wrapper for VikingDBInterface that records all operations. + Wrapper for vector store instances that records all operations. 
Usage: from openviking.eval.recorder import init_recorder @@ -239,7 +264,7 @@ def __init__(self, viking_db: Any, recorder: Optional[IORecorder] = None): Initialize wrapper. Args: - viking_db: VikingDBInterface instance to wrap + viking_db: Vector store instance to wrap recorder: IORecorder instance (uses global if None) """ self._db = viking_db @@ -269,7 +294,7 @@ async def insert(self, collection: str, data: Dict[str, Any]) -> str: request = {"collection": collection, "data": data} start_time = time.time() try: - result = await self._db.insert(collection, data) + result = await self._db.upsert(data) latency_ms = (time.time() - start_time) * 1000 self._record("insert", request, result, latency_ms) return result @@ -283,7 +308,12 @@ async def update(self, collection: str, id: str, data: Dict[str, Any]) -> bool: request = {"collection": collection, "id": id, "data": data} start_time = time.time() try: - result = await self._db.update(collection, id, data) + existing = await self._db.get([id]) + if not existing: + result = False + else: + payload = {**existing[0], **data, "id": id} + result = bool(await self._db.upsert(payload)) latency_ms = (time.time() - start_time) * 1000 self._record("update", request, result, latency_ms) return result @@ -297,7 +327,7 @@ async def upsert(self, collection: str, data: Dict[str, Any]) -> str: request = {"collection": collection, "data": data} start_time = time.time() try: - result = await self._db.upsert(collection, data) + result = await self._db.upsert(data) latency_ms = (time.time() - start_time) * 1000 self._record("upsert", request, result, latency_ms) return result @@ -311,7 +341,7 @@ async def delete(self, collection: str, ids: List[str]) -> int: request = {"collection": collection, "ids": ids} start_time = time.time() try: - result = await self._db.delete(collection, ids) + result = await self._db.delete(ids) latency_ms = (time.time() - start_time) * 1000 self._record("delete", request, result, latency_ms) return result @@ 
-325,7 +355,7 @@ async def get(self, collection: str, ids: List[str]) -> List[Dict[str, Any]]: request = {"collection": collection, "ids": ids} start_time = time.time() try: - result = await self._db.get(collection, ids) + result = await self._db.get(ids) latency_ms = (time.time() - start_time) * 1000 self._record("get", request, result, latency_ms) return result @@ -339,7 +369,7 @@ async def exists(self, collection: str, id: str) -> bool: request = {"collection": collection, "id": id} start_time = time.time() try: - result = await self._db.exists(collection, id) + result = await self._db.exists(id) latency_ms = (time.time() - start_time) * 1000 self._record("exists", request, result, latency_ms) return result @@ -359,7 +389,11 @@ async def search( request = {"collection": collection, "vector": vector, "top_k": top_k, "filter": filter} start_time = time.time() try: - result = await self._db.search(collection, vector, top_k, filter) + result = await self._db.search( + query_vector=vector, + filter=filter, + limit=top_k, + ) latency_ms = (time.time() - start_time) * 1000 self._record("search", request, result, latency_ms) return result @@ -379,7 +413,11 @@ async def filter( request = {"collection": collection, "filter": filter, "limit": limit, "offset": offset} start_time = time.time() try: - result = await self._db.filter(collection, filter, limit, offset) + result = await self._db.filter( + filter=filter, + limit=limit, + offset=offset, + ) latency_ms = (time.time() - start_time) * 1000 self._record("filter", request, result, latency_ms) return result @@ -402,12 +440,12 @@ async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: self._record("create_collection", request, None, latency_ms, False, str(e)) raise - async def drop_collection(self, name: str) -> bool: + async def drop_collection(self) -> bool: """Drop collection with recording.""" - request = {"name": name} + request = {} start_time = time.time() try: - result = await 
self._db.drop_collection(name) + result = await self._db.drop_collection() latency_ms = (time.time() - start_time) * 1000 self._record("drop_collection", request, result, latency_ms) return result @@ -416,12 +454,12 @@ async def drop_collection(self, name: str) -> bool: self._record("drop_collection", request, None, latency_ms, False, str(e)) raise - async def collection_exists(self, name: str) -> bool: + async def collection_exists(self) -> bool: """Check collection exists with recording.""" - request = {"name": name} + request = {} start_time = time.time() try: - result = await self._db.collection_exists(name) + result = await self._db.collection_exists() latency_ms = (time.time() - start_time) * 1000 self._record("collection_exists", request, result, latency_ms) return result @@ -430,21 +468,6 @@ async def collection_exists(self, name: str) -> bool: self._record("collection_exists", request, None, latency_ms, False, str(e)) raise - async def list_collections(self) -> List[str]: - """List collections with recording.""" - request = {} - start_time = time.time() - try: - result = await self._db.list_collections() - latency_ms = (time.time() - start_time) * 1000 - self._record("list_collections", request, result, latency_ms) - return result - except Exception as e: - latency_ms = (time.time() - start_time) * 1000 - self._record("list_collections", request, None, latency_ms, False, str(e)) - raise - def __getattr__(self, name: str) -> Any: """Pass through any other attributes to the wrapped db.""" return getattr(self._db, name) - diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index 7f92e1df..d342c3d9 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -14,7 +14,7 @@ from openviking.models.embedder.base import EmbedResult from openviking.retrieve.memory_lifecycle import hotness_score from openviking.server.identity import RequestContext, Role -from 
openviking.storage import VikingDBInterface +from openviking.storage import VikingVectorIndexBackend from openviking.storage.viking_fs import get_viking_fs from openviking_cli.retrieve.types import ( ContextType, @@ -46,18 +46,18 @@ class HierarchicalRetriever: def __init__( self, - storage: VikingDBInterface, + storage: VikingVectorIndexBackend, embedder: Optional[Any], rerank_config: Optional[RerankConfig] = None, ): """Initialize hierarchical retriever with rerank_config. Args: - storage: VikingDBInterface instance + storage: VikingVectorIndexBackend instance embedder: Embedder instance (supports dense/sparse/hybrid) rerank_config: Rerank configuration (optional, will fallback to vector search only) """ - self.storage = storage + self.vector_store = storage self.embedder = embedder self.rerank_config = rerank_config @@ -85,7 +85,7 @@ async def retrieve( mode: RetrieverMode = RetrieverMode.THINKING, score_threshold: Optional[float] = None, score_gte: bool = False, - metadata_filter: Optional[Dict[str, Any]] = None, + scope_dsl: Optional[Dict[str, Any]] = None, ) -> QueryResult: """ Execute hierarchical retrieval. 
@@ -95,44 +95,19 @@ async def retrieve( score_threshold: Custom score threshold (overrides config) score_gte: True uses >=, False uses > grep_patterns: Keyword match pattern list - metadata_filter: Additional metadata filter conditions + scope_dsl: Additional scope constraints passed from public find/search filter """ # Use custom threshold or default threshold effective_threshold = score_threshold if score_threshold is not None else self.threshold - collection = "context" - target_dirs = [d for d in (query.target_directories or []) if d] - # Create context_type filter (skip when context_type is None = search all types) - filters_to_merge = [] - if query.context_type is not None: - type_filter = { - "op": "must", - "field": "context_type", - "conds": [query.context_type.value], - } - filters_to_merge.append(type_filter) - tenant_filter = self._build_tenant_filter(ctx, context_type=query.context_type) - if tenant_filter: - filters_to_merge.append(tenant_filter) - if target_dirs: - target_filter = { - "op": "or", - "conds": [ - {"op": "must", "field": "uri", "conds": [target_dir]} - for target_dir in target_dirs - ], - } - filters_to_merge.append(target_filter) - if metadata_filter: - filters_to_merge.append(metadata_filter) - - final_metadata_filter = {"op": "and", "conds": filters_to_merge} - - if not await self.storage.collection_exists(collection): - logger.warning(f"[RecursiveSearch] Collection {collection} does not exist") + if not await self.vector_store.collection_exists_bound(): + logger.warning( + "[RecursiveSearch] Collection %s does not exist", + self.vector_store.collection_name, + ) return QueryResult( query=query, matched_contexts=[], @@ -155,11 +130,13 @@ async def retrieve( # Step 2: Global vector search to supplement starting points global_results = await self._global_vector_search( - collection=collection, + ctx=ctx, query_vector=query_vector, sparse_query_vector=sparse_query_vector, + context_type=query.context_type.value if query.context_type 
else None, + target_dirs=target_dirs, + scope_dsl=scope_dsl, limit=self.GLOBAL_SEARCH_TOPK, - filter=final_metadata_filter, ) # Step 3: Merge starting points @@ -168,7 +145,7 @@ async def retrieve( # Step 4: Recursive search candidates = await self._recursive_search( query=query.query, - collection=collection, + ctx=ctx, query_vector=query_vector, sparse_query_vector=sparse_query_vector, starting_points=starting_points, @@ -176,7 +153,9 @@ async def retrieve( mode=mode, threshold=effective_threshold, score_gte=score_gte, - metadata_filter=final_metadata_filter, + context_type=query.context_type.value if query.context_type else None, + target_dirs=target_dirs, + scope_dsl=scope_dsl, ) # Step 6: Convert results @@ -188,56 +167,24 @@ async def retrieve( searched_directories=root_uris, ) - def _build_tenant_filter( - self, ctx: RequestContext, context_type: Optional[ContextType] = None - ) -> Optional[Dict[str, Any]]: - """Build tenant visibility filter by role. - - Args: - ctx: Request context with role and user info. - context_type: When RESOURCE, allow owner_space="" so shared - resources are visible to USER role. 
- """ - if ctx.role == Role.ROOT: - return None - - owner_spaces = [ctx.user.user_space_name(), ctx.user.agent_space_name()] - if context_type == ContextType.RESOURCE: - owner_spaces.append("") - return { - "op": "and", - "conds": [ - {"op": "must", "field": "account_id", "conds": [ctx.account_id]}, - { - "op": "must", - "field": "owner_space", - "conds": owner_spaces, - }, - ], - } - async def _global_vector_search( self, - collection: str, + ctx: RequestContext, query_vector: Optional[List[float]], sparse_query_vector: Optional[Dict[str, float]], + context_type: Optional[str], + target_dirs: List[str], + scope_dsl: Optional[Dict[str, Any]], limit: int, - filter: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """Global vector search to locate initial directories.""" - if not query_vector: - return [] - sparse_query_vector = sparse_query_vector or {} - - global_filter = { - "op": "and", - "conds": [filter, {"op": "must", "field": "level", "conds": [0, 1]}], - } - results = await self.storage.search( - collection=collection, + results = await self.vector_store.search_global_roots_in_tenant( + ctx=ctx, query_vector=query_vector, sparse_query_vector=sparse_query_vector, - filter=global_filter, + context_type=context_type, + target_directories=target_dirs, + extra_filter=scope_dsl, limit=limit, ) return results @@ -283,7 +230,7 @@ def _merge_starting_points( async def _recursive_search( self, query: str, - collection: str, + ctx: RequestContext, query_vector: Optional[List[float]], sparse_query_vector: Optional[Dict[str, float]], starting_points: List[Tuple[str, float]], @@ -291,7 +238,9 @@ async def _recursive_search( mode: str, threshold: Optional[float] = None, score_gte: bool = False, - metadata_filter: Optional[Dict[str, Any]] = None, + context_type: Optional[str] = None, + target_dirs: Optional[List[str]] = None, + scope_dsl: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ Recursive search with directory priority return and score 
propagation. @@ -300,7 +249,7 @@ async def _recursive_search( threshold: Score threshold score_gte: True uses >=, False uses > grep_patterns: Keyword match patterns - metadata_filter: Additional metadata filter conditions + scope_dsl: Additional scope constraints from public find/search filter """ # Use passed threshold or default threshold effective_threshold = threshold if threshold is not None else self.threshold @@ -311,12 +260,6 @@ def passes_threshold(score: float) -> bool: return score >= effective_threshold return score > effective_threshold - def merge_filter(base_filter: Dict, extra_filter: Optional[Dict]) -> Dict: - """Merge filter conditions.""" - if not extra_filter: - return base_filter - return {"op": "and", "conds": [base_filter, extra_filter]} - sparse_query_vector = sparse_query_vector or None collected: List[Dict[str, Any]] = [] # Collected results (directories and leaves) @@ -341,13 +284,14 @@ def merge_filter(base_filter: Dict, extra_filter: Optional[Dict]) -> Dict: pre_filter_limit = max(limit * 2, 20) - results = await self.storage.search( - collection=collection, + results = await self.vector_store.search_children_in_tenant( + ctx=ctx, + parent_uri=current_uri, query_vector=query_vector, sparse_query_vector=sparse_query_vector, # Pass sparse vector - filter=merge_filter( - {"op": "must", "field": "parent_uri", "conds": [current_uri]}, metadata_filter - ), + context_type=context_type, + target_directories=target_dirs, + extra_filter=scope_dsl, limit=pre_filter_limit, ) diff --git a/openviking/server/routers/admin.py b/openviking/server/routers/admin.py index 681dbe5c..b768fe69 100644 --- a/openviking/server/routers/admin.py +++ b/openviking/server/routers/admin.py @@ -120,12 +120,7 @@ async def delete_account( try: storage = viking_fs._get_vector_store() if storage: - account_filter = { - "op": "must", - "field": "account_id", - "conds": [account_id], - } - deleted = await storage.batch_delete("context", account_filter) + deleted = await 
storage.delete_account_data(account_id) logger.info(f"VectorDB cascade delete for account {account_id}: {deleted} records") except Exception as e: logger.warning(f"VectorDB cleanup for account {account_id}: {e}") diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 4007ecc0..5f1b6c8f 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ -113,7 +113,7 @@ async def _delete_existing_memory( try: # rm() already syncs vector deletion in most cases; keep this as a safe fallback. - await self.vikingdb.remove_by_uri("context", memory.uri) + await self.vikingdb.delete_uris(ctx, [memory.uri]) except Exception as e: logger.warning(f"Failed to remove vector record for {memory.uri}: {e}") return True diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py index 67d8ea85..cb038c19 100644 --- a/openviking/session/memory_deduplicator.py +++ b/openviking/session/memory_deduplicator.py @@ -83,7 +83,7 @@ def __init__( ): """Initialize deduplicator.""" self.vikingdb = vikingdb - self.embedder = vikingdb.get_embedder() + self.embedder = self.vikingdb.get_embedder() async def deduplicate( self, @@ -127,42 +127,33 @@ async def _find_similar_memories( embed_result: EmbedResult = self.embedder.embed(query_text) query_vector = embed_result.dense_vector - # Determine collection and filter based on category - collection = "context" - category_uri_prefix = self._category_uri_prefix(candidate.category.value, candidate.user) - # Build filter by memory scope + uri prefix (schema does not have category field yet). 
- filter_conds = [ - {"field": "context_type", "op": "must", "conds": ["memory"]}, - {"field": "level", "op": "must", "conds": [2]}, - ] owner = candidate.user - if hasattr(owner, "account_id"): - filter_conds.append({"field": "account_id", "op": "must", "conds": [owner.account_id]}) + account_id = owner.account_id if hasattr(owner, "account_id") else "default" + owner_space = None if owner and hasattr(owner, "user_space_name"): owner_space = ( owner.agent_space_name() if candidate.category.value in {"cases", "patterns"} else owner.user_space_name() ) - filter_conds.append({"field": "owner_space", "op": "must", "conds": [owner_space]}) - if category_uri_prefix: - filter_conds.append({"field": "uri", "op": "must", "conds": [category_uri_prefix]}) - dedup_filter = {"op": "and", "conds": filter_conds} logger.debug( - "Dedup prefilter candidate category=%s filter=%s", + "Dedup prefilter candidate category=%s account=%s owner_space=%s uri_prefix=%s", candidate.category.value, - dedup_filter, + account_id, + owner_space, + category_uri_prefix, ) try: # Search with memory-scope filter. - results = await self.vikingdb.search( - collection=collection, + results = await self.vikingdb.search_similar_memories( + account_id=account_id, + owner_space=owner_space, + category_uri_prefix=category_uri_prefix, query_vector=query_vector, limit=5, - filter=dedup_filter, ) # Filter by similarity threshold diff --git a/openviking/session/session.py b/openviking/session/session.py index d46c5276..136823e4 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -5,7 +5,6 @@ Session as Context: Sessions integrated into L0/L1/L2 system. """ -import hashlib import json import re from dataclasses import dataclass, field @@ -300,32 +299,12 @@ def _update_active_counts(self) -> int: if not self._vikingdb_manager: return 0 - updated = 0 - storage = self._vikingdb_manager - - for usage in self._usage_records: - try: - # Compute the record ID from the URI directly. 
- # collection_schemas.py assigns id = md5(uri) for every context - # record, so storage.get() gives us a precise single-record lookup - # without the subtree-matching side-effect of fetch_by_uri() on - # path-type fields. - record_id = hashlib.md5(usage.uri.encode("utf-8")).hexdigest() - records = run_async(storage.get(collection="context", ids=[record_id])) - if not records: - logger.debug(f"Record not found for URI: {usage.uri}") - continue - current_count = records[0].get("active_count") or 0 - run_async( - storage.update( - collection="context", - id=record_id, - data={"active_count": current_count + 1}, - ) - ) - updated += 1 - except Exception as e: - logger.debug(f"Could not update active_count for {usage.uri}: {e}") + uris = [usage.uri for usage in self._usage_records if usage.uri] + try: + updated = run_async(self._vikingdb_manager.increment_active_count(self.ctx, uris)) + except Exception as e: + logger.debug(f"Could not update active_count for usage URIs: {e}") + updated = 0 if updated > 0: logger.info(f"Updated active_count for {updated} contexts/skills") diff --git a/openviking/storage/__init__.py b/openviking/storage/__init__.py index 63620722..ac4cd8d5 100644 --- a/openviking/storage/__init__.py +++ b/openviking/storage/__init__.py @@ -2,24 +2,21 @@ # SPDX-License-Identifier: Apache-2.0 """Storage layer interfaces and implementations.""" -from openviking.storage.observers import BaseObserver, QueueObserver -from openviking.storage.queuefs import QueueManager, get_queue_manager, init_queue_manager -from openviking.storage.viking_fs import VikingFS, get_viking_fs, init_viking_fs -from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend -from openviking.storage.vikingdb_interface import ( +from openviking.storage.errors import ( CollectionNotFoundError, ConnectionError, DuplicateKeyError, RecordNotFoundError, SchemaError, StorageException, - VikingDBInterface, ) +from openviking.storage.observers import BaseObserver, 
QueueObserver +from openviking.storage.queuefs import QueueManager, get_queue_manager, init_queue_manager +from openviking.storage.viking_fs import VikingFS, get_viking_fs, init_viking_fs +from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend from openviking.storage.vikingdb_manager import VikingDBManager __all__ = [ - # Interface - "VikingDBInterface", # Exceptions "StorageException", "CollectionNotFoundError", diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index 56b59cd9..0cff8725 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -13,9 +13,10 @@ from typing import Any, Dict, Optional from openviking.models.embedder.base import EmbedResult +from openviking.storage.errors import CollectionNotFoundError from openviking.storage.queuefs.embedding_msg import EmbeddingMsg from openviking.storage.queuefs.named_queue import DequeueHandlerBase -from openviking.storage.vikingdb_interface import CollectionNotFoundError, VikingDBInterface +from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend from openviking_cli.utils import get_logger from openviking_cli.utils.config.open_viking_config import OpenVikingConfig @@ -126,11 +127,11 @@ class TextEmbeddingHandler(DequeueHandlerBase): Supports both dense and sparse embeddings based on configuration. """ - def __init__(self, vikingdb: VikingDBInterface): + def __init__(self, vikingdb: VikingVectorIndexBackend): """Initialize the text embedding handler. 
Args: - vikingdb: VikingDBInterface instance for writing to vector database + vikingdb: VikingVectorIndexBackend instance for writing to vector database """ from openviking_cli.utils.config import get_openviking_config @@ -207,7 +208,7 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, id_seed = f"{account_id}:{owner_space}:{uri}" inserted_data["id"] = hashlib.md5(id_seed.encode("utf-8")).hexdigest() - record_id = await self._vikingdb.insert(self._collection_name, inserted_data) + record_id = await self._vikingdb.upsert(inserted_data) if record_id: logger.debug( f"Successfully wrote embedding to database: {record_id} abstract {inserted_data['abstract']} vector {inserted_data['vector'][:5]}" diff --git a/openviking/storage/errors.py b/openviking/storage/errors.py new file mode 100644 index 00000000..bc3e36be --- /dev/null +++ b/openviking/storage/errors.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Storage-layer exceptions.""" + + +class VikingDBException(Exception): + """Base exception for vector-store operations.""" + + +class StorageException(VikingDBException): + """Legacy alias for VikingDBException for backward compatibility.""" + + +class CollectionNotFoundError(StorageException): + """Raised when a collection does not exist.""" + + +class RecordNotFoundError(StorageException): + """Raised when a record does not exist.""" + + +class DuplicateKeyError(StorageException): + """Raised when trying to insert a duplicate key.""" + + +class ConnectionError(StorageException): + """Raised when storage connection fails.""" + + +class SchemaError(StorageException): + """Raised when schema validation fails.""" diff --git a/openviking/storage/expr.py b/openviking/storage/expr.py new file mode 100644 index 00000000..e19519bd --- /dev/null +++ b/openviking/storage/expr.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 +"""Filter expression AST for vector store queries.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, Union + + +@dataclass(frozen=True) +class And: + conds: List["FilterExpr"] + + +@dataclass(frozen=True) +class Or: + conds: List["FilterExpr"] + + +@dataclass(frozen=True) +class Eq: + field: str + value: Any + + +@dataclass(frozen=True) +class In: + field: str + values: List[Any] + + +@dataclass(frozen=True) +class Range: + field: str + gte: Any | None = None + gt: Any | None = None + lte: Any | None = None + lt: Any | None = None + + +@dataclass(frozen=True) +class Contains: + field: str + substring: str + + +@dataclass(frozen=True) +class TimeRange: + field: str + start: datetime | str | None = None + end: datetime | str | None = None + + +@dataclass(frozen=True) +class RawDSL: + payload: Dict[str, Any] + + +FilterExpr = Union[And, Or, Eq, In, Range, Contains, TimeRange, RawDSL] diff --git a/openviking/storage/observers/vikingdb_observer.py b/openviking/storage/observers/vikingdb_observer.py index dc8f6e35..9e6eeab7 100644 --- a/openviking/storage/observers/vikingdb_observer.py +++ b/openviking/storage/observers/vikingdb_observer.py @@ -30,12 +30,10 @@ async def get_status_table_async(self) -> str: if not self._vikingdb_manager: return "VikingDB manager not initialized." - collection_names = await self._vikingdb_manager.list_collections() - - if not collection_names: + if not await self._vikingdb_manager.collection_exists(): return "No collections found." 
- statuses = await self._get_collection_statuses(collection_names) + statuses = await self._get_collection_statuses([self._vikingdb_manager.collection_name]) return self._format_status_as_table(statuses) def get_status_table(self) -> str: @@ -49,19 +47,12 @@ async def _get_collection_statuses(self, collection_names: list) -> Dict[str, Di for name in collection_names: try: - if not self._vikingdb_manager.project.has_collection(name): - continue - - collection = self._vikingdb_manager.project.get_collection(name) - if not collection: + if not await self._vikingdb_manager.collection_exists(): continue - index_count = len(collection.list_indexes()) - - count_result = collection.aggregate_data( - index_name=self._vikingdb_manager.DEFAULT_INDEX_NAME, op="count" - ) - vector_count = count_result.agg.get("_total", 0) + # Current OpenViking flow uses one managed default index per collection. + index_count = 1 + vector_count = await self._vikingdb_manager.count() statuses[name] = { "index_count": index_count, diff --git a/openviking/storage/queuefs/queue_manager.py b/openviking/storage/queuefs/queue_manager.py index dc9aeb57..95e9aeb2 100644 --- a/openviking/storage/queuefs/queue_manager.py +++ b/openviking/storage/queuefs/queue_manager.py @@ -107,7 +107,7 @@ def start(self) -> None: logger.info("[QueueManager] Started") - def setup_standard_queues(self, vikingdb_interface: Any) -> None: + def setup_standard_queues(self, vector_store: Any) -> None: """ Setup standard queues (Embedding and Semantic) with their handlers. @@ -116,14 +116,14 @@ def setup_standard_queues(self, vikingdb_interface: Any) -> None: queue manager is started. Args: - vikingdb_interface: VikingDBInterface instance for handlers to write results. + vector_store: Vector store instance for handlers to write results. 
""" # Import handlers here to avoid circular dependencies from openviking.storage.collection_schemas import TextEmbeddingHandler from openviking.storage.queuefs import SemanticProcessor # Embedding Queue - embedding_handler = TextEmbeddingHandler(vikingdb_interface) + embedding_handler = TextEmbeddingHandler(vector_store) self.get_queue( self.EMBEDDING, dequeue_handler=embedding_handler, diff --git a/openviking/storage/vectordb/collection/vikingdb_collection.py b/openviking/storage/vectordb/collection/vikingdb_collection.py index 9bfa74cd..b9995a1f 100644 --- a/openviking/storage/vectordb/collection/vikingdb_collection.py +++ b/openviking/storage/vectordb/collection/vikingdb_collection.py @@ -371,17 +371,17 @@ def aggregate_data( filters: Optional[Dict[str, Any]] = None, cond: Optional[Dict[str, Any]] = None, ) -> AggregateResult: - path = "/api/vikingdb/data/aggregate" + path = "/api/vikingdb/data/agg" data = { "project": self.project_name, "collection_name": self.collection_name, "index_name": index_name, - "agg": { - "op": op, - "field": field, - }, + "op": op, + "field": field, "filter": filters, } + if cond is not None: + data["cond"] = cond resp_data = self._data_post(path, data) return self._parse_aggregate_result(resp_data, op, field) diff --git a/openviking/storage/vectordb/collection/volcengine_collection.py b/openviking/storage/vectordb/collection/volcengine_collection.py index 8564c2c0..4acb9445 100644 --- a/openviking/storage/vectordb/collection/volcengine_collection.py +++ b/openviking/storage/vectordb/collection/volcengine_collection.py @@ -35,14 +35,13 @@ def get_or_create_volcengine_collection(config: Dict[str, Any], meta_data: Dict[ ak = config.get("AK") sk = config.get("SK") region = config.get("Region") - host = config.get("Host", "") collection_name = meta_data.get("CollectionName") if not collection_name: raise ValueError("CollectionName is required in config") # Initialize Console client for creating Collection - client = 
ClientForConsoleApi(ak, sk, region, host) + client = ClientForConsoleApi(ak, sk, region) # Try to create Collection try: @@ -64,10 +63,10 @@ def get_or_create_volcengine_collection(config: Dict[str, Any], meta_data: Dict[ raise e logger.info(f"Collection {collection_name} created successfully") - return VolcengineCollection(ak, sk, region, host, meta_data) + return VolcengineCollection(ak, sk, region, meta_data=meta_data) # Return VolcengineCollection instance - return VolcengineCollection(ak=ak, sk=sk, region=region, host=host, meta_data=meta_data) + return VolcengineCollection(ak=ak, sk=sk, region=region, meta_data=meta_data) class VolcengineCollection(ICollection): @@ -76,7 +75,7 @@ def __init__( ak: str, sk: str, region: str, - host: str, + host: Optional[str] = None, meta_data: Optional[Dict[str, Any]] = None, ): self.console_client = ClientForConsoleApi(ak, sk, region, host) @@ -580,16 +579,16 @@ def aggregate_data( filters: Optional[Dict[str, Any]] = None, cond: Optional[Dict[str, Any]] = None, ) -> AggregateResult: - path = "/api/vikingdb/data/aggregate" + path = "/api/vikingdb/data/agg" data = { "project": self.project_name, "collection_name": self.collection_name, "index_name": index_name, - "agg": { - "op": op, - "field": field, - }, + "op": op, + "field": field, "filter": filters, } + if cond is not None: + data["cond"] = cond resp_data = self._data_post(path, data) return self._parse_aggregate_result(resp_data, op, field) diff --git a/openviking/storage/vectordb_adapters/README.md b/openviking/storage/vectordb_adapters/README.md new file mode 100644 index 00000000..3cb8e065 --- /dev/null +++ b/openviking/storage/vectordb_adapters/README.md @@ -0,0 +1,211 @@ +# VectorDB Adapter 接入指南(新增第三方后端) + +本指南说明如何在 `openviking/storage/vectordb_adapters` 下新增一个第三方向量库后端,并接入 OpenViking 现有检索链路。 + +--- + +## 1. 
目标与范围

### 目标
- 以最小改动新增一个向量库后端。
- 保持上层业务接口不变(`find/search` 等无需改调用方式)。
- 将后端差异封装在 Adapter 层,不泄漏到业务层。

### 非目标
- 不改上层语义检索策略(租户、目录层级、召回策略)。
- 不增加新的对外 API 协议。

---

## 2. 架构位置与职责

当前分层职责如下:

1. **上层语义层(OpenViking 业务)**
   面向语义接口,不关心后端协议差异。

2. **通用向量存储层(Store/Backend)**
   提供统一查询、写入、删除、计数能力。

3. **Adapter 层(本目录)**
   负责把统一能力映射到具体后端实现(local/http/volcengine/vikingdb/thirdparty)。

新增后端时,主要只改第 3 层。

---

## 3. 接入前提

在开始前,请确认:

- 你已拿到第三方后端的:
  - 集合管理 API(查/建/删集合)
  - 数据 API(upsert/get/delete/search/aggregate)
- 你已明确该后端的:
  - 认证方式(AK/SK、token、header)
  - 过滤语法能力(是否支持 must/range/and/or)
  - 索引参数约束(dense/sparse、距离度量、索引类型)

---

## 4. 接入步骤

### Step 1:新增 Adapter 文件

在目录下新增文件,例如:

- `openviking/storage/vectordb_adapters/thirdparty_adapter.py`

定义类:

- `ThirdPartyCollectionAdapter(CollectionAdapter)`

基类位于:

- `openviking/storage/vectordb_adapters/base.py`

---

### Step 2:实现最小必需方法

你需要实现以下方法:

1. `from_config(cls, config)`
   - 从 `VectorDBBackendConfig` 读取后端配置并构造 adapter。
   - collection 名建议使用 `config.name or "context"`。

2. `_load_existing_collection_if_needed(self)`
   - 懒加载已存在 collection handle。
   - 若不存在,保持 `_collection is None`。

3. `_create_backend_collection(self, meta)`
   - 按传入 schema 创建 collection 并返回 handle。

---

### Step 3:按后端能力补充可选 Hook

如后端有差异,可重写:

- `_sanitize_scalar_index_fields(...)`
- `_build_default_index_meta(...)`

目的:把后端特性差异收敛在 adapter 内。

---

### Step 4:注册到 Factory

编辑:

- `openviking/storage/vectordb_adapters/factory.py`

在 `_ADAPTER_REGISTRY` 增加映射,例如:

```python
"thirdparty": ThirdPartyCollectionAdapter
```

这样 `create_collection_adapter(config)` 会自动路由到你的实现。

---

### Step 5:补充配置模型

确保配置中可声明新 backend(如 `backend: thirdparty`)及其专属字段(endpoint/auth/region 等)。

原则:
- `create_collection` 时使用配置中的 name 绑定 collection。
- 后续操作默认绑定,不需要每次传 collection_name。

---

## 5.
Filter 与查询兼容规则 + +- Adapter 需要兼容统一过滤表达。 +- 上层传入的过滤表达会经由统一编译流程进入后端查询。 +- 若第三方语法不同,请在 adapter 内做映射,不改上层调用协议。 + +关键原则: +- **后端 DSL 不上浮到业务层**。 +- **业务层不依赖第三方私有查询语法**。 + +--- + +## 6. 最小代码骨架(示例) + +```python +from __future__ import annotations +from typing import Any, Dict + +from openviking.storage.vectordb_adapters.base import CollectionAdapter + +class ThirdPartyCollectionAdapter(CollectionAdapter): + def __init__(self, *, endpoint: str, token: str, collection_name: str): + super().__init__(collection_name=collection_name) + self.mode = "thirdparty" + self._endpoint = endpoint + self._token = token + + @classmethod + def from_config(cls, config: Any): + if not config.thirdparty or not config.thirdparty.endpoint: + raise ValueError("ThirdParty backend requires endpoint") + return cls( + endpoint=config.thirdparty.endpoint, + token=config.thirdparty.token, + collection_name=config.name or "context", + ) + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + # TODO: 查询远端 collection 是否存在,存在则初始化 handle + # self._collection = ... + pass + + def _create_backend_collection(self, meta: Dict[str, Any]): + # TODO: 调后端 create collection,并返回 collection handle + # return ... + raise NotImplementedError +``` + +--- + +## 7. 测试要求(必须) + +至少覆盖以下场景: + +1. backend 工厂路由正确(能创建到新 adapter)。 +2. collection 生命周期可用(exists/create/drop)。 +3. 基础数据链路可用(upsert/get/delete/query)。 +4. count/aggregate 行为正确。 +5. filter 条件可正确生效(含组合条件)。 + +--- + +## 8. 常见问题与排查 + +### Q1:启动时报 backend 不支持 +- 检查 factory 是否注册。 +- 检查配置里的 backend 字符串是否与 registry key 一致。 + +### Q2:集合创建成功但查询为空 +- 检查 collection 绑定名是否一致。 +- 检查索引是否创建成功。 +- 检查 filter 映射是否把条件误转成空条件。 + +### Q3:count 与 query 条数不一致 +- 检查 aggregate API 的字段命名与返回结构解析。 +- 检查 count 使用的 filter 与 query 使用的 filter 是否一致。 + +--- + +## 9. 
验收标准 + +当满足以下条件,即可视为接入完成: + +- `backend=thirdparty` 可正常初始化。 +- create 后可完成 upsert/get/query/delete/count 全流程。 +- 不改上层业务调用方式即可参与 `find/search` 检索链路。 +- 后端差异全部封装在 adapter 层。 diff --git a/openviking/storage/vectordb_adapters/__init__.py b/openviking/storage/vectordb_adapters/__init__.py new file mode 100644 index 00000000..8446b64a --- /dev/null +++ b/openviking/storage/vectordb_adapters/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""VectorDB backend collection adapter package.""" + +from .base import CollectionAdapter +from .factory import create_collection_adapter +from .http_adapter import HttpCollectionAdapter +from .local_adapter import LocalCollectionAdapter +from .vikingdb_private_adapter import VikingDBPrivateCollectionAdapter +from .volcengine_adapter import VolcengineCollectionAdapter + +__all__ = [ + "CollectionAdapter", + "LocalCollectionAdapter", + "HttpCollectionAdapter", + "VolcengineCollectionAdapter", + "VikingDBPrivateCollectionAdapter", + "create_collection_adapter", +] diff --git a/openviking/storage/vectordb_adapters/base.py b/openviking/storage/vectordb_adapters/base.py new file mode 100644 index 00000000..fc74bb79 --- /dev/null +++ b/openviking/storage/vectordb_adapters/base.py @@ -0,0 +1,425 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
# SPDX-License-Identifier: Apache-2.0
"""Base adapter primitives for backend-specific vector collection operations."""

from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Optional
from urllib.parse import urlparse

from openviking.storage.errors import CollectionNotFoundError
from openviking.storage.expr import (
    And,
    Contains,
    Eq,
    FilterExpr,
    In,
    Or,
    Range,
    RawDSL,
    TimeRange,
)
from openviking.storage.vectordb.collection.collection import Collection
from openviking.storage.vectordb.collection.result import FetchDataInCollectionResult
from openviking_cli.utils import get_logger

logger = get_logger(__name__)


def _parse_url(url: str) -> tuple[str, int]:
    """Parse *url* into ``(host, port)``.

    A missing scheme is treated as ``http://``; a missing host falls back to
    ``127.0.0.1`` and a missing port to ``5000``.
    """
    normalized = url
    if not normalized.startswith(("http://", "https://")):
        normalized = f"http://{normalized}"
    parsed = urlparse(normalized)
    host = parsed.hostname or "127.0.0.1"
    port = parsed.port or 5000
    return host, port


def _normalize_collection_names(raw_collections: Iterable[Any]) -> list[str]:
    """Extract collection names from a heterogeneous listing payload.

    Accepts plain strings or dicts keyed by ``CollectionName`` /
    ``collection_name`` / ``name``; anything else is silently skipped.
    """
    names: list[str] = []
    for item in raw_collections:
        if isinstance(item, str):
            names.append(item)
        elif isinstance(item, dict):
            name = item.get("CollectionName") or item.get("collection_name") or item.get("name")
            if isinstance(name, str):
                names.append(name)
    return names


class CollectionAdapter(ABC):
    """Backend-specific adapter for single-collection operations.

    Public API methods are kept without prefix (create/query/upsert/delete/count...).
    Internal extension hooks for subclasses use leading underscore.
    """

    # Backend identifier ("local", "http", "volcengine", "vikingdb", ...);
    # set by each concrete subclass in __init__.
    mode: str

    def __init__(self, collection_name: str):
        """Bind the adapter to *collection_name*; the handle is loaded lazily."""
        self._collection_name = collection_name
        self._collection: Optional[Collection] = None

    @property
    def collection_name(self) -> str:
        """Name of the bound collection."""
        return self._collection_name

    @classmethod
    @abstractmethod
    def from_config(cls, config: Any) -> "CollectionAdapter":
        """Create an adapter instance from VectorDB backend config."""

    @abstractmethod
    def _load_existing_collection_if_needed(self) -> None:
        """Load existing bound collection handle when possible."""

    @abstractmethod
    def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection:
        """Create backend collection handle for bound collection."""

    def collection_exists(self) -> bool:
        """Return True if the bound collection exists on the backend."""
        self._load_existing_collection_if_needed()
        return self._collection is not None

    def get_collection(self) -> Collection:
        """Return the live collection handle.

        Raises:
            CollectionNotFoundError: if the bound collection does not exist.
        """
        self._load_existing_collection_if_needed()
        if self._collection is None:
            raise CollectionNotFoundError(f"Collection {self._collection_name} does not exist")
        return self._collection

    def create_collection(
        self,
        name: str,
        schema: Dict[str, Any],
        *,
        distance: str,
        sparse_weight: float,
        index_name: str,
    ) -> bool:
        """Create the collection plus its default index.

        Returns False (no-op) when the collection already exists, True after a
        successful create. ``sparse_weight > 0`` enables hybrid dense+sparse
        indexing.
        """
        if self.collection_exists():
            return False

        self._collection_name = name
        collection_meta = dict(schema)
        # ScalarIndex is index-level config, not collection meta — pop it out.
        scalar_index_fields = collection_meta.pop("ScalarIndex", [])
        if "CollectionName" not in collection_meta:
            collection_meta["CollectionName"] = name

        self._collection = self._create_backend_collection(collection_meta)

        scalar_index_fields = self._sanitize_scalar_index_fields(
            scalar_index_fields=scalar_index_fields,
            fields_meta=collection_meta.get("Fields", []),
        )
        index_meta = self._build_default_index_meta(
            index_name=index_name,
            distance=distance,
            use_sparse=sparse_weight > 0.0,
            sparse_weight=sparse_weight,
            scalar_index_fields=scalar_index_fields,
        )
        self._collection.create_index(index_name, index_meta)
        return True

    def drop_collection(self) -> bool:
        """Drop the bound collection (and its indexes first).

        Returns False when the collection does not exist or the backend does
        not support dropping; True on success. The cached handle is cleared
        either way once the drop is attempted.
        """
        if not self.collection_exists():
            return False

        coll = self.get_collection()

        # Drop indexes first so index lifecycle remains internal to adapter.
        try:
            for index_name in coll.list_indexes() or []:
                try:
                    coll.drop_index(index_name)
                except Exception as e:
                    logger.warning("Failed to drop index %s: %s", index_name, e)
        except Exception as e:
            logger.warning("Failed to list indexes before dropping collection: %s", e)

        try:
            coll.drop()
        except NotImplementedError:
            logger.warning("Collection drop is not supported by backend mode=%s", self.mode)
            return False
        finally:
            self._collection = None

        return True

    def close(self) -> None:
        """Release the cached collection handle, if any."""
        if self._collection is not None:
            self._collection.close()
            self._collection = None

    def get_collection_info(self) -> Optional[Dict[str, Any]]:
        """Return backend collection metadata, or None when it does not exist."""
        if not self.collection_exists():
            return None
        return self.get_collection().get_meta_data()

    def _sanitize_scalar_index_fields(
        self,
        scalar_index_fields: list[str],
        fields_meta: list[dict[str, Any]],
    ) -> list[str]:
        """Hook: filter scalar-index fields the backend cannot index (default: all allowed)."""
        return scalar_index_fields

    def _build_default_index_meta(
        self,
        *,
        index_name: str,
        distance: str,
        use_sparse: bool,
        sparse_weight: float,
        scalar_index_fields: list[str],
    ) -> Dict[str, Any]:
        """Hook: build the default index spec (flat / flat_hybrid here; subclasses may override)."""
        index_type = "flat_hybrid" if use_sparse else "flat"
        index_meta: Dict[str, Any] = {
            "IndexName": index_name,
            "VectorIndex": {
                "IndexType": index_type,
                "Distance": distance,
                "Quant": "int8",
            },
            "ScalarIndex": scalar_index_fields,
        }
        if use_sparse:
            index_meta["VectorIndex"]["EnableSparse"] = True
            index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight
        return index_meta

    def _normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Hook: post-process a record read from the backend (default: identity)."""
        return record

    def _compile_filter(self, expr: FilterExpr | Dict[str, Any] | None) -> Dict[str, Any]:
        """Compile a FilterExpr tree (or raw dict) into the backend filter DSL.

        ``None`` and empty sub-trees compile to ``{}`` (no filter). Raw dicts
        and RawDSL payloads pass through untouched.

        Raises:
            TypeError: for unknown expression node types.
        """
        if expr is None:
            return {}
        if isinstance(expr, dict):
            return expr
        if isinstance(expr, RawDSL):
            return expr.payload
        if isinstance(expr, And):
            conds = [self._compile_filter(c) for c in expr.conds if c is not None]
            conds = [c for c in conds if c]
            if not conds:
                return {}
            if len(conds) == 1:
                return conds[0]
            return {"op": "and", "conds": conds}
        if isinstance(expr, Or):
            conds = [self._compile_filter(c) for c in expr.conds if c is not None]
            conds = [c for c in conds if c]
            if not conds:
                return {}
            if len(conds) == 1:
                return conds[0]
            return {"op": "or", "conds": conds}
        if isinstance(expr, Eq):
            # Eq is In with a single candidate in this DSL.
            return {"op": "must", "field": expr.field, "conds": [expr.value]}
        if isinstance(expr, In):
            return {"op": "must", "field": expr.field, "conds": list(expr.values)}
        if isinstance(expr, Range):
            payload: Dict[str, Any] = {"op": "range", "field": expr.field}
            if expr.gte is not None:
                payload["gte"] = expr.gte
            if expr.gt is not None:
                payload["gt"] = expr.gt
            if expr.lte is not None:
                payload["lte"] = expr.lte
            if expr.lt is not None:
                payload["lt"] = expr.lt
            return payload
        if isinstance(expr, Contains):
            return {
                "op": "contains",
                "field": expr.field,
                "substring": expr.substring,
            }
        if isinstance(expr, TimeRange):
            # Half-open window: start inclusive (gte), end exclusive (lt).
            payload = {"op": "range", "field": expr.field}
            if expr.start is not None:
                payload["gte"] = expr.start
            if expr.end is not None:
                payload["lt"] = expr.end
            return payload
        raise TypeError(f"Unsupported filter expr type: {type(expr)!r}")

    # Backward-compatible aliases: keep old non-underscore names callable.
    def sanitize_scalar_index_fields(
        self,
        scalar_index_fields: list[str],
        fields_meta: list[dict[str, Any]],
    ) -> list[str]:
        """Deprecated public alias for :meth:`_sanitize_scalar_index_fields`."""
        return self._sanitize_scalar_index_fields(
            scalar_index_fields=scalar_index_fields,
            fields_meta=fields_meta,
        )

    def build_default_index_meta(
        self,
        *,
        index_name: str,
        distance: str,
        use_sparse: bool,
        sparse_weight: float,
        scalar_index_fields: list[str],
    ) -> Dict[str, Any]:
        """Deprecated public alias for :meth:`_build_default_index_meta`."""
        return self._build_default_index_meta(
            index_name=index_name,
            distance=distance,
            use_sparse=use_sparse,
            sparse_weight=sparse_weight,
            scalar_index_fields=scalar_index_fields,
        )

    def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Deprecated public alias for :meth:`_normalize_record_for_read`."""
        return self._normalize_record_for_read(record)

    def compile_filter(self, expr: FilterExpr | Dict[str, Any] | None) -> Dict[str, Any]:
        """Deprecated public alias for :meth:`_compile_filter`."""
        return self._compile_filter(expr)

    def upsert(self, data: Dict[str, Any] | list[Dict[str, Any]]) -> list[str]:
        """Insert-or-update one record or a batch; returns the record ids.

        Records missing an ``id`` get a generated UUID; inputs are copied, not
        mutated.
        """
        coll = self.get_collection()
        records = [data] if isinstance(data, dict) else data
        normalized: list[Dict[str, Any]] = []
        ids: list[str] = []
        for item in records:
            record = dict(item)
            record_id = record.get("id") or str(uuid.uuid4())
            record["id"] = record_id
            ids.append(record_id)
            normalized.append(record)
        coll.upsert_data(normalized)
        return ids

    def get(self, ids: list[str]) -> list[Dict[str, Any]]:
        """Fetch records by id; missing ids are simply absent from the result.

        Handles both the typed ``FetchDataInCollectionResult`` and the legacy
        raw-dict response shape.
        """
        coll = self.get_collection()
        result = coll.fetch_data(ids)

        records: list[Dict[str, Any]] = []
        if isinstance(result, FetchDataInCollectionResult):
            for item in result.items:
                record = dict(item.fields) if item.fields else {}
                record["id"] = item.id
                records.append(self._normalize_record_for_read(record))
            return records

        if isinstance(result, dict) and "fetch" in result:
            for item in result.get("fetch", []):
                record = dict(item.get("fields", {})) if item.get("fields") else {}
                record_id = item.get("id")
                if record_id:
                    record["id"] = record_id
                records.append(self._normalize_record_for_read(record))

        # Bug fix: previously fell off the end and returned None for any
        # other result shape; always return a list.
        return records

    def query(
        self,
        *,
        query_vector: Optional[list[float]] = None,
        sparse_query_vector: Optional[Dict[str, float]] = None,
        filter: Optional[Dict[str, Any] | FilterExpr] = None,
        limit: int = 10,
        offset: int = 0,
        output_fields: Optional[list[str]] = None,
        with_vector: bool = False,
        order_by: Optional[str] = None,
        order_desc: bool = False,
    ) -> list[Dict[str, Any]]:
        """Query the collection via the managed "default" index.

        Dispatch: vector search when any query vector is given, scalar-ordered
        scan when ``order_by`` is set, otherwise a random/plain scan. Each
        returned record carries ``id`` and ``_score`` (0.0 when the backend
        gives no score); vectors are stripped unless ``with_vector``.
        """
        coll = self.get_collection()
        vectordb_filter = self._compile_filter(filter)

        if query_vector or sparse_query_vector:
            result = coll.search_by_vector(
                index_name="default",
                dense_vector=query_vector,
                sparse_vector=sparse_query_vector,
                limit=limit,
                offset=offset,
                filters=vectordb_filter,
                output_fields=output_fields,
            )
        elif order_by:
            result = coll.search_by_scalar(
                index_name="default",
                field=order_by,
                order="desc" if order_desc else "asc",
                limit=limit,
                offset=offset,
                filters=vectordb_filter,
                output_fields=output_fields,
            )
        else:
            result = coll.search_by_random(
                index_name="default",
                limit=limit,
                offset=offset,
                filters=vectordb_filter,
                output_fields=output_fields,
            )

        records: list[Dict[str, Any]] = []
        for item in result.data:
            record = dict(item.fields) if item.fields else {}
            record["id"] = item.id
            record["_score"] = item.score if item.score is not None else 0.0
            record = self._normalize_record_for_read(record)
            if not with_vector:
                record.pop("vector", None)
                record.pop("sparse_vector", None)
            records.append(record)
        return records

    def delete(
        self,
        *,
        ids: Optional[list[str]] = None,
        filter: Optional[Dict[str, Any] | FilterExpr] = None,
        limit: int = 100000,
    ) -> int:
        """Delete records by explicit ids, or by filter when no ids are given.

        Returns the number of ids submitted for deletion. The filter path
        first queries matching ids (up to *limit*).
        """
        coll = self.get_collection()
        delete_ids = list(ids or [])
        if not delete_ids and filter is not None:
            matched = self.query(filter=filter, limit=limit, with_vector=True)
            delete_ids = [record["id"] for record in matched if record.get("id")]

        if not delete_ids:
            return 0

        coll.delete_data(delete_ids)
        return len(delete_ids)

    @staticmethod
    def _coerce_int(value: Any) -> Optional[int]:
        """Best-effort int coercion; None for bools, non-integral floats, and junk."""
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str):
            stripped = value.strip()
            if stripped.isdigit():
                return int(stripped)
        return None

    def count(self, filter: Optional[Dict[str, Any] | FilterExpr] = None) -> int:
        """Count records matching *filter* via the backend's aggregate API.

        Returns 0 when the backend omits ``_total`` or reports it in an
        unparseable form.
        """
        coll = self.get_collection()
        result = coll.aggregate_data(
            index_name="default",
            op="count",
            filters=self._compile_filter(filter),
        )
        if "_total" in result.agg:
            parsed_total = self._coerce_int(result.agg.get("_total"))
            if parsed_total is not None:
                return parsed_total

        return 0

    def clear(self) -> bool:
        """Delete every record in the collection; always returns True."""
        self.get_collection().delete_all_data()
        return True
# SPDX-License-Identifier: Apache-2.0
"""Adapter registry and factory entrypoints."""

from __future__ import annotations

from .base import CollectionAdapter
from .http_adapter import HttpCollectionAdapter
from .local_adapter import LocalCollectionAdapter
from .vikingdb_private_adapter import VikingDBPrivateCollectionAdapter
from .volcengine_adapter import VolcengineCollectionAdapter

# Maps the `backend` config string to the adapter class implementing it.
_ADAPTER_REGISTRY: dict[str, type[CollectionAdapter]] = {
    "local": LocalCollectionAdapter,
    "http": HttpCollectionAdapter,
    "volcengine": VolcengineCollectionAdapter,
    "vikingdb": VikingDBPrivateCollectionAdapter,
}


def create_collection_adapter(config) -> CollectionAdapter:
    """Unified factory entrypoint for backend-specific collection adapters."""
    try:
        adapter_cls = _ADAPTER_REGISTRY[config.backend]
    except KeyError:
        raise ValueError(
            f"Vector backend {config.backend} is not supported. "
            f"Available backends: {sorted(_ADAPTER_REGISTRY)}"
        ) from None
    return adapter_cls.from_config(config)
# SPDX-License-Identifier: Apache-2.0
"""HTTP backend collection adapter."""

from __future__ import annotations

from typing import Any, Dict

from openviking.storage.vectordb.collection.collection import Collection
from openviking.storage.vectordb.collection.http_collection import (
    HttpCollection,
    get_or_create_http_collection,
    list_vikingdb_collections,
)

from .base import CollectionAdapter, _normalize_collection_names, _parse_url


class HttpCollectionAdapter(CollectionAdapter):
    """Adapter for remote HTTP vectordb project."""

    def __init__(self, host: str, port: int, project_name: str, collection_name: str):
        """Bind to one collection of a remote HTTP vectordb project."""
        super().__init__(collection_name=collection_name)
        self.mode = "http"
        self._host = host
        self._port = port
        self._project_name = project_name

    @classmethod
    def from_config(cls, config: Any):
        """Build the adapter from backend config; requires ``config.url``.

        Falls back to project "default" and collection "context" when unset.
        """
        if not config.url:
            raise ValueError("HTTP backend requires a valid URL")
        host, port = _parse_url(config.url)
        return cls(
            host=host,
            port=port,
            project_name=config.project_name or "default",
            collection_name=config.name or "context",
        )

    def _meta(self) -> Dict[str, Any]:
        """Minimal meta payload identifying the bound project/collection."""
        return {
            "ProjectName": self._project_name,
            "CollectionName": self._collection_name,
        }

    def _remote_has_collection(self) -> bool:
        """Ask the remote service whether the bound collection exists.

        Issues a listing call on every invocation (no caching).
        """
        raw = list_vikingdb_collections(
            host=self._host,
            port=self._port,
            project_name=self._project_name,
        )
        return self._collection_name in _normalize_collection_names(raw)

    def _load_existing_collection_if_needed(self) -> None:
        """Lazily attach a handle for an already-existing remote collection."""
        if self._collection is not None:
            return
        if not self._remote_has_collection():
            # Leave _collection unset; collection_exists() then reports False.
            return
        self._collection = Collection(
            HttpCollection(
                ip=self._host,
                port=self._port,
                meta_data=self._meta(),
            )
        )

    def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection:
        """Create (or fetch) the remote collection from the given schema meta.

        The bound project/collection identifiers override any in *meta*.
        """
        payload = dict(meta)
        payload.update(self._meta())
        return get_or_create_http_collection(
            host=self._host,
            port=self._port,
            meta_data=payload,
        )
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0
"""Local backend collection adapter."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Dict

from openviking.storage.vectordb.collection.collection import Collection
from openviking.storage.vectordb.collection.local_collection import get_or_create_local_collection

from .base import CollectionAdapter


class LocalCollectionAdapter(CollectionAdapter):
    """Adapter for local embedded vectordb backend."""

    # Sub-directory under config.path that holds all local collections.
    DEFAULT_LOCAL_PROJECT_NAME = "vectordb"

    def __init__(self, collection_name: str, project_path: str):
        """Bind to *collection_name* stored under *project_path* on disk."""
        super().__init__(collection_name=collection_name)
        self.mode = "local"
        self._project_path = project_path

    @classmethod
    def from_config(cls, config: Any):
        """Build the adapter from backend config.

        An empty ``config.path`` yields an empty project path, which disables
        on-disk lookups (see _collection_path).
        """
        project_path = (
            str(Path(config.path) / cls.DEFAULT_LOCAL_PROJECT_NAME) if config.path else ""
        )
        return cls(collection_name=config.name or "context", project_path=project_path)

    def _collection_path(self) -> str:
        """Filesystem directory of the bound collection; "" when no project path."""
        if not self._project_path:
            return ""
        return str(Path(self._project_path) / self._collection_name)

    def _load_existing_collection_if_needed(self) -> None:
        """Lazily open the collection if its meta file already exists on disk."""
        if self._collection is not None:
            return
        collection_path = self._collection_path()
        if not collection_path:
            return
        # collection_meta.json is the marker that the collection was created.
        meta_path = os.path.join(collection_path, "collection_meta.json")
        if os.path.exists(meta_path):
            self._collection = get_or_create_local_collection(path=collection_path)

    def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection:
        """Create the on-disk collection from schema *meta*, making the directory first."""
        collection_path = self._collection_path()
        if collection_path:
            os.makedirs(collection_path, exist_ok=True)
        return get_or_create_local_collection(meta_data=meta, path=collection_path)
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: Apache-2.0
"""Private VikingDB backend collection adapter."""

from __future__ import annotations

from typing import Any, Dict, Optional

from openviking.storage.vectordb.collection.collection import Collection
from openviking.storage.vectordb.collection.vikingdb_clients import VIKINGDB_APIS, VikingDBClient
from openviking.storage.vectordb.collection.vikingdb_collection import VikingDBCollection

from .base import CollectionAdapter


class VikingDBPrivateCollectionAdapter(CollectionAdapter):
    """Adapter for private VikingDB deployment."""

    def __init__(
        self,
        *,
        host: str,
        headers: Optional[dict[str, str]],
        project_name: str,
        collection_name: str,
    ):
        """Bind to one collection on a privately deployed VikingDB host."""
        super().__init__(collection_name=collection_name)
        self.mode = "vikingdb"
        self._host = host
        self._headers = headers
        self._project_name = project_name

    @classmethod
    def from_config(cls, config: Any):
        """Build the adapter from backend config; requires ``config.vikingdb.host``."""
        if not config.vikingdb or not config.vikingdb.host:
            raise ValueError("VikingDB backend requires a valid host")
        return cls(
            host=config.vikingdb.host,
            headers=config.vikingdb.headers,
            project_name=config.project_name or "default",
            collection_name=config.name or "context",
        )

    def _client(self) -> VikingDBClient:
        """Construct a fresh API client (no connection state cached on the adapter)."""
        return VikingDBClient(self._host, self._headers)

    def _fetch_collection_meta(self) -> Optional[Dict[str, Any]]:
        """Fetch remote collection metadata; None on non-200 or empty Result."""
        path, method = VIKINGDB_APIS["GetVikingdbCollection"]
        req = {
            "ProjectName": self._project_name,
            "CollectionName": self._collection_name,
        }
        response = self._client().do_req(method, path=path, req_body=req)
        if response.status_code != 200:
            return None
        result = response.json()
        meta = result.get("Result", {})
        return meta or None

    def _load_existing_collection_if_needed(self) -> None:
        """Lazily attach a handle when the remote collection already exists."""
        if self._collection is not None:
            return
        meta = self._fetch_collection_meta()
        if meta is None:
            return
        self._collection = Collection(
            VikingDBCollection(
                host=self._host,
                headers=self._headers,
                meta_data=meta,
            )
        )

    def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection:
        """Private deployments do not support creation via this adapter.

        Raises:
            NotImplementedError: when the collection has not been pre-created
                on the server by an administrator.
        """
        self._load_existing_collection_if_needed()
        if self._collection is None:
            raise NotImplementedError("private vikingdb collection should be pre-created")
        return self._collection

    def _sanitize_scalar_index_fields(
        self,
        scalar_index_fields: list[str],
        fields_meta: list[dict[str, Any]],
    ) -> list[str]:
        """Drop date_time fields: this backend cannot scalar-index them."""
        date_time_fields = {
            field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time"
        }
        return [field for field in scalar_index_fields if field not in date_time_fields]

    def _build_default_index_meta(
        self,
        *,
        index_name: str,
        distance: str,
        use_sparse: bool,
        sparse_weight: float,
        scalar_index_fields: list[str],
    ) -> Dict[str, Any]:
        """Build the index spec using HNSW (hnsw_hybrid when sparse search is enabled)."""
        index_type = "hnsw_hybrid" if use_sparse else "hnsw"
        index_meta: Dict[str, Any] = {
            "IndexName": index_name,
            "VectorIndex": {
                "IndexType": index_type,
                "Distance": distance,
                "Quant": "int8",
            },
            "ScalarIndex": scalar_index_fields,
        }
        if use_sparse:
            index_meta["VectorIndex"]["EnableSparse"] = True
            index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight
        return index_meta

    def _normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Re-add the ``viking://`` scheme to uri fields the backend stores stripped.

        Mutates and returns *record* in place.
        """
        for key in ("uri", "parent_uri"):
            value = record.get(key)
            if isinstance(value, str) and not value.startswith("viking://"):
                stripped = value.strip("/")
                if stripped:
                    record[key] = f"viking://{stripped}"
        return record
--git a/openviking/storage/vectordb_adapters/volcengine_adapter.py b/openviking/storage/vectordb_adapters/volcengine_adapter.py new file mode 100644 index 00000000..16d5ab08 --- /dev/null +++ b/openviking/storage/vectordb_adapters/volcengine_adapter.py @@ -0,0 +1,132 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Volcengine backend collection adapter.""" + +from __future__ import annotations + +from typing import Any, Dict + +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.volcengine_collection import ( + VolcengineCollection, + get_or_create_volcengine_collection, +) + +from .base import CollectionAdapter + + +class VolcengineCollectionAdapter(CollectionAdapter): + """Adapter for Volcengine-hosted VikingDB.""" + + def __init__( + self, + *, + ak: str, + sk: str, + region: str, + project_name: str, + collection_name: str, + ): + super().__init__(collection_name=collection_name) + self.mode = "volcengine" + self._ak = ak + self._sk = sk + self._region = region + self._project_name = project_name + + @classmethod + def from_config(cls, config: Any): + if not ( + config.volcengine + and config.volcengine.ak + and config.volcengine.sk + and config.volcengine.region + ): + raise ValueError("Volcengine backend requires AK, SK, and Region configuration") + return cls( + ak=config.volcengine.ak, + sk=config.volcengine.sk, + region=config.volcengine.region, + project_name=config.project_name or "default", + collection_name=config.name or "context", + ) + + def _meta(self) -> Dict[str, Any]: + return { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + + def _config(self) -> Dict[str, Any]: + return { + "AK": self._ak, + "SK": self._sk, + "Region": self._region, + } + + def _new_collection_handle(self) -> VolcengineCollection: + return VolcengineCollection( + ak=self._ak, + sk=self._sk, + 
region=self._region, + meta_data=self._meta(), + ) + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + candidate = self._new_collection_handle() + meta = candidate.get_meta_data() or {} + if meta and meta.get("CollectionName"): + self._collection = candidate + + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + payload = dict(meta) + payload.update(self._meta()) + return get_or_create_volcengine_collection( + config=self._config(), + meta_data=payload, + ) + + def _sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict[str, Any]], + ) -> list[str]: + date_time_fields = { + field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" + } + return [field for field in scalar_index_fields if field not in date_time_fields] + + def _build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> Dict[str, Any]: + index_type = "hnsw_hybrid" if use_sparse else "hnsw" + index_meta: Dict[str, Any] = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def _normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: + for key in ("uri", "parent_uri"): + value = record.get(key) + if isinstance(value, str) and not value.startswith("viking://"): + stripped = value.strip("/") + if stripped: + record[key] = f"viking://{stripped}" + return record diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index deef27db..ae09efe6 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -25,13 +25,13 @@ from 
pyagfs.exceptions import AGFSHTTPError from openviking.server.identity import RequestContext, Role -from openviking.storage.vikingdb_interface import VikingDBInterface from openviking.utils.time_utils import format_simplified, get_current_timestamp, parse_iso_datetime from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils.logger import get_logger from openviking_cli.utils.uri import VikingURI if TYPE_CHECKING: + from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend from openviking_cli.utils.config import RerankConfig logger = get_logger(__name__) @@ -71,7 +71,8 @@ def init_viking_fs( agfs: Any, query_embedder: Optional[Any] = None, rerank_config: Optional["RerankConfig"] = None, - vector_store: Optional["VikingDBInterface"] = None, + vector_store: Optional["VikingVectorIndexBackend"] = None, + timeout: int = 10, enable_recorder: bool = False, ) -> "VikingFS": """Initialize VikingFS singleton. @@ -162,7 +163,8 @@ def __init__( agfs: Any, query_embedder: Optional[Any] = None, rerank_config: Optional["RerankConfig"] = None, - vector_store: Optional["VikingDBInterface"] = None, + vector_store: Optional["VikingVectorIndexBackend"] = None, + timeout: int = 10, ): self.agfs = agfs self.query_embedder = query_embedder @@ -604,7 +606,7 @@ async def find( ctx=self._ctx_or_default(ctx), limit=limit, score_threshold=score_threshold, - metadata_filter=filter, + scope_dsl=filter, ) # Convert QueryResult to FindResult @@ -721,7 +723,7 @@ async def _execute(tq: TypedQuery): ctx=self._ctx_or_default(ctx), limit=limit, score_threshold=score_threshold, - metadata_filter=filter, + scope_dsl=filter, ) query_results = await asyncio.gather(*[_execute(tq) for tq in typed_queries]) @@ -1045,40 +1047,19 @@ async def _delete_from_vector_store( ) -> None: """Delete records with specified URIs from vector store. - Uses storage.remove_by_uri method, which implements recursive deletion of child nodes. 
+ Uses tenant-safe URI deletion semantics from vector store. """ - storage = self._get_vector_store() - if not storage: + vector_store = self._get_vector_store() + if not vector_store: return real_ctx = self._ctx_or_default(ctx) - for uri in uris: - try: - filter_conds: List[Dict[str, Any]] = [ - {"op": "must", "field": "account_id", "conds": [real_ctx.account_id]}, - { - "op": "or", - "conds": [ - {"op": "must", "field": "uri", "conds": [uri]}, - {"op": "must", "field": "uri", "conds": [f"{uri}/"]}, - ], - }, - ] - if real_ctx.role == Role.USER and uri.startswith( - ("viking://user/", "viking://agent/") - ): - owner_space = ( - real_ctx.user.user_space_name() - if uri.startswith("viking://user/") - else real_ctx.user.agent_space_name() - ) - filter_conds.append( - {"op": "must", "field": "owner_space", "conds": [owner_space]} - ) - await storage.batch_delete("context", {"op": "and", "conds": filter_conds}) + try: + await vector_store.delete_uris(real_ctx, uris) + for uri in uris: logger.info(f"[VikingFS] Deleted from vector store: {uri}") - except Exception as e: - logger.warning(f"[VikingFS] Failed to delete {uri} from vector store: {e}") + except Exception as e: + logger.warning(f"[VikingFS] Failed to delete from vector store: {e}") async def _update_vector_store_uris( self, @@ -1091,8 +1072,8 @@ async def _update_vector_store_uris( Preserves vector data, only updates uri and parent_uri fields, no need to regenerate embeddings. 
""" - storage = self._get_vector_store() - if not storage: + vector_store = self._get_vector_store() + if not vector_store: return old_base_uri = self._path_to_uri(old_base, ctx=ctx) @@ -1100,19 +1081,9 @@ async def _update_vector_store_uris( for uri in uris: try: - records = await storage.filter( - collection="context", - filter={ - "op": "and", - "conds": [ - {"op": "must", "field": "uri", "conds": [uri]}, - { - "op": "must", - "field": "account_id", - "conds": [self._ctx_or_default(ctx).account_id], - }, - ], - }, + records = await vector_store.get_context_by_uri( + account_id=self._ctx_or_default(ctx).account_id, + uri=uri, limit=1, ) @@ -1120,7 +1091,6 @@ async def _update_vector_store_uris( continue record = records[0] - record_id = record["id"] new_uri = uri.replace(old_base_uri, new_base_uri, 1) @@ -1129,19 +1099,17 @@ async def _update_vector_store_uris( old_parent_uri.replace(old_base_uri, new_base_uri, 1) if old_parent_uri else "" ) - await storage.update( - "context", - record_id, - { - "uri": new_uri, - "parent_uri": new_parent_uri, - }, + await vector_store.update_uri_mapping( + ctx=self._ctx_or_default(ctx), + uri=uri, + new_uri=new_uri, + new_parent_uri=new_parent_uri, ) logger.info(f"[VikingFS] Updated URI: {uri} -> {new_uri}") except Exception as e: logger.warning(f"[VikingFS] Failed to update {uri} in vector store: {e}") - def _get_vector_store(self) -> Optional["VikingDBInterface"]: + def _get_vector_store(self) -> Optional["VikingVectorIndexBackend"]: """Get vector store instance.""" return self.vector_store diff --git a/openviking/storage/viking_vector_index_backend.py b/openviking/storage/viking_vector_index_backend.py index 7c579759..7307960a 100644 --- a/openviking/storage/viking_vector_index_backend.py +++ b/openviking/storage/viking_vector_index_backend.py @@ -1,920 +1,582 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 -""" -VikingDB storage backend for OpenViking. 
+"""VikingDB storage backend for OpenViking.""" -Implements the VikingDBInterface using the custom vectordb implementation. -Supports both in-memory and local persistent storage modes. -""" +from __future__ import annotations import uuid -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional +from openviking.server.identity import RequestContext, Role +from openviking.storage.expr import And, Eq, FilterExpr, In, Or, RawDSL from openviking.storage.vectordb.collection.collection import Collection -from openviking.storage.vectordb.collection.result import FetchDataInCollectionResult from openviking.storage.vectordb.utils.logging_init import init_cpp_logging -from openviking.storage.vikingdb_interface import CollectionNotFoundError, VikingDBInterface +from openviking.storage.vectordb_adapters import CollectionAdapter, create_collection_adapter from openviking_cli.utils import get_logger from openviking_cli.utils.config.vectordb_config import VectorDBBackendConfig logger = get_logger(__name__) -class VikingVectorIndexBackend(VikingDBInterface): - """ - VikingDB storage backend implementation. +class VikingVectorIndexBackend: + """Single-collection vector backend with adapter-based backend specialization.""" - Features: - - Vector similarity search with BruteForce indexing - - Scalar filtering with support for multiple operators - - Support for local persistent storage, HTTP service, and Volcengine VikingDB - - Auto-managed indexes per collection - - VikingDBManager is derived by VikingVectorIndexBackend. - """ - - # Default project and index names DEFAULT_INDEX_NAME = "default" - DEFAULT_LOCAL_PROJECT_NAME = "vectordb" + ALLOWED_CONTEXT_TYPES = {"resource", "skill", "memory"} - def __init__( - self, - config: Optional[VectorDBBackendConfig], - ): - """ - Initialize VikingDB backend. - - Args: - config: Configuration object for VectorDB backend. - - Examples: - # 1. 
Local persistent storage - config = VectorDBBackendConfig( - backend="local", - path="./data/vectordb" - ) - backend = VikingVectorIndexBackend(config=config) + def __init__(self, config: Optional[VectorDBBackendConfig]): + if config is None: + raise ValueError("VectorDB backend config is required") - # 2. Remote HTTP service - config = VectorDBBackendConfig( - backend="http", - url="http://localhost:5000" - ) - backend = VikingVectorIndexBackend(config=config) - - # 3. Volcengine VikingDB - from openviking_cli.utils.config.storage_config import VolcengineConfig - config = VectorDBBackendConfig( - backend="volcengine", - volcengine=VolcengineConfig( - ak="your-ak", - sk="your-sk", - region="cn-beijing" - ) - ) - backend = VikingVectorIndexBackend(config=config) - """ init_cpp_logging() self.vector_dim = config.dimension self.distance_metric = config.distance_metric self.sparse_weight = config.sparse_weight + self._collection_name = config.name or "context" - if config.backend == "volcengine": - if not ( - config.volcengine - and config.volcengine.ak - and config.volcengine.sk - and config.volcengine.region - ): - raise ValueError("Volcengine backend requires AK, SK, and Region configuration") - - # Volcengine VikingDB mode - self._mode = config.backend - # Convert lowercase keys to uppercase for consistency with volcengine_collection - volc_config = { - "AK": config.volcengine.ak, - "SK": config.volcengine.sk, - "Region": config.volcengine.region, - } + self._adapter: CollectionAdapter = create_collection_adapter(config) + self._mode = self._adapter.mode - from openviking.storage.vectordb.project.volcengine_project import ( - get_or_create_volcengine_project, - ) + logger.info( + "VikingDB backend initialized via adapter %s (mode=%s)", + type(self._adapter).__name__, + self._mode, + ) - self.project = get_or_create_volcengine_project( - project_name=config.project_name, config=volc_config - ) - logger.info( - f"VectorDB backend initialized in Volcengine mode: 
region={volc_config['Region']}" - ) - elif config.backend == "vikingdb": - if not config.vikingdb.host: - raise ValueError("VikingDB backend requires a valid host") - # VikingDB private deployment mode - self._mode = config.backend - viking_config = { - "Host": config.vikingdb.host, - "Headers": config.vikingdb.headers, - } + self._collection_config: Dict[str, Any] = {} + self._meta_data_cache: Dict[str, Any] = {} - from openviking.storage.vectordb.project.vikingdb_project import ( - get_or_create_vikingdb_project, - ) + @property + def collection_name(self) -> str: + return self._collection_name - self.project = get_or_create_vikingdb_project( - project_name=config.project_name, config=viking_config - ) - logger.info(f"VikingDB backend initialized in private mode: {config.vikingdb.host}") - elif config.backend == "http": - if not config.url: - raise ValueError("HTTP backend requires a valid URL") - # Remote mode: parse URL and create HTTP project - self._mode = config.backend - self.host, self.port = self._parse_url(config.url) - - from openviking.storage.vectordb.project.http_project import get_or_create_http_project - - self.project = get_or_create_http_project( - host=self.host, port=self.port, project_name=config.project_name - ) - logger.info(f"VikingDB backend initialized in remote mode: {config.url}") - elif config.backend == "local": - # Local persistent mode - self._mode = config.backend - from openviking.storage.vectordb.project.local_project import ( - get_or_create_local_project, - ) + def _get_collection(self) -> Collection: + return self._adapter.get_collection() - project_path = ( - Path(config.path) / self.DEFAULT_LOCAL_PROJECT_NAME if config.path else "" - ) - self.project = get_or_create_local_project(path=str(project_path)) - logger.info(f"VikingDB backend initialized with local storage: {project_path}") - else: - raise ValueError(f"Unsupported VikingDB backend type: {config.type}") - - self._collection_configs: Dict[str, Dict[str, Any]] = {} - 
# Cache meta_data at collection level to avoid repeated remote calls - self._meta_data_cache: Dict[str, Dict[str, Any]] = {} - - def _parse_url(self, url: str) -> Tuple[str, int]: - """ - Parse VikingVectorIndex service URL to extract host and port. - - Args: - url: Service URL (e.g., "http://localhost:5000" or "localhost:5000") - - Returns: - Tuple of (host, port) - """ - from urllib.parse import urlparse - - # Add scheme if not present - if not url.startswith(("http://", "https://")): - url = f"http://{url}" - - parsed = urlparse(url) - host = parsed.hostname or "127.0.0.1" - port = parsed.port or 5000 - - return host, port - - def _get_collection(self, name: str) -> Collection: - """Get collection object or raise error if not found.""" - if not self.project.has_collection(name): - raise CollectionNotFoundError(f"Collection '{name}' does not exist") - return self.project.get_collection(name) - - def _get_meta_data(self, collection_name: str, coll: Collection) -> Dict[str, Any]: - """Get meta_data with collection-level caching to avoid repeated remote calls.""" - if collection_name not in self._meta_data_cache: - self._meta_data_cache[collection_name] = coll.get_meta_data() - return self._meta_data_cache[collection_name] - - def _update_meta_data_cache(self, collection_name: str, coll: Collection): - """Update the cached meta_data after modifications.""" - meta_data = coll.get_meta_data() - self._meta_data_cache[collection_name] = meta_data + def _get_meta_data(self, coll: Collection) -> Dict[str, Any]: + if not self._meta_data_cache: + self._meta_data_cache = coll.get_meta_data() or {} + return self._meta_data_cache - @staticmethod - def _restore_uri_fields(record: Dict[str, Any]) -> Dict[str, Any]: - """Restore viking:// prefix on uri/parent_uri fields read from VikingDB. - - The volcengine backend sanitizes URIs to /path/ format on write; - this reverses that transformation so the rest of the system sees - the canonical viking:// scheme. 
Idempotent for values that - already carry the prefix (local/http backends). - """ - for key in ("uri", "parent_uri"): - val = record.get(key) - if isinstance(val, str) and not val.startswith("viking://"): - restored = val.strip("/") - if restored: - record[key] = f"viking://{restored}" - return record + def _refresh_meta_data(self, coll: Collection) -> None: + self._meta_data_cache = coll.get_meta_data() or {} + + def _filter_known_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: + try: + coll = self._get_collection() + fields = self._get_meta_data(coll).get("Fields", []) + allowed = {item.get("FieldName") for item in fields} + return {k: v for k, v in data.items() if k in allowed and v is not None} + except Exception: + return data # ========================================================================= - # Collection/Table Management + # Collection Management (single collection) # ========================================================================= async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: - """ - Create a new collection. - - Args: - name: Collection name - schema: VikingVectorIndex collection metadata in the format: - { - "CollectionName": "name", - "Description": "description", - "Fields": [ - {"FieldName": "id", "FieldType": "string", "IsPrimaryKey": True}, - {"FieldName": "vector", "FieldType": "vector", "Dim": 128}, - ... 
- ] - } - - Returns: - True if created successfully, False if already exists - """ try: - if self.project.has_collection(name): - logger.debug(f"Collection '{name}' already exists") - return False - - collection_meta = schema.copy() - - scalar_index_fields = [] - if "ScalarIndex" in collection_meta: - scalar_index_fields = collection_meta.pop("ScalarIndex") - - # Ensure CollectionName is set - if "CollectionName" not in collection_meta: - collection_meta["CollectionName"] = name + collection_meta = dict(schema) - # Extract distance metric and vector_dim for config tracking - distance = self.distance_metric + # Track vector dim from schema for info. vector_dim = self.vector_dim for field in collection_meta.get("Fields", []): if field.get("FieldType") == "vector": vector_dim = field.get("Dim", self.vector_dim) break - logger.info(f"Creating collection mode={self._mode} with meta: {collection_meta}") - - # Create collection using vectordb project - collection = self.project.create_collection(name, collection_meta) - - # Filter date_time fields for volcengine and vikingdb backends - if self._mode in ["volcengine", "vikingdb"]: - date_time_fields = { - field.get("FieldName") - for field in collection_meta.get("Fields", []) - if field.get("FieldType") == "date_time" - } - scalar_index_fields = [ - field for field in scalar_index_fields if field not in date_time_fields - ] - - # Create default index for the collection - use_sparse = self.sparse_weight > 0.0 - index_type = "flat_hybrid" if use_sparse else "flat" - if self._mode in ["volcengine", "vikingdb"]: - index_type = "hnsw_hybrid" if use_sparse else "hnsw" - - index_meta = { - "IndexName": self.DEFAULT_INDEX_NAME, - "VectorIndex": { - "IndexType": index_type, - "Distance": distance, - "Quant": "int8", - }, - "ScalarIndex": scalar_index_fields, - } - if use_sparse: - index_meta["VectorIndex"]["EnableSparse"] = True - index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = self.sparse_weight - - 
logger.info(f"Creating index with meta: {index_meta}") - collection.create_index(self.DEFAULT_INDEX_NAME, index_meta) - - # Update cached meta_data after creating index - self._update_meta_data_cache(name, collection) + created = self._adapter.create_collection( + name=name, + schema=collection_meta, + distance=self.distance_metric, + sparse_weight=self.sparse_weight, + index_name=self.DEFAULT_INDEX_NAME, + ) + if not created: + return False - # Store collection config - self._collection_configs[name] = { + self._collection_name = name + self._collection_config = { "vector_dim": vector_dim, - "distance": distance, + "distance": self.distance_metric, "schema": schema, } - - logger.info(f"Created VikingDB collection: {name} (dim={vector_dim})") + self._refresh_meta_data(self._get_collection()) + logger.info("Created VikingDB collection: %s (dim=%s)", name, vector_dim) return True - except Exception as e: - logger.error(f"Error creating collection '{name}': {e}") - import traceback - - traceback.print_exc() + logger.error("Error creating collection %s: %s", name, e) return False - async def drop_collection(self, name: str) -> bool: - """Drop a collection.""" + async def drop_collection(self) -> bool: try: - if not self.project.has_collection(name): - logger.warning(f"Collection '{name}' does not exist") - return False - - self.project.drop_collection(name) - self._collection_configs.pop(name, None) - # Clear cached meta_data when dropping collection - self._meta_data_cache.pop(name, None) - - logger.info(f"Dropped collection: {name}") - return True + dropped = self._adapter.drop_collection() + if dropped: + self._collection_config = {} + self._meta_data_cache = {} + return dropped except Exception as e: - logger.error(f"Error dropping collection '{name}': {e}") + logger.error("Error dropping collection %s: %s", self._collection_name, e) return False - async def collection_exists(self, name: str) -> bool: - """Check if a collection exists.""" - return 
self.project.has_collection(name) + async def collection_exists(self) -> bool: + return self._adapter.collection_exists() - async def list_collections(self) -> List[str]: - """List all collection names.""" - return self.project.list_collections() - - async def get_collection_info(self, name: str) -> Optional[Dict[str, Any]]: - """Get collection metadata and statistics.""" - try: - if not self.project.has_collection(name): - return None - - config = self._collection_configs.get(name, {}) - - return { - "name": name, - "vector_dim": config.get("vector_dim", self.vector_dim), - "count": 0, # vectordb doesn't easily expose count - "status": "active", - } - except Exception as e: - logger.error(f"Error getting collection info for '{name}': {e}") + async def get_collection_info(self) -> Optional[Dict[str, Any]]: + if not await self.collection_exists(): return None + config = self._collection_config + return { + "name": self._collection_name, + "vector_dim": config.get("vector_dim", self.vector_dim), + "count": await self.count(), + "status": "active", + } + + async def collection_exists_bound(self) -> bool: + return await self.collection_exists() # ========================================================================= - # CRUD Operations - Single Record + # Data Operations # ========================================================================= - async def insert(self, collection: str, data: Dict[str, Any]) -> str: - """Insert a single record.""" - coll = self._get_collection(collection) - - # Ensure ID exists - record_id = data.get("id") - if not record_id: - record_id = str(uuid.uuid4()) - data = {**data, "id": record_id} - - # Validate context_type for context collection - if collection == "context": - context_type = data.get("context_type") - if context_type not in ["resource", "skill", "memory"]: - logger.warning( - f"Invalid context_type: {context_type}. 
" - f"Must be one of ['resource', 'skill', 'memory'], Ignore" - ) - return "" - - fields = self._get_meta_data(collection, coll).get("Fields", []) - fields_dict = {item["FieldName"]: item for item in fields} - new_data = {} - for k in data: - if k in fields_dict and data[k] is not None: - new_data[k] = data[k] - - try: - coll.upsert_data([new_data]) - return record_id - except Exception as e: - logger.error(f"Error inserting record: {e}") - raise - - async def update(self, collection: str, id: str, data: Dict[str, Any]) -> bool: - """Update a record by ID.""" - coll = self._get_collection(collection) - - try: - # Fetch existing record - existing = await self.get(collection, [id]) - if not existing: - return False - - # Merge data with existing record - updated_data = {**existing[0], **data} - updated_data["id"] = id - - # Upsert the updated record - coll.upsert_data([updated_data]) - return True - except Exception as e: - logger.error(f"Error updating record '{id}': {e}") - return False - - async def upsert(self, collection: str, data: Dict[str, Any]) -> str: - """Insert or update a record.""" - coll = self._get_collection(collection) - - record_id = data.get("id") - if not record_id: - record_id = str(uuid.uuid4()) - data = {**data, "id": record_id} - - try: - coll.upsert_data([data]) - return record_id - except Exception as e: - logger.error(f"Error upserting record: {e}") - raise + async def upsert(self, data: Dict[str, Any]) -> str: + payload = dict(data) + context_type = payload.get("context_type") + if context_type and context_type not in self.ALLOWED_CONTEXT_TYPES: + logger.warning( + "Invalid context_type: %s. 
Must be one of %s", + context_type, + sorted(self.ALLOWED_CONTEXT_TYPES), + ) + return "" - async def delete(self, collection: str, ids: List[str]) -> int: - """Delete records by IDs.""" - coll = self._get_collection(collection) + if not payload.get("id"): + payload["id"] = str(uuid.uuid4()) - try: - coll.delete_data(ids) - return len(ids) - except Exception as e: - logger.error(f"Error deleting records: {e}") - return 0 - - async def get(self, collection: str, ids: List[str]) -> List[Dict[str, Any]]: - """Get records by IDs.""" - coll = self._get_collection(collection) + payload = self._filter_known_fields(payload) + ids = self._adapter.upsert(payload) + return ids[0] if ids else "" + async def get(self, ids: List[str]) -> List[Dict[str, Any]]: try: - result = coll.fetch_data(ids) - - if isinstance(result, FetchDataInCollectionResult): - records = [] - for item in result.items: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - self._restore_uri_fields(record) - records.append(record) - return records - elif isinstance(result, dict): - records = [] - if "fetch" in result: - for item in result.get("fetch", []): - record = dict(item.get("fields", {})) if item.get("fields") else {} - record["id"] = item.get("id") - if record["id"]: - self._restore_uri_fields(record) - records.append(record) - return records - else: - logger.warning(f"Unexpected return type from fetch_data: {type(result)}") - return [] + return self._adapter.get(ids) except Exception as e: - logger.error(f"Error getting records: {e}") + logger.error("Error getting records: %s", e) return [] - async def fetch_by_uri(self, collection: str, uri: str) -> Optional[Dict[str, Any]]: - """Fetch a record by URI.""" - coll = self._get_collection(collection) + async def delete(self, ids: List[str]) -> int: try: - result = coll.search_by_random( - index_name=self.DEFAULT_INDEX_NAME, - limit=10, - filters={"op": "must", "field": "uri", "conds": [uri]}, - ) - records = [] - for item in 
result.data: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - self._restore_uri_fields(record) - records.append(record) - if len(records) > 1: - raise ValueError(f"Duplicate records found for URI: {uri}") - if len(records) == 0: - raise ValueError(f"Record not found for URI: {uri}") - return records[0] + return self._adapter.delete(ids=ids) except Exception as e: - logger.error(f"Error fetching record by URI '{uri}': {e}") - return None + logger.error("Error deleting records: %s", e) + return 0 - async def exists(self, collection: str, id: str) -> bool: - """Check if a record exists.""" + async def exists(self, id: str) -> bool: try: - results = await self.get(collection, [id]) - return len(results) > 0 + return len(await self.get([id])) > 0 except Exception: return False - # ========================================================================= - # CRUD Operations - Batch - # ========================================================================= - - async def batch_insert(self, collection: str, data: List[Dict[str, Any]]) -> List[str]: - """Batch insert multiple records.""" - coll = self._get_collection(collection) - - # Ensure all records have IDs - ids = [] - records_with_ids = [] - for record in data: - if "id" not in record: - record_id = str(uuid.uuid4()) - records_with_ids.append({**record, "id": record_id}) - ids.append(record_id) - else: - records_with_ids.append(record) - ids.append(record["id"]) - + async def fetch_by_uri(self, uri: str) -> Optional[Dict[str, Any]]: try: - coll.upsert_data(records_with_ids) - return ids + records = await self.query( + filter={"op": "must", "field": "uri", "conds": [uri]}, + limit=2, + ) + if len(records) == 1: + return records[0] + return None except Exception as e: - logger.error(f"Error batch inserting records: {e}") - raise - - async def batch_upsert(self, collection: str, data: List[Dict[str, Any]]) -> List[str]: - """Batch insert or update multiple records.""" - coll = 
self._get_collection(collection) - - ids = [] - records_with_ids = [] - for record in data: - if "id" not in record: - record_id = str(uuid.uuid4()) - records_with_ids.append({**record, "id": record_id}) - ids.append(record_id) - else: - records_with_ids.append(record) - ids.append(record["id"]) + logger.error("Error fetching record by URI %s: %s", uri, e) + return None + async def query( + self, + query_vector: Optional[List[float]] = None, + sparse_query_vector: Optional[Dict[str, float]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, + limit: int = 10, + offset: int = 0, + output_fields: Optional[List[str]] = None, + with_vector: bool = False, + order_by: Optional[str] = None, + order_desc: bool = False, + ) -> List[Dict[str, Any]]: try: - coll.upsert_data(records_with_ids) - return ids + return self._adapter.query( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=filter, + limit=limit, + offset=offset, + output_fields=output_fields, + with_vector=with_vector, + order_by=order_by, + order_desc=order_desc, + ) except Exception as e: - logger.error(f"Error batch upserting records: {e}") - raise - - async def batch_delete(self, collection: str, filter: Dict[str, Any]) -> int: - """Delete records matching filter conditions.""" - try: - # First, find matching records - matching_records = await self.filter(collection, filter, limit=10000) + logger.error("Error querying collection %s: %s", self._collection_name, e) + return [] - if not matching_records: - return 0 + async def search( + self, + query_vector: Optional[List[float]] = None, + sparse_query_vector: Optional[Dict[str, float]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, + limit: int = 10, + offset: int = 0, + output_fields: Optional[List[str]] = None, + with_vector: bool = False, + ) -> List[Dict[str, Any]]: + # Backward-compatible alias for internal call sites. 
+ return await self.query( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=filter, + limit=limit, + offset=offset, + output_fields=output_fields, + with_vector=with_vector, + ) - # Extract IDs and delete - ids = [record["id"] for record in matching_records if "id" in record] - return await self.delete(collection, ids) - except Exception as e: - logger.error(f"Error batch deleting records: {e}") - return 0 + async def filter( + self, + filter: Dict[str, Any] | FilterExpr, + limit: int = 10, + offset: int = 0, + output_fields: Optional[List[str]] = None, + order_by: Optional[str] = None, + order_desc: bool = False, + ) -> List[Dict[str, Any]]: + return await self.query( + filter=filter, + limit=limit, + offset=offset, + output_fields=output_fields, + order_by=order_by, + order_desc=order_desc, + ) - async def remove_by_uri(self, collection: str, uri: str) -> int: - """Remove resource(s) by URI.""" + async def remove_by_uri(self, uri: str) -> int: try: target_records = await self.filter( - collection=collection, - filter={"op": "must", "field": "uri", "conds": [uri]}, + {"op": "must", "field": "uri", "conds": [uri]}, limit=10, ) - if not target_records: return 0 total_deleted = 0 - - # If any record indicates this URI is a directory node, remove descendants first. 
if any(r.get("level") in [0, 1] for r in target_records): - descendant_count = await self._remove_descendants(collection, uri) - total_deleted += descendant_count + total_deleted += await self._remove_descendants(parent_uri=uri) ids = [r.get("id") for r in target_records if r.get("id")] if ids: - total_deleted += await self.delete(collection, ids) - - logger.info(f"Removed {total_deleted} record(s) for URI: {uri}") + total_deleted += await self.delete(ids) return total_deleted - except Exception as e: - logger.error(f"Error removing URI '{uri}': {e}") + logger.error("Error removing URI %s: %s", uri, e) return 0 - async def _remove_descendants(self, collection: str, parent_uri: str) -> int: - """Recursively remove all descendants of a parent URI.""" + async def _remove_descendants(self, parent_uri: str) -> int: total_deleted = 0 - - # Find direct children children = await self.filter( - collection=collection, - filter={"op": "must", "field": "parent_uri", "conds": [parent_uri]}, - limit=10000, + {"op": "must", "field": "parent_uri", "conds": [parent_uri]}, + limit=100000, ) - for child in children: child_uri = child.get("uri") level = child.get("level", 2) - - # Recursively delete if child is also an intermediate directory if level in [0, 1] and child_uri: - descendant_count = await self._remove_descendants(collection, child_uri) - total_deleted += descendant_count - - # Delete the child - if "id" in child: - await self.delete(collection, [child["id"]]) + total_deleted += await self._remove_descendants(parent_uri=child_uri) + child_id = child.get("id") + if child_id: + await self.delete([child_id]) total_deleted += 1 - return total_deleted # ========================================================================= - # Search Operations + # Semantic Context Operations (Tenant-Aware) # ========================================================================= - async def search( + async def search_in_tenant( self, - collection: str, - query_vector: 
Optional[List[float]] = None, + ctx: RequestContext, + query_vector: Optional[List[float]], sparse_query_vector: Optional[Dict[str, float]] = None, - filter: Optional[Dict[str, Any]] = None, + context_type: Optional[str] = None, + target_directories: Optional[List[str]] = None, + extra_filter: Optional[FilterExpr | Dict[str, Any]] = None, limit: int = 10, offset: int = 0, - output_fields: Optional[List[str]] = None, - with_vector: bool = False, ) -> List[Dict[str, Any]]: - """Hybrid search: vector similarity (dense/sparse/hybrid) + scalar filtering. - - Args: - collection: Collection name, by default it should be "context" - query_vector: Dense query vector (optional) - sparse_query_vector: Sparse query vector as {term: weight} dict (optional) - filter: Scalar filter conditions - limit: Maximum number of results - offset: Offset for pagination - output_fields: Fields to return - with_vector: Whether to include vector field in results - - Returns: - List of matching records with scores - """ - coll = self._get_collection(collection) - - try: - # Filter is already in vectordb DSL format - vectordb_filter = filter if filter else {} - - if query_vector or sparse_query_vector: - # Vector search (dense, sparse, or hybrid) with optional filtering - result = coll.search_by_vector( - index_name=self.DEFAULT_INDEX_NAME, - dense_vector=query_vector, - sparse_vector=sparse_query_vector, - limit=limit, - offset=offset, - filters=vectordb_filter, - output_fields=output_fields, - ) - - # Convert results - records = [] - for item in result.data: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - record["_score"] = item.score if item.score is not None else 0.0 - self._restore_uri_fields(record) - - if not with_vector: - if "vector" in record: - record.pop("vector") - if "sparse_vector" in record: - record.pop("sparse_vector") - - records.append(record) - - return records - else: - # Pure filtering without vector search - return await 
self.filter(collection, filter or {}, limit, offset, output_fields) - - except Exception as e: - logger.error(f"Error searching collection '{collection}': {e}") - import traceback + scope_filter = self._build_scope_filter( + ctx=ctx, + context_type=context_type, + target_directories=target_directories, + extra_filter=extra_filter, + ) + return await self.search( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=scope_filter, + limit=limit, + offset=offset, + ) - traceback.print_exc() + async def search_global_roots_in_tenant( + self, + ctx: RequestContext, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]] = None, + context_type: Optional[str] = None, + target_directories: Optional[List[str]] = None, + extra_filter: Optional[FilterExpr | Dict[str, Any]] = None, + limit: int = 10, + ) -> List[Dict[str, Any]]: + if not query_vector: return [] - async def filter( + merged_filter = self._merge_filters( + self._build_scope_filter( + ctx=ctx, + context_type=context_type, + target_directories=target_directories, + extra_filter=extra_filter, + ), + In("level", [0, 1]), + ) + return await self.search( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=merged_filter, + limit=limit, + ) + + async def search_children_in_tenant( self, - collection: str, - filter: Dict[str, Any], + ctx: RequestContext, + parent_uri: str, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]] = None, + context_type: Optional[str] = None, + target_directories: Optional[List[str]] = None, + extra_filter: Optional[FilterExpr | Dict[str, Any]] = None, limit: int = 10, - offset: int = 0, - output_fields: Optional[List[str]] = None, - order_by: Optional[str] = None, - order_desc: bool = False, ) -> List[Dict[str, Any]]: - """Pure scalar filtering without vector search.""" - coll = self._get_collection(collection) + merged_filter = self._merge_filters( + Eq("parent_uri", 
parent_uri), + self._build_scope_filter( + ctx=ctx, + context_type=context_type, + target_directories=target_directories, + extra_filter=extra_filter, + ), + ) + return await self.search( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=merged_filter, + limit=limit, + ) + + async def search_similar_memories( + self, + account_id: str, + owner_space: Optional[str], + category_uri_prefix: str, + query_vector: List[float], + limit: int = 5, + ) -> List[Dict[str, Any]]: + conds: List[FilterExpr] = [ + Eq("context_type", "memory"), + Eq("level", 2), + Eq("account_id", account_id), + ] + if owner_space: + conds.append(Eq("owner_space", owner_space)) + if category_uri_prefix: + conds.append(In("uri", [category_uri_prefix])) + + return await self.search( + query_vector=query_vector, + filter=And(conds), + limit=limit, + ) - try: - # Filter is already in vectordb DSL format - vectordb_filter = filter if filter else {} - - if order_by: - # Use search_by_scalar for sorting - result = coll.search_by_scalar( - index_name=self.DEFAULT_INDEX_NAME, - field=order_by, - order="desc" if order_desc else "asc", - limit=limit, - offset=offset, - filters=vectordb_filter, - output_fields=output_fields, + async def get_context_by_uri( + self, + account_id: str, + uri: str, + owner_space: Optional[str] = None, + limit: int = 1, + ) -> List[Dict[str, Any]]: + conds: List[FilterExpr] = [Eq("uri", uri), Eq("account_id", account_id)] + if owner_space: + conds.append(Eq("owner_space", owner_space)) + return await self.filter(filter=And(conds), limit=limit) + + async def delete_account_data(self, account_id: str) -> int: + return self._adapter.delete(filter=Eq("account_id", account_id)) + + async def delete_uris(self, ctx: RequestContext, uris: List[str]) -> None: + for uri in uris: + conds: List[FilterExpr] = [ + Eq("account_id", ctx.account_id), + Or([Eq("uri", uri), In("uri", [f"{uri}/"])]), + ] + if ctx.role == Role.USER and uri.startswith(("viking://user/", 
"viking://agent/")): + owner_space = ( + ctx.user.user_space_name() + if uri.startswith("viking://user/") + else ctx.user.agent_space_name() ) + conds.append(Eq("owner_space", owner_space)) + self._adapter.delete(filter=And(conds)) + + async def update_uri_mapping( + self, + ctx: RequestContext, + uri: str, + new_uri: str, + new_parent_uri: str, + ) -> bool: + records = await self.filter( + filter=And([Eq("uri", uri), Eq("account_id", ctx.account_id)]), + limit=1, + ) + if not records or "id" not in records[0]: + return False + updated = {**records[0], "uri": new_uri, "parent_uri": new_parent_uri} + return bool(await self.upsert(updated)) + + async def increment_active_count(self, ctx: RequestContext, uris: List[str]) -> int: + updated = 0 + for uri in uris: + records = await self.get_context_by_uri(account_id=ctx.account_id, uri=uri, limit=1) + if not records: + continue + record = records[0] + current = int(record.get("active_count", 0) or 0) + record["active_count"] = current + 1 + if await self.upsert(record): + updated += 1 + return updated + + def _build_scope_filter( + self, + ctx: RequestContext, + context_type: Optional[str], + target_directories: Optional[List[str]], + extra_filter: Optional[FilterExpr | Dict[str, Any]], + ) -> Optional[FilterExpr]: + filters: List[FilterExpr] = [] + if context_type: + filters.append(Eq("context_type", context_type)) + + tenant_filter = self._tenant_filter(ctx, context_type=context_type) + if tenant_filter: + filters.append(tenant_filter) + + if target_directories: + uri_conds = [In("uri", [target_dir]) for target_dir in target_directories if target_dir] + if uri_conds: + filters.append(Or(uri_conds)) + + if extra_filter: + if isinstance(extra_filter, dict): + filters.append(RawDSL(extra_filter)) else: - # Use search_by_random for pure filtering - result = coll.search_by_random( - index_name=self.DEFAULT_INDEX_NAME, - limit=limit, - offset=offset, - filters=vectordb_filter, - output_fields=output_fields, - ) + 
filters.append(extra_filter) - # Convert results - records = [] - for item in result.data: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - self._restore_uri_fields(record) - records.append(record) + return self._merge_filters(*filters) - return records + @staticmethod + def _tenant_filter( + ctx: RequestContext, context_type: Optional[str] = None + ) -> Optional[FilterExpr]: + if ctx.role == Role.ROOT: + return None - except Exception as e: - logger.error(f"Error filtering collection '{collection}': {e}") - import traceback + owner_spaces = [ctx.user.user_space_name(), ctx.user.agent_space_name()] + if context_type == "resource": + owner_spaces.append("") + return And([Eq("account_id", ctx.account_id), In("owner_space", owner_spaces)]) - traceback.print_exc() - return [] + @staticmethod + def _merge_filters(*filters: Optional[FilterExpr]) -> Optional[FilterExpr]: + non_empty = [f for f in filters if f] + if not non_empty: + return None + if len(non_empty) == 1: + return non_empty[0] + return And(non_empty) async def scroll( self, - collection: str, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, limit: int = 100, cursor: Optional[str] = None, output_fields: Optional[List[str]] = None, - ) -> Tuple[List[Dict[str, Any]], Optional[str]]: - """Scroll through large result sets efficiently.""" - # vectordb doesn't natively support scroll, so we simulate it + ) -> tuple[List[Dict[str, Any]], Optional[str]]: offset = int(cursor) if cursor else 0 - records = await self.filter( - collection=collection, filter=filter or {}, limit=limit, offset=offset, output_fields=output_fields, ) - - # Return next cursor if we got a full batch next_cursor = str(offset + limit) if len(records) == limit else None - return records, next_cursor - # ========================================================================= - # Aggregation Operations - # 
========================================================================= - - async def count(self, collection: str, filter: Optional[Dict[str, Any]] = None) -> int: - """Count records matching filter.""" + async def count(self, filter: Optional[Dict[str, Any] | FilterExpr] = None) -> int: try: - coll = self._get_collection(collection) - result = coll.aggregate_data( - index_name=self.DEFAULT_INDEX_NAME, op="count", filters=filter - ) - return result.agg.get("_total", 0) + return self._adapter.count(filter=filter) except Exception as e: - logger.error(f"Error counting records: {e}") + logger.error("Error counting records: %s", e) return 0 - # ========================================================================= - # Index Operations - # ========================================================================= - - async def create_index( - self, - collection: str, - field: str, - index_type: str, - **kwargs, - ) -> bool: - """Create an index on a field.""" + async def clear(self) -> bool: try: - # vectordb manages indexes at collection level - # Indexes are already created with the collection - logger.info(f"Index creation requested for field '{field}' (managed by vectordb)") - return True + return self._adapter.clear() except Exception as e: - logger.error(f"Error creating index on '{field}': {e}") + logger.error("Error clearing collection: %s", e) return False - async def drop_index(self, collection: str, field: str) -> bool: - """Drop an index on a field.""" - try: - # vectordb manages indexes internally - logger.info(f"Index drop requested for field '{field}' (managed by vectordb)") - return True - except Exception as e: - logger.error(f"Error dropping index on '{field}': {e}") - return False - - # ========================================================================= - # Lifecycle Operations - # ========================================================================= - - async def clear(self, collection: str) -> bool: - """Clear all data in a 
collection.""" - coll = self._get_collection(collection) - - try: - coll.delete_all_data() - logger.info(f"Cleared all data in collection: {collection}") - return True - except Exception as e: - logger.error(f"Error clearing collection: {e}") - return False - - async def optimize(self, collection: str) -> bool: - """Optimize collection for better performance.""" - try: - # vectordb handles optimization internally via index rebuilding - logger.info(f"Optimization requested for collection: {collection}") - return True - except Exception as e: - logger.error(f"Error optimizing collection: {e}") - return False + async def optimize(self) -> bool: + logger.info("Optimization requested for collection: %s", self._collection_name) + return True async def close(self) -> None: - """Close storage connection and release resources.""" try: - if self.project: - self.project.close() - - self._collection_configs.clear() + self._adapter.close() + self._collection_config = {} + self._meta_data_cache = {} logger.info("VikingDB backend closed") except Exception as e: - logger.error(f"Error closing VikingDB backend: {e}") - - # ========================================================================= - # Health & Status - # ========================================================================= + logger.error("Error closing VikingDB backend: %s", e) async def health_check(self) -> bool: - """Check if storage backend is healthy and accessible.""" try: - # Simple check: verify we can access the project - self.project.list_collections() + await self.collection_exists() return True except Exception: return False async def get_stats(self) -> Dict[str, Any]: - """Get storage statistics.""" try: - collections = self.project.list_collections() - - # Count total records across all collections using aggregate_data - total_records = 0 - for collection_name in collections: - try: - coll = self._get_collection(collection_name) - result = coll.aggregate_data( - index_name=self.DEFAULT_INDEX_NAME, 
op="count", filters=None - ) - total_records += result.agg.get("_total", 0) - except Exception as e: - logger.warning(f"Error counting records in collection '{collection_name}': {e}") - continue - + exists = await self.collection_exists() + total_records = await self.count() if exists else 0 return { - "collections": len(collections), + "collections": 1 if exists else 0, "total_records": total_records, "backend": "vikingdb", "mode": self._mode, } except Exception as e: - logger.error(f"Error getting stats: {e}") + logger.error("Error getting stats: %s", e) return { "collections": 0, "total_records": 0, @@ -924,5 +586,4 @@ async def get_stats(self) -> Dict[str, Any]: @property def mode(self) -> str: - """Return the current storage mode.""" return self._mode diff --git a/openviking/storage/vikingdb_interface.py b/openviking/storage/vikingdb_interface.py deleted file mode 100644 index f11e8116..00000000 --- a/openviking/storage/vikingdb_interface.py +++ /dev/null @@ -1,589 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -""" -Storage interface for OpenViking. - -Defines the abstract storage interface inspired by vector database designs -(Milvus/Qdrant). All storage backends must implement this interface. -""" - -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple - - -class VikingDBInterface(ABC): - """ - Abstract vector indexing interface for OpenViking. - - This interface defines all vector indexing capabilities required by OpenViking. - New vector indexing backends should implement this interface to ensure compatibility. 
- - Capabilities: - - Collection management - - CRUD operations (single and batch) - - Vector similarity search - - Scalar filtering - - Index management - - Lifecycle management - """ - - # ========================================================================= - # Collection Management - # ========================================================================= - - @abstractmethod - async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: - """ - Create a new collection.` - - Args: - name: Collection name (e.g., "memory", "resource", "skill") - schema: Schema definition including: - - vector_dim: int - Vector dimension (default: 2048) - - distance: str - Distance metric ("cosine", "euclid", "dot") - - fields: List[dict] - Field definitions with name, type, indexed - - Returns: - True if created successfully, False if already exists - - Example: - schema = { - "vector_dim": 2048, - "distance": "cosine", - "fields": [ - {"name": "uri", "type": "string", "indexed": True}, - {"name": "abstract", "type": "text"}, - {"name": "active_count", "type": "integer"}, - ] - } - """ - pass - - @abstractmethod - async def drop_collection(self, name: str) -> bool: - """ - Drop a collection/table. - - Args: - name: Collection name - - Returns: - True if dropped successfully, False otherwise - """ - pass - - @abstractmethod - async def collection_exists(self, name: str) -> bool: - """ - Check if a collection exists. - - Args: - name: Collection name - - Returns: - True if exists, False otherwise - """ - pass - - @abstractmethod - async def list_collections(self) -> List[str]: - """ - List all collection names. - - Returns: - List of collection names - """ - pass - - @abstractmethod - async def get_collection_info(self, name: str) -> Optional[Dict[str, Any]]: - """ - Get collection metadata and statistics. 
- - Args: - name: Collection name - - Returns: - Dictionary with collection info: - - name: str - - vector_dim: int - - count: int - - status: str - Returns None if collection doesn't exist - """ - pass - - # ========================================================================= - # CRUD Operations - Single Record - # ========================================================================= - - @abstractmethod - async def insert(self, collection: str, data: Dict[str, Any]) -> str: - """ - Insert a single record. - - Args: - collection: Collection name - data: Record data. Must include: - - id: str (optional, auto-generated if not provided) - - vector: List[float] (optional) - - Other payload fields - - Returns: - ID of the inserted record - """ - pass - - @abstractmethod - async def update(self, collection: str, id: str, data: Dict[str, Any]) -> bool: - """ - Update a record by ID. - - Args: - collection: Collection name - id: Record ID - data: Fields to update (can include vector) - - Returns: - True if updated successfully, False if not found - """ - pass - - @abstractmethod - async def upsert(self, collection: str, data: Dict[str, Any]) -> str: - """ - Insert or update a record. - - If record with same ID exists, update it. Otherwise insert new record. - - Args: - collection: Collection name - data: Record data with id field - - Returns: - ID of the upserted record - """ - pass - - @abstractmethod - async def delete(self, collection: str, ids: List[str]) -> int: - """ - Delete records by IDs. - - Args: - collection: Collection name - ids: List of record IDs to delete - - Returns: - Number of records deleted - """ - pass - - @abstractmethod - async def get(self, collection: str, ids: List[str]) -> List[Dict[str, Any]]: - """ - Get records by IDs. 
- - Args: - collection: Collection name - ids: List of record IDs - - Returns: - List of records (may be fewer than requested if some IDs not found) - """ - pass - - @abstractmethod - async def exists(self, collection: str, id: str) -> bool: - """ - Check if a record exists. - - Args: - collection: Collection name - id: Record ID - - Returns: - True if exists, False otherwise - """ - pass - - # ========================================================================= - # CRUD Operations - Batch - # ========================================================================= - - @abstractmethod - async def batch_insert(self, collection: str, data: List[Dict[str, Any]]) -> List[str]: - """ - Batch insert multiple records. - - Args: - collection: Collection name - data: List of records - - Returns: - List of IDs of inserted records - """ - pass - - @abstractmethod - async def batch_upsert(self, collection: str, data: List[Dict[str, Any]]) -> List[str]: - """ - Batch insert or update multiple records. - - Args: - collection: Collection name - data: List of records with id fields - - Returns: - List of IDs of upserted records - """ - pass - - @abstractmethod - async def batch_delete(self, collection: str, filter: Dict[str, Any]) -> int: - """ - Delete records matching filter conditions. - - Args: - collection: Collection name - filter: Filter conditions - - Returns: - Number of records deleted - """ - pass - - @abstractmethod - async def remove_by_uri( - self, - collection: str, - uri: str, - ) -> int: - """ - Remove resource(s) by URI. - - If the URI points to a directory, removes all descendants first, - then removes the directory itself. 
- - Args: - collection: Collection name - uri: URI to remove (e.g., "viking://resources/references/doc_name") - - Returns: - Number of records removed - - Example: - # Remove a single context - await storage.remove_by_uri("resource", "viking://resources/ref/doc/section1") - - # Remove entire document tree (directory + all children) - await storage.remove_by_uri("resource", "viking://resources/ref/doc_name") - """ - pass - - # ========================================================================= - # Search Operations - # ========================================================================= - - @abstractmethod - async def search( - self, - collection: str, - query_vector: Optional[List[float]] = None, - sparse_query_vector: Optional[Dict[str, float]] = None, - filter: Optional[Dict[str, Any]] = None, - limit: int = 10, - offset: int = 0, - output_fields: Optional[List[str]] = None, - with_vector: bool = False, - ) -> List[Dict[str, Any]]: - """ - Hybrid search: vector similarity + scalar filtering + sparse vector matching. - - Args: - collection: Collection name - query_vector: Dense query vector for similarity search (optional) - sparse_query_vector: Sparse query vector for term matching (optional, Dict[str, float]) - filter: Scalar filter conditions (optional) - limit: Maximum number of results - offset: Offset for pagination - output_fields: Fields to return (None for all) - with_vector: Include vector in results - - Returns: - List of matching records. If query_vector provided, includes _score field. 
- - Notes: - - If both query_vector and sparse_query_vector are provided, performs hybrid search - - If only query_vector is provided, performs dense vector search - - If only sparse_query_vector is provided, performs sparse search - - If neither is provided, performs filter-only search - - Filter format (VikingVectorIndex DSL): - { - "op": "and" | "or", - "conds": [ - {"op": "must", "field": "name", "conds": [value]}, - {"op": "range", "field": "age", "gte": 18, "lt": 65}, - {"op": "must", "field": "uri", "conds": ["viking://"]}, - {"op": "contains", "field": "desc", "substring": "hello"} - ] - } - - Example: - # Dense search - results = await storage.search( - collection="context", - query_vector=embedding, - filter={ - "op": "and", - "conds": [ - {"op": "must", "field": "uri", "conds": ["viking://user"]}, - {"op": "range", "field": "active_count", "gte": 1} - ] - }, - limit=10 - ) - """ - pass - - @abstractmethod - async def filter( - self, - collection: str, - filter: Dict[str, Any], - limit: int = 10, - offset: int = 0, - output_fields: Optional[List[str]] = None, - order_by: Optional[str] = None, - order_desc: bool = False, - ) -> List[Dict[str, Any]]: - """ - Pure scalar filtering without vector search. - - Args: - collection: Collection name - filter: Filter conditions - limit: Maximum number of results - offset: Offset for pagination - output_fields: Fields to return - order_by: Field to sort by (optional) - order_desc: Sort descending if True - - Returns: - List of matching records - """ - pass - - @abstractmethod - async def scroll( - self, - collection: str, - filter: Optional[Dict[str, Any]] = None, - limit: int = 100, - cursor: Optional[str] = None, - output_fields: Optional[List[str]] = None, - ) -> Tuple[List[Dict[str, Any]], Optional[str]]: - """ - Scroll through large result sets efficiently. 
- - Args: - collection: Collection name - filter: Optional filter conditions - limit: Batch size - cursor: Cursor from previous scroll (None for first batch) - output_fields: Fields to return - - Returns: - Tuple of (records, next_cursor). next_cursor is None when exhausted. - - Example: - cursor = None - while True: - records, cursor = await storage.scroll( - "memory", limit=100, cursor=cursor - ) - process(records) - if cursor is None: - break - """ - pass - - # ========================================================================= - # Aggregation Operations - # ========================================================================= - - @abstractmethod - async def count(self, collection: str, filter: Optional[Dict[str, Any]] = None) -> int: - """ - Count records matching filter. - - Args: - collection: Collection name - filter: Optional filter conditions - - Returns: - Number of matching records - """ - pass - - # ========================================================================= - # Index Operations - # ========================================================================= - - @abstractmethod - async def create_index( - self, - collection: str, - field: str, - index_type: str, - **kwargs, - ) -> bool: - """ - Create an index on a field. - - Args: - collection: Collection name - field: Field name to index - index_type: Index type: - - "keyword": Exact match index - - "text": Full-text search index - - "integer": Numeric range index - - "float": Numeric range index - - "bool": Boolean index - **kwargs: Additional index parameters - - Returns: - True if created successfully - """ - pass - - @abstractmethod - async def drop_index(self, collection: str, field: str) -> bool: - """ - Drop an index on a field. 
- - Args: - collection: Collection name - field: Field name - - Returns: - True if dropped successfully - """ - pass - - # ========================================================================= - # Lifecycle Operations - # ========================================================================= - - @abstractmethod - async def clear(self, collection: str) -> bool: - """ - Clear all data in a collection (keep schema). - - Args: - collection: Collection name - - Returns: - True if cleared successfully - """ - pass - - @abstractmethod - async def optimize(self, collection: str) -> bool: - """ - Optimize collection for better performance. - - Triggers index optimization, compaction, etc. - - Args: - collection: Collection name - - Returns: - True if optimization completed - """ - pass - - @abstractmethod - async def close(self) -> None: - """ - Close storage connection and release resources. - - Should be called when storage is no longer needed. - """ - pass - - # ========================================================================= - # Health & Status - # ========================================================================= - - @abstractmethod - async def health_check(self) -> bool: - """ - Check if storage backend is healthy and accessible. - - Returns: - True if healthy, False otherwise - """ - pass - - @abstractmethod - async def get_stats(self) -> Dict[str, Any]: - """ - Get storage statistics. 
- - Returns: - Dictionary with stats: - - collections: int - Number of collections - - total_records: int - Total record count - - storage_size: int - Storage size in bytes (if available) - - backend: str - Backend type identifier - """ - pass - - -# ============================================================================= -# Exceptions -# ============================================================================= - - -class VikingDBException(Exception): - """Base exception for VikingDB operations.""" - - pass - - -class StorageException(VikingDBException): - """Legacy alias for VikingDBException for backward compatibility.""" - - pass - - -class CollectionNotFoundError(StorageException): - """Raised when a collection does not exist.""" - - pass - - -class RecordNotFoundError(StorageException): - """Raised when a record does not exist.""" - - pass - - -class DuplicateKeyError(StorageException): - """Raised when trying to insert a duplicate key.""" - - pass - - -class ConnectionError(StorageException): - """Raised when storage connection fails.""" - - pass - - -class SchemaError(StorageException): - """Raised when schema validation fails.""" - - pass diff --git a/openviking_cli/utils/config/vectordb_config.py b/openviking_cli/utils/config/vectordb_config.py index 4052e0c4..0a78877b 100644 --- a/openviking_cli/utils/config/vectordb_config.py +++ b/openviking_cli/utils/config/vectordb_config.py @@ -4,8 +4,11 @@ from pydantic import BaseModel, Field, model_validator +from openviking_cli.utils.logger import get_logger + COLLECTION_NAME = "context" DEFAULT_PROJECT_NAME = "default" +logger = get_logger(__name__) class VolcengineConfig(BaseModel): @@ -16,7 +19,13 @@ class VolcengineConfig(BaseModel): region: Optional[str] = Field( default=None, description="Volcengine region (e.g., 'cn-beijing')" ) - host: Optional[str] = Field(default=None, description="Volcengine VikingDB host (optional)") + host: Optional[str] = Field( + default=None, + description=( + 
"[Deprecated] Ignored in volcengine mode. " + "Hosts are derived from `region` to route console/data APIs correctly." + ), + ) model_config = {"extra": "forbid"} @@ -112,6 +121,12 @@ def validate_config(self): raise ValueError("VectorDB volcengine backend requires 'ak' and 'sk' to be set") if not self.volcengine.region: raise ValueError("VectorDB volcengine backend requires 'region' to be set") + if self.volcengine.host: + logger.warning( + "VectorDB volcengine backend: 'volcengine.host' is deprecated and ignored. " + "Using region-based console/data hosts for region='%s'.", + self.volcengine.region, + ) elif self.backend == "vikingdb": if not self.vikingdb or not self.vikingdb.host: diff --git a/tests/retrieve/test_hierarchical_retriever_target_dirs.py b/tests/retrieve/test_hierarchical_retriever_target_dirs.py index a3f53d3d..019328bb 100644 --- a/tests/retrieve/test_hierarchical_retriever_target_dirs.py +++ b/tests/retrieve/test_hierarchical_retriever_target_dirs.py @@ -6,55 +6,77 @@ import pytest from openviking.retrieve.hierarchical_retriever import HierarchicalRetriever +from openviking.server.identity import RequestContext, Role from openviking_cli.retrieve.types import ContextType, TypedQuery +from openviking_cli.session.user_id import UserIdentifier class DummyStorage: """Minimal storage stub to capture search filters.""" def __init__(self) -> None: - self.search_calls = [] + self.collection_name = "context" + self.global_search_calls = [] + self.child_search_calls = [] - async def collection_exists(self, _name: str) -> bool: + async def collection_exists_bound(self) -> bool: return True - async def search( + async def search_global_roots_in_tenant( self, - collection: str, + ctx, query_vector=None, sparse_query_vector=None, - filter=None, + context_type=None, + target_directories=None, + extra_filter=None, limit: int = 10, - offset: int = 0, - output_fields=None, - with_vector: bool = False, ): - self.search_calls.append( + 
self.global_search_calls.append( { - "collection": collection, - "filter": filter, + "ctx": ctx, + "query_vector": query_vector, + "sparse_query_vector": sparse_query_vector, + "context_type": context_type, + "target_directories": target_directories, + "extra_filter": extra_filter, "limit": limit, - "offset": offset, } ) return [] - -def _contains_prefix_filter(obj, prefix: str) -> bool: - if isinstance(obj, dict): - if obj.get("op") == "prefix" and obj.get("field") == "uri" and obj.get("prefix") == prefix: - return True - return any(_contains_prefix_filter(v, prefix) for v in obj.values()) - if isinstance(obj, list): - return any(_contains_prefix_filter(v, prefix) for v in obj) - return False + async def search_children_in_tenant( + self, + ctx, + parent_uri: str, + query_vector=None, + sparse_query_vector=None, + context_type=None, + target_directories=None, + extra_filter=None, + limit: int = 10, + ): + self.child_search_calls.append( + { + "ctx": ctx, + "parent_uri": parent_uri, + "query_vector": query_vector, + "sparse_query_vector": sparse_query_vector, + "context_type": context_type, + "target_directories": target_directories, + "extra_filter": extra_filter, + "limit": limit, + } + ) + return [] @pytest.mark.asyncio -async def test_retrieve_honors_target_directories_prefix_filter(): +async def test_retrieve_honors_target_directories_scope_filter(): target_uri = "viking://resources/foo" storage = DummyStorage() retriever = HierarchicalRetriever(storage=storage, embedder=None, rerank_config=None) + ctx = RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) query = TypedQuery( query="test", @@ -63,8 +85,11 @@ async def test_retrieve_honors_target_directories_prefix_filter(): target_directories=[target_uri], ) - result = await retriever.retrieve(query, limit=3) + result = await retriever.retrieve(query, ctx=ctx, limit=3) assert result.searched_directories == [target_uri] - assert storage.search_calls - assert 
_contains_prefix_filter(storage.search_calls[0]["filter"], target_uri) + assert storage.global_search_calls + assert storage.global_search_calls[0]["target_directories"] == [target_uri] + assert storage.child_search_calls + assert storage.child_search_calls[0]["target_directories"] == [target_uri] + assert storage.child_search_calls[0]["parent_uri"] == target_uri diff --git a/tests/session/test_memory_dedup_actions.py b/tests/session/test_memory_dedup_actions.py new file mode 100644 index 00000000..8b6e886e --- /dev/null +++ b/tests/session/test_memory_dedup_actions.py @@ -0,0 +1,558 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.core.context import Context +from openviking.message import Message +from openviking.server.identity import RequestContext, Role +from openviking.session.compressor import SessionCompressor +from openviking.session.memory_deduplicator import ( + DedupDecision, + DedupResult, + ExistingMemoryAction, + MemoryActionDecision, + MemoryDeduplicator, +) +from openviking.session.memory_extractor import ( + CandidateMemory, + MemoryCategory, + MemoryExtractor, + MergedMemoryPayload, +) +from openviking_cli.session.user_id import UserIdentifier + + +class _DummyVikingDB: + def __init__(self): + self._embedder = None + + def get_embedder(self): + return self._embedder + + +class _DummyEmbedResult: + def __init__(self, dense_vector): + self.dense_vector = dense_vector + + +class _DummyEmbedder: + def embed(self, _text): + return _DummyEmbedResult([0.1, 0.2, 0.3]) + + +def _make_user() -> UserIdentifier: + return UserIdentifier("acc1", "test_user", "test_agent") + + +def _make_ctx() -> RequestContext: + return RequestContext(user=_make_user(), role=Role.USER) + + +def _make_candidate() -> CandidateMemory: + return CandidateMemory( + category=MemoryCategory.PREFERENCES, + abstract="User prefers concise 
summaries", + overview="User asks for concise answers frequently.", + content="The user prefers concise summaries over long explanations.", + source_session="session_test", + user=_make_user(), + language="en", + ) + + +def _make_existing(uri_suffix: str = "existing.md") -> Context: + user_space = _make_user().user_space_name() + return Context( + uri=f"viking://user/{user_space}/memories/preferences/{uri_suffix}", + parent_uri=f"viking://user/{user_space}/memories/preferences", + is_leaf=True, + abstract="Existing preference memory", + context_type="memory", + category="preferences", + ) + + +class TestMemoryDeduplicatorPayload: + def test_create_with_empty_list_is_valid(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + existing = [_make_existing("a.md")] + + decision, _, actions = dedup._parse_decision_payload( + {"decision": "create", "reason": "new memory", "list": []}, + existing, + ) + + assert decision == DedupDecision.CREATE + assert actions == [] + + def test_create_with_merge_is_normalized_to_none(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + existing = [_make_existing("b.md")] + + decision, _, actions = dedup._parse_decision_payload( + { + "decision": "create", + "list": [{"uri": existing[0].uri, "decide": "merge"}], + }, + existing, + ) + + assert decision == DedupDecision.NONE + assert len(actions) == 1 + assert actions[0].decision == MemoryActionDecision.MERGE + + def test_skip_drops_list_actions(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + existing = [_make_existing("c.md")] + + decision, _, actions = dedup._parse_decision_payload( + { + "decision": "skip", + "list": [{"uri": existing[0].uri, "decide": "delete"}], + }, + existing, + ) + + assert decision == DedupDecision.SKIP + assert actions == [] + + def test_cross_facet_delete_actions_are_kept(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + food = _make_existing("food.md") + food.abstract = "饮食偏好: 喜欢吃苹果和草莓" + routine = 
_make_existing("routine.md") + routine.abstract = "作息习惯: 每天早上7点起床" + existing = [food, routine] + candidate = _make_candidate() + candidate.abstract = "饮食偏好: 不再喜欢吃水果" + candidate.content = "用户不再喜欢吃水果,需要作废过去的水果偏好。" + + decision, _, actions = dedup._parse_decision_payload( + { + "decision": "create", + "list": [ + {"uri": food.uri, "decide": "delete"}, + {"uri": routine.uri, "decide": "delete"}, + ], + }, + existing, + candidate, + ) + + assert decision == DedupDecision.CREATE + assert len(actions) == 2 + assert {a.memory.uri for a in actions} == {food.uri, routine.uri} + assert all(a.decision == MemoryActionDecision.DELETE for a in actions) + + @pytest.mark.asyncio + async def test_find_similar_memories_uses_path_must_filter_and__score(self): + existing = _make_existing("pref_hit.md") + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock( + return_value=[ + { + "id": "uri_pref_hit", + "uri": existing.uri, + "context_type": "memory", + "level": 2, + "account_id": "acc1", + "owner_space": _make_user().user_space_name(), + "abstract": existing.abstract, + "category": "preferences", + "_score": 0.82, + } + ] + ) + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + similar = await dedup._find_similar_memories(candidate) + + assert len(similar) == 1 + assert similar[0].uri == existing.uri + call = vikingdb.search_similar_memories.await_args.kwargs + assert call["account_id"] == "acc1" + assert call["owner_space"] == _make_user().user_space_name() + assert call["category_uri_prefix"] == ( + f"viking://user/{_make_user().user_space_name()}/memories/preferences/" + ) + assert call["limit"] == 5 + + @pytest.mark.asyncio + async def test_find_similar_memories_accepts_low_score_when_threshold_is_zero(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search_similar_memories = AsyncMock( + return_value=[ + { + "id": "uri_low", 
+ "uri": f"viking://user/{_make_user().user_space_name()}/memories/preferences/low.md", + "context_type": "memory", + "level": 2, + "account_id": "acc1", + "owner_space": _make_user().user_space_name(), + "abstract": "low", + "_score": 0.68, + } + ] + ) + dedup = MemoryDeduplicator(vikingdb=vikingdb) + + similar = await dedup._find_similar_memories(_make_candidate()) + + assert len(similar) == 1 + + @pytest.mark.asyncio + async def test_llm_decision_formats_up_to_five_similar_memories(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + similar = [_make_existing(f"m_{i}.md") for i in range(6)] + captured = {} + + def _fake_render_prompt(_template_id, variables): + captured.update(variables) + return "prompt" + + class _DummyVLM: + def is_available(self): + return True + + async def get_completion_async(self, _prompt): + return '{"decision":"skip","reason":"dup"}' + + class _DummyConfig: + vlm = _DummyVLM() + + with ( + patch( + "openviking.session.memory_deduplicator.get_openviking_config", + return_value=_DummyConfig(), + ), + patch( + "openviking.session.memory_deduplicator.render_prompt", + side_effect=_fake_render_prompt, + ), + ): + decision, _, _ = await dedup._llm_decision(_make_candidate(), similar) + + assert decision == DedupDecision.SKIP + existing_text = captured["existing_memories"] + assert existing_text.count("uri=") == 5 + assert similar[0].abstract in existing_text + assert "facet=" in existing_text + assert similar[4].uri in existing_text + assert similar[5].uri not in existing_text + + +@pytest.mark.asyncio +class TestMemoryMergeBundle: + async def test_merge_memory_bundle_parses_structured_response(self): + extractor = MemoryExtractor() + + class _DummyVLM: + def is_available(self): + return True + + async def get_completion_async(self, _prompt): + return ( + '{"decision":"merge","abstract":"Tool preference: Use clang","overview":"## ' + 'Preference Domain","content":"Use clang for C++.","reason":"updated"}' + ) + + class 
_DummyConfig: + vlm = _DummyVLM() + + with patch( + "openviking.session.memory_extractor.get_openviking_config", + return_value=_DummyConfig(), + ): + payload = await extractor._merge_memory_bundle( + existing_abstract="old", + existing_overview="", + existing_content="old content", + new_abstract="new", + new_overview="", + new_content="new content", + category="preferences", + output_language="en", + ) + + assert payload is not None + assert payload.abstract == "Tool preference: Use clang" + assert payload.content == "Use clang for C++." + + async def test_merge_memory_bundle_rejects_missing_required_fields(self): + extractor = MemoryExtractor() + + class _DummyVLM: + def is_available(self): + return True + + async def get_completion_async(self, _prompt): + return '{"decision":"merge","abstract":"","overview":"o","content":"","reason":"r"}' + + class _DummyConfig: + vlm = _DummyVLM() + + with patch( + "openviking.session.memory_extractor.get_openviking_config", + return_value=_DummyConfig(), + ): + payload = await extractor._merge_memory_bundle( + existing_abstract="old", + existing_overview="", + existing_content="old content", + new_abstract="new", + new_overview="", + new_content="new content", + category="preferences", + output_language="en", + ) + + assert payload is None + + +@pytest.mark.asyncio +class TestProfileMergeSafety: + async def test_profile_merge_failure_keeps_existing_content(self): + extractor = MemoryExtractor() + extractor._merge_memory_bundle = AsyncMock(return_value=None) + candidate = CandidateMemory( + category=MemoryCategory.PROFILE, + abstract="User basic info: lives in NYC", + overview="## Background", + content="User currently lives in NYC.", + source_session="session_test", + user="test_user", + language="en", + ) + + fs = MagicMock() + fs.read_file = AsyncMock(return_value="existing profile content") + fs.write_file = AsyncMock() + + payload = await extractor._append_to_profile(candidate, fs, ctx=_make_ctx()) + + assert payload is 
None + fs.write_file.assert_not_called() + + async def test_create_memory_skips_profile_index_payload_when_merge_fails(self): + extractor = MemoryExtractor() + candidate = CandidateMemory( + category=MemoryCategory.PROFILE, + abstract="User basic info: lives in NYC", + overview="## Background", + content="User currently lives in NYC.", + source_session="session_test", + user="test_user", + language="en", + ) + extractor._append_to_profile = AsyncMock(return_value=None) + + with patch("openviking.session.memory_extractor.get_viking_fs", return_value=MagicMock()): + memory = await extractor.create_memory( + candidate, + user=_make_user(), + session_id="s1", + ctx=_make_ctx(), + ) + + assert memory is None + + +@pytest.mark.asyncio +class TestSessionCompressorDedupActions: + async def test_create_with_empty_list_only_creates_new_memory(self): + candidate = _make_candidate() + new_memory = _make_existing("created.md") + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = None + vikingdb.delete_uris = AsyncMock(return_value=None) + vikingdb.enqueue_embedding_msg = AsyncMock() + + compressor = SessionCompressor(vikingdb=vikingdb) + compressor.extractor.extract = AsyncMock(return_value=[candidate]) + compressor.extractor.create_memory = AsyncMock(return_value=new_memory) + compressor.deduplicator.deduplicate = AsyncMock( + return_value=DedupResult( + decision=DedupDecision.CREATE, + candidate=candidate, + similar_memories=[], + actions=[], + ) + ) + compressor._index_memory = AsyncMock(return_value=True) + + fs = MagicMock() + fs.rm = AsyncMock() + + with patch("openviking.session.compressor.get_viking_fs", return_value=fs): + memories = await compressor.extract_long_term_memories( + [Message.create_user("test message")], + user=_make_user(), + session_id="session_test", + ctx=_make_ctx(), + ) + + assert len(memories) == 1 + assert memories[0].uri == new_memory.uri + fs.rm.assert_not_called() + compressor.extractor.create_memory.assert_awaited_once() + + async 
def test_create_with_merge_is_executed_as_none(self): + candidate = _make_candidate() + target = _make_existing("merge_target.md") + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = None + vikingdb.delete_uris = AsyncMock(return_value=None) + vikingdb.enqueue_embedding_msg = AsyncMock() + + compressor = SessionCompressor(vikingdb=vikingdb) + compressor.extractor.extract = AsyncMock(return_value=[candidate]) + compressor.extractor.create_memory = AsyncMock(return_value=_make_existing("never.md")) + compressor.extractor._merge_memory_bundle = AsyncMock( + return_value=MergedMemoryPayload( + abstract="merged abstract", + overview="merged overview", + content="merged memory content", + reason="merged", + ) + ) + compressor.deduplicator.deduplicate = AsyncMock( + return_value=DedupResult( + decision=DedupDecision.CREATE, + candidate=candidate, + similar_memories=[target], + actions=[ + ExistingMemoryAction( + memory=target, + decision=MemoryActionDecision.MERGE, + ) + ], + ) + ) + compressor._index_memory = AsyncMock(return_value=True) + + fs = MagicMock() + fs.read_file = AsyncMock(return_value="old memory content") + fs.write_file = AsyncMock() + fs.rm = AsyncMock() + + with patch("openviking.session.compressor.get_viking_fs", return_value=fs): + memories = await compressor.extract_long_term_memories( + [Message.create_user("test message")], + user=_make_user(), + session_id="session_test", + ctx=_make_ctx(), + ) + + assert memories == [] + compressor.extractor.create_memory.assert_not_called() + fs.write_file.assert_awaited_once_with(target.uri, "merged memory content", ctx=_make_ctx()) + assert target.abstract == "merged abstract" + assert target.meta["overview"] == "merged overview" + compressor._index_memory.assert_awaited_once() + + async def test_merge_bundle_failure_is_skipped_without_fallback(self): + candidate = _make_candidate() + target = _make_existing("merge_target_fail.md") + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = 
None + vikingdb.delete_uris = AsyncMock(return_value=None) + vikingdb.enqueue_embedding_msg = AsyncMock() + + compressor = SessionCompressor(vikingdb=vikingdb) + compressor.extractor.extract = AsyncMock(return_value=[candidate]) + compressor.extractor._merge_memory_bundle = AsyncMock(return_value=None) + compressor.deduplicator.deduplicate = AsyncMock( + return_value=DedupResult( + decision=DedupDecision.NONE, + candidate=candidate, + similar_memories=[target], + actions=[ + ExistingMemoryAction( + memory=target, + decision=MemoryActionDecision.MERGE, + ) + ], + ) + ) + compressor._index_memory = AsyncMock(return_value=True) + + fs = MagicMock() + fs.read_file = AsyncMock(return_value="old memory content") + fs.write_file = AsyncMock() + fs.rm = AsyncMock() + + with patch("openviking.session.compressor.get_viking_fs", return_value=fs): + memories = await compressor.extract_long_term_memories( + [Message.create_user("test message")], + user=_make_user(), + session_id="session_test", + ctx=_make_ctx(), + ) + + assert memories == [] + fs.write_file.assert_not_called() + compressor._index_memory.assert_not_called() + + async def test_create_with_delete_runs_delete_before_create(self): + candidate = _make_candidate() + target = _make_existing("to_delete.md") + new_memory = _make_existing("created_after_delete.md") + call_order = [] + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = None + vikingdb.delete_uris = AsyncMock(return_value=None) + vikingdb.enqueue_embedding_msg = AsyncMock() + + compressor = SessionCompressor(vikingdb=vikingdb) + compressor.extractor.extract = AsyncMock(return_value=[candidate]) + compressor.deduplicator.deduplicate = AsyncMock( + return_value=DedupResult( + decision=DedupDecision.CREATE, + candidate=candidate, + similar_memories=[target], + actions=[ + ExistingMemoryAction( + memory=target, + decision=MemoryActionDecision.DELETE, + ) + ], + ) + ) + + async def _create_memory(*_args, **_kwargs): + call_order.append("create") + 
return new_memory + + compressor.extractor.create_memory = AsyncMock(side_effect=_create_memory) + compressor._index_memory = AsyncMock(return_value=True) + + fs = MagicMock() + + async def _rm(*_args, **_kwargs): + call_order.append("delete") + return {} + + fs.rm = AsyncMock(side_effect=_rm) + + with patch("openviking.session.compressor.get_viking_fs", return_value=fs): + memories = await compressor.extract_long_term_memories( + [Message.create_user("test message")], + user=_make_user(), + session_id="session_test", + ctx=_make_ctx(), + ) + + assert [m.uri for m in memories] == [new_memory.uri] + assert call_order == ["delete", "create"] + vikingdb.delete_uris.assert_awaited_once_with(_make_ctx(), [target.uri]) diff --git a/tests/session/test_session_compressor_vikingdb.py b/tests/session/test_session_compressor_vikingdb.py new file mode 100644 index 00000000..71e22533 --- /dev/null +++ b/tests/session/test_session_compressor_vikingdb.py @@ -0,0 +1,26 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from openviking.server.identity import RequestContext, Role +from openviking.session.compressor import SessionCompressor +from openviking_cli.session.user_id import UserIdentifier + + +@pytest.mark.asyncio +async def test_delete_existing_memory_uses_vikingdb_manager(): + compressor = SessionCompressor.__new__(SessionCompressor) + compressor.vikingdb = AsyncMock() + viking_fs = AsyncMock() + memory = SimpleNamespace(uri="viking://user/user1/memories/events/e1") + ctx = RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) + + ok = await SessionCompressor._delete_existing_memory(compressor, memory, viking_fs, ctx) + + assert ok is True + viking_fs.rm.assert_awaited_once_with(memory.uri, recursive=False, ctx=ctx) + compressor.vikingdb.delete_uris.assert_awaited_once_with(ctx, [memory.uri]) diff --git a/tests/storage/test_semantic_dag_stats.py b/tests/storage/test_semantic_dag_stats.py new file mode 100644 index 00000000..10f06c22 --- /dev/null +++ b/tests/storage/test_semantic_dag_stats.py @@ -0,0 +1,85 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from openviking.server.identity import RequestContext, Role +from openviking.storage.queuefs.semantic_dag import DagStats, SemanticDagExecutor +from openviking_cli.session.user_id import UserIdentifier + + +class _FakeVikingFS: + def __init__(self, tree): + self._tree = tree + self.writes = [] + + async def ls(self, uri, ctx=None): + return self._tree.get(uri, []) + + async def write_file(self, path, content, ctx=None): + self.writes.append((path, content)) + + +class _FakeProcessor: + def __init__(self): + self.vectorized_dirs = [] + self.vectorized_files = [] + + async def _generate_single_file_summary(self, file_path, llm_sem=None, ctx=None): + return {"name": file_path.split("/")[-1], "summary": "summary"} + + async def _generate_overview(self, dir_uri, file_summaries, children_abstracts): + return "overview" + + def _extract_abstract_from_overview(self, overview): + return "abstract" + + async def _vectorize_directory_simple(self, uri, context_type, abstract, overview, ctx=None): + self.vectorized_dirs.append(uri) + + async def _vectorize_single_file( + self, parent_uri, context_type, file_path, summary_dict, ctx=None + ): + self.vectorized_files.append(file_path) + + +@pytest.mark.asyncio +async def test_semantic_dag_stats_collects_nodes(monkeypatch): + root_uri = "viking://resources/root" + tree = { + root_uri: [ + {"name": "a.txt", "isDir": False}, + {"name": "b.txt", "isDir": False}, + {"name": "child", "isDir": True}, + ], + f"{root_uri}/child": [ + {"name": "c.txt", "isDir": False}, + ], + } + fake_fs = _FakeVikingFS(tree) + monkeypatch.setattr("openviking.storage.queuefs.semantic_dag.get_viking_fs", lambda: fake_fs) + + processor = _FakeProcessor() + ctx = RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) + executor = SemanticDagExecutor( + processor=processor, + context_type="resource", + max_concurrent_llm=2, + ctx=ctx, + ) + await executor.run(root_uri) + + stats 
= executor.get_stats() + assert isinstance(stats, DagStats) + assert stats.total_nodes == 5 # 2 dirs + 3 files + assert stats.pending_nodes == 0 + assert stats.done_nodes == 5 + assert stats.in_progress_nodes == 0 + assert processor.vectorized_dirs == [f"{root_uri}/child", root_uri] + assert sorted(processor.vectorized_files) == sorted( + [f"{root_uri}/a.txt", f"{root_uri}/b.txt", f"{root_uri}/child/c.txt"] + ) + + +if __name__ == "__main__": + pytest.main([__file__])