From 3bc7744d9203e5f0252eafe5b2de3fd4996a5a6a Mon Sep 17 00:00:00 2001 From: "zhoujiahui.01" Date: Thu, 26 Feb 2026 11:50:05 +0800 Subject: [PATCH 1/7] refactor: route vector access through semantic gateway --- openviking/async_client.py | 10 +- openviking/core/directories.py | 12 +- openviking/retrieve/hierarchical_retriever.py | 129 ++-- openviking/server/routers/admin.py | 9 +- openviking/session/compressor.py | 4 +- openviking/session/memory_deduplicator.py | 31 +- openviking/session/session.py | 41 +- openviking/storage/__init__.py | 2 + .../storage/context_semantic_gateway.py | 306 ++++++++++ openviking/storage/viking_fs.py | 92 ++- ...test_hierarchical_retriever_target_dirs.py | 21 +- tests/session/test_memory_dedup_actions.py | 568 ++++++++++++++++++ ...est_session_compressor_semantic_gateway.py | 26 + .../storage/test_context_semantic_gateway.py | 49 ++ tests/storage/test_semantic_dag_stats.py | 85 +++ 15 files changed, 1167 insertions(+), 218 deletions(-) create mode 100644 openviking/storage/context_semantic_gateway.py create mode 100644 tests/session/test_memory_dedup_actions.py create mode 100644 tests/session/test_session_compressor_semantic_gateway.py create mode 100644 tests/storage/test_context_semantic_gateway.py create mode 100644 tests/storage/test_semantic_dag_stats.py diff --git a/openviking/async_client.py b/openviking/async_client.py index f351a6cc..1ba053d2 100644 --- a/openviking/async_client.py +++ b/openviking/async_client.py @@ -58,11 +58,13 @@ def __init__( self.user = UserIdentifier.the_default_user() self._initialized = False - self._singleton_initialized = True + # Mark initialized only after LocalClient is successfully constructed. + self._singleton_initialized = False self._client: BaseClient = LocalClient( path=path, ) + self._singleton_initialized = True # ============= Lifecycle methods ============= @@ -78,7 +80,9 @@ async def _ensure_initialized(self): async def close(self) -> None: """Close OpenViking and release resources.""" - await self._client.close() + client = getattr(self, "_client", None) + if client is not None: + await client.close() self._initialized = False self._singleton_initialized = False @@ -88,8 +92,6 @@ async def reset(cls) -> None: with cls._lock: if cls._instance is not None: await cls._instance.close() - cls._instance._initialized = False - cls._instance._singleton_initialized = False cls._instance = None # ============= Session methods ============= diff --git a/openviking/core/directories.py b/openviking/core/directories.py index 2033a29e..0770fbb8 100644 --- a/openviking/core/directories.py +++ b/openviking/core/directories.py @@ -12,6 +12,7 @@ from openviking.core.context import Context, ContextType, Vectorize from openviking.server.identity import RequestContext +from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway from openviking.storage.queuefs.embedding_msg_converter import EmbeddingMsgConverter if TYPE_CHECKING: @@ -145,6 +146,7 @@ def __init__( vikingdb: "VikingDBManager", ): self.vikingdb = vikingdb + self.semantic_gateway = ContextSemanticSearchGateway.from_storage(vikingdb) async def initialize_account_directories(self, ctx: RequestContext) -> int: """Initialize account-shared scope roots.""" @@ -228,13 +230,9 @@ async def _ensure_directory( logger.debug(f"[VikingFS] Directory {uri} already exists") # 2. Ensure record exists in vector storage - from openviking_cli.utils.config import get_openviking_config - - config = get_openviking_config() - - existing = await self.vikingdb.filter( - collection=config.storage.vectordb.name, - filter={"op": "must", "field": "uri", "conds": [uri]}, + existing = await self.semantic_gateway.get_context_by_uri( + account_id=ctx.account_id, + uri=uri, limit=1, ) if not existing: diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index 7f92e1df..670f659b 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -14,7 +14,7 @@ from openviking.models.embedder.base import EmbedResult from openviking.retrieve.memory_lifecycle import hotness_score from openviking.server.identity import RequestContext, Role -from openviking.storage import VikingDBInterface +from openviking.storage import ContextSemanticSearchGateway, VikingDBInterface from openviking.storage.viking_fs import get_viking_fs from openviking_cli.retrieve.types import ( ContextType, @@ -58,6 +58,7 @@ def __init__( rerank_config: Rerank configuration (optional, will fallback to vector search only) """ self.storage = storage + self.semantic_gateway = ContextSemanticSearchGateway.from_storage(storage) self.embedder = embedder self.rerank_config = rerank_config @@ -85,7 +86,7 @@ async def retrieve( mode: RetrieverMode = RetrieverMode.THINKING, score_threshold: Optional[float] = None, score_gte: bool = False, - metadata_filter: Optional[Dict[str, Any]] = None, + scope_dsl: Optional[Dict[str, Any]] = None, ) -> QueryResult: """ Execute hierarchical retrieval. @@ -95,44 +96,19 @@ async def retrieve( score_threshold: Custom score threshold (overrides config) score_gte: True uses >=, False uses > grep_patterns: Keyword match pattern list - metadata_filter: Additional metadata filter conditions + scope_dsl: Additional scope constraints passed from public find/search filter """ # Use custom threshold or default threshold effective_threshold = score_threshold if score_threshold is not None else self.threshold - collection = "context" - target_dirs = [d for d in (query.target_directories or []) if d] - # Create context_type filter (skip when context_type is None = search all types) - filters_to_merge = [] - if query.context_type is not None: - type_filter = { - "op": "must", - "field": "context_type", - "conds": [query.context_type.value], - } - filters_to_merge.append(type_filter) - tenant_filter = self._build_tenant_filter(ctx, context_type=query.context_type) - if tenant_filter: - filters_to_merge.append(tenant_filter) - if target_dirs: - target_filter = { - "op": "or", - "conds": [ - {"op": "must", "field": "uri", "conds": [target_dir]} - for target_dir in target_dirs - ], - } - filters_to_merge.append(target_filter) - if metadata_filter: - filters_to_merge.append(metadata_filter) - - final_metadata_filter = {"op": "and", "conds": filters_to_merge} - - if not await self.storage.collection_exists(collection): - logger.warning(f"[RecursiveSearch] Collection {collection} does not exist") + if not await self.semantic_gateway.collection_exists_bound(): + logger.warning( + "[RecursiveSearch] Collection %s does not exist", + self.semantic_gateway.collection_name, + ) return QueryResult( query=query, matched_contexts=[], @@ -155,11 +131,13 @@ async def retrieve( # Step 2: Global vector search to supplement starting points global_results = await self._global_vector_search( - collection=collection, + ctx=ctx, query_vector=query_vector, sparse_query_vector=sparse_query_vector, + context_type=query.context_type.value if query.context_type else None, + target_dirs=target_dirs, + scope_dsl=scope_dsl, limit=self.GLOBAL_SEARCH_TOPK, - filter=final_metadata_filter, ) # Step 3: Merge starting points @@ -168,7 +146,7 @@ async def retrieve( # Step 4: Recursive search candidates = await self._recursive_search( query=query.query, - collection=collection, + ctx=ctx, query_vector=query_vector, sparse_query_vector=sparse_query_vector, starting_points=starting_points, @@ -176,7 +154,9 @@ async def retrieve( mode=mode, threshold=effective_threshold, score_gte=score_gte, - metadata_filter=final_metadata_filter, + context_type=query.context_type.value if query.context_type else None, + target_dirs=target_dirs, + scope_dsl=scope_dsl, ) # Step 6: Convert results @@ -188,56 +168,24 @@ async def retrieve( searched_directories=root_uris, ) - def _build_tenant_filter( - self, ctx: RequestContext, context_type: Optional[ContextType] = None - ) -> Optional[Dict[str, Any]]: - """Build tenant visibility filter by role. - - Args: - ctx: Request context with role and user info. - context_type: When RESOURCE, allow owner_space="" so shared - resources are visible to USER role. - """ - if ctx.role == Role.ROOT: - return None - - owner_spaces = [ctx.user.user_space_name(), ctx.user.agent_space_name()] - if context_type == ContextType.RESOURCE: - owner_spaces.append("") - return { - "op": "and", - "conds": [ - {"op": "must", "field": "account_id", "conds": [ctx.account_id]}, - { - "op": "must", - "field": "owner_space", - "conds": owner_spaces, - }, - ], - } - async def _global_vector_search( self, - collection: str, + ctx: RequestContext, query_vector: Optional[List[float]], sparse_query_vector: Optional[Dict[str, float]], + context_type: Optional[str], + target_dirs: List[str], + scope_dsl: Optional[Dict[str, Any]], limit: int, - filter: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """Global vector search to locate initial directories.""" - if not query_vector: - return [] - sparse_query_vector = sparse_query_vector or {} - - global_filter = { - "op": "and", - "conds": [filter, {"op": "must", "field": "level", "conds": [0, 1]}], - } - results = await self.storage.search( - collection=collection, + results = await self.semantic_gateway.search_global_roots_in_tenant( + ctx=ctx, query_vector=query_vector, sparse_query_vector=sparse_query_vector, - filter=global_filter, + context_type=context_type, + target_directories=target_dirs, + extra_filter_dsl=scope_dsl, limit=limit, ) return results @@ -283,7 +231,7 @@ def _merge_starting_points( async def _recursive_search( self, query: str, - collection: str, + ctx: RequestContext, query_vector: Optional[List[float]], sparse_query_vector: Optional[Dict[str, float]], starting_points: List[Tuple[str, float]], @@ -291,7 +239,9 @@ async def _recursive_search( mode: str, threshold: Optional[float] = None, score_gte: bool = False, - metadata_filter: Optional[Dict[str, Any]] = None, + context_type: Optional[str] = None, + target_dirs: Optional[List[str]] = None, + scope_dsl: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ Recursive search with directory priority return and score propagation. @@ -300,7 +250,7 @@ async def _recursive_search( threshold: Score threshold score_gte: True uses >=, False uses > grep_patterns: Keyword match patterns - metadata_filter: Additional metadata filter conditions + scope_dsl: Additional scope constraints from public find/search filter """ # Use passed threshold or default threshold effective_threshold = threshold if threshold is not None else self.threshold @@ -311,12 +261,6 @@ def passes_threshold(score: float) -> bool: return score >= effective_threshold return score > effective_threshold - def merge_filter(base_filter: Dict, extra_filter: Optional[Dict]) -> Dict: - """Merge filter conditions.""" - if not extra_filter: - return base_filter - return {"op": "and", "conds": [base_filter, extra_filter]} - sparse_query_vector = sparse_query_vector or None collected: List[Dict[str, Any]] = [] # Collected results (directories and leaves) @@ -341,13 +285,14 @@ def merge_filter(base_filter: Dict, extra_filter: Optional[Dict]) -> Dict: pre_filter_limit = max(limit * 2, 20) - results = await self.storage.search( - collection=collection, + results = await self.semantic_gateway.search_children_in_tenant( + ctx=ctx, + parent_uri=current_uri, query_vector=query_vector, sparse_query_vector=sparse_query_vector, # Pass sparse vector - filter=merge_filter( - {"op": "must", "field": "parent_uri", "conds": [current_uri]}, metadata_filter - ), + context_type=context_type, + target_directories=target_dirs, + extra_filter_dsl=scope_dsl, limit=pre_filter_limit, ) diff --git a/openviking/server/routers/admin.py b/openviking/server/routers/admin.py index 681dbe5c..ae963aa5 100644 --- a/openviking/server/routers/admin.py +++ b/openviking/server/routers/admin.py @@ -9,6 +9,7 @@ from openviking.server.dependencies import get_service from openviking.server.identity import RequestContext, Role from openviking.server.models import Response +from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway from openviking.storage.viking_fs import get_viking_fs from openviking_cli.exceptions import PermissionDeniedError from openviking_cli.session.user_id import UserIdentifier @@ -120,12 +121,8 @@ async def delete_account( try: storage = viking_fs._get_vector_store() if storage: - account_filter = { - "op": "must", - "field": "account_id", - "conds": [account_id], - } - deleted = await storage.batch_delete("context", account_filter) + gateway = ContextSemanticSearchGateway.from_storage(storage) + deleted = await gateway.delete_account_data(account_id) logger.info(f"VectorDB cascade delete for account {account_id}: {deleted} records") except Exception as e: logger.warning(f"VectorDB cleanup for account {account_id}: {e}") diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 4007ecc0..32120fd1 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ -14,6 +14,7 @@ from openviking.message import Message from openviking.server.identity import RequestContext from openviking.storage import VikingDBManager +from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway from openviking.storage.viking_fs import get_viking_fs from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import get_logger @@ -53,6 +54,7 @@ def __init__( ): """Initialize session compressor.""" self.vikingdb = vikingdb + self.semantic_gateway = ContextSemanticSearchGateway.from_storage(vikingdb) self.extractor = MemoryExtractor() self.deduplicator = MemoryDeduplicator(vikingdb=vikingdb) @@ -113,7 +115,7 @@ async def _delete_existing_memory( try: # rm() already syncs vector deletion in most cases; keep this as a safe fallback. - await self.vikingdb.remove_by_uri("context", memory.uri) + await self.semantic_gateway.delete_uris(ctx, [memory.uri]) except Exception as e: logger.warning(f"Failed to remove vector record for {memory.uri}: {e}") return True diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py index 67d8ea85..325899d8 100644 --- a/openviking/session/memory_deduplicator.py +++ b/openviking/session/memory_deduplicator.py @@ -16,6 +16,7 @@ from openviking.models.embedder.base import EmbedResult from openviking.prompts import render_prompt from openviking.storage import VikingDBManager +from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway from openviking_cli.utils import get_logger from openviking_cli.utils.config import get_openviking_config @@ -83,6 +84,7 @@ def __init__( ): """Initialize deduplicator.""" self.vikingdb = vikingdb + self.semantic_gateway = ContextSemanticSearchGateway.from_storage(vikingdb) self.embedder = vikingdb.get_embedder() async def deduplicate( @@ -127,42 +129,33 @@ async def _find_similar_memories( embed_result: EmbedResult = self.embedder.embed(query_text) query_vector = embed_result.dense_vector - # Determine collection and filter based on category - collection = "context" - category_uri_prefix = self._category_uri_prefix(candidate.category.value, candidate.user) - # Build filter by memory scope + uri prefix (schema does not have category field yet). - filter_conds = [ - {"field": "context_type", "op": "must", "conds": ["memory"]}, - {"field": "level", "op": "must", "conds": [2]}, - ] owner = candidate.user - if hasattr(owner, "account_id"): - filter_conds.append({"field": "account_id", "op": "must", "conds": [owner.account_id]}) + account_id = owner.account_id if hasattr(owner, "account_id") else "default" + owner_space = None if owner and hasattr(owner, "user_space_name"): owner_space = ( owner.agent_space_name() if candidate.category.value in {"cases", "patterns"} else owner.user_space_name() ) - filter_conds.append({"field": "owner_space", "op": "must", "conds": [owner_space]}) - if category_uri_prefix: - filter_conds.append({"field": "uri", "op": "must", "conds": [category_uri_prefix]}) - dedup_filter = {"op": "and", "conds": filter_conds} logger.debug( - "Dedup prefilter candidate category=%s filter=%s", + "Dedup prefilter candidate category=%s account=%s owner_space=%s uri_prefix=%s", candidate.category.value, - dedup_filter, + account_id, + owner_space, + category_uri_prefix, ) try: # Search with memory-scope filter. - results = await self.vikingdb.search( - collection=collection, + results = await self.semantic_gateway.search_similar_memories( + account_id=account_id, + owner_space=owner_space, + category_uri_prefix=category_uri_prefix, query_vector=query_vector, limit=5, - filter=dedup_filter, ) # Filter by similarity threshold diff --git a/openviking/session/session.py b/openviking/session/session.py index d46c5276..d6db1f91 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -5,7 +5,6 @@ Session as Context: Sessions integrated into L0/L1/L2 system. """ -import hashlib import json import re from dataclasses import dataclass, field @@ -15,6 +14,7 @@ from openviking.message import Message, Part from openviking.server.identity import RequestContext, Role +from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway from openviking.utils.time_utils import get_current_timestamp from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import get_logger, run_async @@ -78,6 +78,11 @@ def __init__( ): self._viking_fs = viking_fs self._vikingdb_manager = vikingdb_manager + self._semantic_gateway = ( + ContextSemanticSearchGateway.from_storage(vikingdb_manager) + if vikingdb_manager + else None + ) self._session_compressor = session_compressor self.user = user or UserIdentifier.the_default_user() self.ctx = ctx or RequestContext(user=self.user, role=Role.ROOT) @@ -297,35 +302,15 @@ def commit(self) -> Dict[str, Any]: def _update_active_counts(self) -> int: """Update active_count for used contexts/skills.""" - if not self._vikingdb_manager: + if not self._semantic_gateway: return 0 - updated = 0 - storage = self._vikingdb_manager - - for usage in self._usage_records: - try: - # Compute the record ID from the URI directly. - # collection_schemas.py assigns id = md5(uri) for every context - # record, so storage.get() gives us a precise single-record lookup - # without the subtree-matching side-effect of fetch_by_uri() on - # path-type fields. - record_id = hashlib.md5(usage.uri.encode("utf-8")).hexdigest() - records = run_async(storage.get(collection="context", ids=[record_id])) - if not records: - logger.debug(f"Record not found for URI: {usage.uri}") - continue - current_count = records[0].get("active_count") or 0 - run_async( - storage.update( - collection="context", - id=record_id, - data={"active_count": current_count + 1}, - ) - ) - updated += 1 - except Exception as e: - logger.debug(f"Could not update active_count for {usage.uri}: {e}") + uris = [usage.uri for usage in self._usage_records if usage.uri] + try: + updated = run_async(self._semantic_gateway.increment_active_count(self.ctx, uris)) + except Exception as e: + logger.debug(f"Could not update active_count for usage URIs: {e}") + updated = 0 if updated > 0: logger.info(f"Updated active_count for {updated} contexts/skills") diff --git a/openviking/storage/__init__.py b/openviking/storage/__init__.py index 63620722..6d5e45b2 100644 --- a/openviking/storage/__init__.py +++ b/openviking/storage/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Storage layer interfaces and implementations.""" +from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway from openviking.storage.observers import BaseObserver, QueueObserver from openviking.storage.queuefs import QueueManager, get_queue_manager, init_queue_manager from openviking.storage.viking_fs import VikingFS, get_viking_fs, init_viking_fs @@ -30,6 +31,7 @@ # Backend "VikingVectorIndexBackend", "VikingDBManager", + "ContextSemanticSearchGateway", # QueueFS "QueueManager", "init_queue_manager", diff --git a/openviking/storage/context_semantic_gateway.py b/openviking/storage/context_semantic_gateway.py new file mode 100644 index 00000000..7252eb52 --- /dev/null +++ b/openviking/storage/context_semantic_gateway.py @@ -0,0 +1,306 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +""" +Semantic vector gateway for OpenViking business flows. + +This module keeps raw filter DSL usage inside storage integration code so +business modules can call intent-based methods. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from openviking.server.identity import RequestContext, Role +from openviking.storage.vikingdb_interface import VikingDBInterface +from openviking_cli.utils.config import get_openviking_config + + +class ContextSemanticSearchGateway: + """Semantic methods over the bound context collection.""" + + def __init__(self, storage: VikingDBInterface, collection_name: str): + self._storage = storage + self._collection_name = collection_name + + @classmethod + def from_storage( + cls, storage: VikingDBInterface, collection_name: Optional[str] = None + ) -> "ContextSemanticSearchGateway": + if collection_name: + bound_collection = collection_name + else: + try: + bound_collection = get_openviking_config().storage.vectordb.name + except Exception: + # Keep simple tests and lightweight call sites usable. + bound_collection = "context" + return cls(storage=storage, collection_name=bound_collection) + + @property + def collection_name(self) -> str: + return self._collection_name + + async def collection_exists_bound(self) -> bool: + return await self._storage.collection_exists(self._collection_name) + + async def search_in_tenant( + self, + ctx: RequestContext, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]] = None, + context_type: Optional[str] = None, + target_directories: Optional[List[str]] = None, + extra_filter_dsl: Optional[Dict[str, Any]] = None, + limit: int = 10, + offset: int = 0, + ) -> List[Dict[str, Any]]: + scope_filter = self._build_scope_filter( + ctx=ctx, + context_type=context_type, + target_directories=target_directories, + extra_filter_dsl=extra_filter_dsl, + ) + return await self._storage.search( + collection=self._collection_name, + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=scope_filter, + limit=limit, + offset=offset, + ) + + async def search_global_roots_in_tenant( + self, + ctx: RequestContext, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]] = None, + context_type: Optional[str] = None, + target_directories: Optional[List[str]] = None, + extra_filter_dsl: Optional[Dict[str, Any]] = None, + limit: int = 10, + ) -> List[Dict[str, Any]]: + if not query_vector: + return [] + + merged_filter = self._merge_filters( + self._build_scope_filter( + ctx=ctx, + context_type=context_type, + target_directories=target_directories, + extra_filter_dsl=extra_filter_dsl, + ), + {"op": "must", "field": "level", "conds": [0, 1]}, + ) + return await self._storage.search( + collection=self._collection_name, + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=merged_filter, + limit=limit, + ) + + async def search_children_in_tenant( + self, + ctx: RequestContext, + parent_uri: str, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]] = None, + context_type: Optional[str] = None, + target_directories: Optional[List[str]] = None, + extra_filter_dsl: Optional[Dict[str, Any]] = None, + limit: int = 10, + ) -> List[Dict[str, Any]]: + merged_filter = self._merge_filters( + {"op": "must", "field": "parent_uri", "conds": [parent_uri]}, + self._build_scope_filter( + ctx=ctx, + context_type=context_type, + target_directories=target_directories, + extra_filter_dsl=extra_filter_dsl, + ), + ) + return await self._storage.search( + collection=self._collection_name, + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=merged_filter, + limit=limit, + ) + + async def search_similar_memories( + self, + account_id: str, + owner_space: Optional[str], + category_uri_prefix: str, + query_vector: List[float], + limit: int = 5, + ) -> List[Dict[str, Any]]: + conds: List[Dict[str, Any]] = [ + {"op": "must", "field": "context_type", "conds": ["memory"]}, + {"op": "must", "field": "level", "conds": [2]}, + {"op": "must", "field": "account_id", "conds": [account_id]}, + ] + if owner_space: + conds.append({"op": "must", "field": "owner_space", "conds": [owner_space]}) + if category_uri_prefix: + conds.append({"op": "must", "field": "uri", "conds": [category_uri_prefix]}) + + return await self._storage.search( + collection=self._collection_name, + query_vector=query_vector, + filter={"op": "and", "conds": conds}, + limit=limit, + ) + + async def get_context_by_uri( + self, + account_id: str, + uri: str, + owner_space: Optional[str] = None, + limit: int = 1, + ) -> List[Dict[str, Any]]: + conds: List[Dict[str, Any]] = [ + {"op": "must", "field": "uri", "conds": [uri]}, + {"op": "must", "field": "account_id", "conds": [account_id]}, + ] + if owner_space: + conds.append({"op": "must", "field": "owner_space", "conds": [owner_space]}) + + return await self._storage.filter( + collection=self._collection_name, + filter={"op": "and", "conds": conds}, + limit=limit, + ) + + async def delete_account_data(self, account_id: str) -> int: + return await self._storage.batch_delete( + self._collection_name, + {"op": "must", "field": "account_id", "conds": [account_id]}, + ) + + async def delete_uris(self, ctx: RequestContext, uris: List[str]) -> None: + for uri in uris: + conds: List[Dict[str, Any]] = [ + {"op": "must", "field": "account_id", "conds": [ctx.account_id]}, + { + "op": "or", + "conds": [ + {"op": "must", "field": "uri", "conds": [uri]}, + {"op": "must", "field": "uri", "conds": [f"{uri}/"]}, + ], + }, + ] + if ctx.role == Role.USER and uri.startswith(("viking://user/", "viking://agent/")): + owner_space = ( + ctx.user.user_space_name() + if uri.startswith("viking://user/") + else ctx.user.agent_space_name() + ) + conds.append({"op": "must", "field": "owner_space", "conds": [owner_space]}) + await self._storage.batch_delete( + self._collection_name, + {"op": "and", "conds": conds}, + ) + + async def update_uri_mapping( + self, + ctx: RequestContext, + uri: str, + new_uri: str, + new_parent_uri: str, + ) -> bool: + records = await self._storage.filter( + collection=self._collection_name, + filter={ + "op": "and", + "conds": [ + {"op": "must", "field": "uri", "conds": [uri]}, + {"op": "must", "field": "account_id", "conds": [ctx.account_id]}, + ], + }, + limit=1, + ) + if not records or "id" not in records[0]: + return False + + return await self._storage.update( + self._collection_name, + records[0]["id"], + {"uri": new_uri, "parent_uri": new_parent_uri}, + ) + + async def increment_active_count(self, ctx: RequestContext, uris: List[str]) -> int: + updated = 0 + for uri in uris: + records = await self.get_context_by_uri(account_id=ctx.account_id, uri=uri, limit=1) + if not records: + continue + record = records[0] + record_id = record.get("id") + if not record_id: + continue + current = int(record.get("active_count", 0) or 0) + if await self._storage.update( + self._collection_name, + record_id, + {"active_count": current + 1}, + ): + updated += 1 + return updated + + def _build_scope_filter( + self, + ctx: RequestContext, + context_type: Optional[str], + target_directories: Optional[List[str]], + extra_filter_dsl: Optional[Dict[str, Any]], + ) -> Optional[Dict[str, Any]]: + filters: List[Dict[str, Any]] = [] + if context_type: + filters.append({"op": "must", "field": "context_type", "conds": [context_type]}) + + tenant_filter = self._tenant_filter(ctx, context_type=context_type) + if tenant_filter: + filters.append(tenant_filter) + + if target_directories: + uri_conds = [ + {"op": "must", "field": "uri", "conds": [target_dir]} + for target_dir in target_directories + if target_dir + ] + if uri_conds: + filters.append({"op": "or", "conds": uri_conds}) + + if extra_filter_dsl: + filters.append(extra_filter_dsl) + + return self._merge_filters(*filters) + + @staticmethod + def _tenant_filter( + ctx: RequestContext, context_type: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + if ctx.role == Role.ROOT: + return None + + owner_spaces = [ctx.user.user_space_name(), ctx.user.agent_space_name()] + if context_type == "resource": + owner_spaces.append("") + return { + "op": "and", + "conds": [ + {"op": "must", "field": "account_id", "conds": [ctx.account_id]}, + {"op": "must", "field": "owner_space", "conds": owner_spaces}, + ], + } + + @staticmethod + def _merge_filters(*filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + non_empty = [f for f in filters if f] + if not non_empty: + return None + if len(non_empty) == 1: + return non_empty[0] + return {"op": "and", "conds": non_empty} diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index deef27db..78dbecc6 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -25,6 +25,7 @@ from pyagfs.exceptions import AGFSHTTPError from openviking.server.identity import RequestContext, Role +from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway from openviking.storage.vikingdb_interface import VikingDBInterface from openviking.utils.time_utils import format_simplified, get_current_timestamp, parse_iso_datetime from openviking_cli.session.user_id import UserIdentifier @@ -168,6 +169,9 @@ def __init__( self.query_embedder = query_embedder self.rerank_config = rerank_config self.vector_store = vector_store + self._context_semantic_gateway: Optional[ContextSemanticSearchGateway] = ( + ContextSemanticSearchGateway.from_storage(vector_store) if vector_store else None + ) self._bound_ctx: contextvars.ContextVar[Optional[RequestContext]] = contextvars.ContextVar( "vikingfs_bound_ctx", default=None ) @@ -604,7 +608,7 @@ async def find( ctx=self._ctx_or_default(ctx), limit=limit, score_threshold=score_threshold, - metadata_filter=filter, + scope_dsl=filter, ) # Convert QueryResult to FindResult @@ -721,7 +725,7 @@ async def _execute(tq: TypedQuery): ctx=self._ctx_or_default(ctx), limit=limit, score_threshold=score_threshold, - metadata_filter=filter, + scope_dsl=filter, ) query_results = await asyncio.gather(*[_execute(tq) for tq in typed_queries]) @@ -1045,40 +1049,21 @@ async def _delete_from_vector_store( ) -> None: """Delete records with specified URIs from vector store. - Uses storage.remove_by_uri method, which implements recursive deletion of child nodes. + Uses semantic gateway to apply tenant-safe URI deletion semantics. """ - storage = self._get_vector_store() - if not storage: + if not self._get_vector_store(): + return + gateway = self._get_context_semantic_gateway() + if not gateway: return real_ctx = self._ctx_or_default(ctx) - for uri in uris: - try: - filter_conds: List[Dict[str, Any]] = [ - {"op": "must", "field": "account_id", "conds": [real_ctx.account_id]}, - { - "op": "or", - "conds": [ - {"op": "must", "field": "uri", "conds": [uri]}, - {"op": "must", "field": "uri", "conds": [f"{uri}/"]}, - ], - }, - ] - if real_ctx.role == Role.USER and uri.startswith( - ("viking://user/", "viking://agent/") - ): - owner_space = ( - real_ctx.user.user_space_name() - if uri.startswith("viking://user/") - else real_ctx.user.agent_space_name() - ) - filter_conds.append( - {"op": "must", "field": "owner_space", "conds": [owner_space]} - ) - await storage.batch_delete("context", {"op": "and", "conds": filter_conds}) + try: + await gateway.delete_uris(real_ctx, uris) + for uri in uris: logger.info(f"[VikingFS] Deleted from vector store: {uri}") - except Exception as e: - logger.warning(f"[VikingFS] Failed to delete {uri} from vector store: {e}") + except Exception as e: + logger.warning(f"[VikingFS] Failed to delete from vector store: {e}") async def _update_vector_store_uris( self, @@ -1091,8 +1076,10 @@ async def _update_vector_store_uris( Preserves vector data, only updates uri and parent_uri fields, no need to regenerate embeddings. """ - storage = self._get_vector_store() - if not storage: + if not self._get_vector_store(): + return + gateway = self._get_context_semantic_gateway() + if not gateway: return old_base_uri = self._path_to_uri(old_base, ctx=ctx) @@ -1100,19 +1087,9 @@ async def _update_vector_store_uris( for uri in uris: try: - records = await storage.filter( - collection="context", - filter={ - "op": "and", - "conds": [ - {"op": "must", "field": "uri", "conds": [uri]}, - { - "op": "must", - "field": "account_id", - "conds": [self._ctx_or_default(ctx).account_id], - }, - ], - }, + records = await gateway.get_context_by_uri( + account_id=self._ctx_or_default(ctx).account_id, + uri=uri, limit=1, ) @@ -1120,7 +1097,6 @@ async def _update_vector_store_uris( continue record = records[0] - record_id = record["id"] new_uri = uri.replace(old_base_uri, new_base_uri, 1) @@ -1129,13 +1105,11 @@ async def _update_vector_store_uris( old_parent_uri.replace(old_base_uri, new_base_uri, 1) if old_parent_uri else "" ) - await storage.update( - "context", - record_id, - { - "uri": new_uri, - "parent_uri": new_parent_uri, - }, + await gateway.update_uri_mapping( + ctx=self._ctx_or_default(ctx), + uri=uri, + new_uri=new_uri, + new_parent_uri=new_parent_uri, ) logger.info(f"[VikingFS] Updated URI: {uri} -> {new_uri}") except Exception as e: @@ -1145,6 +1119,16 @@ def _get_vector_store(self) -> Optional["VikingDBInterface"]: """Get vector store instance.""" return self.vector_store + def _get_context_semantic_gateway(self) -> Optional[ContextSemanticSearchGateway]: + """Get semantic vector gateway bound to configured collection.""" + storage = self._get_vector_store() + if not storage: + self._context_semantic_gateway = None + return None + if not self._context_semantic_gateway: + self._context_semantic_gateway = ContextSemanticSearchGateway.from_storage(storage) + return self._context_semantic_gateway + def _get_embedder(self) -> Any: """Get embedder instance.""" return self.query_embedder diff --git a/tests/retrieve/test_hierarchical_retriever_target_dirs.py b/tests/retrieve/test_hierarchical_retriever_target_dirs.py index a3f53d3d..31c8165a 100644 --- a/tests/retrieve/test_hierarchical_retriever_target_dirs.py +++ b/tests/retrieve/test_hierarchical_retriever_target_dirs.py @@ -6,7 +6,9 @@ import pytest from openviking.retrieve.hierarchical_retriever import HierarchicalRetriever +from openviking.server.identity import RequestContext, Role from openviking_cli.retrieve.types import ContextType, TypedQuery +from openviking_cli.session.user_id import UserIdentifier class DummyStorage: @@ -40,21 +42,26 @@ async def search( return [] -def _contains_prefix_filter(obj, prefix: str) -> bool: +def _contains_uri_scope_filter(obj, target_uri: str) -> bool: if isinstance(obj, dict): - if obj.get("op") == "prefix" and obj.get("field") == "uri" and obj.get("prefix") == prefix: + if ( + obj.get("op") == "must" + and obj.get("field") == "uri" + and target_uri in obj.get("conds", []) + ): return True - return any(_contains_prefix_filter(v, prefix) for v in obj.values()) + return any(_contains_uri_scope_filter(v, target_uri) for v in obj.values()) if isinstance(obj, list): - return any(_contains_prefix_filter(v, prefix) for v in obj) + return any(_contains_uri_scope_filter(v, target_uri) for v in obj) return False @pytest.mark.asyncio -async def test_retrieve_honors_target_directories_prefix_filter(): +async def test_retrieve_honors_target_directories_scope_filter(): target_uri = "viking://resources/foo" storage = DummyStorage() retriever = HierarchicalRetriever(storage=storage, embedder=None, rerank_config=None) + ctx = RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) query = TypedQuery( query="test", @@ -63,8 +70,8 @@ async def test_retrieve_honors_target_directories_prefix_filter(): target_directories=[target_uri], ) - result = await retriever.retrieve(query, limit=3) + result = await retriever.retrieve(query, ctx=ctx, limit=3) assert result.searched_directories == [target_uri] assert storage.search_calls - assert _contains_prefix_filter(storage.search_calls[0]["filter"], target_uri) + assert _contains_uri_scope_filter(storage.search_calls[0]["filter"], target_uri) diff --git a/tests/session/test_memory_dedup_actions.py b/tests/session/test_memory_dedup_actions.py new file mode 100644 index 00000000..bd88306f --- /dev/null +++ b/tests/session/test_memory_dedup_actions.py @@ -0,0 +1,568 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openviking.core.context import Context +from openviking.message import Message +from openviking.server.identity import RequestContext, Role +from openviking.session.compressor import SessionCompressor +from openviking.session.memory_deduplicator import ( + DedupDecision, + DedupResult, + ExistingMemoryAction, + MemoryActionDecision, + MemoryDeduplicator, +) +from openviking.session.memory_extractor import ( + CandidateMemory, + MemoryCategory, + MemoryExtractor, + MergedMemoryPayload, +) +from openviking_cli.session.user_id import UserIdentifier + + +class _DummyVikingDB: + def __init__(self): + self._embedder = None + + def get_embedder(self): + return self._embedder + + +class _DummyEmbedResult: + def __init__(self, dense_vector): + self.dense_vector = dense_vector + + +class _DummyEmbedder: + def embed(self, _text): + return _DummyEmbedResult([0.1, 0.2, 0.3]) + + +def _make_user() -> UserIdentifier: + return UserIdentifier("acc1", "test_user", "test_agent") + + +def _make_ctx() -> RequestContext: + return RequestContext(user=_make_user(), role=Role.USER) + + +def _make_candidate() -> CandidateMemory: + return CandidateMemory( + category=MemoryCategory.PREFERENCES, + abstract="User prefers concise summaries", + overview="User asks for concise answers frequently.", + content="The user prefers concise summaries over long explanations.", + source_session="session_test", + user=_make_user(), + language="en", + ) + + +def _make_existing(uri_suffix: str = "existing.md") -> Context: + user_space = _make_user().user_space_name() + return Context( + uri=f"viking://user/{user_space}/memories/preferences/{uri_suffix}", + parent_uri=f"viking://user/{user_space}/memories/preferences", + is_leaf=True, + abstract="Existing preference memory", + context_type="memory", + category="preferences", + ) + + +class TestMemoryDeduplicatorPayload: + def test_create_with_empty_list_is_valid(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + existing = [_make_existing("a.md")] + + decision, _, actions = dedup._parse_decision_payload( + {"decision": "create", "reason": "new memory", "list": []}, + existing, + ) + + assert decision == DedupDecision.CREATE + assert actions == [] + + def test_create_with_merge_is_normalized_to_none(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + existing = [_make_existing("b.md")] + + decision, _, actions = dedup._parse_decision_payload( + { + "decision": "create", + "list": [{"uri": existing[0].uri, "decide": "merge"}], + }, + existing, + ) + + assert decision == DedupDecision.NONE + assert len(actions) == 1 + assert actions[0].decision == MemoryActionDecision.MERGE + + def test_skip_drops_list_actions(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + existing = [_make_existing("c.md")] + + decision, _, actions = dedup._parse_decision_payload( + { + "decision": "skip", + "list": [{"uri": existing[0].uri, "decide": "delete"}], + }, + existing, + ) + + assert decision == DedupDecision.SKIP + assert actions == [] + + def test_cross_facet_delete_actions_are_kept(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + food = _make_existing("food.md") + food.abstract = "饮食偏好: 喜欢吃苹果和草莓" + routine = _make_existing("routine.md") + routine.abstract = "作息习惯: 每天早上7点起床" + existing = [food, routine] + candidate = _make_candidate() + candidate.abstract = "饮食偏好: 不再喜欢吃水果" + candidate.content = "用户不再喜欢吃水果,需要作废过去的水果偏好。" + + decision, _, actions = dedup._parse_decision_payload( + { + "decision": "create", + "list": [ + {"uri": food.uri, "decide": "delete"}, + {"uri": routine.uri, "decide": "delete"}, + ], + }, + existing, + candidate, + ) + + assert decision == DedupDecision.CREATE + assert len(actions) == 2 + assert {a.memory.uri for a in actions} == {food.uri, routine.uri} + assert all(a.decision == MemoryActionDecision.DELETE for a in actions) + + @pytest.mark.asyncio + async def test_find_similar_memories_uses_path_must_filter_and__score(self): + existing = _make_existing("pref_hit.md") + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search = AsyncMock( + return_value=[ + { + "id": "uri_pref_hit", + "uri": existing.uri, + "context_type": "memory", + "level": 2, + "account_id": "acc1", + "owner_space": _make_user().user_space_name(), + "abstract": existing.abstract, + "category": "preferences", + "_score": 0.82, + } + ] + ) + dedup = MemoryDeduplicator(vikingdb=vikingdb) + candidate = _make_candidate() + + similar = await dedup._find_similar_memories(candidate) + + assert len(similar) == 1 + assert similar[0].uri == existing.uri + call = vikingdb.search.await_args.kwargs + assert call["filter"]["op"] == "and" + assert {"field": "context_type", "op": "must", "conds": ["memory"]} in call["filter"][ + "conds" + ] + assert {"field": "level", "op": "must", "conds": [2]} in call["filter"]["conds"] + assert {"field": "account_id", "op": "must", "conds": ["acc1"]} in call["filter"]["conds"] + assert { + "field": "owner_space", + "op": "must", + "conds": [_make_user().user_space_name()], + } in call["filter"]["conds"] + assert { + "field": "uri", + "op": "must", + "conds": [f"viking://user/{_make_user().user_space_name()}/memories/preferences/"], + } in call["filter"]["conds"] + + @pytest.mark.asyncio + async def test_find_similar_memories_accepts_low_score_when_threshold_is_zero(self): + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = _DummyEmbedder() + vikingdb.search = AsyncMock( + return_value=[ + { + "id": "uri_low", + "uri": f"viking://user/{_make_user().user_space_name()}/memories/preferences/low.md", + "context_type": "memory", + "level": 2, + "account_id": "acc1", + "owner_space": _make_user().user_space_name(), + "abstract": "low", + "_score": 0.68, + } + ] + ) + dedup = MemoryDeduplicator(vikingdb=vikingdb) + + similar = await dedup._find_similar_memories(_make_candidate()) + + assert len(similar) == 1 + + @pytest.mark.asyncio + async def test_llm_decision_formats_up_to_five_similar_memories(self): + dedup = MemoryDeduplicator(vikingdb=_DummyVikingDB()) + similar = [_make_existing(f"m_{i}.md") for i in range(6)] + captured = {} + + def _fake_render_prompt(_template_id, variables): + captured.update(variables) + return "prompt" + + class _DummyVLM: + def is_available(self): + return True + + async def get_completion_async(self, _prompt): + return '{"decision":"skip","reason":"dup"}' + + class _DummyConfig: + vlm = _DummyVLM() + + with ( + patch( + "openviking.session.memory_deduplicator.get_openviking_config", + return_value=_DummyConfig(), + ), + patch( + "openviking.session.memory_deduplicator.render_prompt", + side_effect=_fake_render_prompt, + ), + ): + decision, _, _ = await dedup._llm_decision(_make_candidate(), similar) + + assert decision == DedupDecision.SKIP + existing_text = captured["existing_memories"] + assert existing_text.count("uri=") == 5 + assert similar[0].abstract in existing_text + assert "facet=" in existing_text + assert similar[4].uri in existing_text + assert similar[5].uri not in existing_text + + +@pytest.mark.asyncio +class TestMemoryMergeBundle: + async def test_merge_memory_bundle_parses_structured_response(self): + extractor = MemoryExtractor() + + class _DummyVLM: + def is_available(self): + return True + + async def get_completion_async(self, _prompt): + return ( + '{"decision":"merge","abstract":"Tool preference: Use clang","overview":"## ' + 'Preference Domain","content":"Use clang for C++.","reason":"updated"}' + ) + + class _DummyConfig: + vlm = _DummyVLM() + + with patch( + "openviking.session.memory_extractor.get_openviking_config", + return_value=_DummyConfig(), + ): + payload = await extractor._merge_memory_bundle( + existing_abstract="old", + existing_overview="", + existing_content="old content", + new_abstract="new", + new_overview="", + new_content="new content", + category="preferences", + output_language="en", + ) + + assert payload is not None + assert payload.abstract == "Tool preference: Use clang" + assert payload.content == "Use clang for C++." + + async def test_merge_memory_bundle_rejects_missing_required_fields(self): + extractor = MemoryExtractor() + + class _DummyVLM: + def is_available(self): + return True + + async def get_completion_async(self, _prompt): + return '{"decision":"merge","abstract":"","overview":"o","content":"","reason":"r"}' + + class _DummyConfig: + vlm = _DummyVLM() + + with patch( + "openviking.session.memory_extractor.get_openviking_config", + return_value=_DummyConfig(), + ): + payload = await extractor._merge_memory_bundle( + existing_abstract="old", + existing_overview="", + existing_content="old content", + new_abstract="new", + new_overview="", + new_content="new content", + category="preferences", + output_language="en", + ) + + assert payload is None + + +@pytest.mark.asyncio +class TestProfileMergeSafety: + async def test_profile_merge_failure_keeps_existing_content(self): + extractor = MemoryExtractor() + extractor._merge_memory_bundle = AsyncMock(return_value=None) + candidate = CandidateMemory( + category=MemoryCategory.PROFILE, + abstract="User basic info: lives in NYC", + overview="## Background", + content="User currently lives in NYC.", + source_session="session_test", + user="test_user", + language="en", + ) + + fs = MagicMock() + fs.read_file = AsyncMock(return_value="existing profile content") + fs.write_file = AsyncMock() + + payload = await extractor._append_to_profile(candidate, fs, ctx=_make_ctx()) + + assert payload is None + fs.write_file.assert_not_called() + + async def test_create_memory_skips_profile_index_payload_when_merge_fails(self): + extractor = MemoryExtractor() + candidate = CandidateMemory( + category=MemoryCategory.PROFILE, + abstract="User basic info: lives in NYC", + overview="## Background", + content="User currently lives in NYC.", + source_session="session_test", + user="test_user", + language="en", + ) + extractor._append_to_profile = AsyncMock(return_value=None) + + with patch("openviking.session.memory_extractor.get_viking_fs", return_value=MagicMock()): + memory = await extractor.create_memory( + candidate, + user=_make_user(), + session_id="s1", + ctx=_make_ctx(), + ) + + assert memory is None + + +@pytest.mark.asyncio +class TestSessionCompressorDedupActions: + async def test_create_with_empty_list_only_creates_new_memory(self): + candidate = _make_candidate() + new_memory = _make_existing("created.md") + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = None + vikingdb.batch_delete = AsyncMock(return_value=1) + vikingdb.enqueue_embedding_msg = AsyncMock() + + compressor = SessionCompressor(vikingdb=vikingdb) + compressor.extractor.extract = AsyncMock(return_value=[candidate]) + compressor.extractor.create_memory = AsyncMock(return_value=new_memory) + compressor.deduplicator.deduplicate = AsyncMock( + return_value=DedupResult( + decision=DedupDecision.CREATE, + candidate=candidate, + similar_memories=[], + actions=[], + ) + ) + compressor._index_memory = AsyncMock(return_value=True) + + fs = MagicMock() + fs.rm = AsyncMock() + + with patch("openviking.session.compressor.get_viking_fs", return_value=fs): + memories = await compressor.extract_long_term_memories( + [Message.create_user("test message")], + user=_make_user(), + session_id="session_test", + ctx=_make_ctx(), + ) + + assert len(memories) == 1 + assert memories[0].uri == new_memory.uri + fs.rm.assert_not_called() + compressor.extractor.create_memory.assert_awaited_once() + + async def test_create_with_merge_is_executed_as_none(self): + candidate = _make_candidate() + target = _make_existing("merge_target.md") + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = None + vikingdb.batch_delete = AsyncMock(return_value=1) + vikingdb.enqueue_embedding_msg = AsyncMock() + + compressor = SessionCompressor(vikingdb=vikingdb) + compressor.extractor.extract = AsyncMock(return_value=[candidate]) + compressor.extractor.create_memory = AsyncMock(return_value=_make_existing("never.md")) + compressor.extractor._merge_memory_bundle = AsyncMock( + return_value=MergedMemoryPayload( + abstract="merged abstract", + overview="merged overview", + content="merged memory content", + reason="merged", + ) + ) + compressor.deduplicator.deduplicate = AsyncMock( + return_value=DedupResult( + decision=DedupDecision.CREATE, + candidate=candidate, + similar_memories=[target], + actions=[ + ExistingMemoryAction( + memory=target, + decision=MemoryActionDecision.MERGE, + ) + ], + ) + ) + compressor._index_memory = AsyncMock(return_value=True) + + fs = MagicMock() + fs.read_file = AsyncMock(return_value="old memory content") + fs.write_file = AsyncMock() + fs.rm = AsyncMock() + + with patch("openviking.session.compressor.get_viking_fs", return_value=fs): + memories = await compressor.extract_long_term_memories( + [Message.create_user("test message")], + user=_make_user(), + session_id="session_test", + ctx=_make_ctx(), + ) + + assert memories == [] + compressor.extractor.create_memory.assert_not_called() + fs.write_file.assert_awaited_once_with(target.uri, "merged memory content", ctx=_make_ctx()) + assert target.abstract == "merged abstract" + assert target.meta["overview"] == "merged overview" + compressor._index_memory.assert_awaited_once() + + async def test_merge_bundle_failure_is_skipped_without_fallback(self): + candidate = _make_candidate() + target = _make_existing("merge_target_fail.md") + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = None + vikingdb.batch_delete = AsyncMock(return_value=1) + vikingdb.enqueue_embedding_msg = AsyncMock() + + compressor = SessionCompressor(vikingdb=vikingdb) + compressor.extractor.extract = AsyncMock(return_value=[candidate]) + compressor.extractor._merge_memory_bundle = AsyncMock(return_value=None) + compressor.deduplicator.deduplicate = AsyncMock( + return_value=DedupResult( + decision=DedupDecision.NONE, + candidate=candidate, + similar_memories=[target], + actions=[ + ExistingMemoryAction( + memory=target, + decision=MemoryActionDecision.MERGE, + ) + ], + ) + ) + compressor._index_memory = AsyncMock(return_value=True) + + fs = MagicMock() + fs.read_file = AsyncMock(return_value="old memory content") + fs.write_file = AsyncMock() + fs.rm = AsyncMock() + + with patch("openviking.session.compressor.get_viking_fs", return_value=fs): + memories = await compressor.extract_long_term_memories( + [Message.create_user("test message")], + user=_make_user(), + session_id="session_test", + ctx=_make_ctx(), + ) + + assert memories == [] + fs.write_file.assert_not_called() + compressor._index_memory.assert_not_called() + + async def test_create_with_delete_runs_delete_before_create(self): + candidate = _make_candidate() + target = _make_existing("to_delete.md") + new_memory = _make_existing("created_after_delete.md") + call_order = [] + + vikingdb = MagicMock() + vikingdb.get_embedder.return_value = None + vikingdb.batch_delete = AsyncMock(return_value=1) + vikingdb.enqueue_embedding_msg = AsyncMock() + + compressor = SessionCompressor(vikingdb=vikingdb) + compressor.extractor.extract = AsyncMock(return_value=[candidate]) + compressor.deduplicator.deduplicate = AsyncMock( + return_value=DedupResult( + decision=DedupDecision.CREATE, + candidate=candidate, + similar_memories=[target], + actions=[ + ExistingMemoryAction( + memory=target, + decision=MemoryActionDecision.DELETE, + ) + ], + ) + ) + + async def _create_memory(*_args, **_kwargs): + call_order.append("create") + return new_memory + + compressor.extractor.create_memory = AsyncMock(side_effect=_create_memory) + compressor._index_memory = AsyncMock(return_value=True) + + fs = MagicMock() + + async def _rm(*_args, **_kwargs): + call_order.append("delete") + return {} + + fs.rm = AsyncMock(side_effect=_rm) + + with patch("openviking.session.compressor.get_viking_fs", return_value=fs): + memories = await compressor.extract_long_term_memories( + [Message.create_user("test message")], + user=_make_user(), + session_id="session_test", + ctx=_make_ctx(), + ) + + assert [m.uri for m in memories] == [new_memory.uri] + assert call_order == ["delete", "create"] + vikingdb.batch_delete.assert_awaited_once() diff --git a/tests/session/test_session_compressor_semantic_gateway.py b/tests/session/test_session_compressor_semantic_gateway.py new file mode 100644 index 00000000..d2a1c7d7 --- /dev/null +++ b/tests/session/test_session_compressor_semantic_gateway.py @@ -0,0 +1,26 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from openviking.server.identity import RequestContext, Role +from openviking.session.compressor import SessionCompressor +from openviking_cli.session.user_id import UserIdentifier + + +@pytest.mark.asyncio +async def test_delete_existing_memory_uses_semantic_gateway(): + compressor = SessionCompressor.__new__(SessionCompressor) + compressor.semantic_gateway = AsyncMock() + viking_fs = AsyncMock() + memory = SimpleNamespace(uri="viking://user/user1/memories/events/e1") + ctx = RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) + + ok = await SessionCompressor._delete_existing_memory(compressor, memory, viking_fs, ctx) + + assert ok is True + viking_fs.rm.assert_awaited_once_with(memory.uri, recursive=False, ctx=ctx) + compressor.semantic_gateway.delete_uris.assert_awaited_once_with(ctx, [memory.uri]) diff --git a/tests/storage/test_context_semantic_gateway.py b/tests/storage/test_context_semantic_gateway.py new file mode 100644 index 00000000..0cfe53fb --- /dev/null +++ b/tests/storage/test_context_semantic_gateway.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock + +import pytest + +from openviking.server.identity import RequestContext, Role +from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway +from openviking_cli.session.user_id import UserIdentifier + + +def _make_ctx() -> RequestContext: + return RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) + + +@pytest.mark.asyncio +async def test_search_in_tenant_uses_bound_collection_and_tenant_scope(): + storage = AsyncMock() + storage.search.return_value = [] + gateway = ContextSemanticSearchGateway.from_storage(storage, collection_name="ctx_custom") + + await gateway.search_in_tenant( + ctx=_make_ctx(), + query_vector=[0.1], + context_type="resource", + target_directories=["viking://resources/foo"], + limit=2, + ) + + call = storage.search.await_args.kwargs + assert call["collection"] == "ctx_custom" + assert call["filter"]["op"] == "and" + + +@pytest.mark.asyncio +async def test_increment_active_count_updates_by_uri(): + storage = AsyncMock() + storage.filter.return_value = [{"id": "r1", "active_count": 3}] + storage.update.return_value = True + gateway = ContextSemanticSearchGateway.from_storage(storage, collection_name="ctx_custom") + + updated = await gateway.increment_active_count(_make_ctx(), ["viking://resources/foo"]) + + assert updated == 1 + update_call = storage.update.await_args + assert update_call.args[0] == "ctx_custom" + assert update_call.args[1] == "r1" + assert update_call.args[2]["active_count"] == 4 diff --git a/tests/storage/test_semantic_dag_stats.py b/tests/storage/test_semantic_dag_stats.py new file mode 100644 index 00000000..10f06c22 --- /dev/null +++ b/tests/storage/test_semantic_dag_stats.py @@ -0,0 +1,85 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from openviking.server.identity import RequestContext, Role +from openviking.storage.queuefs.semantic_dag import DagStats, SemanticDagExecutor +from openviking_cli.session.user_id import UserIdentifier + + +class _FakeVikingFS: + def __init__(self, tree): + self._tree = tree + self.writes = [] + + async def ls(self, uri, ctx=None): + return self._tree.get(uri, []) + + async def write_file(self, path, content, ctx=None): + self.writes.append((path, content)) + + +class _FakeProcessor: + def __init__(self): + self.vectorized_dirs = [] + self.vectorized_files = [] + + async def _generate_single_file_summary(self, file_path, llm_sem=None, ctx=None): + return {"name": file_path.split("/")[-1], "summary": "summary"} + + async def _generate_overview(self, dir_uri, file_summaries, children_abstracts): + return "overview" + + def _extract_abstract_from_overview(self, overview): + return "abstract" + + async def _vectorize_directory_simple(self, uri, context_type, abstract, overview, ctx=None): + self.vectorized_dirs.append(uri) + + async def _vectorize_single_file( + self, parent_uri, context_type, file_path, summary_dict, ctx=None + ): + self.vectorized_files.append(file_path) + + +@pytest.mark.asyncio +async def test_semantic_dag_stats_collects_nodes(monkeypatch): + root_uri = "viking://resources/root" + tree = { + root_uri: [ + {"name": "a.txt", "isDir": False}, + {"name": "b.txt", "isDir": False}, + {"name": "child", "isDir": True}, + ], + f"{root_uri}/child": [ + {"name": "c.txt", "isDir": False}, + ], + } + fake_fs = _FakeVikingFS(tree) + monkeypatch.setattr("openviking.storage.queuefs.semantic_dag.get_viking_fs", lambda: fake_fs) + + processor = _FakeProcessor() + ctx = RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) + executor = SemanticDagExecutor( + processor=processor, + context_type="resource", + max_concurrent_llm=2, + ctx=ctx, + ) + await executor.run(root_uri) + + stats = executor.get_stats() + assert isinstance(stats, DagStats) + assert stats.total_nodes == 5 # 2 dirs + 3 files + assert stats.pending_nodes == 0 + assert stats.done_nodes == 5 + assert stats.in_progress_nodes == 0 + assert processor.vectorized_dirs == [f"{root_uri}/child", root_uri] + assert sorted(processor.vectorized_files) == sorted( + [f"{root_uri}/a.txt", f"{root_uri}/b.txt", f"{root_uri}/child/c.txt"] + ) + + +if __name__ == "__main__": + pytest.main([__file__]) From 4f0567392ceee7e17891f4a197d32d67a2c7a7c1 Mon Sep 17 00:00:00 2001 From: "zhoujiahui.01" Date: Thu, 26 Feb 2026 14:05:33 +0800 Subject: [PATCH 2/7] refactor vector storage to collection-bound drivers --- openviking/core/directories.py | 4 +- openviking/retrieve/hierarchical_retriever.py | 8 +- openviking/server/routers/admin.py | 4 +- openviking/session/compressor.py | 4 +- openviking/session/memory_deduplicator.py | 4 +- openviking/session/session.py | 6 +- openviking/storage/__init__.py | 4 +- ...c_gateway.py => context_vector_gateway.py} | 116 ++++++------ .../storage/observers/vikingdb_observer.py | 15 +- openviking/storage/vector_store/__init__.py | 43 +++++ openviking/storage/vector_store/driver.py | 124 +++++++++++++ .../storage/vector_store/drivers/__init__.py | 15 ++ .../storage/vector_store/drivers/common.py | 33 ++++ .../vector_store/drivers/http_driver.py | 119 ++++++++++++ .../vector_store/drivers/local_driver.py | 91 +++++++++ .../vector_store/drivers/vikingdb_driver.py | 116 ++++++++++++ .../vector_store/drivers/volcengine_driver.py | 130 +++++++++++++ openviking/storage/vector_store/expr.py | 73 ++++++++ openviking/storage/vector_store/factory.py | 17 ++ openviking/storage/vector_store/registry.py | 35 ++++ openviking/storage/viking_fs.py | 20 +- .../storage/viking_vector_index_backend.py | 174 +++++------------- openviking/storage/vikingdb_interface.py | 24 ++- ...test_hierarchical_retriever_target_dirs.py | 5 + tests/session/test_memory_dedup_actions.py | 30 ++- ...eway.py => test_context_vector_gateway.py} | 9 +- 26 files changed, 970 insertions(+), 253 deletions(-) rename openviking/storage/{context_semantic_gateway.py => context_vector_gateway.py} (71%) create mode 100644 openviking/storage/vector_store/__init__.py create mode 100644 openviking/storage/vector_store/driver.py create mode 100644 openviking/storage/vector_store/drivers/__init__.py create mode 100644 openviking/storage/vector_store/drivers/common.py create mode 100644 openviking/storage/vector_store/drivers/http_driver.py create mode 100644 openviking/storage/vector_store/drivers/local_driver.py create mode 100644 openviking/storage/vector_store/drivers/vikingdb_driver.py create mode 100644 openviking/storage/vector_store/drivers/volcengine_driver.py create mode 100644 openviking/storage/vector_store/expr.py create mode 100644 openviking/storage/vector_store/factory.py create mode 100644 openviking/storage/vector_store/registry.py rename tests/storage/{test_context_semantic_gateway.py => test_context_vector_gateway.py} (79%) diff --git a/openviking/core/directories.py b/openviking/core/directories.py index 0770fbb8..c2b07302 100644 --- a/openviking/core/directories.py +++ b/openviking/core/directories.py @@ -12,7 +12,7 @@ from openviking.core.context import Context, ContextType, Vectorize from openviking.server.identity import RequestContext -from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway +from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.storage.queuefs.embedding_msg_converter import EmbeddingMsgConverter if TYPE_CHECKING: @@ -146,7 +146,7 @@ def __init__( vikingdb: "VikingDBManager", ): self.vikingdb = vikingdb - self.semantic_gateway = ContextSemanticSearchGateway.from_storage(vikingdb) + self.semantic_gateway = ContextVectorGateway.from_storage(vikingdb) async def initialize_account_directories(self, ctx: RequestContext) -> int: """Initialize account-shared scope roots.""" diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index 670f659b..f4eb4896 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -14,7 +14,7 @@ from openviking.models.embedder.base import EmbedResult from openviking.retrieve.memory_lifecycle import hotness_score from openviking.server.identity import RequestContext, Role -from openviking.storage import ContextSemanticSearchGateway, VikingDBInterface +from openviking.storage import ContextVectorGateway, VikingDBInterface from openviking.storage.viking_fs import get_viking_fs from openviking_cli.retrieve.types import ( ContextType, @@ -58,7 +58,7 @@ def __init__( rerank_config: Rerank configuration (optional, will fallback to vector search only) """ self.storage = storage - self.semantic_gateway = ContextSemanticSearchGateway.from_storage(storage) + self.semantic_gateway = ContextVectorGateway.from_storage(storage) self.embedder = embedder self.rerank_config = rerank_config @@ -185,7 +185,7 @@ async def _global_vector_search( sparse_query_vector=sparse_query_vector, context_type=context_type, target_directories=target_dirs, - extra_filter_dsl=scope_dsl, + extra_filter=scope_dsl, limit=limit, ) return results @@ -292,7 +292,7 @@ def passes_threshold(score: float) -> bool: sparse_query_vector=sparse_query_vector, # Pass sparse vector context_type=context_type, target_directories=target_dirs, - extra_filter_dsl=scope_dsl, + extra_filter=scope_dsl, limit=pre_filter_limit, ) diff --git a/openviking/server/routers/admin.py b/openviking/server/routers/admin.py index ae963aa5..e68f0626 100644 --- a/openviking/server/routers/admin.py +++ b/openviking/server/routers/admin.py @@ -9,7 +9,7 @@ from openviking.server.dependencies import get_service from openviking.server.identity import RequestContext, Role from openviking.server.models import Response -from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway +from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.storage.viking_fs import get_viking_fs from openviking_cli.exceptions import PermissionDeniedError from openviking_cli.session.user_id import UserIdentifier @@ -121,7 +121,7 @@ async def delete_account( try: storage = viking_fs._get_vector_store() if storage: - gateway = ContextSemanticSearchGateway.from_storage(storage) + gateway = ContextVectorGateway.from_storage(storage) deleted = await gateway.delete_account_data(account_id) logger.info(f"VectorDB cascade delete for account {account_id}: {deleted} records") except Exception as e: diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 32120fd1..139c6acb 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ -14,7 +14,7 @@ from openviking.message import Message from openviking.server.identity import RequestContext from openviking.storage import VikingDBManager -from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway +from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.storage.viking_fs import get_viking_fs from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import get_logger @@ -54,7 +54,7 @@ def __init__( ): """Initialize session compressor.""" self.vikingdb = vikingdb - self.semantic_gateway = ContextSemanticSearchGateway.from_storage(vikingdb) + self.semantic_gateway = ContextVectorGateway.from_storage(vikingdb) self.extractor = MemoryExtractor() self.deduplicator = MemoryDeduplicator(vikingdb=vikingdb) diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py index 325899d8..ceb207ca 100644 --- a/openviking/session/memory_deduplicator.py +++ b/openviking/session/memory_deduplicator.py @@ -16,7 +16,7 @@ from openviking.models.embedder.base import EmbedResult from openviking.prompts import render_prompt from openviking.storage import VikingDBManager -from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway +from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking_cli.utils import get_logger from openviking_cli.utils.config import get_openviking_config @@ -84,7 +84,7 @@ def __init__( ): """Initialize deduplicator.""" self.vikingdb = vikingdb - self.semantic_gateway = ContextSemanticSearchGateway.from_storage(vikingdb) + self.semantic_gateway = ContextVectorGateway.from_storage(vikingdb) self.embedder = vikingdb.get_embedder() async def deduplicate( diff --git a/openviking/session/session.py b/openviking/session/session.py index d6db1f91..f26d7750 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -14,7 +14,7 @@ from openviking.message import Message, Part from openviking.server.identity import RequestContext, Role -from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway +from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.utils.time_utils import get_current_timestamp from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import get_logger, run_async @@ -79,9 +79,7 @@ def __init__( self._viking_fs = viking_fs self._vikingdb_manager = vikingdb_manager self._semantic_gateway = ( - ContextSemanticSearchGateway.from_storage(vikingdb_manager) - if vikingdb_manager - else None + ContextVectorGateway.from_storage(vikingdb_manager) if vikingdb_manager else None ) self._session_compressor = session_compressor self.user = user or UserIdentifier.the_default_user() diff --git a/openviking/storage/__init__.py b/openviking/storage/__init__.py index 6d5e45b2..b6a0abfc 100644 --- a/openviking/storage/__init__.py +++ b/openviking/storage/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Storage layer interfaces and implementations.""" -from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway +from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.storage.observers import BaseObserver, QueueObserver from openviking.storage.queuefs import QueueManager, get_queue_manager, init_queue_manager from openviking.storage.viking_fs import VikingFS, get_viking_fs, init_viking_fs @@ -31,7 +31,7 @@ # Backend "VikingVectorIndexBackend", "VikingDBManager", - "ContextSemanticSearchGateway", + "ContextVectorGateway", # QueueFS "QueueManager", "init_queue_manager", diff --git a/openviking/storage/context_semantic_gateway.py b/openviking/storage/context_vector_gateway.py similarity index 71% rename from openviking/storage/context_semantic_gateway.py rename to openviking/storage/context_vector_gateway.py index 7252eb52..73376a41 100644 --- a/openviking/storage/context_semantic_gateway.py +++ b/openviking/storage/context_vector_gateway.py @@ -12,11 +12,22 @@ from typing import Any, Dict, List, Optional from openviking.server.identity import RequestContext, Role +from openviking.storage.vector_store.expr import ( + And, + Eq, + FilterExpr, + In, + Or, + Prefix, + RawDSL, +) from openviking.storage.vikingdb_interface import VikingDBInterface from openviking_cli.utils.config import get_openviking_config +WhereExpr = FilterExpr | Dict[str, Any] -class ContextSemanticSearchGateway: + +class ContextVectorGateway: """Semantic methods over the bound context collection.""" def __init__(self, storage: VikingDBInterface, collection_name: str): @@ -26,7 +37,7 @@ def __init__(self, storage: VikingDBInterface, collection_name: str): @classmethod def from_storage( cls, storage: VikingDBInterface, collection_name: Optional[str] = None - ) -> "ContextSemanticSearchGateway": + ) -> "ContextVectorGateway": if collection_name: bound_collection = collection_name else: @@ -51,7 +62,7 @@ async def search_in_tenant( sparse_query_vector: Optional[Dict[str, float]] = None, context_type: Optional[str] = None, target_directories: Optional[List[str]] = None, - extra_filter_dsl: Optional[Dict[str, Any]] = None, + extra_filter: Optional[WhereExpr] = None, limit: int = 10, offset: int = 0, ) -> List[Dict[str, Any]]: @@ -59,7 +70,7 @@ async def search_in_tenant( ctx=ctx, context_type=context_type, target_directories=target_directories, - extra_filter_dsl=extra_filter_dsl, + extra_filter=extra_filter, ) return await self._storage.search( collection=self._collection_name, @@ -77,7 +88,7 @@ async def search_global_roots_in_tenant( sparse_query_vector: Optional[Dict[str, float]] = None, context_type: Optional[str] = None, target_directories: Optional[List[str]] = None, - extra_filter_dsl: Optional[Dict[str, Any]] = None, + extra_filter: Optional[WhereExpr] = None, limit: int = 10, ) -> List[Dict[str, Any]]: if not query_vector: @@ -88,9 +99,9 @@ async def search_global_roots_in_tenant( ctx=ctx, context_type=context_type, target_directories=target_directories, - extra_filter_dsl=extra_filter_dsl, + extra_filter=extra_filter, ), - {"op": "must", "field": "level", "conds": [0, 1]}, + In("level", [0, 1]), ) return await self._storage.search( collection=self._collection_name, @@ -108,16 +119,16 @@ async def search_children_in_tenant( sparse_query_vector: Optional[Dict[str, float]] = None, context_type: Optional[str] = None, target_directories: Optional[List[str]] = None, - extra_filter_dsl: Optional[Dict[str, Any]] = None, + extra_filter: Optional[WhereExpr] = None, limit: int = 10, ) -> List[Dict[str, Any]]: merged_filter = self._merge_filters( - {"op": "must", "field": "parent_uri", "conds": [parent_uri]}, + Eq("parent_uri", parent_uri), self._build_scope_filter( ctx=ctx, context_type=context_type, target_directories=target_directories, - extra_filter_dsl=extra_filter_dsl, + extra_filter=extra_filter, ), ) return await self._storage.search( @@ -136,20 +147,20 @@ async def search_similar_memories( query_vector: List[float], limit: int = 5, ) -> List[Dict[str, Any]]: - conds: List[Dict[str, Any]] = [ - {"op": "must", "field": "context_type", "conds": ["memory"]}, - {"op": "must", "field": "level", "conds": [2]}, - {"op": "must", "field": "account_id", "conds": [account_id]}, + conds: List[FilterExpr] = [ + Eq("context_type", "memory"), + Eq("level", 2), + Eq("account_id", account_id), ] if owner_space: - conds.append({"op": "must", "field": "owner_space", "conds": [owner_space]}) + conds.append(Eq("owner_space", owner_space)) if category_uri_prefix: - conds.append({"op": "must", "field": "uri", "conds": [category_uri_prefix]}) + conds.append(Prefix("uri", category_uri_prefix)) return await self._storage.search( collection=self._collection_name, query_vector=query_vector, - filter={"op": "and", "conds": conds}, + filter=And(conds), limit=limit, ) @@ -160,36 +171,30 @@ async def get_context_by_uri( owner_space: Optional[str] = None, limit: int = 1, ) -> List[Dict[str, Any]]: - conds: List[Dict[str, Any]] = [ - {"op": "must", "field": "uri", "conds": [uri]}, - {"op": "must", "field": "account_id", "conds": [account_id]}, + conds: List[FilterExpr] = [ + Eq("uri", uri), + Eq("account_id", account_id), ] if owner_space: - conds.append({"op": "must", "field": "owner_space", "conds": [owner_space]}) + conds.append(Eq("owner_space", owner_space)) return await self._storage.filter( collection=self._collection_name, - filter={"op": "and", "conds": conds}, + filter=And(conds), limit=limit, ) async def delete_account_data(self, account_id: str) -> int: return await self._storage.batch_delete( self._collection_name, - {"op": "must", "field": "account_id", "conds": [account_id]}, + Eq("account_id", account_id), ) async def delete_uris(self, ctx: RequestContext, uris: List[str]) -> None: for uri in uris: - conds: List[Dict[str, Any]] = [ - {"op": "must", "field": "account_id", "conds": [ctx.account_id]}, - { - "op": "or", - "conds": [ - {"op": "must", "field": "uri", "conds": [uri]}, - {"op": "must", "field": "uri", "conds": [f"{uri}/"]}, - ], - }, + conds: List[FilterExpr] = [ + Eq("account_id", ctx.account_id), + Or([Eq("uri", uri), Prefix("uri", f"{uri}/")]), ] if ctx.role == Role.USER and uri.startswith(("viking://user/", "viking://agent/")): owner_space = ( @@ -197,10 +202,10 @@ async def delete_uris(self, ctx: RequestContext, uris: List[str]) -> None: if uri.startswith("viking://user/") else ctx.user.agent_space_name() ) - conds.append({"op": "must", "field": "owner_space", "conds": [owner_space]}) + conds.append(Eq("owner_space", owner_space)) await self._storage.batch_delete( self._collection_name, - {"op": "and", "conds": conds}, + And(conds), ) async def update_uri_mapping( @@ -212,13 +217,7 @@ async def update_uri_mapping( ) -> bool: records = await self._storage.filter( collection=self._collection_name, - filter={ - "op": "and", - "conds": [ - {"op": "must", "field": "uri", "conds": [uri]}, - {"op": "must", "field": "account_id", "conds": [ctx.account_id]}, - ], - }, + filter=And([Eq("uri", uri), Eq("account_id", ctx.account_id)]), limit=1, ) if not records or "id" not in records[0]: @@ -254,11 +253,11 @@ def _build_scope_filter( ctx: RequestContext, context_type: Optional[str], target_directories: Optional[List[str]], - extra_filter_dsl: Optional[Dict[str, Any]], - ) -> Optional[Dict[str, Any]]: - filters: List[Dict[str, Any]] = [] + extra_filter: Optional[WhereExpr], + ) -> Optional[FilterExpr]: + filters: List[FilterExpr] = [] if context_type: - filters.append({"op": "must", "field": "context_type", "conds": [context_type]}) + filters.append(Eq("context_type", context_type)) tenant_filter = self._tenant_filter(ctx, context_type=context_type) if tenant_filter: @@ -266,41 +265,36 @@ def _build_scope_filter( if target_directories: uri_conds = [ - {"op": "must", "field": "uri", "conds": [target_dir]} - for target_dir in target_directories - if target_dir + Prefix("uri", target_dir) for target_dir in target_directories if target_dir ] if uri_conds: - filters.append({"op": "or", "conds": uri_conds}) + filters.append(Or(uri_conds)) - if extra_filter_dsl: - filters.append(extra_filter_dsl) + if extra_filter: + if isinstance(extra_filter, dict): + filters.append(RawDSL(extra_filter)) + else: + filters.append(extra_filter) return self._merge_filters(*filters) @staticmethod def _tenant_filter( ctx: RequestContext, context_type: Optional[str] = None - ) -> Optional[Dict[str, Any]]: + ) -> Optional[FilterExpr]: if ctx.role == Role.ROOT: return None owner_spaces = [ctx.user.user_space_name(), ctx.user.agent_space_name()] if context_type == "resource": owner_spaces.append("") - return { - "op": "and", - "conds": [ - {"op": "must", "field": "account_id", "conds": [ctx.account_id]}, - {"op": "must", "field": "owner_space", "conds": owner_spaces}, - ], - } + return And([Eq("account_id", ctx.account_id), In("owner_space", owner_spaces)]) @staticmethod - def _merge_filters(*filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + def _merge_filters(*filters: Optional[FilterExpr]) -> Optional[FilterExpr]: non_empty = [f for f in filters if f] if not non_empty: return None if len(non_empty) == 1: return non_empty[0] - return {"op": "and", "conds": non_empty} + return And(non_empty) diff --git a/openviking/storage/observers/vikingdb_observer.py b/openviking/storage/observers/vikingdb_observer.py index dc8f6e35..4b25fb0c 100644 --- a/openviking/storage/observers/vikingdb_observer.py +++ b/openviking/storage/observers/vikingdb_observer.py @@ -49,19 +49,12 @@ async def _get_collection_statuses(self, collection_names: list) -> Dict[str, Di for name in collection_names: try: - if not self._vikingdb_manager.project.has_collection(name): + if not await self._vikingdb_manager.collection_exists(name): continue - collection = self._vikingdb_manager.project.get_collection(name) - if not collection: - continue - - index_count = len(collection.list_indexes()) - - count_result = collection.aggregate_data( - index_name=self._vikingdb_manager.DEFAULT_INDEX_NAME, op="count" - ) - vector_count = count_result.agg.get("_total", 0) + # Current OpenViking flow uses one managed default index per collection. + index_count = 1 + vector_count = await self._vikingdb_manager.count(name) statuses[name] = { "index_count": index_count, diff --git a/openviking/storage/vector_store/__init__.py b/openviking/storage/vector_store/__init__.py new file mode 100644 index 00000000..8820f2ac --- /dev/null +++ b/openviking/storage/vector_store/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Vector store driver architecture.""" + +from openviking.storage.vector_store.driver import VectorStoreDriver +from openviking.storage.vector_store.expr import ( + And, + Contains, + Eq, + FilterExpr, + In, + Or, + Prefix, + Range, + RawDSL, + Regex, + TimeRange, +) +from openviking.storage.vector_store.factory import create_driver +from openviking.storage.vector_store.registry import ( + get_driver_class, + list_registered_backends, + register_driver, +) + +__all__ = [ + "VectorStoreDriver", + "FilterExpr", + "And", + "Or", + "Eq", + "In", + "Prefix", + "Range", + "Contains", + "Regex", + "TimeRange", + "RawDSL", + "create_driver", + "register_driver", + "get_driver_class", + "list_registered_backends", +] diff --git a/openviking/storage/vector_store/driver.py b/openviking/storage/vector_store/driver.py new file mode 100644 index 00000000..c22afc55 --- /dev/null +++ b/openviking/storage/vector_store/driver.py @@ -0,0 +1,124 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Base driver contracts for vector store backends.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict + +from openviking.storage.vector_store.expr import ( + And, + Contains, + Eq, + FilterExpr, + In, + Or, + Prefix, + Range, + RawDSL, + Regex, + TimeRange, +) + + +class VectorStoreDriver(ABC): + """Backend-specific adapter for collection operations + filter AST compilation.""" + + mode: str + + @classmethod + @abstractmethod + def from_config(cls, config: Any) -> "VectorStoreDriver": + """Create a driver instance from VectorDB backend config.""" + + @abstractmethod + def has_collection(self, name: str) -> bool: + """Return whether collection exists.""" + + @abstractmethod + def get_collection(self, name: str) -> Any: + """Return backend collection handle.""" + + @abstractmethod + def create_collection(self, name: str, meta: Dict[str, Any]) -> Any: + """Create a collection and return backend collection handle.""" + + @abstractmethod + def drop_collection(self, name: str) -> None: + """Drop collection.""" + + @abstractmethod + def list_collections(self) -> list[str]: + """List all collections.""" + + def close(self) -> None: + """Release backend resources.""" + + def compile_expr(self, expr: FilterExpr | None) -> Dict[str, Any]: + """Compile a filter AST node to vectordb DSL.""" + if expr is None: + return {} + + if isinstance(expr, RawDSL): + return expr.payload + + if isinstance(expr, And): + conds = [self.compile_expr(c) for c in expr.conds if c is not None] + conds = [c for c in conds if c] + if not conds: + return {} + if len(conds) == 1: + return conds[0] + return {"op": "and", "conds": conds} + + if isinstance(expr, Or): + conds = [self.compile_expr(c) for c in expr.conds if c is not None] + conds = [c for c in conds if c] + if not conds: + return {} + if len(conds) == 1: + return conds[0] + return {"op": "or", "conds": conds} + + if isinstance(expr, Eq): + return {"op": "must", "field": expr.field, "conds": [expr.value]} + + if isinstance(expr, In): + return {"op": "must", "field": expr.field, "conds": list(expr.values)} + + if isinstance(expr, Prefix): + # For path fields the current vectordb implementation uses `must` semantics. + return {"op": "must", "field": expr.field, "conds": [expr.prefix]} + + if isinstance(expr, Range): + payload: Dict[str, Any] = {"op": "range", "field": expr.field} + if expr.gte is not None: + payload["gte"] = expr.gte + if expr.gt is not None: + payload["gt"] = expr.gt + if expr.lte is not None: + payload["lte"] = expr.lte + if expr.lt is not None: + payload["lt"] = expr.lt + return payload + + if isinstance(expr, Contains): + return { + "op": "contains", + "field": expr.field, + "substring": expr.substring, + } + + if isinstance(expr, Regex): + return {"op": "regex", "field": expr.field, "pattern": expr.pattern} + + if isinstance(expr, TimeRange): + payload: Dict[str, Any] = {"op": "range", "field": expr.field} + if expr.start is not None: + payload["gte"] = expr.start + if expr.end is not None: + payload["lt"] = expr.end + return payload + + raise TypeError(f"Unsupported filter expr type: {type(expr)!r}") diff --git a/openviking/storage/vector_store/drivers/__init__.py b/openviking/storage/vector_store/drivers/__init__.py new file mode 100644 index 00000000..7b02773b --- /dev/null +++ b/openviking/storage/vector_store/drivers/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Driver module imports for static registration side effects.""" + +from openviking.storage.vector_store.drivers.http_driver import HttpVectorDriver +from openviking.storage.vector_store.drivers.local_driver import LocalVectorDriver +from openviking.storage.vector_store.drivers.vikingdb_driver import VikingDBPrivateDriver +from openviking.storage.vector_store.drivers.volcengine_driver import VolcengineVectorDriver + +__all__ = [ + "LocalVectorDriver", + "HttpVectorDriver", + "VolcengineVectorDriver", + "VikingDBPrivateDriver", +] diff --git a/openviking/storage/vector_store/drivers/common.py b/openviking/storage/vector_store/drivers/common.py new file mode 100644 index 00000000..00856d2b --- /dev/null +++ b/openviking/storage/vector_store/drivers/common.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Shared helpers for vector backend drivers.""" + +from __future__ import annotations + +from typing import Any, Iterable +from urllib.parse import urlparse + + +def parse_url(url: str) -> tuple[str, int]: + """Parse backend URL to host/port pair.""" + normalized = url + if not normalized.startswith(("http://", "https://")): + normalized = f"http://{normalized}" + + parsed = urlparse(normalized) + host = parsed.hostname or "127.0.0.1" + port = parsed.port or 5000 + return host, port + + +def normalize_collection_names(raw_collections: Iterable[Any]) -> list[str]: + """Normalize collection listing results to plain collection-name strings.""" + names: list[str] = [] + for item in raw_collections: + if isinstance(item, str): + names.append(item) + elif isinstance(item, dict): + name = item.get("CollectionName") or item.get("collection_name") or item.get("name") + if isinstance(name, str): + names.append(name) + return names diff --git a/openviking/storage/vector_store/drivers/http_driver.py b/openviking/storage/vector_store/drivers/http_driver.py new file mode 100644 index 00000000..18002785 --- /dev/null +++ b/openviking/storage/vector_store/drivers/http_driver.py @@ -0,0 +1,119 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Remote HTTP vector backend driver.""" + +from __future__ import annotations + +from openviking.storage.vector_store.driver import VectorStoreDriver +from openviking.storage.vector_store.drivers.common import normalize_collection_names, parse_url +from openviking.storage.vector_store.registry import register_driver +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.http_collection import ( + HttpCollection, + get_or_create_http_collection, + list_vikingdb_collections, +) + + +@register_driver("http") +class HttpVectorDriver(VectorStoreDriver): + """Driver for remote HTTP vectordb project.""" + + def __init__(self, host: str, port: int, project_name: str, collection_name: str): + self.mode = "http" + self._host = host + self._port = port + self._project_name = project_name + self._collection_name = collection_name + self._collection = None + + @classmethod + def from_config(cls, config): + if not config.url: + raise ValueError("HTTP backend requires a valid URL") + + host, port = parse_url(config.url) + collection_name = config.name or "context" + project_name = config.project_name or "default" + return cls( + host=host, + port=port, + project_name=project_name, + collection_name=collection_name, + ) + + def _match(self, name: str) -> bool: + return name == self._collection_name + + def _meta(self) -> dict: + return { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + + def _remote_has_collection(self) -> bool: + raw = list_vikingdb_collections( + host=self._host, + port=self._port, + project_name=self._project_name, + ) + names = normalize_collection_names(raw) + return self._collection_name in names + + def _ensure_collection_handle(self) -> None: + if self._collection is not None: + return + if not self._remote_has_collection(): + return + self._collection = Collection( + HttpCollection( + ip=self._host, + port=self._port, + meta_data=self._meta(), + ) + ) + + def has_collection(self, name: str) -> bool: + if not self._match(name): + return False + exists = self._remote_has_collection() + if exists: + self._ensure_collection_handle() + return exists + + def get_collection(self, name: str): + if not self._match(name): + return None + self._ensure_collection_handle() + return self._collection + + def create_collection(self, name: str, meta): + if not self._match(name): + raise ValueError( + f"http backend is bound to collection '{self._collection_name}', got '{name}'" + ) + payload = dict(meta) + payload.update(self._meta()) + self._collection = get_or_create_http_collection( + host=self._host, + port=self._port, + meta_data=payload, + ) + return self._collection + + def drop_collection(self, name: str) -> None: + if not self._match(name): + return + coll = self.get_collection(name) + if coll is None: + return + coll.drop() + self._collection = None + + def list_collections(self) -> list[str]: + return [self._collection_name] if self.has_collection(self._collection_name) else [] + + def close(self) -> None: + if self._collection is not None: + self._collection.close() + self._collection = None diff --git a/openviking/storage/vector_store/drivers/local_driver.py b/openviking/storage/vector_store/drivers/local_driver.py new file mode 100644 index 00000000..c7ce1975 --- /dev/null +++ b/openviking/storage/vector_store/drivers/local_driver.py @@ -0,0 +1,91 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Local persistent vector backend driver.""" + +from __future__ import annotations + +import os +from pathlib import Path + +from openviking.storage.vector_store.driver import VectorStoreDriver +from openviking.storage.vector_store.registry import register_driver +from openviking.storage.vectordb.collection.local_collection import ( + get_or_create_local_collection, +) + + +@register_driver("local") +class LocalVectorDriver(VectorStoreDriver): + """Driver for local embedded vectordb backend.""" + + DEFAULT_LOCAL_PROJECT_NAME = "vectordb" + + def __init__(self, collection_name: str, collection_path: str): + self.mode = "local" + self._collection_name = collection_name + self._collection_path = collection_path + self._collection = None + + @classmethod + def from_config(cls, config): + collection_name = config.name or "context" + if config.path: + project_path = Path(config.path) / cls.DEFAULT_LOCAL_PROJECT_NAME + collection_path = str(project_path / collection_name) + else: + collection_path = "" + return cls(collection_name=collection_name, collection_path=collection_path) + + def _match(self, name: str) -> bool: + return name == self._collection_name + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + if not self._collection_path: + return + meta_path = os.path.join(self._collection_path, "collection_meta.json") + if os.path.exists(meta_path): + self._collection = get_or_create_local_collection(path=self._collection_path) + + def has_collection(self, name: str) -> bool: + if not self._match(name): + return False + self._load_existing_collection_if_needed() + return self._collection is not None + + def get_collection(self, name: str): + if not self._match(name): + return None + self._load_existing_collection_if_needed() + return self._collection + + def create_collection(self, name: str, meta): + if not self._match(name): + raise ValueError( + f"local backend is bound to collection '{self._collection_name}', got '{name}'" + ) + if self._collection is not None: + return self._collection + if self._collection_path: + os.makedirs(self._collection_path, exist_ok=True) + self._collection = get_or_create_local_collection( + meta_data=meta, + path=self._collection_path, + ) + return self._collection + + def drop_collection(self, name: str) -> None: + if not self.has_collection(name): + return + assert self._collection is not None + self._collection.drop() + self._collection = None + + def list_collections(self) -> list[str]: + return [self._collection_name] if self.has_collection(self._collection_name) else [] + + def close(self) -> None: + if self._collection is not None: + self._collection.close() + self._collection = None diff --git a/openviking/storage/vector_store/drivers/vikingdb_driver.py b/openviking/storage/vector_store/drivers/vikingdb_driver.py new file mode 100644 index 00000000..4e21c91d --- /dev/null +++ b/openviking/storage/vector_store/drivers/vikingdb_driver.py @@ -0,0 +1,116 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Private VikingDB deployment backend driver.""" + +from __future__ import annotations + +from openviking.storage.vector_store.driver import VectorStoreDriver +from openviking.storage.vector_store.drivers.common import normalize_collection_names +from openviking.storage.vector_store.registry import register_driver +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.vikingdb_clients import ( + VIKINGDB_APIS, + VikingDBClient, +) +from openviking.storage.vectordb.collection.vikingdb_collection import VikingDBCollection + + +@register_driver("vikingdb") +class VikingDBPrivateDriver(VectorStoreDriver): + """Driver for private VikingDB deployment.""" + + def __init__( + self, + *, + host: str, + headers: dict | None, + project_name: str, + collection_name: str, + ): + self.mode = "vikingdb" + self._host = host + self._headers = headers + self._project_name = project_name + self._collection_name = collection_name + self._collection = None + + @classmethod + def from_config(cls, config): + if not config.vikingdb or not config.vikingdb.host: + raise ValueError("VikingDB backend requires a valid host") + + return cls( + host=config.vikingdb.host, + headers=config.vikingdb.headers, + project_name=config.project_name or "default", + collection_name=config.name or "context", + ) + + def _match(self, name: str) -> bool: + return name == self._collection_name + + def _client(self) -> VikingDBClient: + return VikingDBClient(self._host, self._headers) + + def _fetch_collection_meta(self): + path, method = VIKINGDB_APIS["GetVikingdbCollection"] + req = { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + response = self._client().do_req(method, path=path, req_body=req) + if response.status_code != 200: + return None + result = response.json() + meta = result.get("Result", {}) + return meta or None + + def has_collection(self, name: str) -> bool: + if not self._match(name): + return False + return self._fetch_collection_meta() is not None + + def get_collection(self, name: str): + if not self._match(name): + return None + if self._collection is not None: + return self._collection + meta = self._fetch_collection_meta() + if meta is None: + return None + self._collection = Collection( + VikingDBCollection( + host=self._host, + headers=self._headers, + meta_data=meta, + ) + ) + return self._collection + + def create_collection(self, name: str, meta): + raise NotImplementedError("private vikingdb collection should be pre-created") + + def drop_collection(self, name: str) -> None: + if not self._match(name): + return + coll = self.get_collection(name) + if coll is None: + return + coll.drop() + self._collection = None + + def list_collections(self) -> list[str]: + path, method = VIKINGDB_APIS["ListVikingdbCollection"] + req = {"ProjectName": self._project_name} + response = self._client().do_req(method, path=path, req_body=req) + if response.status_code != 200: + return [] + result = response.json() + raw = result.get("Result", {}).get("Collections", []) + names = normalize_collection_names(raw) + return [n for n in names if n == self._collection_name] + + def close(self) -> None: + if self._collection is not None: + self._collection.close() + self._collection = None diff --git a/openviking/storage/vector_store/drivers/volcengine_driver.py b/openviking/storage/vector_store/drivers/volcengine_driver.py new file mode 100644 index 00000000..d09cdd6d --- /dev/null +++ b/openviking/storage/vector_store/drivers/volcengine_driver.py @@ -0,0 +1,130 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Volcengine VikingDB backend driver.""" + +from __future__ import annotations + +from openviking.storage.vector_store.driver import VectorStoreDriver +from openviking.storage.vector_store.registry import register_driver +from openviking.storage.vectordb.collection.volcengine_collection import ( + VolcengineCollection, + get_or_create_volcengine_collection, +) + + +@register_driver("volcengine") +class VolcengineVectorDriver(VectorStoreDriver): + """Driver for Volcengine-hosted VikingDB.""" + + def __init__( + self, + *, + ak: str, + sk: str, + region: str, + host: str, + project_name: str, + collection_name: str, + ): + self.mode = "volcengine" + self._ak = ak + self._sk = sk + self._region = region + self._host = host + self._project_name = project_name + self._collection_name = collection_name + self._collection = None + + @classmethod + def from_config(cls, config): + if not ( + config.volcengine + and config.volcengine.ak + and config.volcengine.sk + and config.volcengine.region + ): + raise ValueError("Volcengine backend requires AK, SK, and Region configuration") + + return cls( + ak=config.volcengine.ak, + sk=config.volcengine.sk, + region=config.volcengine.region, + host=config.volcengine.host or "", + project_name=config.project_name or "default", + collection_name=config.name or "context", + ) + + def _match(self, name: str) -> bool: + return name == self._collection_name + + def _meta(self) -> dict: + return { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + + def _config(self) -> dict: + return { + "AK": self._ak, + "SK": self._sk, + "Region": self._region, + "Host": self._host, + } + + def _new_collection_handle(self) -> VolcengineCollection: + return VolcengineCollection( + ak=self._ak, + sk=self._sk, + region=self._region, + host=self._host, + meta_data=self._meta(), + ) + + def has_collection(self, name: str) -> bool: + if not self._match(name): + return False + candidate = self._collection or self._new_collection_handle() + meta = candidate.get_meta_data() or {} + exists = bool(meta and meta.get("CollectionName")) + if exists and self._collection is None: + self._collection = candidate + return exists + + def get_collection(self, name: str): + if not self._match(name): + return None + if self._collection is not None: + return self._collection + if self.has_collection(name): + return self._collection + return None + + def create_collection(self, name: str, meta): + if not self._match(name): + raise ValueError( + f"volcengine backend is bound to collection '{self._collection_name}', got '{name}'" + ) + payload = dict(meta) + payload.update(self._meta()) + self._collection = get_or_create_volcengine_collection( + config=self._config(), + meta_data=payload, + ) + return self._collection + + def drop_collection(self, name: str) -> None: + if not self._match(name): + return + coll = self.get_collection(name) + if coll is None: + return + coll.drop() + self._collection = None + + def list_collections(self) -> list[str]: + return [self._collection_name] if self.has_collection(self._collection_name) else [] + + def close(self) -> None: + if self._collection is not None: + self._collection.close() + self._collection = None diff --git a/openviking/storage/vector_store/expr.py b/openviking/storage/vector_store/expr.py new file mode 100644 index 00000000..9966213a --- /dev/null +++ b/openviking/storage/vector_store/expr.py @@ -0,0 +1,73 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Filter expression AST for vector store queries.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, Union + + +@dataclass(frozen=True) +class And: + conds: List["FilterExpr"] + + +@dataclass(frozen=True) +class Or: + conds: List["FilterExpr"] + + +@dataclass(frozen=True) +class Eq: + field: str + value: Any + + +@dataclass(frozen=True) +class In: + field: str + values: List[Any] + + +@dataclass(frozen=True) +class Prefix: + field: str + prefix: str + + +@dataclass(frozen=True) +class Range: + field: str + gte: Any | None = None + gt: Any | None = None + lte: Any | None = None + lt: Any | None = None + + +@dataclass(frozen=True) +class Contains: + field: str + substring: str + + +@dataclass(frozen=True) +class Regex: + field: str + pattern: str + + +@dataclass(frozen=True) +class TimeRange: + field: str + start: datetime | str | None = None + end: datetime | str | None = None + + +@dataclass(frozen=True) +class RawDSL: + payload: Dict[str, Any] + + +FilterExpr = Union[And, Or, Eq, In, Prefix, Range, Contains, Regex, TimeRange, RawDSL] diff --git a/openviking/storage/vector_store/factory.py b/openviking/storage/vector_store/factory.py new file mode 100644 index 00000000..fb670ff0 --- /dev/null +++ b/openviking/storage/vector_store/factory.py @@ -0,0 +1,17 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Factory for vector backend drivers.""" + +from __future__ import annotations + +from openviking.storage.vector_store.driver import VectorStoreDriver +from openviking.storage.vector_store.registry import get_driver_class + + +def create_driver(config) -> VectorStoreDriver: + """Create backend driver from `VectorDBBackendConfig` without backend if/else.""" + # Ensure all static registrations are loaded. + import openviking.storage.vector_store.drivers # noqa: F401 + + driver_cls = get_driver_class(config.backend) + return driver_cls.from_config(config) diff --git a/openviking/storage/vector_store/registry.py b/openviking/storage/vector_store/registry.py new file mode 100644 index 00000000..a80952cc --- /dev/null +++ b/openviking/storage/vector_store/registry.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Static registry for vector backend drivers.""" + +from __future__ import annotations + +from typing import Callable, Dict, Type + +from openviking.storage.vector_store.driver import VectorStoreDriver + +_DRIVER_REGISTRY: Dict[str, Type[VectorStoreDriver]] = {} + + +def register_driver(name: str) -> Callable[[Type[VectorStoreDriver]], Type[VectorStoreDriver]]: + """Register a vector backend driver class by backend name.""" + + def decorator(cls: Type[VectorStoreDriver]) -> Type[VectorStoreDriver]: + _DRIVER_REGISTRY[name] = cls + return cls + + return decorator + + +def get_driver_class(name: str) -> Type[VectorStoreDriver]: + """Resolve registered driver class for backend name.""" + if name not in _DRIVER_REGISTRY: + raise ValueError( + f"Vector backend {name} is not registered. " + f"Available backends: {sorted(_DRIVER_REGISTRY)}" + ) + return _DRIVER_REGISTRY[name] + + +def list_registered_backends() -> list[str]: + return sorted(_DRIVER_REGISTRY) diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index 78dbecc6..b432f78d 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -25,7 +25,7 @@ from pyagfs.exceptions import AGFSHTTPError from openviking.server.identity import RequestContext, Role -from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway +from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.storage.vikingdb_interface import VikingDBInterface from openviking.utils.time_utils import format_simplified, get_current_timestamp, parse_iso_datetime from openviking_cli.session.user_id import UserIdentifier @@ -169,8 +169,8 @@ def __init__( self.query_embedder = query_embedder self.rerank_config = rerank_config self.vector_store = vector_store - self._context_semantic_gateway: Optional[ContextSemanticSearchGateway] = ( - ContextSemanticSearchGateway.from_storage(vector_store) if vector_store else None + self._context_vector_gateway: Optional[ContextVectorGateway] = ( + ContextVectorGateway.from_storage(vector_store) if vector_store else None ) self._bound_ctx: contextvars.ContextVar[Optional[RequestContext]] = contextvars.ContextVar( "vikingfs_bound_ctx", default=None @@ -1053,7 +1053,7 @@ async def _delete_from_vector_store( """ if not self._get_vector_store(): return - gateway = self._get_context_semantic_gateway() + gateway = self._get_context_vector_gateway() if not gateway: return real_ctx = self._ctx_or_default(ctx) @@ -1078,7 +1078,7 @@ async def _update_vector_store_uris( """ if not self._get_vector_store(): return - gateway = self._get_context_semantic_gateway() + gateway = self._get_context_vector_gateway() if not gateway: return @@ -1119,15 +1119,15 @@ def _get_vector_store(self) -> Optional["VikingDBInterface"]: """Get vector store instance.""" return self.vector_store - def _get_context_semantic_gateway(self) -> Optional[ContextSemanticSearchGateway]: + def _get_context_vector_gateway(self) -> Optional[ContextVectorGateway]: """Get semantic vector gateway bound to configured collection.""" storage = self._get_vector_store() if not storage: - self._context_semantic_gateway = None + self._context_vector_gateway = None return None - if not self._context_semantic_gateway: - self._context_semantic_gateway = ContextSemanticSearchGateway.from_storage(storage) - return self._context_semantic_gateway + if not self._context_vector_gateway: + self._context_vector_gateway = ContextVectorGateway.from_storage(storage) + return self._context_vector_gateway def _get_embedder(self) -> Any: """Get embedder instance.""" diff --git a/openviking/storage/viking_vector_index_backend.py b/openviking/storage/viking_vector_index_backend.py index 7c579759..c9f14688 100644 --- a/openviking/storage/viking_vector_index_backend.py +++ b/openviking/storage/viking_vector_index_backend.py @@ -8,9 +8,10 @@ """ import uuid -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional +from openviking.storage.vector_store import FilterExpr, create_driver +from openviking.storage.vector_store.expr import RawDSL from openviking.storage.vectordb.collection.collection import Collection from openviking.storage.vectordb.collection.result import FetchDataInCollectionResult from openviking.storage.vectordb.utils.logging_init import init_cpp_logging @@ -34,9 +35,8 @@ class VikingVectorIndexBackend(VikingDBInterface): VikingDBManager is derived by VikingVectorIndexBackend. """ - # Default project and index names + # Default index name DEFAULT_INDEX_NAME = "default" - DEFAULT_LOCAL_PROJECT_NAME = "vectordb" def __init__( self, @@ -75,117 +75,44 @@ def __init__( ) backend = VikingVectorIndexBackend(config=config) """ + if config is None: + raise ValueError("VectorDB backend config is required") + init_cpp_logging() self.vector_dim = config.dimension self.distance_metric = config.distance_metric self.sparse_weight = config.sparse_weight - if config.backend == "volcengine": - if not ( - config.volcengine - and config.volcengine.ak - and config.volcengine.sk - and config.volcengine.region - ): - raise ValueError("Volcengine backend requires AK, SK, and Region configuration") - - # Volcengine VikingDB mode - self._mode = config.backend - # Convert lowercase keys to uppercase for consistency with volcengine_collection - volc_config = { - "AK": config.volcengine.ak, - "SK": config.volcengine.sk, - "Region": config.volcengine.region, - } - - from openviking.storage.vectordb.project.volcengine_project import ( - get_or_create_volcengine_project, - ) - - self.project = get_or_create_volcengine_project( - project_name=config.project_name, config=volc_config - ) - logger.info( - f"VectorDB backend initialized in Volcengine mode: region={volc_config['Region']}" - ) - elif config.backend == "vikingdb": - if not config.vikingdb.host: - raise ValueError("VikingDB backend requires a valid host") - # VikingDB private deployment mode - self._mode = config.backend - viking_config = { - "Host": config.vikingdb.host, - "Headers": config.vikingdb.headers, - } - - from openviking.storage.vectordb.project.vikingdb_project import ( - get_or_create_vikingdb_project, - ) - - self.project = get_or_create_vikingdb_project( - project_name=config.project_name, config=viking_config - ) - logger.info(f"VikingDB backend initialized in private mode: {config.vikingdb.host}") - elif config.backend == "http": - if not config.url: - raise ValueError("HTTP backend requires a valid URL") - # Remote mode: parse URL and create HTTP project - self._mode = config.backend - self.host, self.port = self._parse_url(config.url) - - from openviking.storage.vectordb.project.http_project import get_or_create_http_project - - self.project = get_or_create_http_project( - host=self.host, port=self.port, project_name=config.project_name - ) - logger.info(f"VikingDB backend initialized in remote mode: {config.url}") - elif config.backend == "local": - # Local persistent mode - self._mode = config.backend - from openviking.storage.vectordb.project.local_project import ( - get_or_create_local_project, - ) + # Backend selection is delegated to static driver registry. + self._driver = create_driver(config) + self._mode = self._driver.mode - project_path = ( - Path(config.path) / self.DEFAULT_LOCAL_PROJECT_NAME if config.path else "" - ) - self.project = get_or_create_local_project(path=str(project_path)) - logger.info(f"VikingDB backend initialized with local storage: {project_path}") - else: - raise ValueError(f"Unsupported VikingDB backend type: {config.type}") + logger.info( + "VikingDB backend initialized via driver '%s' (mode=%s)", + type(self._driver).__name__, + self._mode, + ) self._collection_configs: Dict[str, Dict[str, Any]] = {} # Cache meta_data at collection level to avoid repeated remote calls self._meta_data_cache: Dict[str, Dict[str, Any]] = {} - def _parse_url(self, url: str) -> Tuple[str, int]: - """ - Parse VikingVectorIndex service URL to extract host and port. - - Args: - url: Service URL (e.g., "http://localhost:5000" or "localhost:5000") - - Returns: - Tuple of (host, port) - """ - from urllib.parse import urlparse - - # Add scheme if not present - if not url.startswith(("http://", "https://")): - url = f"http://{url}" - - parsed = urlparse(url) - host = parsed.hostname or "127.0.0.1" - port = parsed.port or 5000 - - return host, port + def _compile_filter(self, filter_expr: Optional[FilterExpr | Dict[str, Any]]) -> Dict[str, Any]: + """Compile AST filters via driver; allow raw DSL passthrough.""" + if filter_expr is None: + return {} + if isinstance(filter_expr, dict): + return filter_expr + if isinstance(filter_expr, RawDSL): + return filter_expr.payload + return self._driver.compile_expr(filter_expr) def _get_collection(self, name: str) -> Collection: """Get collection object or raise error if not found.""" - if not self.project.has_collection(name): + if not self._driver.has_collection(name): raise CollectionNotFoundError(f"Collection '{name}' does not exist") - return self.project.get_collection(name) + return self._driver.get_collection(name) def _get_meta_data(self, collection_name: str, coll: Collection) -> Dict[str, Any]: """Get meta_data with collection-level caching to avoid repeated remote calls.""" @@ -240,7 +167,7 @@ async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: True if created successfully, False if already exists """ try: - if self.project.has_collection(name): + if self._driver.has_collection(name): logger.debug(f"Collection '{name}' already exists") return False @@ -264,8 +191,8 @@ async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: logger.info(f"Creating collection mode={self._mode} with meta: {collection_meta}") - # Create collection using vectordb project - collection = self.project.create_collection(name, collection_meta) + # Create collection via backend-specific collection driver + collection = self._driver.create_collection(name, collection_meta) # Filter date_time fields for volcengine and vikingdb backends if self._mode in ["volcengine", "vikingdb"]: @@ -323,11 +250,11 @@ async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: async def drop_collection(self, name: str) -> bool: """Drop a collection.""" try: - if not self.project.has_collection(name): + if not self._driver.has_collection(name): logger.warning(f"Collection '{name}' does not exist") return False - self.project.drop_collection(name) + self._driver.drop_collection(name) self._collection_configs.pop(name, None) # Clear cached meta_data when dropping collection self._meta_data_cache.pop(name, None) @@ -340,16 +267,16 @@ async def drop_collection(self, name: str) -> bool: async def collection_exists(self, name: str) -> bool: """Check if a collection exists.""" - return self.project.has_collection(name) + return self._driver.has_collection(name) async def list_collections(self) -> List[str]: """List all collection names.""" - return self.project.list_collections() + return self._driver.list_collections() async def get_collection_info(self, name: str) -> Optional[Dict[str, Any]]: """Get collection metadata and statistics.""" try: - if not self.project.has_collection(name): + if not self._driver.has_collection(name): return None config = self._collection_configs.get(name, {}) @@ -563,7 +490,7 @@ async def batch_upsert(self, collection: str, data: List[Dict[str, Any]]) -> Lis logger.error(f"Error batch upserting records: {e}") raise - async def batch_delete(self, collection: str, filter: Dict[str, Any]) -> int: + async def batch_delete(self, collection: str, filter: Dict[str, Any] | FilterExpr) -> int: """Delete records matching filter conditions.""" try: # First, find matching records @@ -645,7 +572,7 @@ async def search( collection: str, query_vector: Optional[List[float]] = None, sparse_query_vector: Optional[Dict[str, float]] = None, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, limit: int = 10, offset: int = 0, output_fields: Optional[List[str]] = None, @@ -669,8 +596,7 @@ async def search( coll = self._get_collection(collection) try: - # Filter is already in vectordb DSL format - vectordb_filter = filter if filter else {} + vectordb_filter = self._compile_filter(filter) if query_vector or sparse_query_vector: # Vector search (dense, sparse, or hybrid) with optional filtering @@ -715,7 +641,7 @@ async def search( async def filter( self, collection: str, - filter: Dict[str, Any], + filter: Dict[str, Any] | FilterExpr, limit: int = 10, offset: int = 0, output_fields: Optional[List[str]] = None, @@ -726,8 +652,7 @@ async def filter( coll = self._get_collection(collection) try: - # Filter is already in vectordb DSL format - vectordb_filter = filter if filter else {} + vectordb_filter = self._compile_filter(filter) if order_by: # Use search_by_scalar for sorting @@ -770,11 +695,11 @@ async def filter( async def scroll( self, collection: str, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, limit: int = 100, cursor: Optional[str] = None, output_fields: Optional[List[str]] = None, - ) -> Tuple[List[Dict[str, Any]], Optional[str]]: + ) -> tuple[List[Dict[str, Any]], Optional[str]]: """Scroll through large result sets efficiently.""" # vectordb doesn't natively support scroll, so we simulate it offset = int(cursor) if cursor else 0 @@ -796,12 +721,16 @@ async def scroll( # Aggregation Operations # ========================================================================= - async def count(self, collection: str, filter: Optional[Dict[str, Any]] = None) -> int: + async def count( + self, collection: str, filter: Optional[Dict[str, Any] | FilterExpr] = None + ) -> int: """Count records matching filter.""" try: coll = self._get_collection(collection) result = coll.aggregate_data( - index_name=self.DEFAULT_INDEX_NAME, op="count", filters=filter + index_name=self.DEFAULT_INDEX_NAME, + op="count", + filters=self._compile_filter(filter), ) return result.agg.get("_total", 0) except Exception as e: @@ -868,8 +797,7 @@ async def optimize(self, collection: str) -> bool: async def close(self) -> None: """Close storage connection and release resources.""" try: - if self.project: - self.project.close() + self._driver.close() self._collection_configs.clear() logger.info("VikingDB backend closed") @@ -883,8 +811,8 @@ async def close(self) -> None: async def health_check(self) -> bool: """Check if storage backend is healthy and accessible.""" try: - # Simple check: verify we can access the project - self.project.list_collections() + # Simple check: verify we can access collections metadata. + self._driver.list_collections() return True except Exception: return False @@ -892,7 +820,7 @@ async def health_check(self) -> bool: async def get_stats(self) -> Dict[str, Any]: """Get storage statistics.""" try: - collections = self.project.list_collections() + collections = self._driver.list_collections() # Count total records across all collections using aggregate_data total_records = 0 diff --git a/openviking/storage/vikingdb_interface.py b/openviking/storage/vikingdb_interface.py index f11e8116..a3f6ec67 100644 --- a/openviking/storage/vikingdb_interface.py +++ b/openviking/storage/vikingdb_interface.py @@ -10,6 +10,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple +from openviking.storage.vector_store.expr import FilterExpr + class VikingDBInterface(ABC): """ @@ -240,13 +242,13 @@ async def batch_upsert(self, collection: str, data: List[Dict[str, Any]]) -> Lis pass @abstractmethod - async def batch_delete(self, collection: str, filter: Dict[str, Any]) -> int: + async def batch_delete(self, collection: str, filter: Dict[str, Any] | FilterExpr) -> int: """ Delete records matching filter conditions. Args: collection: Collection name - filter: Filter conditions + filter: Filter conditions (AST or backend DSL dict) Returns: Number of records deleted @@ -291,7 +293,7 @@ async def search( collection: str, query_vector: Optional[List[float]] = None, sparse_query_vector: Optional[Dict[str, float]] = None, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, limit: int = 10, offset: int = 0, output_fields: Optional[List[str]] = None, @@ -304,7 +306,7 @@ async def search( collection: Collection name query_vector: Dense query vector for similarity search (optional) sparse_query_vector: Sparse query vector for term matching (optional, Dict[str, float]) - filter: Scalar filter conditions (optional) + filter: Scalar filter conditions (optional, AST or backend DSL dict) limit: Maximum number of results offset: Offset for pagination output_fields: Fields to return (None for all) @@ -351,7 +353,7 @@ async def search( async def filter( self, collection: str, - filter: Dict[str, Any], + filter: Dict[str, Any] | FilterExpr, limit: int = 10, offset: int = 0, output_fields: Optional[List[str]] = None, @@ -363,7 +365,7 @@ async def filter( Args: collection: Collection name - filter: Filter conditions + filter: Filter conditions (AST or backend DSL dict) limit: Maximum number of results offset: Offset for pagination output_fields: Fields to return @@ -379,7 +381,7 @@ async def filter( async def scroll( self, collection: str, - filter: Optional[Dict[str, Any]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, limit: int = 100, cursor: Optional[str] = None, output_fields: Optional[List[str]] = None, @@ -389,7 +391,7 @@ async def scroll( Args: collection: Collection name - filter: Optional filter conditions + filter: Optional filter conditions (AST or backend DSL dict) limit: Batch size cursor: Cursor from previous scroll (None for first batch) output_fields: Fields to return @@ -414,13 +416,15 @@ async def scroll( # ========================================================================= @abstractmethod - async def count(self, collection: str, filter: Optional[Dict[str, Any]] = None) -> int: + async def count( + self, collection: str, filter: Optional[Dict[str, Any] | FilterExpr] = None + ) -> int: """ Count records matching filter. Args: collection: Collection name - filter: Optional filter conditions + filter: Optional filter conditions (AST or backend DSL dict) Returns: Number of matching records diff --git a/tests/retrieve/test_hierarchical_retriever_target_dirs.py b/tests/retrieve/test_hierarchical_retriever_target_dirs.py index 31c8165a..ba54a24a 100644 --- a/tests/retrieve/test_hierarchical_retriever_target_dirs.py +++ b/tests/retrieve/test_hierarchical_retriever_target_dirs.py @@ -7,6 +7,7 @@ from openviking.retrieve.hierarchical_retriever import HierarchicalRetriever from openviking.server.identity import RequestContext, Role +from openviking.storage.vector_store.expr import Prefix from openviking_cli.retrieve.types import ContextType, TypedQuery from openviking_cli.session.user_id import UserIdentifier @@ -43,6 +44,8 @@ async def search( def _contains_uri_scope_filter(obj, target_uri: str) -> bool: + if isinstance(obj, Prefix): + return obj.field == "uri" and obj.prefix == target_uri if isinstance(obj, dict): if ( obj.get("op") == "must" @@ -53,6 +56,8 @@ def _contains_uri_scope_filter(obj, target_uri: str) -> bool: return any(_contains_uri_scope_filter(v, target_uri) for v in obj.values()) if isinstance(obj, list): return any(_contains_uri_scope_filter(v, target_uri) for v in obj) + if hasattr(obj, "__dict__"): + return any(_contains_uri_scope_filter(v, target_uri) for v in vars(obj).values()) return False diff --git a/tests/session/test_memory_dedup_actions.py b/tests/session/test_memory_dedup_actions.py index bd88306f..6b94a674 100644 --- a/tests/session/test_memory_dedup_actions.py +++ b/tests/session/test_memory_dedup_actions.py @@ -22,6 +22,7 @@ MemoryExtractor, MergedMemoryPayload, ) +from openviking.storage.vector_store.expr import And, Eq, Prefix from openviking_cli.session.user_id import UserIdentifier @@ -176,22 +177,19 @@ async def test_find_similar_memories_uses_path_must_filter_and__score(self): assert len(similar) == 1 assert similar[0].uri == existing.uri call = vikingdb.search.await_args.kwargs - assert call["filter"]["op"] == "and" - assert {"field": "context_type", "op": "must", "conds": ["memory"]} in call["filter"][ - "conds" - ] - assert {"field": "level", "op": "must", "conds": [2]} in call["filter"]["conds"] - assert {"field": "account_id", "op": "must", "conds": ["acc1"]} in call["filter"]["conds"] - assert { - "field": "owner_space", - "op": "must", - "conds": [_make_user().user_space_name()], - } in call["filter"]["conds"] - assert { - "field": "uri", - "op": "must", - "conds": [f"viking://user/{_make_user().user_space_name()}/memories/preferences/"], - } in call["filter"]["conds"] + assert isinstance(call["filter"], And) + conds = call["filter"].conds + assert Eq("context_type", "memory") in conds + assert Eq("level", 2) in conds + assert Eq("account_id", "acc1") in conds + assert Eq("owner_space", _make_user().user_space_name()) in conds + assert ( + Prefix( + "uri", + f"viking://user/{_make_user().user_space_name()}/memories/preferences/", + ) + in conds + ) @pytest.mark.asyncio async def test_find_similar_memories_accepts_low_score_when_threshold_is_zero(self): diff --git a/tests/storage/test_context_semantic_gateway.py b/tests/storage/test_context_vector_gateway.py similarity index 79% rename from tests/storage/test_context_semantic_gateway.py rename to tests/storage/test_context_vector_gateway.py index 0cfe53fb..30b054f9 100644 --- a/tests/storage/test_context_semantic_gateway.py +++ b/tests/storage/test_context_vector_gateway.py @@ -6,7 +6,8 @@ import pytest from openviking.server.identity import RequestContext, Role -from openviking.storage.context_semantic_gateway import ContextSemanticSearchGateway +from openviking.storage.context_vector_gateway import ContextVectorGateway +from openviking.storage.vector_store.expr import And from openviking_cli.session.user_id import UserIdentifier @@ -18,7 +19,7 @@ def _make_ctx() -> RequestContext: async def test_search_in_tenant_uses_bound_collection_and_tenant_scope(): storage = AsyncMock() storage.search.return_value = [] - gateway = ContextSemanticSearchGateway.from_storage(storage, collection_name="ctx_custom") + gateway = ContextVectorGateway.from_storage(storage, collection_name="ctx_custom") await gateway.search_in_tenant( ctx=_make_ctx(), @@ -30,7 +31,7 @@ async def test_search_in_tenant_uses_bound_collection_and_tenant_scope(): call = storage.search.await_args.kwargs assert call["collection"] == "ctx_custom" - assert call["filter"]["op"] == "and" + assert isinstance(call["filter"], And) @pytest.mark.asyncio @@ -38,7 +39,7 @@ async def test_increment_active_count_updates_by_uri(): storage = AsyncMock() storage.filter.return_value = [{"id": "r1", "active_count": 3}] storage.update.return_value = True - gateway = ContextSemanticSearchGateway.from_storage(storage, collection_name="ctx_custom") + gateway = ContextVectorGateway.from_storage(storage, collection_name="ctx_custom") updated = await gateway.increment_active_count(_make_ctx(), ["viking://resources/foo"]) From 73ecb526bd9e20131137bc2dcf21a3821d703b09 Mon Sep 17 00:00:00 2001 From: "zhoujiahui.01" Date: Thu, 26 Feb 2026 18:36:19 +0800 Subject: [PATCH 3/7] refactor(storage): collapse gateway/interface into single-collection backend --- openviking/core/directories.py | 3 +- openviking/eval/ragas/playback.py | 73 ++- openviking/eval/recorder/wrapper.py | 66 +- openviking/retrieve/hierarchical_retriever.py | 17 +- openviking/server/routers/admin.py | 4 +- openviking/session/compressor.py | 4 +- openviking/session/memory_deduplicator.py | 6 +- openviking/session/session.py | 5 +- openviking/storage/__init__.py | 15 +- openviking/storage/collection_schemas.py | 9 +- openviking/storage/context_vector_gateway.py | 300 --------- openviking/storage/errors.py | 31 + .../storage/observers/vikingdb_observer.py | 2 +- openviking/storage/queuefs/queue_manager.py | 6 +- openviking/storage/vector_store/driver.py | 37 ++ .../vector_store/drivers/vikingdb_driver.py | 43 ++ .../vector_store/drivers/volcengine_driver.py | 43 ++ openviking/storage/viking_fs.py | 44 +- .../storage/viking_vector_index_backend.py | 448 +++++++++---- openviking/storage/vikingdb_interface.py | 593 ------------------ 20 files changed, 636 insertions(+), 1113 deletions(-) delete mode 100644 openviking/storage/context_vector_gateway.py create mode 100644 openviking/storage/errors.py delete mode 100644 openviking/storage/vikingdb_interface.py diff --git a/openviking/core/directories.py b/openviking/core/directories.py index c2b07302..4c8c598f 100644 --- a/openviking/core/directories.py +++ b/openviking/core/directories.py @@ -12,7 +12,6 @@ from openviking.core.context import Context, ContextType, Vectorize from openviking.server.identity import RequestContext -from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.storage.queuefs.embedding_msg_converter import EmbeddingMsgConverter if TYPE_CHECKING: @@ -146,7 +145,7 @@ def __init__( vikingdb: "VikingDBManager", ): self.vikingdb = vikingdb - self.semantic_gateway = ContextVectorGateway.from_storage(vikingdb) + self.semantic_gateway = vikingdb async def initialize_account_directories(self, ctx: RequestContext) -> int: """Initialize account-shared scope roots.""" diff --git a/openviking/eval/ragas/playback.py b/openviking/eval/ragas/playback.py index 6ec9e10c..8f48d820 100644 --- a/openviking/eval/ragas/playback.py +++ b/openviking/eval/ragas/playback.py @@ -415,21 +415,78 @@ async def _play_vikingdb_operation(self, record: IORecord) -> PlaybackResult: kwargs = request.get("kwargs", {}) if operation == "insert": - await self._vector_store.insert(*args, **kwargs) + if args: + payload = args[-1] + else: + payload = kwargs.get("data", request.get("data", {})) + await self._vector_store.insert(payload) elif operation == "update": - await self._vector_store.update(*args, **kwargs) + if len(args) >= 3: + record_id = args[-2] + payload = args[-1] + elif len(args) == 2: + record_id = args[0] + payload = args[1] + else: + record_id = kwargs.get("id", request.get("id")) + payload = kwargs.get("data", request.get("data", {})) + await self._vector_store.update(record_id, payload) elif operation == "upsert": - await self._vector_store.upsert(*args, **kwargs) + if args: + payload = args[-1] + else: + payload = kwargs.get("data", request.get("data", {})) + await self._vector_store.upsert(payload) elif operation == "delete": - await self._vector_store.delete(*args, **kwargs) + if args: + ids = args[-1] + else: + ids = kwargs.get("ids", request.get("ids", [])) + await self._vector_store.delete(ids) elif operation == "get": - await self._vector_store.get(*args, **kwargs) + if args: + ids = args[-1] + else: + ids = kwargs.get("ids", request.get("ids", [])) + await self._vector_store.get(ids) elif operation == "exists": - await self._vector_store.exists(*args, **kwargs) + if len(args) >= 2: + record_id = args[-1] + elif len(args) == 1: + record_id = args[0] + else: + record_id = kwargs.get("id", request.get("id")) + await self._vector_store.exists(record_id) elif operation == "search": - await self._vector_store.search(*args, **kwargs) + if len(args) >= 4: + query_vector = args[1] + limit = args[2] + where = args[3] + elif args: + query_vector = args[0] + limit = kwargs.get("top_k", kwargs.get("limit", 10)) + where = kwargs.get("filter") + else: + query_vector = kwargs.get("vector", kwargs.get("query_vector")) + limit = kwargs.get("top_k", kwargs.get("limit", request.get("top_k", 10))) + where = kwargs.get("filter", request.get("filter")) + await self._vector_store.search( + query_vector=query_vector, filter=where, limit=limit + ) elif operation == "filter": - await self._vector_store.filter(*args, **kwargs) + if len(args) >= 4: + where = args[1] + limit = args[2] + offset = args[3] + elif args: + where = args[0] + limit = kwargs.get("limit", 100) + offset = kwargs.get("offset", 0) + else: + where = kwargs.get("filter", request.get("filter", {})) + limit = kwargs.get("limit", request.get("limit", 100)) + offset = kwargs.get("offset", request.get("offset", 0)) + await self._vector_store.filter(filter=where, limit=limit, offset=offset) elif operation == "create_collection": await self._vector_store.create_collection(*args, **kwargs) elif operation == "drop_collection": diff --git a/openviking/eval/recorder/wrapper.py b/openviking/eval/recorder/wrapper.py index e8fb2686..cfb13793 100644 --- a/openviking/eval/recorder/wrapper.py +++ b/openviking/eval/recorder/wrapper.py @@ -105,12 +105,35 @@ def __getattr__(self, name: str) -> Any: if not callable(original_attr) or name.startswith("_"): return original_attr # viking_fs文件操作 - if name not in ("ls", "mkdir", "stat", "rm", "mv", "read", "write", "grep", "glob", "tree", - "abstract", "overview", "relations", "link", "unlink", - "write_file", "read_file", "read_file_bytes", "write_file_bytes", "append_file", "move_file", - "delete_temp", "write_context", "get_relations", "get_relations_with_content", - "find", "search", - ): + if name not in ( + "ls", + "mkdir", + "stat", + "rm", + "mv", + "read", + "write", + "grep", + "glob", + "tree", + "abstract", + "overview", + "relations", + "link", + "unlink", + "write_file", + "read_file", + "read_file_bytes", + "write_file_bytes", + "append_file", + "move_file", + "delete_temp", + "write_context", + "get_relations", + "get_relations_with_content", + "find", + "search", + ): return original_attr async def wrapped_async(*args, **kwargs): @@ -179,6 +202,7 @@ def wrapped_sync(*args, **kwargs): raise import inspect + if inspect.iscoroutinefunction(original_attr) or name.startswith("_"): return wrapped_async @@ -201,6 +225,7 @@ def _build_request(self, name: str, args: tuple, kwargs: dict) -> Dict[str, Any] param_names = [] try: import inspect + original_attr = getattr(self._fs, name, None) if original_attr and callable(original_attr): sig = inspect.signature(original_attr) @@ -223,7 +248,7 @@ def _build_request(self, name: str, args: tuple, kwargs: dict) -> Dict[str, Any] class RecordingVikingDB: """ - Wrapper for VikingDBInterface that records all operations. + Wrapper for vector store instances that records all operations. Usage: from openviking.eval.recorder import init_recorder @@ -239,7 +264,7 @@ def __init__(self, viking_db: Any, recorder: Optional[IORecorder] = None): Initialize wrapper. Args: - viking_db: VikingDBInterface instance to wrap + viking_db: Vector store instance to wrap recorder: IORecorder instance (uses global if None) """ self._db = viking_db @@ -269,7 +294,7 @@ async def insert(self, collection: str, data: Dict[str, Any]) -> str: request = {"collection": collection, "data": data} start_time = time.time() try: - result = await self._db.insert(collection, data) + result = await self._db.insert(data) latency_ms = (time.time() - start_time) * 1000 self._record("insert", request, result, latency_ms) return result @@ -283,7 +308,7 @@ async def update(self, collection: str, id: str, data: Dict[str, Any]) -> bool: request = {"collection": collection, "id": id, "data": data} start_time = time.time() try: - result = await self._db.update(collection, id, data) + result = await self._db.update(id, data) latency_ms = (time.time() - start_time) * 1000 self._record("update", request, result, latency_ms) return result @@ -297,7 +322,7 @@ async def upsert(self, collection: str, data: Dict[str, Any]) -> str: request = {"collection": collection, "data": data} start_time = time.time() try: - result = await self._db.upsert(collection, data) + result = await self._db.upsert(data) latency_ms = (time.time() - start_time) * 1000 self._record("upsert", request, result, latency_ms) return result @@ -311,7 +336,7 @@ async def delete(self, collection: str, ids: List[str]) -> int: request = {"collection": collection, "ids": ids} start_time = time.time() try: - result = await self._db.delete(collection, ids) + result = await self._db.delete(ids) latency_ms = (time.time() - start_time) * 1000 self._record("delete", request, result, latency_ms) return result @@ -325,7 +350,7 @@ async def get(self, collection: str, ids: List[str]) -> List[Dict[str, Any]]: request = {"collection": collection, "ids": ids} start_time = time.time() try: - result = await self._db.get(collection, ids) + result = await self._db.get(ids) latency_ms = (time.time() - start_time) * 1000 self._record("get", request, result, latency_ms) return result @@ -339,7 +364,7 @@ async def exists(self, collection: str, id: str) -> bool: request = {"collection": collection, "id": id} start_time = time.time() try: - result = await self._db.exists(collection, id) + result = await self._db.exists(id) latency_ms = (time.time() - start_time) * 1000 self._record("exists", request, result, latency_ms) return result @@ -359,7 +384,11 @@ async def search( request = {"collection": collection, "vector": vector, "top_k": top_k, "filter": filter} start_time = time.time() try: - result = await self._db.search(collection, vector, top_k, filter) + result = await self._db.search( + query_vector=vector, + filter=filter, + limit=top_k, + ) latency_ms = (time.time() - start_time) * 1000 self._record("search", request, result, latency_ms) return result @@ -379,7 +408,11 @@ async def filter( request = {"collection": collection, "filter": filter, "limit": limit, "offset": offset} start_time = time.time() try: - result = await self._db.filter(collection, filter, limit, offset) + result = await self._db.filter( + filter=filter, + limit=limit, + offset=offset, + ) latency_ms = (time.time() - start_time) * 1000 self._record("filter", request, result, latency_ms) return result @@ -447,4 +480,3 @@ async def list_collections(self) -> List[str]: def __getattr__(self, name: str) -> Any: """Pass through any other attributes to the wrapped db.""" return getattr(self._db, name) - diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index f4eb4896..d342c3d9 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -14,7 +14,7 @@ from openviking.models.embedder.base import EmbedResult from openviking.retrieve.memory_lifecycle import hotness_score from openviking.server.identity import RequestContext, Role -from openviking.storage import ContextVectorGateway, VikingDBInterface +from openviking.storage import VikingVectorIndexBackend from openviking.storage.viking_fs import get_viking_fs from openviking_cli.retrieve.types import ( ContextType, @@ -46,19 +46,18 @@ class HierarchicalRetriever: def __init__( self, - storage: VikingDBInterface, + storage: VikingVectorIndexBackend, embedder: Optional[Any], rerank_config: Optional[RerankConfig] = None, ): """Initialize hierarchical retriever with rerank_config. Args: - storage: VikingDBInterface instance + storage: VikingVectorIndexBackend instance embedder: Embedder instance (supports dense/sparse/hybrid) rerank_config: Rerank configuration (optional, will fallback to vector search only) """ - self.storage = storage - self.semantic_gateway = ContextVectorGateway.from_storage(storage) + self.vector_store = storage self.embedder = embedder self.rerank_config = rerank_config @@ -104,10 +103,10 @@ async def retrieve( target_dirs = [d for d in (query.target_directories or []) if d] - if not await self.semantic_gateway.collection_exists_bound(): + if not await self.vector_store.collection_exists_bound(): logger.warning( "[RecursiveSearch] Collection %s does not exist", - self.semantic_gateway.collection_name, + self.vector_store.collection_name, ) return QueryResult( query=query, @@ -179,7 +178,7 @@ async def _global_vector_search( limit: int, ) -> List[Dict[str, Any]]: """Global vector search to locate initial directories.""" - results = await self.semantic_gateway.search_global_roots_in_tenant( + results = await self.vector_store.search_global_roots_in_tenant( ctx=ctx, query_vector=query_vector, sparse_query_vector=sparse_query_vector, @@ -285,7 +284,7 @@ def passes_threshold(score: float) -> bool: pre_filter_limit = max(limit * 2, 20) - results = await self.semantic_gateway.search_children_in_tenant( + results = await self.vector_store.search_children_in_tenant( ctx=ctx, parent_uri=current_uri, query_vector=query_vector, diff --git a/openviking/server/routers/admin.py b/openviking/server/routers/admin.py index e68f0626..b768fe69 100644 --- a/openviking/server/routers/admin.py +++ b/openviking/server/routers/admin.py @@ -9,7 +9,6 @@ from openviking.server.dependencies import get_service from openviking.server.identity import RequestContext, Role from openviking.server.models import Response -from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.storage.viking_fs import get_viking_fs from openviking_cli.exceptions import PermissionDeniedError from openviking_cli.session.user_id import UserIdentifier @@ -121,8 +120,7 @@ async def delete_account( try: storage = viking_fs._get_vector_store() if storage: - gateway = ContextVectorGateway.from_storage(storage) - deleted = await gateway.delete_account_data(account_id) + deleted = await storage.delete_account_data(account_id) logger.info(f"VectorDB cascade delete for account {account_id}: {deleted} records") except Exception as e: logger.warning(f"VectorDB cleanup for account {account_id}: {e}") diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 139c6acb..5f1b6c8f 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ -14,7 +14,6 @@ from openviking.message import Message from openviking.server.identity import RequestContext from openviking.storage import VikingDBManager -from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.storage.viking_fs import get_viking_fs from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import get_logger @@ -54,7 +53,6 @@ def __init__( ): """Initialize session compressor.""" self.vikingdb = vikingdb - self.semantic_gateway = ContextVectorGateway.from_storage(vikingdb) self.extractor = MemoryExtractor() self.deduplicator = MemoryDeduplicator(vikingdb=vikingdb) @@ -115,7 +113,7 @@ async def _delete_existing_memory( try: # rm() already syncs vector deletion in most cases; keep this as a safe fallback. - await self.semantic_gateway.delete_uris(ctx, [memory.uri]) + await self.vikingdb.delete_uris(ctx, [memory.uri]) except Exception as e: logger.warning(f"Failed to remove vector record for {memory.uri}: {e}") return True diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py index ceb207ca..cb038c19 100644 --- a/openviking/session/memory_deduplicator.py +++ b/openviking/session/memory_deduplicator.py @@ -16,7 +16,6 @@ from openviking.models.embedder.base import EmbedResult from openviking.prompts import render_prompt from openviking.storage import VikingDBManager -from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking_cli.utils import get_logger from openviking_cli.utils.config import get_openviking_config @@ -84,8 +83,7 @@ def __init__( ): """Initialize deduplicator.""" self.vikingdb = vikingdb - self.semantic_gateway = ContextVectorGateway.from_storage(vikingdb) - self.embedder = vikingdb.get_embedder() + self.embedder = self.vikingdb.get_embedder() async def deduplicate( self, @@ -150,7 +148,7 @@ async def _find_similar_memories( try: # Search with memory-scope filter. - results = await self.semantic_gateway.search_similar_memories( + results = await self.vikingdb.search_similar_memories( account_id=account_id, owner_space=owner_space, category_uri_prefix=category_uri_prefix, diff --git a/openviking/session/session.py b/openviking/session/session.py index f26d7750..07cb5fde 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -14,7 +14,6 @@ from openviking.message import Message, Part from openviking.server.identity import RequestContext, Role -from openviking.storage.context_vector_gateway import ContextVectorGateway from openviking.utils.time_utils import get_current_timestamp from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils import get_logger, run_async @@ -78,9 +77,7 @@ def __init__( ): self._viking_fs = viking_fs self._vikingdb_manager = vikingdb_manager - self._semantic_gateway = ( - ContextVectorGateway.from_storage(vikingdb_manager) if vikingdb_manager else None - ) + self._semantic_gateway = vikingdb_manager self._session_compressor = session_compressor self.user = user or UserIdentifier.the_default_user() self.ctx = ctx or RequestContext(user=self.user, role=Role.ROOT) diff --git a/openviking/storage/__init__.py b/openviking/storage/__init__.py index b6a0abfc..ac4cd8d5 100644 --- a/openviking/storage/__init__.py +++ b/openviking/storage/__init__.py @@ -2,25 +2,21 @@ # SPDX-License-Identifier: Apache-2.0 """Storage layer interfaces and implementations.""" -from openviking.storage.context_vector_gateway import ContextVectorGateway -from openviking.storage.observers import BaseObserver, QueueObserver -from openviking.storage.queuefs import QueueManager, get_queue_manager, init_queue_manager -from openviking.storage.viking_fs import VikingFS, get_viking_fs, init_viking_fs -from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend -from openviking.storage.vikingdb_interface import ( +from openviking.storage.errors import ( CollectionNotFoundError, ConnectionError, DuplicateKeyError, RecordNotFoundError, SchemaError, StorageException, - VikingDBInterface, ) +from openviking.storage.observers import BaseObserver, QueueObserver +from openviking.storage.queuefs import QueueManager, get_queue_manager, init_queue_manager +from openviking.storage.viking_fs import VikingFS, get_viking_fs, init_viking_fs +from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend from openviking.storage.vikingdb_manager import VikingDBManager __all__ = [ - # Interface - "VikingDBInterface", # Exceptions "StorageException", "CollectionNotFoundError", @@ -31,7 +27,6 @@ # Backend "VikingVectorIndexBackend", "VikingDBManager", - "ContextVectorGateway", # QueueFS "QueueManager", "init_queue_manager", diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index 56b59cd9..f5086d1a 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -13,9 +13,10 @@ from typing import Any, Dict, Optional from openviking.models.embedder.base import EmbedResult +from openviking.storage.errors import CollectionNotFoundError from openviking.storage.queuefs.embedding_msg import EmbeddingMsg from openviking.storage.queuefs.named_queue import DequeueHandlerBase -from openviking.storage.vikingdb_interface import CollectionNotFoundError, VikingDBInterface +from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend from openviking_cli.utils import get_logger from openviking_cli.utils.config.open_viking_config import OpenVikingConfig @@ -126,11 +127,11 @@ class TextEmbeddingHandler(DequeueHandlerBase): Supports both dense and sparse embeddings based on configuration. """ - def __init__(self, vikingdb: VikingDBInterface): + def __init__(self, vikingdb: VikingVectorIndexBackend): """Initialize the text embedding handler. Args: - vikingdb: VikingDBInterface instance for writing to vector database + vikingdb: VikingVectorIndexBackend instance for writing to vector database """ from openviking_cli.utils.config import get_openviking_config @@ -207,7 +208,7 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, id_seed = f"{account_id}:{owner_space}:{uri}" inserted_data["id"] = hashlib.md5(id_seed.encode("utf-8")).hexdigest() - record_id = await self._vikingdb.insert(self._collection_name, inserted_data) + record_id = await self._vikingdb.insert(inserted_data) if record_id: logger.debug( f"Successfully wrote embedding to database: {record_id} abstract {inserted_data['abstract']} vector {inserted_data['vector'][:5]}" diff --git a/openviking/storage/context_vector_gateway.py b/openviking/storage/context_vector_gateway.py deleted file mode 100644 index 73376a41..00000000 --- a/openviking/storage/context_vector_gateway.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -""" -Semantic vector gateway for OpenViking business flows. - -This module keeps raw filter DSL usage inside storage integration code so -business modules can call intent-based methods. -""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from openviking.server.identity import RequestContext, Role -from openviking.storage.vector_store.expr import ( - And, - Eq, - FilterExpr, - In, - Or, - Prefix, - RawDSL, -) -from openviking.storage.vikingdb_interface import VikingDBInterface -from openviking_cli.utils.config import get_openviking_config - -WhereExpr = FilterExpr | Dict[str, Any] - - -class ContextVectorGateway: - """Semantic methods over the bound context collection.""" - - def __init__(self, storage: VikingDBInterface, collection_name: str): - self._storage = storage - self._collection_name = collection_name - - @classmethod - def from_storage( - cls, storage: VikingDBInterface, collection_name: Optional[str] = None - ) -> "ContextVectorGateway": - if collection_name: - bound_collection = collection_name - else: - try: - bound_collection = get_openviking_config().storage.vectordb.name - except Exception: - # Keep simple tests and lightweight call sites usable. - bound_collection = "context" - return cls(storage=storage, collection_name=bound_collection) - - @property - def collection_name(self) -> str: - return self._collection_name - - async def collection_exists_bound(self) -> bool: - return await self._storage.collection_exists(self._collection_name) - - async def search_in_tenant( - self, - ctx: RequestContext, - query_vector: Optional[List[float]], - sparse_query_vector: Optional[Dict[str, float]] = None, - context_type: Optional[str] = None, - target_directories: Optional[List[str]] = None, - extra_filter: Optional[WhereExpr] = None, - limit: int = 10, - offset: int = 0, - ) -> List[Dict[str, Any]]: - scope_filter = self._build_scope_filter( - ctx=ctx, - context_type=context_type, - target_directories=target_directories, - extra_filter=extra_filter, - ) - return await self._storage.search( - collection=self._collection_name, - query_vector=query_vector, - sparse_query_vector=sparse_query_vector, - filter=scope_filter, - limit=limit, - offset=offset, - ) - - async def search_global_roots_in_tenant( - self, - ctx: RequestContext, - query_vector: Optional[List[float]], - sparse_query_vector: Optional[Dict[str, float]] = None, - context_type: Optional[str] = None, - target_directories: Optional[List[str]] = None, - extra_filter: Optional[WhereExpr] = None, - limit: int = 10, - ) -> List[Dict[str, Any]]: - if not query_vector: - return [] - - merged_filter = self._merge_filters( - self._build_scope_filter( - ctx=ctx, - context_type=context_type, - target_directories=target_directories, - extra_filter=extra_filter, - ), - In("level", [0, 1]), - ) - return await self._storage.search( - collection=self._collection_name, - query_vector=query_vector, - sparse_query_vector=sparse_query_vector, - filter=merged_filter, - limit=limit, - ) - - async def search_children_in_tenant( - self, - ctx: RequestContext, - parent_uri: str, - query_vector: Optional[List[float]], - sparse_query_vector: Optional[Dict[str, float]] = None, - context_type: Optional[str] = None, - target_directories: Optional[List[str]] = None, - extra_filter: Optional[WhereExpr] = None, - limit: int = 10, - ) -> List[Dict[str, Any]]: - merged_filter = self._merge_filters( - Eq("parent_uri", parent_uri), - self._build_scope_filter( - ctx=ctx, - context_type=context_type, - target_directories=target_directories, - extra_filter=extra_filter, - ), - ) - return await self._storage.search( - collection=self._collection_name, - query_vector=query_vector, - sparse_query_vector=sparse_query_vector, - filter=merged_filter, - limit=limit, - ) - - async def search_similar_memories( - self, - account_id: str, - owner_space: Optional[str], - category_uri_prefix: str, - query_vector: List[float], - limit: int = 5, - ) -> List[Dict[str, Any]]: - conds: List[FilterExpr] = [ - Eq("context_type", "memory"), - Eq("level", 2), - Eq("account_id", account_id), - ] - if owner_space: - conds.append(Eq("owner_space", owner_space)) - if category_uri_prefix: - conds.append(Prefix("uri", category_uri_prefix)) - - return await self._storage.search( - collection=self._collection_name, - query_vector=query_vector, - filter=And(conds), - limit=limit, - ) - - async def get_context_by_uri( - self, - account_id: str, - uri: str, - owner_space: Optional[str] = None, - limit: int = 1, - ) -> List[Dict[str, Any]]: - conds: List[FilterExpr] = [ - Eq("uri", uri), - Eq("account_id", account_id), - ] - if owner_space: - conds.append(Eq("owner_space", owner_space)) - - return await self._storage.filter( - collection=self._collection_name, - filter=And(conds), - limit=limit, - ) - - async def delete_account_data(self, account_id: str) -> int: - return await self._storage.batch_delete( - self._collection_name, - Eq("account_id", account_id), - ) - - async def delete_uris(self, ctx: RequestContext, uris: List[str]) -> None: - for uri in uris: - conds: List[FilterExpr] = [ - Eq("account_id", ctx.account_id), - Or([Eq("uri", uri), Prefix("uri", f"{uri}/")]), - ] - if ctx.role == Role.USER and uri.startswith(("viking://user/", "viking://agent/")): - owner_space = ( - ctx.user.user_space_name() - if uri.startswith("viking://user/") - else ctx.user.agent_space_name() - ) - conds.append(Eq("owner_space", owner_space)) - await self._storage.batch_delete( - self._collection_name, - And(conds), - ) - - async def update_uri_mapping( - self, - ctx: RequestContext, - uri: str, - new_uri: str, - new_parent_uri: str, - ) -> bool: - records = await self._storage.filter( - collection=self._collection_name, - filter=And([Eq("uri", uri), Eq("account_id", ctx.account_id)]), - limit=1, - ) - if not records or "id" not in records[0]: - return False - - return await self._storage.update( - self._collection_name, - records[0]["id"], - {"uri": new_uri, "parent_uri": new_parent_uri}, - ) - - async def increment_active_count(self, ctx: RequestContext, uris: List[str]) -> int: - updated = 0 - for uri in uris: - records = await self.get_context_by_uri(account_id=ctx.account_id, uri=uri, limit=1) - if not records: - continue - record = records[0] - record_id = record.get("id") - if not record_id: - continue - current = int(record.get("active_count", 0) or 0) - if await self._storage.update( - self._collection_name, - record_id, - {"active_count": current + 1}, - ): - updated += 1 - return updated - - def _build_scope_filter( - self, - ctx: RequestContext, - context_type: Optional[str], - target_directories: Optional[List[str]], - extra_filter: Optional[WhereExpr], - ) -> Optional[FilterExpr]: - filters: List[FilterExpr] = [] - if context_type: - filters.append(Eq("context_type", context_type)) - - tenant_filter = self._tenant_filter(ctx, context_type=context_type) - if tenant_filter: - filters.append(tenant_filter) - - if target_directories: - uri_conds = [ - Prefix("uri", target_dir) for target_dir in target_directories if target_dir - ] - if uri_conds: - filters.append(Or(uri_conds)) - - if extra_filter: - if isinstance(extra_filter, dict): - filters.append(RawDSL(extra_filter)) - else: - filters.append(extra_filter) - - return self._merge_filters(*filters) - - @staticmethod - def _tenant_filter( - ctx: RequestContext, context_type: Optional[str] = None - ) -> Optional[FilterExpr]: - if ctx.role == Role.ROOT: - return None - - owner_spaces = [ctx.user.user_space_name(), ctx.user.agent_space_name()] - if context_type == "resource": - owner_spaces.append("") - return And([Eq("account_id", ctx.account_id), In("owner_space", owner_spaces)]) - - @staticmethod - def _merge_filters(*filters: Optional[FilterExpr]) -> Optional[FilterExpr]: - non_empty = [f for f in filters if f] - if not non_empty: - return None - if len(non_empty) == 1: - return non_empty[0] - return And(non_empty) diff --git a/openviking/storage/errors.py b/openviking/storage/errors.py new file mode 100644 index 00000000..bc3e36be --- /dev/null +++ b/openviking/storage/errors.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Storage-layer exceptions.""" + + +class VikingDBException(Exception): + """Base exception for vector-store operations.""" + + +class StorageException(VikingDBException): + """Legacy alias for VikingDBException for backward compatibility.""" + + +class CollectionNotFoundError(StorageException): + """Raised when a collection does not exist.""" + + +class RecordNotFoundError(StorageException): + """Raised when a record does not exist.""" + + +class DuplicateKeyError(StorageException): + """Raised when trying to insert a duplicate key.""" + + +class ConnectionError(StorageException): + """Raised when storage connection fails.""" + + +class SchemaError(StorageException): + """Raised when schema validation fails.""" diff --git a/openviking/storage/observers/vikingdb_observer.py b/openviking/storage/observers/vikingdb_observer.py index 4b25fb0c..8f2fca3b 100644 --- a/openviking/storage/observers/vikingdb_observer.py +++ b/openviking/storage/observers/vikingdb_observer.py @@ -54,7 +54,7 @@ async def _get_collection_statuses(self, collection_names: list) -> Dict[str, Di # Current OpenViking flow uses one managed default index per collection. index_count = 1 - vector_count = await self._vikingdb_manager.count(name) + vector_count = await self._vikingdb_manager.count() statuses[name] = { "index_count": index_count, diff --git a/openviking/storage/queuefs/queue_manager.py b/openviking/storage/queuefs/queue_manager.py index dc9aeb57..95e9aeb2 100644 --- a/openviking/storage/queuefs/queue_manager.py +++ b/openviking/storage/queuefs/queue_manager.py @@ -107,7 +107,7 @@ def start(self) -> None: logger.info("[QueueManager] Started") - def setup_standard_queues(self, vikingdb_interface: Any) -> None: + def setup_standard_queues(self, vector_store: Any) -> None: """ Setup standard queues (Embedding and Semantic) with their handlers. @@ -116,14 +116,14 @@ def setup_standard_queues(self, vikingdb_interface: Any) -> None: queue manager is started. Args: - vikingdb_interface: VikingDBInterface instance for handlers to write results. + vector_store: Vector store instance for handlers to write results. """ # Import handlers here to avoid circular dependencies from openviking.storage.collection_schemas import TextEmbeddingHandler from openviking.storage.queuefs import SemanticProcessor # Embedding Queue - embedding_handler = TextEmbeddingHandler(vikingdb_interface) + embedding_handler = TextEmbeddingHandler(vector_store) self.get_queue( self.EMBEDDING, dequeue_handler=embedding_handler, diff --git a/openviking/storage/vector_store/driver.py b/openviking/storage/vector_store/driver.py index c22afc55..97f30cbf 100644 --- a/openviking/storage/vector_store/driver.py +++ b/openviking/storage/vector_store/driver.py @@ -55,6 +55,43 @@ def list_collections(self) -> list[str]: def close(self) -> None: """Release backend resources.""" + def sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict[str, Any]], + ) -> list[str]: + """Normalize scalar index fields for backend-specific constraints.""" + return scalar_index_fields + + def build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> Dict[str, Any]: + """Build default index meta payload for this backend.""" + index_type = "flat_hybrid" if use_sparse else "flat" + index_meta: Dict[str, Any] = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: + """Normalize record fields after reading from backend.""" + return record + def compile_expr(self, expr: FilterExpr | None) -> Dict[str, Any]: """Compile a filter AST node to vectordb DSL.""" if expr is None: diff --git a/openviking/storage/vector_store/drivers/vikingdb_driver.py b/openviking/storage/vector_store/drivers/vikingdb_driver.py index 4e21c91d..c09fad69 100644 --- a/openviking/storage/vector_store/drivers/vikingdb_driver.py +++ b/openviking/storage/vector_store/drivers/vikingdb_driver.py @@ -114,3 +114,46 @@ def close(self) -> None: if self._collection is not None: self._collection.close() self._collection = None + + def sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict], + ) -> list[str]: + date_time_fields = { + field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" + } + return [field for field in scalar_index_fields if field not in date_time_fields] + + def build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> dict: + index_type = "hnsw_hybrid" if use_sparse else "hnsw" + index_meta = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def normalize_record_for_read(self, record: dict) -> dict: + for key in ("uri", "parent_uri"): + value = record.get(key) + if isinstance(value, str) and not value.startswith("viking://"): + stripped = value.strip("/") + if stripped: + record[key] = f"viking://{stripped}" + return record diff --git a/openviking/storage/vector_store/drivers/volcengine_driver.py b/openviking/storage/vector_store/drivers/volcengine_driver.py index d09cdd6d..d42f849c 100644 --- a/openviking/storage/vector_store/drivers/volcengine_driver.py +++ b/openviking/storage/vector_store/drivers/volcengine_driver.py @@ -128,3 +128,46 @@ def close(self) -> None: if self._collection is not None: self._collection.close() self._collection = None + + def sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict], + ) -> list[str]: + date_time_fields = { + field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" + } + return [field for field in scalar_index_fields if field not in date_time_fields] + + def build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> dict: + index_type = "hnsw_hybrid" if use_sparse else "hnsw" + index_meta = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def normalize_record_for_read(self, record: dict) -> dict: + for key in ("uri", "parent_uri"): + value = record.get(key) + if isinstance(value, str) and not value.startswith("viking://"): + stripped = value.strip("/") + if stripped: + record[key] = f"viking://{stripped}" + return record diff --git a/openviking/storage/viking_fs.py b/openviking/storage/viking_fs.py index b432f78d..ae09efe6 100644 --- a/openviking/storage/viking_fs.py +++ b/openviking/storage/viking_fs.py @@ -25,14 +25,13 @@ from pyagfs.exceptions import AGFSHTTPError from openviking.server.identity import RequestContext, Role -from openviking.storage.context_vector_gateway import ContextVectorGateway -from openviking.storage.vikingdb_interface import VikingDBInterface from openviking.utils.time_utils import format_simplified, get_current_timestamp, parse_iso_datetime from openviking_cli.session.user_id import UserIdentifier from openviking_cli.utils.logger import get_logger from openviking_cli.utils.uri import VikingURI if TYPE_CHECKING: + from openviking.storage.viking_vector_index_backend import VikingVectorIndexBackend from openviking_cli.utils.config import RerankConfig logger = get_logger(__name__) @@ -72,7 +71,8 @@ def init_viking_fs( agfs: Any, query_embedder: Optional[Any] = None, rerank_config: Optional["RerankConfig"] = None, - vector_store: Optional["VikingDBInterface"] = None, + vector_store: Optional["VikingVectorIndexBackend"] = None, + timeout: int = 10, enable_recorder: bool = False, ) -> "VikingFS": """Initialize VikingFS singleton. @@ -163,15 +163,13 @@ def __init__( agfs: Any, query_embedder: Optional[Any] = None, rerank_config: Optional["RerankConfig"] = None, - vector_store: Optional["VikingDBInterface"] = None, + vector_store: Optional["VikingVectorIndexBackend"] = None, + timeout: int = 10, ): self.agfs = agfs self.query_embedder = query_embedder self.rerank_config = rerank_config self.vector_store = vector_store - self._context_vector_gateway: Optional[ContextVectorGateway] = ( - ContextVectorGateway.from_storage(vector_store) if vector_store else None - ) self._bound_ctx: contextvars.ContextVar[Optional[RequestContext]] = contextvars.ContextVar( "vikingfs_bound_ctx", default=None ) @@ -1049,17 +1047,15 @@ async def _delete_from_vector_store( ) -> None: """Delete records with specified URIs from vector store. - Uses semantic gateway to apply tenant-safe URI deletion semantics. + Uses tenant-safe URI deletion semantics from vector store. """ - if not self._get_vector_store(): - return - gateway = self._get_context_vector_gateway() - if not gateway: + vector_store = self._get_vector_store() + if not vector_store: return real_ctx = self._ctx_or_default(ctx) try: - await gateway.delete_uris(real_ctx, uris) + await vector_store.delete_uris(real_ctx, uris) for uri in uris: logger.info(f"[VikingFS] Deleted from vector store: {uri}") except Exception as e: @@ -1076,10 +1072,8 @@ async def _update_vector_store_uris( Preserves vector data, only updates uri and parent_uri fields, no need to regenerate embeddings. """ - if not self._get_vector_store(): - return - gateway = self._get_context_vector_gateway() - if not gateway: + vector_store = self._get_vector_store() + if not vector_store: return old_base_uri = self._path_to_uri(old_base, ctx=ctx) @@ -1087,7 +1081,7 @@ async def _update_vector_store_uris( for uri in uris: try: - records = await gateway.get_context_by_uri( + records = await vector_store.get_context_by_uri( account_id=self._ctx_or_default(ctx).account_id, uri=uri, limit=1, @@ -1105,7 +1099,7 @@ async def _update_vector_store_uris( old_parent_uri.replace(old_base_uri, new_base_uri, 1) if old_parent_uri else "" ) - await gateway.update_uri_mapping( + await vector_store.update_uri_mapping( ctx=self._ctx_or_default(ctx), uri=uri, new_uri=new_uri, @@ -1115,20 +1109,10 @@ async def _update_vector_store_uris( except Exception as e: logger.warning(f"[VikingFS] Failed to update {uri} in vector store: {e}") - def _get_vector_store(self) -> Optional["VikingDBInterface"]: + def _get_vector_store(self) -> Optional["VikingVectorIndexBackend"]: """Get vector store instance.""" return self.vector_store - def _get_context_vector_gateway(self) -> Optional[ContextVectorGateway]: - """Get semantic vector gateway bound to configured collection.""" - storage = self._get_vector_store() - if not storage: - self._context_vector_gateway = None - return None - if not self._context_vector_gateway: - self._context_vector_gateway = ContextVectorGateway.from_storage(storage) - return self._context_vector_gateway - def _get_embedder(self) -> Any: """Get embedder instance.""" return self.query_embedder diff --git a/openviking/storage/viking_vector_index_backend.py b/openviking/storage/viking_vector_index_backend.py index c9f14688..3a1fc1db 100644 --- a/openviking/storage/viking_vector_index_backend.py +++ b/openviking/storage/viking_vector_index_backend.py @@ -3,26 +3,26 @@ """ VikingDB storage backend for OpenViking. -Implements the VikingDBInterface using the custom vectordb implementation. Supports both in-memory and local persistent storage modes. """ import uuid from typing import Any, Dict, List, Optional +from openviking.server.identity import RequestContext, Role +from openviking.storage.errors import CollectionNotFoundError from openviking.storage.vector_store import FilterExpr, create_driver -from openviking.storage.vector_store.expr import RawDSL +from openviking.storage.vector_store.expr import And, Eq, In, Or, RawDSL from openviking.storage.vectordb.collection.collection import Collection from openviking.storage.vectordb.collection.result import FetchDataInCollectionResult from openviking.storage.vectordb.utils.logging_init import init_cpp_logging -from openviking.storage.vikingdb_interface import CollectionNotFoundError, VikingDBInterface from openviking_cli.utils import get_logger from openviking_cli.utils.config.vectordb_config import VectorDBBackendConfig logger = get_logger(__name__) -class VikingVectorIndexBackend(VikingDBInterface): +class VikingVectorIndexBackend: """ VikingDB storage backend implementation. @@ -83,6 +83,7 @@ def __init__( self.vector_dim = config.dimension self.distance_metric = config.distance_metric self.sparse_weight = config.sparse_weight + self._collection_name = config.name or "context" # Backend selection is delegated to static driver registry. self._driver = create_driver(config) @@ -125,22 +126,14 @@ def _update_meta_data_cache(self, collection_name: str, coll: Collection): meta_data = coll.get_meta_data() self._meta_data_cache[collection_name] = meta_data - @staticmethod - def _restore_uri_fields(record: Dict[str, Any]) -> Dict[str, Any]: - """Restore viking:// prefix on uri/parent_uri fields read from VikingDB. + @property + def collection_name(self) -> str: + """Return bound collection name for this store instance.""" + return self._collection_name - The volcengine backend sanitizes URIs to /path/ format on write; - this reverses that transformation so the rest of the system sees - the canonical viking:// scheme. Idempotent for values that - already carry the prefix (local/http backends). - """ - for key in ("uri", "parent_uri"): - val = record.get(key) - if isinstance(val, str) and not val.startswith("viking://"): - restored = val.strip("/") - if restored: - record[key] = f"viking://{restored}" - return record + def _resolve_collection_name(self, collection_name: Optional[str] = None) -> str: + """Resolve collection name with bound default.""" + return collection_name or self._collection_name # ========================================================================= # Collection/Table Management @@ -194,35 +187,20 @@ async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: # Create collection via backend-specific collection driver collection = self._driver.create_collection(name, collection_meta) - # Filter date_time fields for volcengine and vikingdb backends - if self._mode in ["volcengine", "vikingdb"]: - date_time_fields = { - field.get("FieldName") - for field in collection_meta.get("Fields", []) - if field.get("FieldType") == "date_time" - } - scalar_index_fields = [ - field for field in scalar_index_fields if field not in date_time_fields - ] + scalar_index_fields = self._driver.sanitize_scalar_index_fields( + scalar_index_fields=scalar_index_fields, + fields_meta=collection_meta.get("Fields", []), + ) # Create default index for the collection use_sparse = self.sparse_weight > 0.0 - index_type = "flat_hybrid" if use_sparse else "flat" - if self._mode in ["volcengine", "vikingdb"]: - index_type = "hnsw_hybrid" if use_sparse else "hnsw" - - index_meta = { - "IndexName": self.DEFAULT_INDEX_NAME, - "VectorIndex": { - "IndexType": index_type, - "Distance": distance, - "Quant": "int8", - }, - "ScalarIndex": scalar_index_fields, - } - if use_sparse: - index_meta["VectorIndex"]["EnableSparse"] = True - index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = self.sparse_weight + index_meta = self._driver.build_default_index_meta( + index_name=self.DEFAULT_INDEX_NAME, + distance=distance, + use_sparse=use_sparse, + sparse_weight=self.sparse_weight, + scalar_index_fields=scalar_index_fields, + ) logger.info(f"Creating index with meta: {index_meta}") collection.create_index(self.DEFAULT_INDEX_NAME, index_meta) @@ -236,6 +214,7 @@ async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: "distance": distance, "schema": schema, } + self._collection_name = name logger.info(f"Created VikingDB collection: {name} (dim={vector_dim})") return True @@ -265,9 +244,9 @@ async def drop_collection(self, name: str) -> bool: logger.error(f"Error dropping collection '{name}': {e}") return False - async def collection_exists(self, name: str) -> bool: + async def collection_exists(self, name: Optional[str] = None) -> bool: """Check if a collection exists.""" - return self._driver.has_collection(name) + return self._driver.has_collection(self._resolve_collection_name(name)) async def list_collections(self) -> List[str]: """List all collection names.""" @@ -291,13 +270,17 @@ async def get_collection_info(self, name: str) -> Optional[Dict[str, Any]]: logger.error(f"Error getting collection info for '{name}': {e}") return None + async def collection_exists_bound(self) -> bool: + """Check whether the bound collection exists.""" + return await self.collection_exists(self._collection_name) + # ========================================================================= # CRUD Operations - Single Record # ========================================================================= - async def insert(self, collection: str, data: Dict[str, Any]) -> str: - """Insert a single record.""" - coll = self._get_collection(collection) + async def insert(self, data: Dict[str, Any]) -> str: + """Insert a single record into the bound collection.""" + coll = self._get_collection(self._collection_name) # Ensure ID exists record_id = data.get("id") @@ -306,21 +289,20 @@ async def insert(self, collection: str, data: Dict[str, Any]) -> str: data = {**data, "id": record_id} # Validate context_type for context collection - if collection == "context": - context_type = data.get("context_type") - if context_type not in ["resource", "skill", "memory"]: - logger.warning( - f"Invalid context_type: {context_type}. " - f"Must be one of ['resource', 'skill', 'memory'], Ignore" - ) - return "" + context_type = data.get("context_type") + if context_type not in ["resource", "skill", "memory"]: + logger.warning( + f"Invalid context_type: {context_type}. " + f"Must be one of ['resource', 'skill', 'memory'], Ignore" + ) + return "" - fields = self._get_meta_data(collection, coll).get("Fields", []) + fields = self._get_meta_data(self._collection_name, coll).get("Fields", []) fields_dict = {item["FieldName"]: item for item in fields} new_data = {} - for k in data: - if k in fields_dict and data[k] is not None: - new_data[k] = data[k] + for key in data: + if key in fields_dict and data[key] is not None: + new_data[key] = data[key] try: coll.upsert_data([new_data]) @@ -329,13 +311,13 @@ async def insert(self, collection: str, data: Dict[str, Any]) -> str: logger.error(f"Error inserting record: {e}") raise - async def update(self, collection: str, id: str, data: Dict[str, Any]) -> bool: - """Update a record by ID.""" - coll = self._get_collection(collection) + async def update(self, id: str, data: Dict[str, Any]) -> bool: + """Update a record by ID in the bound collection.""" + coll = self._get_collection(self._collection_name) try: # Fetch existing record - existing = await self.get(collection, [id]) + existing = await self.get([id]) if not existing: return False @@ -350,9 +332,9 @@ async def update(self, collection: str, id: str, data: Dict[str, Any]) -> bool: logger.error(f"Error updating record '{id}': {e}") return False - async def upsert(self, collection: str, data: Dict[str, Any]) -> str: - """Insert or update a record.""" - coll = self._get_collection(collection) + async def upsert(self, data: Dict[str, Any]) -> str: + """Insert or update a record in the bound collection.""" + coll = self._get_collection(self._collection_name) record_id = data.get("id") if not record_id: @@ -366,9 +348,9 @@ async def upsert(self, collection: str, data: Dict[str, Any]) -> str: logger.error(f"Error upserting record: {e}") raise - async def delete(self, collection: str, ids: List[str]) -> int: - """Delete records by IDs.""" - coll = self._get_collection(collection) + async def delete(self, ids: List[str]) -> int: + """Delete records by IDs from the bound collection.""" + coll = self._get_collection(self._collection_name) try: coll.delete_data(ids) @@ -377,9 +359,9 @@ async def delete(self, collection: str, ids: List[str]) -> int: logger.error(f"Error deleting records: {e}") return 0 - async def get(self, collection: str, ids: List[str]) -> List[Dict[str, Any]]: - """Get records by IDs.""" - coll = self._get_collection(collection) + async def get(self, ids: List[str]) -> List[Dict[str, Any]]: + """Get records by IDs from the bound collection.""" + coll = self._get_collection(self._collection_name) try: result = coll.fetch_data(ids) @@ -389,7 +371,7 @@ async def get(self, collection: str, ids: List[str]) -> List[Dict[str, Any]]: for item in result.items: record = dict(item.fields) if item.fields else {} record["id"] = item.id - self._restore_uri_fields(record) + self._driver.normalize_record_for_read(record) records.append(record) return records elif isinstance(result, dict): @@ -399,7 +381,7 @@ async def get(self, collection: str, ids: List[str]) -> List[Dict[str, Any]]: record = dict(item.get("fields", {})) if item.get("fields") else {} record["id"] = item.get("id") if record["id"]: - self._restore_uri_fields(record) + self._driver.normalize_record_for_read(record) records.append(record) return records else: @@ -409,9 +391,9 @@ async def get(self, collection: str, ids: List[str]) -> List[Dict[str, Any]]: logger.error(f"Error getting records: {e}") return [] - async def fetch_by_uri(self, collection: str, uri: str) -> Optional[Dict[str, Any]]: + async def fetch_by_uri(self, uri: str) -> Optional[Dict[str, Any]]: """Fetch a record by URI.""" - coll = self._get_collection(collection) + coll = self._get_collection(self._collection_name) try: result = coll.search_by_random( index_name=self.DEFAULT_INDEX_NAME, @@ -422,7 +404,7 @@ async def fetch_by_uri(self, collection: str, uri: str) -> Optional[Dict[str, An for item in result.data: record = dict(item.fields) if item.fields else {} record["id"] = item.id - self._restore_uri_fields(record) + self._driver.normalize_record_for_read(record) records.append(record) if len(records) > 1: raise ValueError(f"Duplicate records found for URI: {uri}") @@ -433,10 +415,10 @@ async def fetch_by_uri(self, collection: str, uri: str) -> Optional[Dict[str, An logger.error(f"Error fetching record by URI '{uri}': {e}") return None - async def exists(self, collection: str, id: str) -> bool: + async def exists(self, id: str) -> bool: """Check if a record exists.""" try: - results = await self.get(collection, [id]) + results = await self.get([id]) return len(results) > 0 except Exception: return False @@ -445,9 +427,9 @@ async def exists(self, collection: str, id: str) -> bool: # CRUD Operations - Batch # ========================================================================= - async def batch_insert(self, collection: str, data: List[Dict[str, Any]]) -> List[str]: - """Batch insert multiple records.""" - coll = self._get_collection(collection) + async def batch_insert(self, data: List[Dict[str, Any]]) -> List[str]: + """Batch insert multiple records into the bound collection.""" + coll = self._get_collection(self._collection_name) # Ensure all records have IDs ids = [] @@ -468,9 +450,9 @@ async def batch_insert(self, collection: str, data: List[Dict[str, Any]]) -> Lis logger.error(f"Error batch inserting records: {e}") raise - async def batch_upsert(self, collection: str, data: List[Dict[str, Any]]) -> List[str]: - """Batch insert or update multiple records.""" - coll = self._get_collection(collection) + async def batch_upsert(self, data: List[Dict[str, Any]]) -> List[str]: + """Batch insert or update multiple records in the bound collection.""" + coll = self._get_collection(self._collection_name) ids = [] records_with_ids = [] @@ -490,28 +472,27 @@ async def batch_upsert(self, collection: str, data: List[Dict[str, Any]]) -> Lis logger.error(f"Error batch upserting records: {e}") raise - async def batch_delete(self, collection: str, filter: Dict[str, Any] | FilterExpr) -> int: + async def batch_delete(self, filter: Dict[str, Any] | FilterExpr) -> int: """Delete records matching filter conditions.""" try: # First, find matching records - matching_records = await self.filter(collection, filter, limit=10000) + matching_records = await self.filter(filter, limit=10000) if not matching_records: return 0 # Extract IDs and delete ids = [record["id"] for record in matching_records if "id" in record] - return await self.delete(collection, ids) + return await self.delete(ids) except Exception as e: logger.error(f"Error batch deleting records: {e}") return 0 - async def remove_by_uri(self, collection: str, uri: str) -> int: + async def remove_by_uri(self, uri: str) -> int: """Remove resource(s) by URI.""" try: target_records = await self.filter( - collection=collection, - filter={"op": "must", "field": "uri", "conds": [uri]}, + {"op": "must", "field": "uri", "conds": [uri]}, limit=10, ) @@ -522,12 +503,12 @@ async def remove_by_uri(self, collection: str, uri: str) -> int: # If any record indicates this URI is a directory node, remove descendants first. if any(r.get("level") in [0, 1] for r in target_records): - descendant_count = await self._remove_descendants(collection, uri) + descendant_count = await self._remove_descendants(parent_uri=uri) total_deleted += descendant_count ids = [r.get("id") for r in target_records if r.get("id")] if ids: - total_deleted += await self.delete(collection, ids) + total_deleted += await self.delete(ids) logger.info(f"Removed {total_deleted} record(s) for URI: {uri}") return total_deleted @@ -536,14 +517,13 @@ async def remove_by_uri(self, collection: str, uri: str) -> int: logger.error(f"Error removing URI '{uri}': {e}") return 0 - async def _remove_descendants(self, collection: str, parent_uri: str) -> int: + async def _remove_descendants(self, parent_uri: str) -> int: """Recursively remove all descendants of a parent URI.""" total_deleted = 0 # Find direct children children = await self.filter( - collection=collection, - filter={"op": "must", "field": "parent_uri", "conds": [parent_uri]}, + {"op": "must", "field": "parent_uri", "conds": [parent_uri]}, limit=10000, ) @@ -553,12 +533,12 @@ async def _remove_descendants(self, collection: str, parent_uri: str) -> int: # Recursively delete if child is also an intermediate directory if level in [0, 1] and child_uri: - descendant_count = await self._remove_descendants(collection, child_uri) + descendant_count = await self._remove_descendants(parent_uri=child_uri) total_deleted += descendant_count # Delete the child if "id" in child: - await self.delete(collection, [child["id"]]) + await self.delete([child["id"]]) total_deleted += 1 return total_deleted @@ -569,7 +549,6 @@ async def _remove_descendants(self, collection: str, parent_uri: str) -> int: async def search( self, - collection: str, query_vector: Optional[List[float]] = None, sparse_query_vector: Optional[Dict[str, float]] = None, filter: Optional[Dict[str, Any] | FilterExpr] = None, @@ -581,7 +560,6 @@ async def search( """Hybrid search: vector similarity (dense/sparse/hybrid) + scalar filtering. Args: - collection: Collection name, by default it should be "context" query_vector: Dense query vector (optional) sparse_query_vector: Sparse query vector as {term: weight} dict (optional) filter: Scalar filter conditions @@ -593,7 +571,7 @@ async def search( Returns: List of matching records with scores """ - coll = self._get_collection(collection) + coll = self._get_collection(self._collection_name) try: vectordb_filter = self._compile_filter(filter) @@ -616,7 +594,7 @@ async def search( record = dict(item.fields) if item.fields else {} record["id"] = item.id record["_score"] = item.score if item.score is not None else 0.0 - self._restore_uri_fields(record) + self._driver.normalize_record_for_read(record) if not with_vector: if "vector" in record: @@ -629,10 +607,15 @@ async def search( return records else: # Pure filtering without vector search - return await self.filter(collection, filter or {}, limit, offset, output_fields) + return await self.filter( + filter or {}, + limit=limit, + offset=offset, + output_fields=output_fields, + ) except Exception as e: - logger.error(f"Error searching collection '{collection}': {e}") + logger.error(f"Error searching collection '{self._collection_name}': {e}") import traceback traceback.print_exc() @@ -640,7 +623,6 @@ async def search( async def filter( self, - collection: str, filter: Dict[str, Any] | FilterExpr, limit: int = 10, offset: int = 0, @@ -649,7 +631,7 @@ async def filter( order_desc: bool = False, ) -> List[Dict[str, Any]]: """Pure scalar filtering without vector search.""" - coll = self._get_collection(collection) + coll = self._get_collection(self._collection_name) try: vectordb_filter = self._compile_filter(filter) @@ -680,21 +662,244 @@ async def filter( for item in result.data: record = dict(item.fields) if item.fields else {} record["id"] = item.id - self._restore_uri_fields(record) + self._driver.normalize_record_for_read(record) records.append(record) return records except Exception as e: - logger.error(f"Error filtering collection '{collection}': {e}") + logger.error(f"Error filtering collection '{self._collection_name}': {e}") import traceback traceback.print_exc() return [] + # ========================================================================= + # Semantic Context Operations (Tenant-Aware) + # ========================================================================= + + async def search_in_tenant( + self, + ctx: RequestContext, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]] = None, + context_type: Optional[str] = None, + target_directories: Optional[List[str]] = None, + extra_filter: Optional[FilterExpr | Dict[str, Any]] = None, + limit: int = 10, + offset: int = 0, + ) -> List[Dict[str, Any]]: + scope_filter = self._build_scope_filter( + ctx=ctx, + context_type=context_type, + target_directories=target_directories, + extra_filter=extra_filter, + ) + return await self.search( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=scope_filter, + limit=limit, + offset=offset, + ) + + async def search_global_roots_in_tenant( + self, + ctx: RequestContext, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]] = None, + context_type: Optional[str] = None, + target_directories: Optional[List[str]] = None, + extra_filter: Optional[FilterExpr | Dict[str, Any]] = None, + limit: int = 10, + ) -> List[Dict[str, Any]]: + if not query_vector: + return [] + + merged_filter = self._merge_filters( + self._build_scope_filter( + ctx=ctx, + context_type=context_type, + target_directories=target_directories, + extra_filter=extra_filter, + ), + In("level", [0, 1]), + ) + return await self.search( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=merged_filter, + limit=limit, + ) + + async def search_children_in_tenant( + self, + ctx: RequestContext, + parent_uri: str, + query_vector: Optional[List[float]], + sparse_query_vector: Optional[Dict[str, float]] = None, + context_type: Optional[str] = None, + target_directories: Optional[List[str]] = None, + extra_filter: Optional[FilterExpr | Dict[str, Any]] = None, + limit: int = 10, + ) -> List[Dict[str, Any]]: + merged_filter = self._merge_filters( + Eq("parent_uri", parent_uri), + self._build_scope_filter( + ctx=ctx, + context_type=context_type, + target_directories=target_directories, + extra_filter=extra_filter, + ), + ) + return await self.search( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=merged_filter, + limit=limit, + ) + + async def search_similar_memories( + self, + account_id: str, + owner_space: Optional[str], + category_uri_prefix: str, + query_vector: List[float], + limit: int = 5, + ) -> List[Dict[str, Any]]: + conds: List[FilterExpr] = [ + Eq("context_type", "memory"), + Eq("level", 2), + Eq("account_id", account_id), + ] + if owner_space: + conds.append(Eq("owner_space", owner_space)) + if category_uri_prefix: + conds.append(In("uri", [category_uri_prefix])) + + return await self.search( + query_vector=query_vector, + filter=And(conds), + limit=limit, + ) + + async def get_context_by_uri( + self, + account_id: str, + uri: str, + owner_space: Optional[str] = None, + limit: int = 1, + ) -> List[Dict[str, Any]]: + conds: List[FilterExpr] = [ + Eq("uri", uri), + Eq("account_id", account_id), + ] + if owner_space: + conds.append(Eq("owner_space", owner_space)) + return await self.filter( + filter=And(conds), + limit=limit, + ) + + async def delete_account_data(self, account_id: str) -> int: + return await self.batch_delete(Eq("account_id", account_id)) + + async def delete_uris(self, ctx: RequestContext, uris: List[str]) -> None: + for uri in uris: + conds: List[FilterExpr] = [ + Eq("account_id", ctx.account_id), + Or([Eq("uri", uri), In("uri", [f"{uri}/"])]), + ] + if ctx.role == Role.USER and uri.startswith(("viking://user/", "viking://agent/")): + owner_space = ( + ctx.user.user_space_name() + if uri.startswith("viking://user/") + else ctx.user.agent_space_name() + ) + conds.append(Eq("owner_space", owner_space)) + await self.batch_delete(And(conds)) + + async def update_uri_mapping( + self, + ctx: RequestContext, + uri: str, + new_uri: str, + new_parent_uri: str, + ) -> bool: + records = await self.filter( + filter=And([Eq("uri", uri), Eq("account_id", ctx.account_id)]), + limit=1, + ) + if not records or "id" not in records[0]: + return False + return await self.update(records[0]["id"], {"uri": new_uri, "parent_uri": new_parent_uri}) + + async def increment_active_count(self, ctx: RequestContext, uris: List[str]) -> int: + updated = 0 + for uri in uris: + records = await self.get_context_by_uri(account_id=ctx.account_id, uri=uri, limit=1) + if not records: + continue + record = records[0] + record_id = record.get("id") + if not record_id: + continue + current = int(record.get("active_count", 0) or 0) + if await self.update(record_id, {"active_count": current + 1}): + updated += 1 + return updated + + def _build_scope_filter( + self, + ctx: RequestContext, + context_type: Optional[str], + target_directories: Optional[List[str]], + extra_filter: Optional[FilterExpr | Dict[str, Any]], + ) -> Optional[FilterExpr]: + filters: List[FilterExpr] = [] + if context_type: + filters.append(Eq("context_type", context_type)) + + tenant_filter = self._tenant_filter(ctx, context_type=context_type) + if tenant_filter: + filters.append(tenant_filter) + + if target_directories: + uri_conds = [In("uri", [target_dir]) for target_dir in target_directories if target_dir] + if uri_conds: + filters.append(Or(uri_conds)) + + if extra_filter: + if isinstance(extra_filter, dict): + filters.append(RawDSL(extra_filter)) + else: + filters.append(extra_filter) + + return self._merge_filters(*filters) + + @staticmethod + def _tenant_filter( + ctx: RequestContext, context_type: Optional[str] = None + ) -> Optional[FilterExpr]: + if ctx.role == Role.ROOT: + return None + + owner_spaces = [ctx.user.user_space_name(), ctx.user.agent_space_name()] + if context_type == "resource": + owner_spaces.append("") + return And([Eq("account_id", ctx.account_id), In("owner_space", owner_spaces)]) + + @staticmethod + def _merge_filters(*filters: Optional[FilterExpr]) -> Optional[FilterExpr]: + non_empty = [f for f in filters if f] + if not non_empty: + return None + if len(non_empty) == 1: + return non_empty[0] + return And(non_empty) + async def scroll( self, - collection: str, filter: Optional[Dict[str, Any] | FilterExpr] = None, limit: int = 100, cursor: Optional[str] = None, @@ -705,7 +910,6 @@ async def scroll( offset = int(cursor) if cursor else 0 records = await self.filter( - collection=collection, filter=filter or {}, limit=limit, offset=offset, @@ -722,11 +926,12 @@ async def scroll( # ========================================================================= async def count( - self, collection: str, filter: Optional[Dict[str, Any] | FilterExpr] = None + self, + filter: Optional[Dict[str, Any] | FilterExpr] = None, ) -> int: """Count records matching filter.""" try: - coll = self._get_collection(collection) + coll = self._get_collection(self._collection_name) result = coll.aggregate_data( index_name=self.DEFAULT_INDEX_NAME, op="count", @@ -743,7 +948,6 @@ async def count( async def create_index( self, - collection: str, field: str, index_type: str, **kwargs, @@ -758,7 +962,7 @@ async def create_index( logger.error(f"Error creating index on '{field}': {e}") return False - async def drop_index(self, collection: str, field: str) -> bool: + async def drop_index(self, field: str) -> bool: """Drop an index on a field.""" try: # vectordb manages indexes internally @@ -772,23 +976,23 @@ async def drop_index(self, collection: str, field: str) -> bool: # Lifecycle Operations # ========================================================================= - async def clear(self, collection: str) -> bool: + async def clear(self) -> bool: """Clear all data in a collection.""" - coll = self._get_collection(collection) + coll = self._get_collection(self._collection_name) try: coll.delete_all_data() - logger.info(f"Cleared all data in collection: {collection}") + logger.info(f"Cleared all data in collection: {self._collection_name}") return True except Exception as e: logger.error(f"Error clearing collection: {e}") return False - async def optimize(self, collection: str) -> bool: + async def optimize(self) -> bool: """Optimize collection for better performance.""" try: # vectordb handles optimization internally via index rebuilding - logger.info(f"Optimization requested for collection: {collection}") + logger.info("Optimization requested for collection: %s", self._collection_name) return True except Exception as e: logger.error(f"Error optimizing collection: {e}") diff --git a/openviking/storage/vikingdb_interface.py b/openviking/storage/vikingdb_interface.py deleted file mode 100644 index a3f6ec67..00000000 --- a/openviking/storage/vikingdb_interface.py +++ /dev/null @@ -1,593 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -""" -Storage interface for OpenViking. - -Defines the abstract storage interface inspired by vector database designs -(Milvus/Qdrant). All storage backends must implement this interface. -""" - -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple - -from openviking.storage.vector_store.expr import FilterExpr - - -class VikingDBInterface(ABC): - """ - Abstract vector indexing interface for OpenViking. - - This interface defines all vector indexing capabilities required by OpenViking. - New vector indexing backends should implement this interface to ensure compatibility. - - Capabilities: - - Collection management - - CRUD operations (single and batch) - - Vector similarity search - - Scalar filtering - - Index management - - Lifecycle management - """ - - # ========================================================================= - # Collection Management - # ========================================================================= - - @abstractmethod - async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: - """ - Create a new collection.` - - Args: - name: Collection name (e.g., "memory", "resource", "skill") - schema: Schema definition including: - - vector_dim: int - Vector dimension (default: 2048) - - distance: str - Distance metric ("cosine", "euclid", "dot") - - fields: List[dict] - Field definitions with name, type, indexed - - Returns: - True if created successfully, False if already exists - - Example: - schema = { - "vector_dim": 2048, - "distance": "cosine", - "fields": [ - {"name": "uri", "type": "string", "indexed": True}, - {"name": "abstract", "type": "text"}, - {"name": "active_count", "type": "integer"}, - ] - } - """ - pass - - @abstractmethod - async def drop_collection(self, name: str) -> bool: - """ - Drop a collection/table. - - Args: - name: Collection name - - Returns: - True if dropped successfully, False otherwise - """ - pass - - @abstractmethod - async def collection_exists(self, name: str) -> bool: - """ - Check if a collection exists. - - Args: - name: Collection name - - Returns: - True if exists, False otherwise - """ - pass - - @abstractmethod - async def list_collections(self) -> List[str]: - """ - List all collection names. - - Returns: - List of collection names - """ - pass - - @abstractmethod - async def get_collection_info(self, name: str) -> Optional[Dict[str, Any]]: - """ - Get collection metadata and statistics. - - Args: - name: Collection name - - Returns: - Dictionary with collection info: - - name: str - - vector_dim: int - - count: int - - status: str - Returns None if collection doesn't exist - """ - pass - - # ========================================================================= - # CRUD Operations - Single Record - # ========================================================================= - - @abstractmethod - async def insert(self, collection: str, data: Dict[str, Any]) -> str: - """ - Insert a single record. - - Args: - collection: Collection name - data: Record data. Must include: - - id: str (optional, auto-generated if not provided) - - vector: List[float] (optional) - - Other payload fields - - Returns: - ID of the inserted record - """ - pass - - @abstractmethod - async def update(self, collection: str, id: str, data: Dict[str, Any]) -> bool: - """ - Update a record by ID. - - Args: - collection: Collection name - id: Record ID - data: Fields to update (can include vector) - - Returns: - True if updated successfully, False if not found - """ - pass - - @abstractmethod - async def upsert(self, collection: str, data: Dict[str, Any]) -> str: - """ - Insert or update a record. - - If record with same ID exists, update it. Otherwise insert new record. - - Args: - collection: Collection name - data: Record data with id field - - Returns: - ID of the upserted record - """ - pass - - @abstractmethod - async def delete(self, collection: str, ids: List[str]) -> int: - """ - Delete records by IDs. - - Args: - collection: Collection name - ids: List of record IDs to delete - - Returns: - Number of records deleted - """ - pass - - @abstractmethod - async def get(self, collection: str, ids: List[str]) -> List[Dict[str, Any]]: - """ - Get records by IDs. - - Args: - collection: Collection name - ids: List of record IDs - - Returns: - List of records (may be fewer than requested if some IDs not found) - """ - pass - - @abstractmethod - async def exists(self, collection: str, id: str) -> bool: - """ - Check if a record exists. - - Args: - collection: Collection name - id: Record ID - - Returns: - True if exists, False otherwise - """ - pass - - # ========================================================================= - # CRUD Operations - Batch - # ========================================================================= - - @abstractmethod - async def batch_insert(self, collection: str, data: List[Dict[str, Any]]) -> List[str]: - """ - Batch insert multiple records. - - Args: - collection: Collection name - data: List of records - - Returns: - List of IDs of inserted records - """ - pass - - @abstractmethod - async def batch_upsert(self, collection: str, data: List[Dict[str, Any]]) -> List[str]: - """ - Batch insert or update multiple records. - - Args: - collection: Collection name - data: List of records with id fields - - Returns: - List of IDs of upserted records - """ - pass - - @abstractmethod - async def batch_delete(self, collection: str, filter: Dict[str, Any] | FilterExpr) -> int: - """ - Delete records matching filter conditions. - - Args: - collection: Collection name - filter: Filter conditions (AST or backend DSL dict) - - Returns: - Number of records deleted - """ - pass - - @abstractmethod - async def remove_by_uri( - self, - collection: str, - uri: str, - ) -> int: - """ - Remove resource(s) by URI. - - If the URI points to a directory, removes all descendants first, - then removes the directory itself. - - Args: - collection: Collection name - uri: URI to remove (e.g., "viking://resources/references/doc_name") - - Returns: - Number of records removed - - Example: - # Remove a single context - await storage.remove_by_uri("resource", "viking://resources/ref/doc/section1") - - # Remove entire document tree (directory + all children) - await storage.remove_by_uri("resource", "viking://resources/ref/doc_name") - """ - pass - - # ========================================================================= - # Search Operations - # ========================================================================= - - @abstractmethod - async def search( - self, - collection: str, - query_vector: Optional[List[float]] = None, - sparse_query_vector: Optional[Dict[str, float]] = None, - filter: Optional[Dict[str, Any] | FilterExpr] = None, - limit: int = 10, - offset: int = 0, - output_fields: Optional[List[str]] = None, - with_vector: bool = False, - ) -> List[Dict[str, Any]]: - """ - Hybrid search: vector similarity + scalar filtering + sparse vector matching. - - Args: - collection: Collection name - query_vector: Dense query vector for similarity search (optional) - sparse_query_vector: Sparse query vector for term matching (optional, Dict[str, float]) - filter: Scalar filter conditions (optional, AST or backend DSL dict) - limit: Maximum number of results - offset: Offset for pagination - output_fields: Fields to return (None for all) - with_vector: Include vector in results - - Returns: - List of matching records. If query_vector provided, includes _score field. - - Notes: - - If both query_vector and sparse_query_vector are provided, performs hybrid search - - If only query_vector is provided, performs dense vector search - - If only sparse_query_vector is provided, performs sparse search - - If neither is provided, performs filter-only search - - Filter format (VikingVectorIndex DSL): - { - "op": "and" | "or", - "conds": [ - {"op": "must", "field": "name", "conds": [value]}, - {"op": "range", "field": "age", "gte": 18, "lt": 65}, - {"op": "must", "field": "uri", "conds": ["viking://"]}, - {"op": "contains", "field": "desc", "substring": "hello"} - ] - } - - Example: - # Dense search - results = await storage.search( - collection="context", - query_vector=embedding, - filter={ - "op": "and", - "conds": [ - {"op": "must", "field": "uri", "conds": ["viking://user"]}, - {"op": "range", "field": "active_count", "gte": 1} - ] - }, - limit=10 - ) - """ - pass - - @abstractmethod - async def filter( - self, - collection: str, - filter: Dict[str, Any] | FilterExpr, - limit: int = 10, - offset: int = 0, - output_fields: Optional[List[str]] = None, - order_by: Optional[str] = None, - order_desc: bool = False, - ) -> List[Dict[str, Any]]: - """ - Pure scalar filtering without vector search. - - Args: - collection: Collection name - filter: Filter conditions (AST or backend DSL dict) - limit: Maximum number of results - offset: Offset for pagination - output_fields: Fields to return - order_by: Field to sort by (optional) - order_desc: Sort descending if True - - Returns: - List of matching records - """ - pass - - @abstractmethod - async def scroll( - self, - collection: str, - filter: Optional[Dict[str, Any] | FilterExpr] = None, - limit: int = 100, - cursor: Optional[str] = None, - output_fields: Optional[List[str]] = None, - ) -> Tuple[List[Dict[str, Any]], Optional[str]]: - """ - Scroll through large result sets efficiently. - - Args: - collection: Collection name - filter: Optional filter conditions (AST or backend DSL dict) - limit: Batch size - cursor: Cursor from previous scroll (None for first batch) - output_fields: Fields to return - - Returns: - Tuple of (records, next_cursor). next_cursor is None when exhausted. - - Example: - cursor = None - while True: - records, cursor = await storage.scroll( - "memory", limit=100, cursor=cursor - ) - process(records) - if cursor is None: - break - """ - pass - - # ========================================================================= - # Aggregation Operations - # ========================================================================= - - @abstractmethod - async def count( - self, collection: str, filter: Optional[Dict[str, Any] | FilterExpr] = None - ) -> int: - """ - Count records matching filter. - - Args: - collection: Collection name - filter: Optional filter conditions (AST or backend DSL dict) - - Returns: - Number of matching records - """ - pass - - # ========================================================================= - # Index Operations - # ========================================================================= - - @abstractmethod - async def create_index( - self, - collection: str, - field: str, - index_type: str, - **kwargs, - ) -> bool: - """ - Create an index on a field. - - Args: - collection: Collection name - field: Field name to index - index_type: Index type: - - "keyword": Exact match index - - "text": Full-text search index - - "integer": Numeric range index - - "float": Numeric range index - - "bool": Boolean index - **kwargs: Additional index parameters - - Returns: - True if created successfully - """ - pass - - @abstractmethod - async def drop_index(self, collection: str, field: str) -> bool: - """ - Drop an index on a field. - - Args: - collection: Collection name - field: Field name - - Returns: - True if dropped successfully - """ - pass - - # ========================================================================= - # Lifecycle Operations - # ========================================================================= - - @abstractmethod - async def clear(self, collection: str) -> bool: - """ - Clear all data in a collection (keep schema). - - Args: - collection: Collection name - - Returns: - True if cleared successfully - """ - pass - - @abstractmethod - async def optimize(self, collection: str) -> bool: - """ - Optimize collection for better performance. - - Triggers index optimization, compaction, etc. - - Args: - collection: Collection name - - Returns: - True if optimization completed - """ - pass - - @abstractmethod - async def close(self) -> None: - """ - Close storage connection and release resources. - - Should be called when storage is no longer needed. - """ - pass - - # ========================================================================= - # Health & Status - # ========================================================================= - - @abstractmethod - async def health_check(self) -> bool: - """ - Check if storage backend is healthy and accessible. - - Returns: - True if healthy, False otherwise - """ - pass - - @abstractmethod - async def get_stats(self) -> Dict[str, Any]: - """ - Get storage statistics. - - Returns: - Dictionary with stats: - - collections: int - Number of collections - - total_records: int - Total record count - - storage_size: int - Storage size in bytes (if available) - - backend: str - Backend type identifier - """ - pass - - -# ============================================================================= -# Exceptions -# ============================================================================= - - -class VikingDBException(Exception): - """Base exception for VikingDB operations.""" - - pass - - -class StorageException(VikingDBException): - """Legacy alias for VikingDBException for backward compatibility.""" - - pass - - -class CollectionNotFoundError(StorageException): - """Raised when a collection does not exist.""" - - pass - - -class RecordNotFoundError(StorageException): - """Raised when a record does not exist.""" - - pass - - -class DuplicateKeyError(StorageException): - """Raised when trying to insert a duplicate key.""" - - pass - - -class ConnectionError(StorageException): - """Raised when storage connection fails.""" - - pass - - -class SchemaError(StorageException): - """Raised when schema validation fails.""" - - pass From f003c3319573ab10e433bca6aa89bff0c5419139 Mon Sep 17 00:00:00 2001 From: "zhoujiahui.01" Date: Thu, 26 Feb 2026 21:19:39 +0800 Subject: [PATCH 4/7] refactor vectordb backend to single-collection adapter model --- .../vectordb-gateway-collection-refactor.md | 328 +++++++ openviking/eval/ragas/playback.py | 13 +- openviking/eval/recorder/wrapper.py | 35 +- openviking/storage/collection_adapter.py | 729 +++++++++++++++ openviking/storage/collection_schemas.py | 2 +- .../storage/observers/vikingdb_observer.py | 8 +- openviking/storage/vector_store/__init__.py | 18 +- openviking/storage/vector_store/driver.py | 161 ---- .../storage/vector_store/drivers/__init__.py | 15 - .../storage/vector_store/drivers/common.py | 33 - .../vector_store/drivers/http_driver.py | 119 --- .../vector_store/drivers/local_driver.py | 91 -- .../vector_store/drivers/vikingdb_driver.py | 159 ---- .../vector_store/drivers/volcengine_driver.py | 173 ---- openviking/storage/vector_store/expr.py | 14 +- openviking/storage/vector_store/factory.py | 17 - openviking/storage/vector_store/registry.py | 35 - .../storage/viking_vector_index_backend.py | 851 ++++-------------- ...test_hierarchical_retriever_target_dirs.py | 77 +- tests/session/test_memory_dedup_actions.py | 34 +- 20 files changed, 1331 insertions(+), 1581 deletions(-) create mode 100644 docs/design/storage/vectordb-gateway-collection-refactor.md create mode 100644 openviking/storage/collection_adapter.py delete mode 100644 openviking/storage/vector_store/driver.py delete mode 100644 openviking/storage/vector_store/drivers/__init__.py delete mode 100644 openviking/storage/vector_store/drivers/common.py delete mode 100644 openviking/storage/vector_store/drivers/http_driver.py delete mode 100644 openviking/storage/vector_store/drivers/local_driver.py delete mode 100644 openviking/storage/vector_store/drivers/vikingdb_driver.py delete mode 100644 openviking/storage/vector_store/drivers/volcengine_driver.py delete mode 100644 openviking/storage/vector_store/factory.py delete mode 100644 openviking/storage/vector_store/registry.py diff --git a/docs/design/storage/vectordb-gateway-collection-refactor.md b/docs/design/storage/vectordb-gateway-collection-refactor.md new file mode 100644 index 00000000..f99398f5 --- /dev/null +++ b/docs/design/storage/vectordb-gateway-collection-refactor.md @@ -0,0 +1,328 @@ +# OpenViking 向量存储分层重构设计(Gateway / Collection / Filter) + +> 日期:2026-02-26 +> 状态:Draft(可直接进入实施) +> 范围:`openviking/storage`、`openviking/retrieve`、`openviking/eval` 的向量存储接入层 + +--- + +## 1. 背景与目标 + +当前向量存储路径中,`VikingVectorIndexBackend`、`vector_store.driver`、`collection_adapter`、`vectordb.collection` 在职责上有重叠: + +- 过滤表达式编译(AST -> DSL)分散在多层。 +- 后端差异(URL 适配、索引参数、读写字段规范化)未完全下沉。 +- `collection` 语义与“单 collection 绑定”的运行模型不完全一致。 + +本次重构目标: + +1. **通用业务逻辑上收**到 gateway/backend 层。 +2. **后端差异下沉**到 collection 层。 +3. `compile_filter` **可扩展但简单**(默认实现 + 子类 override)。 +4. backend/store **单 collection 绑定**(除 `create_collection(name, ...)` 外,不再传 collection)。 +5. 过滤语义统一:`uri`/`parent_uri` 等 path 字段统一使用 `must` 语义。 +6. 移除 `Prefix` / `Regex` expr AST 能力。 + +--- + +## 2. 已确认决策(来自前序讨论) + +### 2.1 分层决策 + +- `CollectionAdapter` 与旧 `driver` 职责高度重叠,最终应移除。 +- `compile_filter` 不放在独立 driver 层,放到 `ICollection`,由 `Collection` 在调用查询接口前触发。 +- 新增向量库时,只需在对应 collection 实现重写 `compile_filter`(默认可不重写)。 + +### 2.2 单 collection 约束 + +- `VikingVectorIndexBackend`(及其上层管理者)内部只持有一个当前 collection。 +- `create_collection(name, ...)` 是唯一需要显式 `name` 的入口。 +- 后续 CRUD / search / filter 等都操作绑定 collection,不再重复传 collection 参数。 + +### 2.3 filter 语义决策 + +- 输入兼容:`FilterExpr | dict | None`。 +- path 字段过滤不使用 prefix op,统一映射到 `must`。 +- 不引入 `_path_must` 之类额外包装层。 + +### 2.4 命名与代码风格决策 + +- 业务代码中原有 `self.vikingdb` 命名保持,不做无收益替换为 `self.vector_store`。 +- gateway 命名应更语义化(见第 8 节迁移建议)。 + +### 2.5 表达式能力决策 + +- `Prefix` / `Regex` 两个 expr AST 能力移除。 +- 若需要后端特有复杂语法,使用 `RawDSL` 或 backend-specific collection override。 + +--- + +## 3. 当前代码现状(实施前基线) + +### 3.1 主要组件 + +- `openviking/storage/viking_vector_index_backend.py` + - 当前承担大量业务逻辑、collection 管理、filter 编译调用。 +- `openviking/storage/vector_store/driver.py` + `drivers/*` + - 当前仍存在后端差异封装与 `compile_expr`。 +- `openviking/storage/collection_adapter.py` + - 与 driver 层重复(工厂 + backend 分发 + filter 编译 + normalize)。 +- `openviking/storage/vectordb/collection/collection.py` + - `ICollection` / `Collection` 封装,尚未成为 filter 编译单一入口。 + +### 3.2 现状问题清单 + +1. **重复抽象**:driver 与 collection_adapter 并存,维护成本高。 +2. **职责漂移**:backend 层仍携带后端差异处理逻辑。 +3. **扩展成本高**:新增后端需改多层(factory/driver/backend)。 +4. **接口噪音**:单 collection 场景下仍出现 collection 参数概念。 + +--- + +## 4. 目标架构 + +```text +[Business Callers] + | + v +[Semantic Gateway / VikingDBManager] + - 租户作用域/业务语义 + - 单 collection 生命周期 + - 通用查询编排 + | + v +[Collection (wrapper)] + - 所有含 filters 的调用前统一 compile_filter + - 统一结果包装 + | + v +[ICollection implementations] + - LocalCollection + - HttpCollection + - VolcengineCollection + - VikingDBCollection + - backend-specific compile_filter override(可选) +``` + +核心原则: + +- **只保留一条“filter 编译路径”**:`ICollection.compile_filter`。 +- **后端差异只出现在具体 collection 子类**。 +- **gateway 不持有 backend 语法细节**。 + +--- + +## 5. `compile_filter` 设计规范 + +### 5.1 接口定义 + +在 `ICollection` 增加默认实现: + +```python +def compile_filter(self, filter_expr: FilterExpr | dict | None) -> dict: + ... +``` + +### 5.2 默认行为 + +- `None` -> `{}` +- `dict` -> 原样透传 +- `RawDSL` -> 透传 payload +- `Eq/In` -> `{"op": "must", "field": ..., "conds": [...]}` +- `And/Or/Range/Contains/TimeRange` -> 按统一 DSL 映射 + +### 5.3 可扩展机制 + +- 新后端如语法不同:只在该后端 collection 中重写 `compile_filter`。 +- 未重写时自动使用默认实现,保证接入门槛低。 + +### 5.4 复杂度控制 + +- 不引入额外 compiler 注册中心。 +- 不新增 driver 级 compiler 层。 +- 优先“默认实现 + 最小 override”。 + +--- + +## 6. FilterExpr 能力边界(重构后) + +### 6.1 保留能力 + +- `And` +- `Or` +- `Eq` +- `In` +- `Range` +- `Contains` +- `TimeRange` +- `RawDSL` + +### 6.2 移除能力 + +- `Prefix` +- `Regex` + +### 6.3 path 字段规则 + +对 `uri` / `parent_uri` / 其他路径字段,统一使用: + +```json +{"op":"must","field":"uri","conds":["..."]} +``` + +不允许在 AST 层保留 prefix 语义入口。 + +--- + +## 7. 单 Collection 运行模型 + +### 7.1 约束 + +- backend 实例内部仅绑定一个 active collection。 +- 除 `create_collection(name, schema)` 外,其余操作基于绑定对象。 + +### 7.2 操作模型 + +- 创建流程: + - `create_collection(name, schema)` + - 建立 `self._collection` 绑定 + - 更新 meta cache +- 数据流程: + - `insert/update/upsert/delete/get/search/filter/count/...` 均操作 `self._collection` + +### 7.3 错误模型 + +- 未绑定 collection 时抛 `CollectionNotFoundError` 或统一运行时错误。 +- 不再依赖每次调用传 collection_name 做防御。 + +--- + +## 8. 命名与接口整理建议 + +### 8.1 gateway 命名 + +建议把“语义检索网关”命名统一为更语义化名称(例如 `SemanticGateway` / `SemanticContextGateway`)。 + +> 兼容策略:保留原类名 alias 一段时间,避免一次性大面积改动。 + +### 8.2 变量命名一致性 + +- 业务模块中已存在 `self.vikingdb` 的位置保持不变。 +- 不做“仅换名不换义”的全局改名(避免噪音 diff)。 + +### 8.3 说明 + +本设计文档不处理 CRUD private 收口;该议题后续单独评估。 + +--- + +## 9. 分阶段迁移计划(实施顺序) + +### Phase A:FilterExpr 与语义基线 + +1. 移除 `Prefix` / `Regex` AST 定义与导出。 +2. 清理编译分支中对应逻辑。 +3. 修复测试中对 `Prefix`/`Regex` 的 AST 断言。 + +**验收**:无 `Prefix`/`Regex` 类型引用;编译与静态检查通过。 + +### Phase B:compile_filter 下沉到 Collection + +1. `ICollection` 增加默认 `compile_filter`。 +2. `Collection` wrapper 在查询前统一调用 `compile_filter`。 +3. backend 中重复的 filter 编译逻辑删除。 + +**验收**:调用方仍可传 AST/dict,行为一致。 + +### Phase C:去除 driver / adapter 重复层 + +1. backend 不再依赖 `create_driver` / `VectorStoreDriver`。 +2. 删除(或停用)`collection_adapter.py` 与 `vector_store/driver*` 路径。 +3. 后端差异迁移到具体 `ICollection` 子类。 + +**验收**:新增 backend 仅需实现 collection + optional compile_filter override。 + +### Phase D:收口与稳定 + +1. 清理遗留 import/export。 +2. 更新设计文档与开发文档。 +3. 完成回归测试矩阵。 + +**验收**:无 dead adapter/driver 引用;主链路稳定。 + +--- + +## 10. 影响面与兼容性 + +### 10.1 影响模块 + +- `openviking/storage/*` +- `openviking/retrieve/*` +- `openviking/session/*` +- `openviking/eval/recorder/*` +- `openviking/eval/ragas/*` + +### 10.2 兼容策略 + +- dict filter 调用保持兼容。 +- AST 精简(去 Prefix/Regex)属于显式破坏性变更,依赖方需改为 `In` 或 `RawDSL`。 +- eval/recorder 现状可继续使用通用 CRUD,不在本文档范围做 private 收口。 + +--- + +## 11. 测试与验收标准 + +### 11.1 静态与构建 + +- `ruff check` 通过 +- `python -m compileall openviking` 通过 + +### 11.2 功能测试矩阵 + +1. **filter 编译**:AST / dict / RawDSL / None +2. **路径过滤**:`uri`/`parent_uri` 使用 `must` 语义 +3. **单 collection**:create 后无需传 collection 参数 +4. **后端差异**:至少验证一个后端 override `compile_filter` 生效 +5. **回归链路**:检索、去重、目录初始化、URI 更新映射 + +### 11.3 代码检索验收 + +- 无 `Prefix` / `Regex` expr 定义与引用。 +- 无 `collection_adapter` / `VectorStoreDriver` 生产路径引用(完成 Phase C 后)。 + +--- + +## 12. 风险与缓解 + +### 风险 1:迁移期双实现并存导致行为不一致 + +- **缓解**:以 `ICollection.compile_filter` 为唯一真源,旧分支尽快删除。 + +### 风险 2:测试桩接口与真实接口漂移 + +- **缓解**:统一测试 stub 最小接口契约,优先修复 `collection_exists_bound` 等缺失。 + +### 风险 3:后端特化语法回归 + +- **缓解**:在对应 collection override 中增加最小单测覆盖。 + +--- + +## 13. Out of Scope(本轮明确不做) + +1. CRUD public/private 收口策略。 +2. eval/recorder 能力边界重定义。 +3. 非向量存储模块(FS/Parser/Client)的结构性重构。 + +--- + +## 14. 实施完成定义(DoD) + +满足以下条件可认为本重构完成: + +1. filter 编译链路单一(Collection 入口)。 +2. backend 单 collection 绑定模式稳定。 +3. driver/adapter 重复层移除。 +4. `Prefix`/`Regex` expr 能力移除且无残留调用。 +5. 主流程回归通过并补齐设计文档。 + diff --git a/openviking/eval/ragas/playback.py b/openviking/eval/ragas/playback.py index 8f48d820..3a9b75e7 100644 --- a/openviking/eval/ragas/playback.py +++ b/openviking/eval/ragas/playback.py @@ -419,7 +419,7 @@ async def _play_vikingdb_operation(self, record: IORecord) -> PlaybackResult: payload = args[-1] else: payload = kwargs.get("data", request.get("data", {})) - await self._vector_store.insert(payload) + await self._vector_store.upsert(payload) elif operation == "update": if len(args) >= 3: record_id = args[-2] @@ -430,7 +430,10 @@ async def _play_vikingdb_operation(self, record: IORecord) -> PlaybackResult: else: record_id = kwargs.get("id", request.get("id")) payload = kwargs.get("data", request.get("data", {})) - await self._vector_store.update(record_id, payload) + existing = await self._vector_store.get([record_id]) + if existing: + merged = {**existing[0], **payload, "id": record_id} + await self._vector_store.upsert(merged) elif operation == "upsert": if args: payload = args[-1] @@ -490,11 +493,9 @@ async def _play_vikingdb_operation(self, record: IORecord) -> PlaybackResult: elif operation == "create_collection": await self._vector_store.create_collection(*args, **kwargs) elif operation == "drop_collection": - await self._vector_store.drop_collection(*args, **kwargs) + await self._vector_store.drop_collection() elif operation == "collection_exists": - await self._vector_store.collection_exists(*args, **kwargs) - elif operation == "list_collections": - await self._vector_store.list_collections(*args, **kwargs) + await self._vector_store.collection_exists() else: raise ValueError(f"Unknown VikingDB operation: {operation}") diff --git a/openviking/eval/recorder/wrapper.py b/openviking/eval/recorder/wrapper.py index cfb13793..685bfb3b 100644 --- a/openviking/eval/recorder/wrapper.py +++ b/openviking/eval/recorder/wrapper.py @@ -294,7 +294,7 @@ async def insert(self, collection: str, data: Dict[str, Any]) -> str: request = {"collection": collection, "data": data} start_time = time.time() try: - result = await self._db.insert(data) + result = await self._db.upsert(data) latency_ms = (time.time() - start_time) * 1000 self._record("insert", request, result, latency_ms) return result @@ -308,7 +308,12 @@ async def update(self, collection: str, id: str, data: Dict[str, Any]) -> bool: request = {"collection": collection, "id": id, "data": data} start_time = time.time() try: - result = await self._db.update(id, data) + existing = await self._db.get([id]) + if not existing: + result = False + else: + payload = {**existing[0], **data, "id": id} + result = bool(await self._db.upsert(payload)) latency_ms = (time.time() - start_time) * 1000 self._record("update", request, result, latency_ms) return result @@ -435,12 +440,12 @@ async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: self._record("create_collection", request, None, latency_ms, False, str(e)) raise - async def drop_collection(self, name: str) -> bool: + async def drop_collection(self) -> bool: """Drop collection with recording.""" - request = {"name": name} + request = {} start_time = time.time() try: - result = await self._db.drop_collection(name) + result = await self._db.drop_collection() latency_ms = (time.time() - start_time) * 1000 self._record("drop_collection", request, result, latency_ms) return result @@ -449,12 +454,12 @@ async def drop_collection(self, name: str) -> bool: self._record("drop_collection", request, None, latency_ms, False, str(e)) raise - async def collection_exists(self, name: str) -> bool: + async def collection_exists(self) -> bool: """Check collection exists with recording.""" - request = {"name": name} + request = {} start_time = time.time() try: - result = await self._db.collection_exists(name) + result = await self._db.collection_exists() latency_ms = (time.time() - start_time) * 1000 self._record("collection_exists", request, result, latency_ms) return result @@ -463,20 +468,6 @@ async def collection_exists(self, name: str) -> bool: self._record("collection_exists", request, None, latency_ms, False, str(e)) raise - async def list_collections(self) -> List[str]: - """List collections with recording.""" - request = {} - start_time = time.time() - try: - result = await self._db.list_collections() - latency_ms = (time.time() - start_time) * 1000 - self._record("list_collections", request, result, latency_ms) - return result - except Exception as e: - latency_ms = (time.time() - start_time) * 1000 - self._record("list_collections", request, None, latency_ms, False, str(e)) - raise - def __getattr__(self, name: str) -> Any: """Pass through any other attributes to the wrapped db.""" return getattr(self._db, name) diff --git a/openviking/storage/collection_adapter.py b/openviking/storage/collection_adapter.py new file mode 100644 index 00000000..5e621ba0 --- /dev/null +++ b/openviking/storage/collection_adapter.py @@ -0,0 +1,729 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Collection adapter layer for backend-specific storage integration.""" + +from __future__ import annotations + +import os +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, Iterable, Optional +from urllib.parse import urlparse + +from openviking.storage.errors import CollectionNotFoundError +from openviking.storage.vector_store.expr import ( + And, + Contains, + Eq, + FilterExpr, + In, + Or, + Range, + RawDSL, + TimeRange, +) +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.http_collection import ( + HttpCollection, + get_or_create_http_collection, + list_vikingdb_collections, +) +from openviking.storage.vectordb.collection.local_collection import get_or_create_local_collection +from openviking.storage.vectordb.collection.result import FetchDataInCollectionResult +from openviking.storage.vectordb.collection.vikingdb_clients import VIKINGDB_APIS, VikingDBClient +from openviking.storage.vectordb.collection.vikingdb_collection import VikingDBCollection +from openviking.storage.vectordb.collection.volcengine_collection import ( + VolcengineCollection, + get_or_create_volcengine_collection, +) +from openviking_cli.utils import get_logger + +logger = get_logger(__name__) + + +def _parse_url(url: str) -> tuple[str, int]: + normalized = url + if not normalized.startswith(("http://", "https://")): + normalized = f"http://{normalized}" + parsed = urlparse(normalized) + host = parsed.hostname or "127.0.0.1" + port = parsed.port or 5000 + return host, port + + +def _normalize_collection_names(raw_collections: Iterable[Any]) -> list[str]: + names: list[str] = [] + for item in raw_collections: + if isinstance(item, str): + names.append(item) + elif isinstance(item, dict): + name = item.get("CollectionName") or item.get("collection_name") or item.get("name") + if isinstance(name, str): + names.append(name) + return names + + +class CollectionAdapter(ABC): + """Backend-specific adapter for single-collection operations.""" + + mode: str + + def __init__(self, collection_name: str): + self._collection_name = collection_name + self._collection: Optional[Collection] = None + + @property + def collection_name(self) -> str: + return self._collection_name + + @classmethod + @abstractmethod + def from_config(cls, config: Any) -> "CollectionAdapter": + """Create an adapter instance from VectorDB backend config.""" + + @abstractmethod + def _load_existing_collection_if_needed(self) -> None: + """Load existing bound collection handle when possible.""" + + @abstractmethod + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + """Create backend collection handle for bound collection.""" + + def collection_exists(self) -> bool: + self._load_existing_collection_if_needed() + return self._collection is not None + + def get_collection(self) -> Collection: + self._load_existing_collection_if_needed() + if self._collection is None: + raise CollectionNotFoundError(f"Collection {self._collection_name} does not exist") + return self._collection + + def create_collection( + self, + name: str, + schema: Dict[str, Any], + *, + distance: str, + sparse_weight: float, + index_name: str, + ) -> bool: + if self.collection_exists(): + return False + + self._collection_name = name + collection_meta = dict(schema) + scalar_index_fields = collection_meta.pop("ScalarIndex", []) + if "CollectionName" not in collection_meta: + collection_meta["CollectionName"] = name + + self._collection = self._create_backend_collection(collection_meta) + + scalar_index_fields = self.sanitize_scalar_index_fields( + scalar_index_fields=scalar_index_fields, + fields_meta=collection_meta.get("Fields", []), + ) + index_meta = self.build_default_index_meta( + index_name=index_name, + distance=distance, + use_sparse=sparse_weight > 0.0, + sparse_weight=sparse_weight, + scalar_index_fields=scalar_index_fields, + ) + self._collection.create_index(index_name, index_meta) + return True + + def drop_collection(self) -> bool: + if not self.collection_exists(): + return False + + coll = self.get_collection() + + # Drop indexes first so index lifecycle remains internal to adapter. + try: + for index_name in coll.list_indexes() or []: + try: + coll.drop_index(index_name) + except Exception as e: + logger.warning("Failed to drop index %s: %s", index_name, e) + except Exception as e: + logger.warning("Failed to list indexes before dropping collection: %s", e) + + try: + coll.drop() + except NotImplementedError: + logger.warning("Collection drop is not supported by backend mode=%s", self.mode) + return False + finally: + self._collection = None + + return True + + def close(self) -> None: + if self._collection is not None: + self._collection.close() + self._collection = None + + def get_collection_info(self) -> Optional[Dict[str, Any]]: + if not self.collection_exists(): + return None + return self.get_collection().get_meta_data() + + def sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict[str, Any]], + ) -> list[str]: + return scalar_index_fields + + def build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> Dict[str, Any]: + index_type = "flat_hybrid" if use_sparse else "flat" + index_meta: Dict[str, Any] = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: + return record + + def compile_filter(self, expr: FilterExpr | Dict[str, Any] | None) -> Dict[str, Any]: + if expr is None: + return {} + if isinstance(expr, dict): + return expr + if isinstance(expr, RawDSL): + return expr.payload + if isinstance(expr, And): + conds = [self.compile_filter(c) for c in expr.conds if c is not None] + conds = [c for c in conds if c] + if not conds: + return {} + if len(conds) == 1: + return conds[0] + return {"op": "and", "conds": conds} + if isinstance(expr, Or): + conds = [self.compile_filter(c) for c in expr.conds if c is not None] + conds = [c for c in conds if c] + if not conds: + return {} + if len(conds) == 1: + return conds[0] + return {"op": "or", "conds": conds} + if isinstance(expr, Eq): + return {"op": "must", "field": expr.field, "conds": [expr.value]} + if isinstance(expr, In): + return {"op": "must", "field": expr.field, "conds": list(expr.values)} + if isinstance(expr, Range): + payload: Dict[str, Any] = {"op": "range", "field": expr.field} + if expr.gte is not None: + payload["gte"] = expr.gte + if expr.gt is not None: + payload["gt"] = expr.gt + if expr.lte is not None: + payload["lte"] = expr.lte + if expr.lt is not None: + payload["lt"] = expr.lt + return payload + if isinstance(expr, Contains): + return { + "op": "contains", + "field": expr.field, + "substring": expr.substring, + } + if isinstance(expr, TimeRange): + payload: Dict[str, Any] = {"op": "range", "field": expr.field} + if expr.start is not None: + payload["gte"] = expr.start + if expr.end is not None: + payload["lt"] = expr.end + return payload + raise TypeError(f"Unsupported filter expr type: {type(expr)!r}") + + def upsert(self, data: Dict[str, Any] | list[Dict[str, Any]]) -> list[str]: + coll = self.get_collection() + records = [data] if isinstance(data, dict) else data + normalized: list[Dict[str, Any]] = [] + ids: list[str] = [] + for item in records: + record = dict(item) + record_id = record.get("id") or str(uuid.uuid4()) + record["id"] = record_id + ids.append(record_id) + normalized.append(record) + coll.upsert_data(normalized) + return ids + + def get(self, ids: list[str]) -> list[Dict[str, Any]]: + coll = self.get_collection() + result = coll.fetch_data(ids) + + records: list[Dict[str, Any]] = [] + if isinstance(result, FetchDataInCollectionResult): + for item in result.items: + record = dict(item.fields) if item.fields else {} + record["id"] = item.id + records.append(self.normalize_record_for_read(record)) + return records + + if isinstance(result, dict) and "fetch" in result: + for item in result.get("fetch", []): + record = dict(item.get("fields", {})) if item.get("fields") else {} + record_id = item.get("id") + if record_id: + record["id"] = record_id + records.append(self.normalize_record_for_read(record)) + return records + + def query( + self, + *, + query_vector: Optional[list[float]] = None, + sparse_query_vector: Optional[Dict[str, float]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, + limit: int = 10, + offset: int = 0, + output_fields: Optional[list[str]] = None, + with_vector: bool = False, + order_by: Optional[str] = None, + order_desc: bool = False, + ) -> list[Dict[str, Any]]: + coll = self.get_collection() + vectordb_filter = self.compile_filter(filter) + + if query_vector or sparse_query_vector: + result = coll.search_by_vector( + index_name="default", + dense_vector=query_vector, + sparse_vector=sparse_query_vector, + limit=limit, + offset=offset, + filters=vectordb_filter, + output_fields=output_fields, + ) + elif order_by: + result = coll.search_by_scalar( + index_name="default", + field=order_by, + order="desc" if order_desc else "asc", + limit=limit, + offset=offset, + filters=vectordb_filter, + output_fields=output_fields, + ) + else: + result = coll.search_by_random( + index_name="default", + limit=limit, + offset=offset, + filters=vectordb_filter, + output_fields=output_fields, + ) + + records: list[Dict[str, Any]] = [] + for item in result.data: + record = dict(item.fields) if item.fields else {} + record["id"] = item.id + record["_score"] = item.score if item.score is not None else 0.0 + record = self.normalize_record_for_read(record) + if not with_vector: + record.pop("vector", None) + record.pop("sparse_vector", None) + records.append(record) + return records + + def delete( + self, + *, + ids: Optional[list[str]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, + limit: int = 100000, + ) -> int: + coll = self.get_collection() + delete_ids = list(ids or []) + if not delete_ids and filter is not None: + matched = self.query(filter=filter, limit=limit, with_vector=True) + delete_ids = [record["id"] for record in matched if record.get("id")] + + if not delete_ids: + return 0 + + coll.delete_data(delete_ids) + return len(delete_ids) + + def count(self, filter: Optional[Dict[str, Any] | FilterExpr] = None) -> int: + coll = self.get_collection() + result = coll.aggregate_data( + index_name="default", + op="count", + filters=self.compile_filter(filter), + ) + return result.agg.get("_total", 0) + + def clear(self) -> bool: + self.get_collection().delete_all_data() + return True + + +class LocalCollectionAdapter(CollectionAdapter): + """Adapter for local embedded vectordb backend.""" + + DEFAULT_LOCAL_PROJECT_NAME = "vectordb" + + def __init__(self, collection_name: str, project_path: str): + super().__init__(collection_name=collection_name) + self.mode = "local" + self._project_path = project_path + + @classmethod + def from_config(cls, config): + project_path = ( + str(Path(config.path) / cls.DEFAULT_LOCAL_PROJECT_NAME) if config.path else "" + ) + return cls(collection_name=config.name or "context", project_path=project_path) + + def _collection_path(self) -> str: + if not self._project_path: + return "" + return str(Path(self._project_path) / self._collection_name) + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + collection_path = self._collection_path() + if not collection_path: + return + meta_path = os.path.join(collection_path, "collection_meta.json") + if os.path.exists(meta_path): + self._collection = get_or_create_local_collection(path=collection_path) + + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + collection_path = self._collection_path() + if collection_path: + os.makedirs(collection_path, exist_ok=True) + return get_or_create_local_collection(meta_data=meta, path=collection_path) + + +class HttpCollectionAdapter(CollectionAdapter): + """Adapter for remote HTTP vectordb project.""" + + def __init__(self, host: str, port: int, project_name: str, collection_name: str): + super().__init__(collection_name=collection_name) + self.mode = "http" + self._host = host + self._port = port + self._project_name = project_name + + @classmethod + def from_config(cls, config): + if not config.url: + raise ValueError("HTTP backend requires a valid URL") + host, port = _parse_url(config.url) + return cls( + host=host, + port=port, + project_name=config.project_name or "default", + collection_name=config.name or "context", + ) + + def _meta(self) -> Dict[str, Any]: + return { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + + def _remote_has_collection(self) -> bool: + raw = list_vikingdb_collections( + host=self._host, + port=self._port, + project_name=self._project_name, + ) + return self._collection_name in _normalize_collection_names(raw) + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + if not self._remote_has_collection(): + return + self._collection = Collection( + HttpCollection( + ip=self._host, + port=self._port, + meta_data=self._meta(), + ) + ) + + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + payload = dict(meta) + payload.update(self._meta()) + return get_or_create_http_collection( + host=self._host, + port=self._port, + meta_data=payload, + ) + + +class VolcengineCollectionAdapter(CollectionAdapter): + """Adapter for Volcengine-hosted VikingDB.""" + + def __init__( + self, + *, + ak: str, + sk: str, + region: str, + host: str, + project_name: str, + collection_name: str, + ): + super().__init__(collection_name=collection_name) + self.mode = "volcengine" + self._ak = ak + self._sk = sk + self._region = region + self._host = host + self._project_name = project_name + + @classmethod + def from_config(cls, config): + if not ( + config.volcengine + and config.volcengine.ak + and config.volcengine.sk + and config.volcengine.region + ): + raise ValueError("Volcengine backend requires AK, SK, and Region configuration") + return cls( + ak=config.volcengine.ak, + sk=config.volcengine.sk, + region=config.volcengine.region, + host=config.volcengine.host or "", + project_name=config.project_name or "default", + collection_name=config.name or "context", + ) + + def _meta(self) -> Dict[str, Any]: + return { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + + def _config(self) -> Dict[str, Any]: + return { + "AK": self._ak, + "SK": self._sk, + "Region": self._region, + "Host": self._host, + } + + def _new_collection_handle(self) -> VolcengineCollection: + return VolcengineCollection( + ak=self._ak, + sk=self._sk, + region=self._region, + host=self._host, + meta_data=self._meta(), + ) + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + candidate = self._new_collection_handle() + meta = candidate.get_meta_data() or {} + if meta and meta.get("CollectionName"): + self._collection = candidate + + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + payload = dict(meta) + payload.update(self._meta()) + return get_or_create_volcengine_collection( + config=self._config(), + meta_data=payload, + ) + + def sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict[str, Any]], + ) -> list[str]: + date_time_fields = { + field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" + } + return [field for field in scalar_index_fields if field not in date_time_fields] + + def build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> Dict[str, Any]: + index_type = "hnsw_hybrid" if use_sparse else "hnsw" + index_meta: Dict[str, Any] = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: + for key in ("uri", "parent_uri"): + value = record.get(key) + if isinstance(value, str) and not value.startswith("viking://"): + stripped = value.strip("/") + if stripped: + record[key] = f"viking://{stripped}" + return record + + +class VikingDBPrivateCollectionAdapter(CollectionAdapter): + """Adapter for private VikingDB deployment.""" + + def __init__( + self, + *, + host: str, + headers: Optional[dict[str, str]], + project_name: str, + collection_name: str, + ): + super().__init__(collection_name=collection_name) + self.mode = "vikingdb" + self._host = host + self._headers = headers + self._project_name = project_name + + @classmethod + def from_config(cls, config): + if not config.vikingdb or not config.vikingdb.host: + raise ValueError("VikingDB backend requires a valid host") + return cls( + host=config.vikingdb.host, + headers=config.vikingdb.headers, + project_name=config.project_name or "default", + collection_name=config.name or "context", + ) + + def _client(self) -> VikingDBClient: + return VikingDBClient(self._host, self._headers) + + def _fetch_collection_meta(self) -> Optional[Dict[str, Any]]: + path, method = VIKINGDB_APIS["GetVikingdbCollection"] + req = { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + response = self._client().do_req(method, path=path, req_body=req) + if response.status_code != 200: + return None + result = response.json() + meta = result.get("Result", {}) + return meta or None + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + meta = self._fetch_collection_meta() + if meta is None: + return + self._collection = Collection( + VikingDBCollection( + host=self._host, + headers=self._headers, + meta_data=meta, + ) + ) + + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + self._load_existing_collection_if_needed() + if self._collection is None: + raise NotImplementedError("private vikingdb collection should be pre-created") + return self._collection + + def sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict[str, Any]], + ) -> list[str]: + date_time_fields = { + field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" + } + return [field for field in scalar_index_fields if field not in date_time_fields] + + def build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> Dict[str, Any]: + index_type = "hnsw_hybrid" if use_sparse else "hnsw" + index_meta: Dict[str, Any] = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: + for key in ("uri", "parent_uri"): + value = record.get(key) + if isinstance(value, str) and not value.startswith("viking://"): + stripped = value.strip("/") + if stripped: + record[key] = f"viking://{stripped}" + return record + + +_ADAPTER_REGISTRY: dict[str, type[CollectionAdapter]] = { + "local": LocalCollectionAdapter, + "http": HttpCollectionAdapter, + "volcengine": VolcengineCollectionAdapter, + "vikingdb": VikingDBPrivateCollectionAdapter, +} + + +def create_collection_adapter(config) -> CollectionAdapter: + """Unified factory entrypoint for backend-specific collection adapters.""" + adapter_cls = _ADAPTER_REGISTRY.get(config.backend) + if adapter_cls is None: + raise ValueError( + f"Vector backend {config.backend} is not supported. " + f"Available backends: {sorted(_ADAPTER_REGISTRY)}" + ) + return adapter_cls.from_config(config) diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index f5086d1a..0cff8725 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -208,7 +208,7 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, id_seed = f"{account_id}:{owner_space}:{uri}" inserted_data["id"] = hashlib.md5(id_seed.encode("utf-8")).hexdigest() - record_id = await self._vikingdb.insert(inserted_data) + record_id = await self._vikingdb.upsert(inserted_data) if record_id: logger.debug( f"Successfully wrote embedding to database: {record_id} abstract {inserted_data['abstract']} vector {inserted_data['vector'][:5]}" diff --git a/openviking/storage/observers/vikingdb_observer.py b/openviking/storage/observers/vikingdb_observer.py index 8f2fca3b..9e6eeab7 100644 --- a/openviking/storage/observers/vikingdb_observer.py +++ b/openviking/storage/observers/vikingdb_observer.py @@ -30,12 +30,10 @@ async def get_status_table_async(self) -> str: if not self._vikingdb_manager: return "VikingDB manager not initialized." - collection_names = await self._vikingdb_manager.list_collections() - - if not collection_names: + if not await self._vikingdb_manager.collection_exists(): return "No collections found." - statuses = await self._get_collection_statuses(collection_names) + statuses = await self._get_collection_statuses([self._vikingdb_manager.collection_name]) return self._format_status_as_table(statuses) def get_status_table(self) -> str: @@ -49,7 +47,7 @@ async def _get_collection_statuses(self, collection_names: list) -> Dict[str, Di for name in collection_names: try: - if not await self._vikingdb_manager.collection_exists(name): + if not await self._vikingdb_manager.collection_exists(): continue # Current OpenViking flow uses one managed default index per collection. diff --git a/openviking/storage/vector_store/__init__.py b/openviking/storage/vector_store/__init__.py index 8820f2ac..05ed11f9 100644 --- a/openviking/storage/vector_store/__init__.py +++ b/openviking/storage/vector_store/__init__.py @@ -1,8 +1,7 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 -"""Vector store driver architecture.""" +"""Vector store filter expression types.""" -from openviking.storage.vector_store.driver import VectorStoreDriver from openviking.storage.vector_store.expr import ( And, Contains, @@ -10,34 +9,19 @@ FilterExpr, In, Or, - Prefix, Range, RawDSL, - Regex, TimeRange, ) -from openviking.storage.vector_store.factory import create_driver -from openviking.storage.vector_store.registry import ( - get_driver_class, - list_registered_backends, - register_driver, -) __all__ = [ - "VectorStoreDriver", "FilterExpr", "And", "Or", "Eq", "In", - "Prefix", "Range", "Contains", - "Regex", "TimeRange", "RawDSL", - "create_driver", - "register_driver", - "get_driver_class", - "list_registered_backends", ] diff --git a/openviking/storage/vector_store/driver.py b/openviking/storage/vector_store/driver.py deleted file mode 100644 index 97f30cbf..00000000 --- a/openviking/storage/vector_store/driver.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Base driver contracts for vector store backends.""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any, Dict - -from openviking.storage.vector_store.expr import ( - And, - Contains, - Eq, - FilterExpr, - In, - Or, - Prefix, - Range, - RawDSL, - Regex, - TimeRange, -) - - -class VectorStoreDriver(ABC): - """Backend-specific adapter for collection operations + filter AST compilation.""" - - mode: str - - @classmethod - @abstractmethod - def from_config(cls, config: Any) -> "VectorStoreDriver": - """Create a driver instance from VectorDB backend config.""" - - @abstractmethod - def has_collection(self, name: str) -> bool: - """Return whether collection exists.""" - - @abstractmethod - def get_collection(self, name: str) -> Any: - """Return backend collection handle.""" - - @abstractmethod - def create_collection(self, name: str, meta: Dict[str, Any]) -> Any: - """Create a collection and return backend collection handle.""" - - @abstractmethod - def drop_collection(self, name: str) -> None: - """Drop collection.""" - - @abstractmethod - def list_collections(self) -> list[str]: - """List all collections.""" - - def close(self) -> None: - """Release backend resources.""" - - def sanitize_scalar_index_fields( - self, - scalar_index_fields: list[str], - fields_meta: list[dict[str, Any]], - ) -> list[str]: - """Normalize scalar index fields for backend-specific constraints.""" - return scalar_index_fields - - def build_default_index_meta( - self, - *, - index_name: str, - distance: str, - use_sparse: bool, - sparse_weight: float, - scalar_index_fields: list[str], - ) -> Dict[str, Any]: - """Build default index meta payload for this backend.""" - index_type = "flat_hybrid" if use_sparse else "flat" - index_meta: Dict[str, Any] = { - "IndexName": index_name, - "VectorIndex": { - "IndexType": index_type, - "Distance": distance, - "Quant": "int8", - }, - "ScalarIndex": scalar_index_fields, - } - if use_sparse: - index_meta["VectorIndex"]["EnableSparse"] = True - index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight - return index_meta - - def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: - """Normalize record fields after reading from backend.""" - return record - - def compile_expr(self, expr: FilterExpr | None) -> Dict[str, Any]: - """Compile a filter AST node to vectordb DSL.""" - if expr is None: - return {} - - if isinstance(expr, RawDSL): - return expr.payload - - if isinstance(expr, And): - conds = [self.compile_expr(c) for c in expr.conds if c is not None] - conds = [c for c in conds if c] - if not conds: - return {} - if len(conds) == 1: - return conds[0] - return {"op": "and", "conds": conds} - - if isinstance(expr, Or): - conds = [self.compile_expr(c) for c in expr.conds if c is not None] - conds = [c for c in conds if c] - if not conds: - return {} - if len(conds) == 1: - return conds[0] - return {"op": "or", "conds": conds} - - if isinstance(expr, Eq): - return {"op": "must", "field": expr.field, "conds": [expr.value]} - - if isinstance(expr, In): - return {"op": "must", "field": expr.field, "conds": list(expr.values)} - - if isinstance(expr, Prefix): - # For path fields the current vectordb implementation uses `must` semantics. - return {"op": "must", "field": expr.field, "conds": [expr.prefix]} - - if isinstance(expr, Range): - payload: Dict[str, Any] = {"op": "range", "field": expr.field} - if expr.gte is not None: - payload["gte"] = expr.gte - if expr.gt is not None: - payload["gt"] = expr.gt - if expr.lte is not None: - payload["lte"] = expr.lte - if expr.lt is not None: - payload["lt"] = expr.lt - return payload - - if isinstance(expr, Contains): - return { - "op": "contains", - "field": expr.field, - "substring": expr.substring, - } - - if isinstance(expr, Regex): - return {"op": "regex", "field": expr.field, "pattern": expr.pattern} - - if isinstance(expr, TimeRange): - payload: Dict[str, Any] = {"op": "range", "field": expr.field} - if expr.start is not None: - payload["gte"] = expr.start - if expr.end is not None: - payload["lt"] = expr.end - return payload - - raise TypeError(f"Unsupported filter expr type: {type(expr)!r}") diff --git a/openviking/storage/vector_store/drivers/__init__.py b/openviking/storage/vector_store/drivers/__init__.py deleted file mode 100644 index 7b02773b..00000000 --- a/openviking/storage/vector_store/drivers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Driver module imports for static registration side effects.""" - -from openviking.storage.vector_store.drivers.http_driver import HttpVectorDriver -from openviking.storage.vector_store.drivers.local_driver import LocalVectorDriver -from openviking.storage.vector_store.drivers.vikingdb_driver import VikingDBPrivateDriver -from openviking.storage.vector_store.drivers.volcengine_driver import VolcengineVectorDriver - -__all__ = [ - "LocalVectorDriver", - "HttpVectorDriver", - "VolcengineVectorDriver", - "VikingDBPrivateDriver", -] diff --git a/openviking/storage/vector_store/drivers/common.py b/openviking/storage/vector_store/drivers/common.py deleted file mode 100644 index 00856d2b..00000000 --- a/openviking/storage/vector_store/drivers/common.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Shared helpers for vector backend drivers.""" - -from __future__ import annotations - -from typing import Any, Iterable -from urllib.parse import urlparse - - -def parse_url(url: str) -> tuple[str, int]: - """Parse backend URL to host/port pair.""" - normalized = url - if not normalized.startswith(("http://", "https://")): - normalized = f"http://{normalized}" - - parsed = urlparse(normalized) - host = parsed.hostname or "127.0.0.1" - port = parsed.port or 5000 - return host, port - - -def normalize_collection_names(raw_collections: Iterable[Any]) -> list[str]: - """Normalize collection listing results to plain collection-name strings.""" - names: list[str] = [] - for item in raw_collections: - if isinstance(item, str): - names.append(item) - elif isinstance(item, dict): - name = item.get("CollectionName") or item.get("collection_name") or item.get("name") - if isinstance(name, str): - names.append(name) - return names diff --git a/openviking/storage/vector_store/drivers/http_driver.py b/openviking/storage/vector_store/drivers/http_driver.py deleted file mode 100644 index 18002785..00000000 --- a/openviking/storage/vector_store/drivers/http_driver.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Remote HTTP vector backend driver.""" - -from __future__ import annotations - -from openviking.storage.vector_store.driver import VectorStoreDriver -from openviking.storage.vector_store.drivers.common import normalize_collection_names, parse_url -from openviking.storage.vector_store.registry import register_driver -from openviking.storage.vectordb.collection.collection import Collection -from openviking.storage.vectordb.collection.http_collection import ( - HttpCollection, - get_or_create_http_collection, - list_vikingdb_collections, -) - - -@register_driver("http") -class HttpVectorDriver(VectorStoreDriver): - """Driver for remote HTTP vectordb project.""" - - def __init__(self, host: str, port: int, project_name: str, collection_name: str): - self.mode = "http" - self._host = host - self._port = port - self._project_name = project_name - self._collection_name = collection_name - self._collection = None - - @classmethod - def from_config(cls, config): - if not config.url: - raise ValueError("HTTP backend requires a valid URL") - - host, port = parse_url(config.url) - collection_name = config.name or "context" - project_name = config.project_name or "default" - return cls( - host=host, - port=port, - project_name=project_name, - collection_name=collection_name, - ) - - def _match(self, name: str) -> bool: - return name == self._collection_name - - def _meta(self) -> dict: - return { - "ProjectName": self._project_name, - "CollectionName": self._collection_name, - } - - def _remote_has_collection(self) -> bool: - raw = list_vikingdb_collections( - host=self._host, - port=self._port, - project_name=self._project_name, - ) - names = normalize_collection_names(raw) - return self._collection_name in names - - def _ensure_collection_handle(self) -> None: - if self._collection is not None: - return - if not self._remote_has_collection(): - return - self._collection = Collection( - HttpCollection( - ip=self._host, - port=self._port, - meta_data=self._meta(), - ) - ) - - def has_collection(self, name: str) -> bool: - if not self._match(name): - return False - exists = self._remote_has_collection() - if exists: - self._ensure_collection_handle() - return exists - - def get_collection(self, name: str): - if not self._match(name): - return None - self._ensure_collection_handle() - return self._collection - - def create_collection(self, name: str, meta): - if not self._match(name): - raise ValueError( - f"http backend is bound to collection '{self._collection_name}', got '{name}'" - ) - payload = dict(meta) - payload.update(self._meta()) - self._collection = get_or_create_http_collection( - host=self._host, - port=self._port, - meta_data=payload, - ) - return self._collection - - def drop_collection(self, name: str) -> None: - if not self._match(name): - return - coll = self.get_collection(name) - if coll is None: - return - coll.drop() - self._collection = None - - def list_collections(self) -> list[str]: - return [self._collection_name] if self.has_collection(self._collection_name) else [] - - def close(self) -> None: - if self._collection is not None: - self._collection.close() - self._collection = None diff --git a/openviking/storage/vector_store/drivers/local_driver.py b/openviking/storage/vector_store/drivers/local_driver.py deleted file mode 100644 index c7ce1975..00000000 --- a/openviking/storage/vector_store/drivers/local_driver.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Local persistent vector backend driver.""" - -from __future__ import annotations - -import os -from pathlib import Path - -from openviking.storage.vector_store.driver import VectorStoreDriver -from openviking.storage.vector_store.registry import register_driver -from openviking.storage.vectordb.collection.local_collection import ( - get_or_create_local_collection, -) - - -@register_driver("local") -class LocalVectorDriver(VectorStoreDriver): - """Driver for local embedded vectordb backend.""" - - DEFAULT_LOCAL_PROJECT_NAME = "vectordb" - - def __init__(self, collection_name: str, collection_path: str): - self.mode = "local" - self._collection_name = collection_name - self._collection_path = collection_path - self._collection = None - - @classmethod - def from_config(cls, config): - collection_name = config.name or "context" - if config.path: - project_path = Path(config.path) / cls.DEFAULT_LOCAL_PROJECT_NAME - collection_path = str(project_path / collection_name) - else: - collection_path = "" - return cls(collection_name=collection_name, collection_path=collection_path) - - def _match(self, name: str) -> bool: - return name == self._collection_name - - def _load_existing_collection_if_needed(self) -> None: - if self._collection is not None: - return - if not self._collection_path: - return - meta_path = os.path.join(self._collection_path, "collection_meta.json") - if os.path.exists(meta_path): - self._collection = get_or_create_local_collection(path=self._collection_path) - - def has_collection(self, name: str) -> bool: - if not self._match(name): - return False - self._load_existing_collection_if_needed() - return self._collection is not None - - def get_collection(self, name: str): - if not self._match(name): - return None - self._load_existing_collection_if_needed() - return self._collection - - def create_collection(self, name: str, meta): - if not self._match(name): - raise ValueError( - f"local backend is bound to collection '{self._collection_name}', got '{name}'" - ) - if self._collection is not None: - return self._collection - if self._collection_path: - os.makedirs(self._collection_path, exist_ok=True) - self._collection = get_or_create_local_collection( - meta_data=meta, - path=self._collection_path, - ) - return self._collection - - def drop_collection(self, name: str) -> None: - if not self.has_collection(name): - return - assert self._collection is not None - self._collection.drop() - self._collection = None - - def list_collections(self) -> list[str]: - return [self._collection_name] if self.has_collection(self._collection_name) else [] - - def close(self) -> None: - if self._collection is not None: - self._collection.close() - self._collection = None diff --git a/openviking/storage/vector_store/drivers/vikingdb_driver.py b/openviking/storage/vector_store/drivers/vikingdb_driver.py deleted file mode 100644 index c09fad69..00000000 --- a/openviking/storage/vector_store/drivers/vikingdb_driver.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Private VikingDB deployment backend driver.""" - -from __future__ import annotations - -from openviking.storage.vector_store.driver import VectorStoreDriver -from openviking.storage.vector_store.drivers.common import normalize_collection_names -from openviking.storage.vector_store.registry import register_driver -from openviking.storage.vectordb.collection.collection import Collection -from openviking.storage.vectordb.collection.vikingdb_clients import ( - VIKINGDB_APIS, - VikingDBClient, -) -from openviking.storage.vectordb.collection.vikingdb_collection import VikingDBCollection - - -@register_driver("vikingdb") -class VikingDBPrivateDriver(VectorStoreDriver): - """Driver for private VikingDB deployment.""" - - def __init__( - self, - *, - host: str, - headers: dict | None, - project_name: str, - collection_name: str, - ): - self.mode = "vikingdb" - self._host = host - self._headers = headers - self._project_name = project_name - self._collection_name = collection_name - self._collection = None - - @classmethod - def from_config(cls, config): - if not config.vikingdb or not config.vikingdb.host: - raise ValueError("VikingDB backend requires a valid host") - - return cls( - host=config.vikingdb.host, - headers=config.vikingdb.headers, - project_name=config.project_name or "default", - collection_name=config.name or "context", - ) - - def _match(self, name: str) -> bool: - return name == self._collection_name - - def _client(self) -> VikingDBClient: - return VikingDBClient(self._host, self._headers) - - def _fetch_collection_meta(self): - path, method = VIKINGDB_APIS["GetVikingdbCollection"] - req = { - "ProjectName": self._project_name, - "CollectionName": self._collection_name, - } - response = self._client().do_req(method, path=path, req_body=req) - if response.status_code != 200: - return None - result = response.json() - meta = result.get("Result", {}) - return meta or None - - def has_collection(self, name: str) -> bool: - if not self._match(name): - return False - return self._fetch_collection_meta() is not None - - def get_collection(self, name: str): - if not self._match(name): - return None - if self._collection is not None: - return self._collection - meta = self._fetch_collection_meta() - if meta is None: - return None - self._collection = Collection( - VikingDBCollection( - host=self._host, - headers=self._headers, - meta_data=meta, - ) - ) - return self._collection - - def create_collection(self, name: str, meta): - raise NotImplementedError("private vikingdb collection should be pre-created") - - def drop_collection(self, name: str) -> None: - if not self._match(name): - return - coll = self.get_collection(name) - if coll is None: - return - coll.drop() - self._collection = None - - def list_collections(self) -> list[str]: - path, method = VIKINGDB_APIS["ListVikingdbCollection"] - req = {"ProjectName": self._project_name} - response = self._client().do_req(method, path=path, req_body=req) - if response.status_code != 200: - return [] - result = response.json() - raw = result.get("Result", {}).get("Collections", []) - names = normalize_collection_names(raw) - return [n for n in names if n == self._collection_name] - - def close(self) -> None: - if self._collection is not None: - self._collection.close() - self._collection = None - - def sanitize_scalar_index_fields( - self, - scalar_index_fields: list[str], - fields_meta: list[dict], - ) -> list[str]: - date_time_fields = { - field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" - } - return [field for field in scalar_index_fields if field not in date_time_fields] - - def build_default_index_meta( - self, - *, - index_name: str, - distance: str, - use_sparse: bool, - sparse_weight: float, - scalar_index_fields: list[str], - ) -> dict: - index_type = "hnsw_hybrid" if use_sparse else "hnsw" - index_meta = { - "IndexName": index_name, - "VectorIndex": { - "IndexType": index_type, - "Distance": distance, - "Quant": "int8", - }, - "ScalarIndex": scalar_index_fields, - } - if use_sparse: - index_meta["VectorIndex"]["EnableSparse"] = True - index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight - return index_meta - - def normalize_record_for_read(self, record: dict) -> dict: - for key in ("uri", "parent_uri"): - value = record.get(key) - if isinstance(value, str) and not value.startswith("viking://"): - stripped = value.strip("/") - if stripped: - record[key] = f"viking://{stripped}" - return record diff --git a/openviking/storage/vector_store/drivers/volcengine_driver.py b/openviking/storage/vector_store/drivers/volcengine_driver.py deleted file mode 100644 index d42f849c..00000000 --- a/openviking/storage/vector_store/drivers/volcengine_driver.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Volcengine VikingDB backend driver.""" - -from __future__ import annotations - -from openviking.storage.vector_store.driver import VectorStoreDriver -from openviking.storage.vector_store.registry import register_driver -from openviking.storage.vectordb.collection.volcengine_collection import ( - VolcengineCollection, - get_or_create_volcengine_collection, -) - - -@register_driver("volcengine") -class VolcengineVectorDriver(VectorStoreDriver): - """Driver for Volcengine-hosted VikingDB.""" - - def __init__( - self, - *, - ak: str, - sk: str, - region: str, - host: str, - project_name: str, - collection_name: str, - ): - self.mode = "volcengine" - self._ak = ak - self._sk = sk - self._region = region - self._host = host - self._project_name = project_name - self._collection_name = collection_name - self._collection = None - - @classmethod - def from_config(cls, config): - if not ( - config.volcengine - and config.volcengine.ak - and config.volcengine.sk - and config.volcengine.region - ): - raise ValueError("Volcengine backend requires AK, SK, and Region configuration") - - return cls( - ak=config.volcengine.ak, - sk=config.volcengine.sk, - region=config.volcengine.region, - host=config.volcengine.host or "", - project_name=config.project_name or "default", - collection_name=config.name or "context", - ) - - def _match(self, name: str) -> bool: - return name == self._collection_name - - def _meta(self) -> dict: - return { - "ProjectName": self._project_name, - "CollectionName": self._collection_name, - } - - def _config(self) -> dict: - return { - "AK": self._ak, - "SK": self._sk, - "Region": self._region, - "Host": self._host, - } - - def _new_collection_handle(self) -> VolcengineCollection: - return VolcengineCollection( - ak=self._ak, - sk=self._sk, - region=self._region, - host=self._host, - meta_data=self._meta(), - ) - - def has_collection(self, name: str) -> bool: - if not self._match(name): - return False - candidate = self._collection or self._new_collection_handle() - meta = candidate.get_meta_data() or {} - exists = bool(meta and meta.get("CollectionName")) - if exists and self._collection is None: - self._collection = candidate - return exists - - def get_collection(self, name: str): - if not self._match(name): - return None - if self._collection is not None: - return self._collection - if self.has_collection(name): - return self._collection - return None - - def create_collection(self, name: str, meta): - if not self._match(name): - raise ValueError( - f"volcengine backend is bound to collection '{self._collection_name}', got '{name}'" - ) - payload = dict(meta) - payload.update(self._meta()) - self._collection = get_or_create_volcengine_collection( - config=self._config(), - meta_data=payload, - ) - return self._collection - - def drop_collection(self, name: str) -> None: - if not self._match(name): - return - coll = self.get_collection(name) - if coll is None: - return - coll.drop() - self._collection = None - - def list_collections(self) -> list[str]: - return [self._collection_name] if self.has_collection(self._collection_name) else [] - - def close(self) -> None: - if self._collection is not None: - self._collection.close() - self._collection = None - - def sanitize_scalar_index_fields( - self, - scalar_index_fields: list[str], - fields_meta: list[dict], - ) -> list[str]: - date_time_fields = { - field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" - } - return [field for field in scalar_index_fields if field not in date_time_fields] - - def build_default_index_meta( - self, - *, - index_name: str, - distance: str, - use_sparse: bool, - sparse_weight: float, - scalar_index_fields: list[str], - ) -> dict: - index_type = "hnsw_hybrid" if use_sparse else "hnsw" - index_meta = { - "IndexName": index_name, - "VectorIndex": { - "IndexType": index_type, - "Distance": distance, - "Quant": "int8", - }, - "ScalarIndex": scalar_index_fields, - } - if use_sparse: - index_meta["VectorIndex"]["EnableSparse"] = True - index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight - return index_meta - - def normalize_record_for_read(self, record: dict) -> dict: - for key in ("uri", "parent_uri"): - value = record.get(key) - if isinstance(value, str) and not value.startswith("viking://"): - stripped = value.strip("/") - if stripped: - record[key] = f"viking://{stripped}" - return record diff --git a/openviking/storage/vector_store/expr.py b/openviking/storage/vector_store/expr.py index 9966213a..e19519bd 100644 --- a/openviking/storage/vector_store/expr.py +++ b/openviking/storage/vector_store/expr.py @@ -31,12 +31,6 @@ class In: values: List[Any] -@dataclass(frozen=True) -class Prefix: - field: str - prefix: str - - @dataclass(frozen=True) class Range: field: str @@ -52,12 +46,6 @@ class Contains: substring: str -@dataclass(frozen=True) -class Regex: - field: str - pattern: str - - @dataclass(frozen=True) class TimeRange: field: str @@ -70,4 +58,4 @@ class RawDSL: payload: Dict[str, Any] -FilterExpr = Union[And, Or, Eq, In, Prefix, Range, Contains, Regex, TimeRange, RawDSL] +FilterExpr = Union[And, Or, Eq, In, Range, Contains, TimeRange, RawDSL] diff --git a/openviking/storage/vector_store/factory.py b/openviking/storage/vector_store/factory.py deleted file mode 100644 index fb670ff0..00000000 --- a/openviking/storage/vector_store/factory.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Factory for vector backend drivers.""" - -from __future__ import annotations - -from openviking.storage.vector_store.driver import VectorStoreDriver -from openviking.storage.vector_store.registry import get_driver_class - - -def create_driver(config) -> VectorStoreDriver: - """Create backend driver from `VectorDBBackendConfig` without backend if/else.""" - # Ensure all static registrations are loaded. - import openviking.storage.vector_store.drivers # noqa: F401 - - driver_cls = get_driver_class(config.backend) - return driver_cls.from_config(config) diff --git a/openviking/storage/vector_store/registry.py b/openviking/storage/vector_store/registry.py deleted file mode 100644 index a80952cc..00000000 --- a/openviking/storage/vector_store/registry.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Static registry for vector backend drivers.""" - -from __future__ import annotations - -from typing import Callable, Dict, Type - -from openviking.storage.vector_store.driver import VectorStoreDriver - -_DRIVER_REGISTRY: Dict[str, Type[VectorStoreDriver]] = {} - - -def register_driver(name: str) -> Callable[[Type[VectorStoreDriver]], Type[VectorStoreDriver]]: - """Register a vector backend driver class by backend name.""" - - def decorator(cls: Type[VectorStoreDriver]) -> Type[VectorStoreDriver]: - _DRIVER_REGISTRY[name] = cls - return cls - - return decorator - - -def get_driver_class(name: str) -> Type[VectorStoreDriver]: - """Resolve registered driver class for backend name.""" - if name not in _DRIVER_REGISTRY: - raise ValueError( - f"Vector backend {name} is not registered. " - f"Available backends: {sorted(_DRIVER_REGISTRY)}" - ) - return _DRIVER_REGISTRY[name] - - -def list_registered_backends() -> list[str]: - return sorted(_DRIVER_REGISTRY) diff --git a/openviking/storage/viking_vector_index_backend.py b/openviking/storage/viking_vector_index_backend.py index 3a1fc1db..18044cec 100644 --- a/openviking/storage/viking_vector_index_backend.py +++ b/openviking/storage/viking_vector_index_backend.py @@ -1,20 +1,16 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. # SPDX-License-Identifier: Apache-2.0 -""" -VikingDB storage backend for OpenViking. +"""VikingDB storage backend for OpenViking.""" -Supports both in-memory and local persistent storage modes. -""" +from __future__ import annotations import uuid from typing import Any, Dict, List, Optional from openviking.server.identity import RequestContext, Role -from openviking.storage.errors import CollectionNotFoundError -from openviking.storage.vector_store import FilterExpr, create_driver -from openviking.storage.vector_store.expr import And, Eq, In, Or, RawDSL +from openviking.storage.collection_adapter import CollectionAdapter, create_collection_adapter +from openviking.storage.vector_store.expr import And, Eq, FilterExpr, In, Or, RawDSL from openviking.storage.vectordb.collection.collection import Collection -from openviking.storage.vectordb.collection.result import FetchDataInCollectionResult from openviking.storage.vectordb.utils.logging_init import init_cpp_logging from openviking_cli.utils import get_logger from openviking_cli.utils.config.vectordb_config import VectorDBBackendConfig @@ -23,58 +19,12 @@ class VikingVectorIndexBackend: - """ - VikingDB storage backend implementation. + """Single-collection vector backend with adapter-based backend specialization.""" - Features: - - Vector similarity search with BruteForce indexing - - Scalar filtering with support for multiple operators - - Support for local persistent storage, HTTP service, and Volcengine VikingDB - - Auto-managed indexes per collection - - VikingDBManager is derived by VikingVectorIndexBackend. - """ - - # Default index name DEFAULT_INDEX_NAME = "default" + ALLOWED_CONTEXT_TYPES = {"resource", "skill", "memory"} - def __init__( - self, - config: Optional[VectorDBBackendConfig], - ): - """ - Initialize VikingDB backend. - - Args: - config: Configuration object for VectorDB backend. - - Examples: - # 1. Local persistent storage - config = VectorDBBackendConfig( - backend="local", - path="./data/vectordb" - ) - backend = VikingVectorIndexBackend(config=config) - - # 2. Remote HTTP service - config = VectorDBBackendConfig( - backend="http", - url="http://localhost:5000" - ) - backend = VikingVectorIndexBackend(config=config) - - # 3. Volcengine VikingDB - from openviking_cli.utils.config.storage_config import VolcengineConfig - config = VectorDBBackendConfig( - backend="volcengine", - volcengine=VolcengineConfig( - ak="your-ak", - sk="your-sk", - region="cn-beijing" - ) - ) - backend = VikingVectorIndexBackend(config=config) - """ + def __init__(self, config: Optional[VectorDBBackendConfig]): if config is None: raise ValueError("VectorDB backend config is required") @@ -85,595 +35,268 @@ def __init__( self.sparse_weight = config.sparse_weight self._collection_name = config.name or "context" - # Backend selection is delegated to static driver registry. - self._driver = create_driver(config) - self._mode = self._driver.mode + self._adapter: CollectionAdapter = create_collection_adapter(config) + self._mode = self._adapter.mode logger.info( - "VikingDB backend initialized via driver '%s' (mode=%s)", - type(self._driver).__name__, + "VikingDB backend initialized via adapter %s (mode=%s)", + type(self._adapter).__name__, self._mode, ) - self._collection_configs: Dict[str, Dict[str, Any]] = {} - # Cache meta_data at collection level to avoid repeated remote calls - self._meta_data_cache: Dict[str, Dict[str, Any]] = {} - - def _compile_filter(self, filter_expr: Optional[FilterExpr | Dict[str, Any]]) -> Dict[str, Any]: - """Compile AST filters via driver; allow raw DSL passthrough.""" - if filter_expr is None: - return {} - if isinstance(filter_expr, dict): - return filter_expr - if isinstance(filter_expr, RawDSL): - return filter_expr.payload - return self._driver.compile_expr(filter_expr) - - def _get_collection(self, name: str) -> Collection: - """Get collection object or raise error if not found.""" - if not self._driver.has_collection(name): - raise CollectionNotFoundError(f"Collection '{name}' does not exist") - return self._driver.get_collection(name) - - def _get_meta_data(self, collection_name: str, coll: Collection) -> Dict[str, Any]: - """Get meta_data with collection-level caching to avoid repeated remote calls.""" - if collection_name not in self._meta_data_cache: - self._meta_data_cache[collection_name] = coll.get_meta_data() - return self._meta_data_cache[collection_name] - - def _update_meta_data_cache(self, collection_name: str, coll: Collection): - """Update the cached meta_data after modifications.""" - meta_data = coll.get_meta_data() - self._meta_data_cache[collection_name] = meta_data + self._collection_config: Dict[str, Any] = {} + self._meta_data_cache: Dict[str, Any] = {} @property def collection_name(self) -> str: - """Return bound collection name for this store instance.""" return self._collection_name - def _resolve_collection_name(self, collection_name: Optional[str] = None) -> str: - """Resolve collection name with bound default.""" - return collection_name or self._collection_name + def _get_collection(self) -> Collection: + return self._adapter.get_collection() + + def _get_meta_data(self, coll: Collection) -> Dict[str, Any]: + if not self._meta_data_cache: + self._meta_data_cache = coll.get_meta_data() or {} + return self._meta_data_cache + + def _refresh_meta_data(self, coll: Collection) -> None: + self._meta_data_cache = coll.get_meta_data() or {} + + def _filter_known_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: + try: + coll = self._get_collection() + fields = self._get_meta_data(coll).get("Fields", []) + allowed = {item.get("FieldName") for item in fields} + return {k: v for k, v in data.items() if k in allowed and v is not None} + except Exception: + return data # ========================================================================= - # Collection/Table Management + # Collection Management (single collection) # ========================================================================= async def create_collection(self, name: str, schema: Dict[str, Any]) -> bool: - """ - Create a new collection. - - Args: - name: Collection name - schema: VikingVectorIndex collection metadata in the format: - { - "CollectionName": "name", - "Description": "description", - "Fields": [ - {"FieldName": "id", "FieldType": "string", "IsPrimaryKey": True}, - {"FieldName": "vector", "FieldType": "vector", "Dim": 128}, - ... - ] - } - - Returns: - True if created successfully, False if already exists - """ try: - if self._driver.has_collection(name): - logger.debug(f"Collection '{name}' already exists") - return False - - collection_meta = schema.copy() + collection_meta = dict(schema) - scalar_index_fields = [] - if "ScalarIndex" in collection_meta: - scalar_index_fields = collection_meta.pop("ScalarIndex") - - # Ensure CollectionName is set - if "CollectionName" not in collection_meta: - collection_meta["CollectionName"] = name - - # Extract distance metric and vector_dim for config tracking - distance = self.distance_metric + # Track vector dim from schema for info. vector_dim = self.vector_dim for field in collection_meta.get("Fields", []): if field.get("FieldType") == "vector": vector_dim = field.get("Dim", self.vector_dim) break - logger.info(f"Creating collection mode={self._mode} with meta: {collection_meta}") - - # Create collection via backend-specific collection driver - collection = self._driver.create_collection(name, collection_meta) - - scalar_index_fields = self._driver.sanitize_scalar_index_fields( - scalar_index_fields=scalar_index_fields, - fields_meta=collection_meta.get("Fields", []), - ) - - # Create default index for the collection - use_sparse = self.sparse_weight > 0.0 - index_meta = self._driver.build_default_index_meta( - index_name=self.DEFAULT_INDEX_NAME, - distance=distance, - use_sparse=use_sparse, + created = self._adapter.create_collection( + name=name, + schema=collection_meta, + distance=self.distance_metric, sparse_weight=self.sparse_weight, - scalar_index_fields=scalar_index_fields, + index_name=self.DEFAULT_INDEX_NAME, ) + if not created: + return False - logger.info(f"Creating index with meta: {index_meta}") - collection.create_index(self.DEFAULT_INDEX_NAME, index_meta) - - # Update cached meta_data after creating index - self._update_meta_data_cache(name, collection) - - # Store collection config - self._collection_configs[name] = { + self._collection_name = name + self._collection_config = { "vector_dim": vector_dim, - "distance": distance, + "distance": self.distance_metric, "schema": schema, } - self._collection_name = name - - logger.info(f"Created VikingDB collection: {name} (dim={vector_dim})") + self._refresh_meta_data(self._get_collection()) + logger.info("Created VikingDB collection: %s (dim=%s)", name, vector_dim) return True - except Exception as e: - logger.error(f"Error creating collection '{name}': {e}") - import traceback - - traceback.print_exc() + logger.error("Error creating collection %s: %s", name, e) return False - async def drop_collection(self, name: str) -> bool: - """Drop a collection.""" + async def drop_collection(self) -> bool: try: - if not self._driver.has_collection(name): - logger.warning(f"Collection '{name}' does not exist") - return False - - self._driver.drop_collection(name) - self._collection_configs.pop(name, None) - # Clear cached meta_data when dropping collection - self._meta_data_cache.pop(name, None) - - logger.info(f"Dropped collection: {name}") - return True + dropped = self._adapter.drop_collection() + if dropped: + self._collection_config = {} + self._meta_data_cache = {} + return dropped except Exception as e: - logger.error(f"Error dropping collection '{name}': {e}") + logger.error("Error dropping collection %s: %s", self._collection_name, e) return False - async def collection_exists(self, name: Optional[str] = None) -> bool: - """Check if a collection exists.""" - return self._driver.has_collection(self._resolve_collection_name(name)) - - async def list_collections(self) -> List[str]: - """List all collection names.""" - return self._driver.list_collections() - - async def get_collection_info(self, name: str) -> Optional[Dict[str, Any]]: - """Get collection metadata and statistics.""" - try: - if not self._driver.has_collection(name): - return None + async def collection_exists(self) -> bool: + return self._adapter.collection_exists() - config = self._collection_configs.get(name, {}) - - return { - "name": name, - "vector_dim": config.get("vector_dim", self.vector_dim), - "count": 0, # vectordb doesn't easily expose count - "status": "active", - } - except Exception as e: - logger.error(f"Error getting collection info for '{name}': {e}") + async def get_collection_info(self) -> Optional[Dict[str, Any]]: + if not await self.collection_exists(): return None + config = self._collection_config + return { + "name": self._collection_name, + "vector_dim": config.get("vector_dim", self.vector_dim), + "count": await self.count(), + "status": "active", + } async def collection_exists_bound(self) -> bool: - """Check whether the bound collection exists.""" - return await self.collection_exists(self._collection_name) + return await self.collection_exists() # ========================================================================= - # CRUD Operations - Single Record + # Data Operations # ========================================================================= - async def insert(self, data: Dict[str, Any]) -> str: - """Insert a single record into the bound collection.""" - coll = self._get_collection(self._collection_name) - - # Ensure ID exists - record_id = data.get("id") - if not record_id: - record_id = str(uuid.uuid4()) - data = {**data, "id": record_id} - - # Validate context_type for context collection - context_type = data.get("context_type") - if context_type not in ["resource", "skill", "memory"]: + async def upsert(self, data: Dict[str, Any]) -> str: + payload = dict(data) + context_type = payload.get("context_type") + if context_type and context_type not in self.ALLOWED_CONTEXT_TYPES: logger.warning( - f"Invalid context_type: {context_type}. " - f"Must be one of ['resource', 'skill', 'memory'], Ignore" + "Invalid context_type: %s. Must be one of %s", + context_type, + sorted(self.ALLOWED_CONTEXT_TYPES), ) return "" - fields = self._get_meta_data(self._collection_name, coll).get("Fields", []) - fields_dict = {item["FieldName"]: item for item in fields} - new_data = {} - for key in data: - if key in fields_dict and data[key] is not None: - new_data[key] = data[key] + if not payload.get("id"): + payload["id"] = str(uuid.uuid4()) - try: - coll.upsert_data([new_data]) - return record_id - except Exception as e: - logger.error(f"Error inserting record: {e}") - raise - - async def update(self, id: str, data: Dict[str, Any]) -> bool: - """Update a record by ID in the bound collection.""" - coll = self._get_collection(self._collection_name) - - try: - # Fetch existing record - existing = await self.get([id]) - if not existing: - return False - - # Merge data with existing record - updated_data = {**existing[0], **data} - updated_data["id"] = id - - # Upsert the updated record - coll.upsert_data([updated_data]) - return True - except Exception as e: - logger.error(f"Error updating record '{id}': {e}") - return False - - async def upsert(self, data: Dict[str, Any]) -> str: - """Insert or update a record in the bound collection.""" - coll = self._get_collection(self._collection_name) - - record_id = data.get("id") - if not record_id: - record_id = str(uuid.uuid4()) - data = {**data, "id": record_id} - - try: - coll.upsert_data([data]) - return record_id - except Exception as e: - logger.error(f"Error upserting record: {e}") - raise - - async def delete(self, ids: List[str]) -> int: - """Delete records by IDs from the bound collection.""" - coll = self._get_collection(self._collection_name) - - try: - coll.delete_data(ids) - return len(ids) - except Exception as e: - logger.error(f"Error deleting records: {e}") - return 0 + payload = self._filter_known_fields(payload) + ids = self._adapter.upsert(payload) + return ids[0] if ids else "" async def get(self, ids: List[str]) -> List[Dict[str, Any]]: - """Get records by IDs from the bound collection.""" - coll = self._get_collection(self._collection_name) - try: - result = coll.fetch_data(ids) - - if isinstance(result, FetchDataInCollectionResult): - records = [] - for item in result.items: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - self._driver.normalize_record_for_read(record) - records.append(record) - return records - elif isinstance(result, dict): - records = [] - if "fetch" in result: - for item in result.get("fetch", []): - record = dict(item.get("fields", {})) if item.get("fields") else {} - record["id"] = item.get("id") - if record["id"]: - self._driver.normalize_record_for_read(record) - records.append(record) - return records - else: - logger.warning(f"Unexpected return type from fetch_data: {type(result)}") - return [] + return self._adapter.get(ids) except Exception as e: - logger.error(f"Error getting records: {e}") + logger.error("Error getting records: %s", e) return [] - async def fetch_by_uri(self, uri: str) -> Optional[Dict[str, Any]]: - """Fetch a record by URI.""" - coll = self._get_collection(self._collection_name) + async def delete(self, ids: List[str]) -> int: try: - result = coll.search_by_random( - index_name=self.DEFAULT_INDEX_NAME, - limit=10, - filters={"op": "must", "field": "uri", "conds": [uri]}, - ) - records = [] - for item in result.data: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - self._driver.normalize_record_for_read(record) - records.append(record) - if len(records) > 1: - raise ValueError(f"Duplicate records found for URI: {uri}") - if len(records) == 0: - raise ValueError(f"Record not found for URI: {uri}") - return records[0] + return self._adapter.delete(ids=ids) except Exception as e: - logger.error(f"Error fetching record by URI '{uri}': {e}") - return None + logger.error("Error deleting records: %s", e) + return 0 async def exists(self, id: str) -> bool: - """Check if a record exists.""" try: - results = await self.get([id]) - return len(results) > 0 + return len(await self.get([id])) > 0 except Exception: return False - # ========================================================================= - # CRUD Operations - Batch - # ========================================================================= - - async def batch_insert(self, data: List[Dict[str, Any]]) -> List[str]: - """Batch insert multiple records into the bound collection.""" - coll = self._get_collection(self._collection_name) - - # Ensure all records have IDs - ids = [] - records_with_ids = [] - for record in data: - if "id" not in record: - record_id = str(uuid.uuid4()) - records_with_ids.append({**record, "id": record_id}) - ids.append(record_id) - else: - records_with_ids.append(record) - ids.append(record["id"]) - + async def fetch_by_uri(self, uri: str) -> Optional[Dict[str, Any]]: try: - coll.upsert_data(records_with_ids) - return ids + records = await self.query( + filter={"op": "must", "field": "uri", "conds": [uri]}, + limit=2, + ) + if len(records) == 1: + return records[0] + return None except Exception as e: - logger.error(f"Error batch inserting records: {e}") - raise - - async def batch_upsert(self, data: List[Dict[str, Any]]) -> List[str]: - """Batch insert or update multiple records in the bound collection.""" - coll = self._get_collection(self._collection_name) - - ids = [] - records_with_ids = [] - for record in data: - if "id" not in record: - record_id = str(uuid.uuid4()) - records_with_ids.append({**record, "id": record_id}) - ids.append(record_id) - else: - records_with_ids.append(record) - ids.append(record["id"]) + logger.error("Error fetching record by URI %s: %s", uri, e) + return None + async def query( + self, + query_vector: Optional[List[float]] = None, + sparse_query_vector: Optional[Dict[str, float]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, + limit: int = 10, + offset: int = 0, + output_fields: Optional[List[str]] = None, + with_vector: bool = False, + order_by: Optional[str] = None, + order_desc: bool = False, + ) -> List[Dict[str, Any]]: try: - coll.upsert_data(records_with_ids) - return ids + return self._adapter.query( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=filter, + limit=limit, + offset=offset, + output_fields=output_fields, + with_vector=with_vector, + order_by=order_by, + order_desc=order_desc, + ) except Exception as e: - logger.error(f"Error batch upserting records: {e}") - raise - - async def batch_delete(self, filter: Dict[str, Any] | FilterExpr) -> int: - """Delete records matching filter conditions.""" - try: - # First, find matching records - matching_records = await self.filter(filter, limit=10000) + logger.error("Error querying collection %s: %s", self._collection_name, e) + return [] - if not matching_records: - return 0 + async def search( + self, + query_vector: Optional[List[float]] = None, + sparse_query_vector: Optional[Dict[str, float]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, + limit: int = 10, + offset: int = 0, + output_fields: Optional[List[str]] = None, + with_vector: bool = False, + ) -> List[Dict[str, Any]]: + # Backward-compatible alias for internal call sites. + return await self.query( + query_vector=query_vector, + sparse_query_vector=sparse_query_vector, + filter=filter, + limit=limit, + offset=offset, + output_fields=output_fields, + with_vector=with_vector, + ) - # Extract IDs and delete - ids = [record["id"] for record in matching_records if "id" in record] - return await self.delete(ids) - except Exception as e: - logger.error(f"Error batch deleting records: {e}") - return 0 + async def filter( + self, + filter: Dict[str, Any] | FilterExpr, + limit: int = 10, + offset: int = 0, + output_fields: Optional[List[str]] = None, + order_by: Optional[str] = None, + order_desc: bool = False, + ) -> List[Dict[str, Any]]: + return await self.query( + filter=filter, + limit=limit, + offset=offset, + output_fields=output_fields, + order_by=order_by, + order_desc=order_desc, + ) async def remove_by_uri(self, uri: str) -> int: - """Remove resource(s) by URI.""" try: target_records = await self.filter( {"op": "must", "field": "uri", "conds": [uri]}, limit=10, ) - if not target_records: return 0 total_deleted = 0 - - # If any record indicates this URI is a directory node, remove descendants first. if any(r.get("level") in [0, 1] for r in target_records): - descendant_count = await self._remove_descendants(parent_uri=uri) - total_deleted += descendant_count + total_deleted += await self._remove_descendants(parent_uri=uri) ids = [r.get("id") for r in target_records if r.get("id")] if ids: total_deleted += await self.delete(ids) - - logger.info(f"Removed {total_deleted} record(s) for URI: {uri}") return total_deleted - except Exception as e: - logger.error(f"Error removing URI '{uri}': {e}") + logger.error("Error removing URI %s: %s", uri, e) return 0 async def _remove_descendants(self, parent_uri: str) -> int: - """Recursively remove all descendants of a parent URI.""" total_deleted = 0 - - # Find direct children children = await self.filter( {"op": "must", "field": "parent_uri", "conds": [parent_uri]}, - limit=10000, + limit=100000, ) - for child in children: child_uri = child.get("uri") level = child.get("level", 2) - - # Recursively delete if child is also an intermediate directory if level in [0, 1] and child_uri: - descendant_count = await self._remove_descendants(parent_uri=child_uri) - total_deleted += descendant_count - - # Delete the child - if "id" in child: - await self.delete([child["id"]]) + total_deleted += await self._remove_descendants(parent_uri=child_uri) + child_id = child.get("id") + if child_id: + await self.delete([child_id]) total_deleted += 1 - return total_deleted - # ========================================================================= - # Search Operations - # ========================================================================= - - async def search( - self, - query_vector: Optional[List[float]] = None, - sparse_query_vector: Optional[Dict[str, float]] = None, - filter: Optional[Dict[str, Any] | FilterExpr] = None, - limit: int = 10, - offset: int = 0, - output_fields: Optional[List[str]] = None, - with_vector: bool = False, - ) -> List[Dict[str, Any]]: - """Hybrid search: vector similarity (dense/sparse/hybrid) + scalar filtering. - - Args: - query_vector: Dense query vector (optional) - sparse_query_vector: Sparse query vector as {term: weight} dict (optional) - filter: Scalar filter conditions - limit: Maximum number of results - offset: Offset for pagination - output_fields: Fields to return - with_vector: Whether to include vector field in results - - Returns: - List of matching records with scores - """ - coll = self._get_collection(self._collection_name) - - try: - vectordb_filter = self._compile_filter(filter) - - if query_vector or sparse_query_vector: - # Vector search (dense, sparse, or hybrid) with optional filtering - result = coll.search_by_vector( - index_name=self.DEFAULT_INDEX_NAME, - dense_vector=query_vector, - sparse_vector=sparse_query_vector, - limit=limit, - offset=offset, - filters=vectordb_filter, - output_fields=output_fields, - ) - - # Convert results - records = [] - for item in result.data: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - record["_score"] = item.score if item.score is not None else 0.0 - self._driver.normalize_record_for_read(record) - - if not with_vector: - if "vector" in record: - record.pop("vector") - if "sparse_vector" in record: - record.pop("sparse_vector") - - records.append(record) - - return records - else: - # Pure filtering without vector search - return await self.filter( - filter or {}, - limit=limit, - offset=offset, - output_fields=output_fields, - ) - - except Exception as e: - logger.error(f"Error searching collection '{self._collection_name}': {e}") - import traceback - - traceback.print_exc() - return [] - - async def filter( - self, - filter: Dict[str, Any] | FilterExpr, - limit: int = 10, - offset: int = 0, - output_fields: Optional[List[str]] = None, - order_by: Optional[str] = None, - order_desc: bool = False, - ) -> List[Dict[str, Any]]: - """Pure scalar filtering without vector search.""" - coll = self._get_collection(self._collection_name) - - try: - vectordb_filter = self._compile_filter(filter) - - if order_by: - # Use search_by_scalar for sorting - result = coll.search_by_scalar( - index_name=self.DEFAULT_INDEX_NAME, - field=order_by, - order="desc" if order_desc else "asc", - limit=limit, - offset=offset, - filters=vectordb_filter, - output_fields=output_fields, - ) - else: - # Use search_by_random for pure filtering - result = coll.search_by_random( - index_name=self.DEFAULT_INDEX_NAME, - limit=limit, - offset=offset, - filters=vectordb_filter, - output_fields=output_fields, - ) - - # Convert results - records = [] - for item in result.data: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - self._driver.normalize_record_for_read(record) - records.append(record) - - return records - - except Exception as e: - logger.error(f"Error filtering collection '{self._collection_name}': {e}") - import traceback - - traceback.print_exc() - return [] - # ========================================================================= # Semantic Context Operations (Tenant-Aware) # ========================================================================= @@ -790,19 +413,13 @@ async def get_context_by_uri( owner_space: Optional[str] = None, limit: int = 1, ) -> List[Dict[str, Any]]: - conds: List[FilterExpr] = [ - Eq("uri", uri), - Eq("account_id", account_id), - ] + conds: List[FilterExpr] = [Eq("uri", uri), Eq("account_id", account_id)] if owner_space: conds.append(Eq("owner_space", owner_space)) - return await self.filter( - filter=And(conds), - limit=limit, - ) + return await self.filter(filter=And(conds), limit=limit) async def delete_account_data(self, account_id: str) -> int: - return await self.batch_delete(Eq("account_id", account_id)) + return self._adapter.delete(filter=Eq("account_id", account_id)) async def delete_uris(self, ctx: RequestContext, uris: List[str]) -> None: for uri in uris: @@ -817,7 +434,7 @@ async def delete_uris(self, ctx: RequestContext, uris: List[str]) -> None: else ctx.user.agent_space_name() ) conds.append(Eq("owner_space", owner_space)) - await self.batch_delete(And(conds)) + self._adapter.delete(filter=And(conds)) async def update_uri_mapping( self, @@ -832,7 +449,8 @@ async def update_uri_mapping( ) if not records or "id" not in records[0]: return False - return await self.update(records[0]["id"], {"uri": new_uri, "parent_uri": new_parent_uri}) + updated = {**records[0], "uri": new_uri, "parent_uri": new_parent_uri} + return bool(await self.upsert(updated)) async def increment_active_count(self, ctx: RequestContext, uris: List[str]) -> int: updated = 0 @@ -841,11 +459,9 @@ async def increment_active_count(self, ctx: RequestContext, uris: List[str]) -> if not records: continue record = records[0] - record_id = record.get("id") - if not record_id: - continue current = int(record.get("active_count", 0) or 0) - if await self.update(record_id, {"active_count": current + 1}): + record["active_count"] = current + 1 + if await self.upsert(record): updated += 1 return updated @@ -905,148 +521,62 @@ async def scroll( cursor: Optional[str] = None, output_fields: Optional[List[str]] = None, ) -> tuple[List[Dict[str, Any]], Optional[str]]: - """Scroll through large result sets efficiently.""" - # vectordb doesn't natively support scroll, so we simulate it offset = int(cursor) if cursor else 0 - records = await self.filter( filter=filter or {}, limit=limit, offset=offset, output_fields=output_fields, ) - - # Return next cursor if we got a full batch next_cursor = str(offset + limit) if len(records) == limit else None - return records, next_cursor - # ========================================================================= - # Aggregation Operations - # ========================================================================= - - async def count( - self, - filter: Optional[Dict[str, Any] | FilterExpr] = None, - ) -> int: - """Count records matching filter.""" + async def count(self, filter: Optional[Dict[str, Any] | FilterExpr] = None) -> int: try: - coll = self._get_collection(self._collection_name) - result = coll.aggregate_data( - index_name=self.DEFAULT_INDEX_NAME, - op="count", - filters=self._compile_filter(filter), - ) - return result.agg.get("_total", 0) + return self._adapter.count(filter=filter) except Exception as e: - logger.error(f"Error counting records: {e}") + logger.error("Error counting records: %s", e) return 0 - # ========================================================================= - # Index Operations - # ========================================================================= - - async def create_index( - self, - field: str, - index_type: str, - **kwargs, - ) -> bool: - """Create an index on a field.""" - try: - # vectordb manages indexes at collection level - # Indexes are already created with the collection - logger.info(f"Index creation requested for field '{field}' (managed by vectordb)") - return True - except Exception as e: - logger.error(f"Error creating index on '{field}': {e}") - return False - - async def drop_index(self, field: str) -> bool: - """Drop an index on a field.""" - try: - # vectordb manages indexes internally - logger.info(f"Index drop requested for field '{field}' (managed by vectordb)") - return True - except Exception as e: - logger.error(f"Error dropping index on '{field}': {e}") - return False - - # ========================================================================= - # Lifecycle Operations - # ========================================================================= - async def clear(self) -> bool: - """Clear all data in a collection.""" - coll = self._get_collection(self._collection_name) - try: - coll.delete_all_data() - logger.info(f"Cleared all data in collection: {self._collection_name}") - return True + return self._adapter.clear() except Exception as e: - logger.error(f"Error clearing collection: {e}") + logger.error("Error clearing collection: %s", e) return False async def optimize(self) -> bool: - """Optimize collection for better performance.""" - try: - # vectordb handles optimization internally via index rebuilding - logger.info("Optimization requested for collection: %s", self._collection_name) - return True - except Exception as e: - logger.error(f"Error optimizing collection: {e}") - return False + logger.info("Optimization requested for collection: %s", self._collection_name) + return True async def close(self) -> None: - """Close storage connection and release resources.""" try: - self._driver.close() - - self._collection_configs.clear() + self._adapter.close() + self._collection_config = {} + self._meta_data_cache = {} logger.info("VikingDB backend closed") except Exception as e: - logger.error(f"Error closing VikingDB backend: {e}") - - # ========================================================================= - # Health & Status - # ========================================================================= + logger.error("Error closing VikingDB backend: %s", e) async def health_check(self) -> bool: - """Check if storage backend is healthy and accessible.""" try: - # Simple check: verify we can access collections metadata. - self._driver.list_collections() + await self.collection_exists() return True except Exception: return False async def get_stats(self) -> Dict[str, Any]: - """Get storage statistics.""" try: - collections = self._driver.list_collections() - - # Count total records across all collections using aggregate_data - total_records = 0 - for collection_name in collections: - try: - coll = self._get_collection(collection_name) - result = coll.aggregate_data( - index_name=self.DEFAULT_INDEX_NAME, op="count", filters=None - ) - total_records += result.agg.get("_total", 0) - except Exception as e: - logger.warning(f"Error counting records in collection '{collection_name}': {e}") - continue - + exists = await self.collection_exists() + total_records = await self.count() if exists else 0 return { - "collections": len(collections), + "collections": 1 if exists else 0, "total_records": total_records, "backend": "vikingdb", "mode": self._mode, } except Exception as e: - logger.error(f"Error getting stats: {e}") + logger.error("Error getting stats: %s", e) return { "collections": 0, "total_records": 0, @@ -1056,5 +586,4 @@ async def get_stats(self) -> Dict[str, Any]: @property def mode(self) -> str: - """Return the current storage mode.""" return self._mode diff --git a/tests/retrieve/test_hierarchical_retriever_target_dirs.py b/tests/retrieve/test_hierarchical_retriever_target_dirs.py index ba54a24a..019328bb 100644 --- a/tests/retrieve/test_hierarchical_retriever_target_dirs.py +++ b/tests/retrieve/test_hierarchical_retriever_target_dirs.py @@ -7,7 +7,6 @@ from openviking.retrieve.hierarchical_retriever import HierarchicalRetriever from openviking.server.identity import RequestContext, Role -from openviking.storage.vector_store.expr import Prefix from openviking_cli.retrieve.types import ContextType, TypedQuery from openviking_cli.session.user_id import UserIdentifier @@ -16,49 +15,60 @@ class DummyStorage: """Minimal storage stub to capture search filters.""" def __init__(self) -> None: - self.search_calls = [] + self.collection_name = "context" + self.global_search_calls = [] + self.child_search_calls = [] - async def collection_exists(self, _name: str) -> bool: + async def collection_exists_bound(self) -> bool: return True - async def search( + async def search_global_roots_in_tenant( self, - collection: str, + ctx, query_vector=None, sparse_query_vector=None, - filter=None, + context_type=None, + target_directories=None, + extra_filter=None, limit: int = 10, - offset: int = 0, - output_fields=None, - with_vector: bool = False, ): - self.search_calls.append( + self.global_search_calls.append( { - "collection": collection, - "filter": filter, + "ctx": ctx, + "query_vector": query_vector, + "sparse_query_vector": sparse_query_vector, + "context_type": context_type, + "target_directories": target_directories, + "extra_filter": extra_filter, "limit": limit, - "offset": offset, } ) return [] - -def _contains_uri_scope_filter(obj, target_uri: str) -> bool: - if isinstance(obj, Prefix): - return obj.field == "uri" and obj.prefix == target_uri - if isinstance(obj, dict): - if ( - obj.get("op") == "must" - and obj.get("field") == "uri" - and target_uri in obj.get("conds", []) - ): - return True - return any(_contains_uri_scope_filter(v, target_uri) for v in obj.values()) - if isinstance(obj, list): - return any(_contains_uri_scope_filter(v, target_uri) for v in obj) - if hasattr(obj, "__dict__"): - return any(_contains_uri_scope_filter(v, target_uri) for v in vars(obj).values()) - return False + async def search_children_in_tenant( + self, + ctx, + parent_uri: str, + query_vector=None, + sparse_query_vector=None, + context_type=None, + target_directories=None, + extra_filter=None, + limit: int = 10, + ): + self.child_search_calls.append( + { + "ctx": ctx, + "parent_uri": parent_uri, + "query_vector": query_vector, + "sparse_query_vector": sparse_query_vector, + "context_type": context_type, + "target_directories": target_directories, + "extra_filter": extra_filter, + "limit": limit, + } + ) + return [] @pytest.mark.asyncio @@ -78,5 +88,8 @@ async def test_retrieve_honors_target_directories_scope_filter(): result = await retriever.retrieve(query, ctx=ctx, limit=3) assert result.searched_directories == [target_uri] - assert storage.search_calls - assert _contains_uri_scope_filter(storage.search_calls[0]["filter"], target_uri) + assert storage.global_search_calls + assert storage.global_search_calls[0]["target_directories"] == [target_uri] + assert storage.child_search_calls + assert storage.child_search_calls[0]["target_directories"] == [target_uri] + assert storage.child_search_calls[0]["parent_uri"] == target_uri diff --git a/tests/session/test_memory_dedup_actions.py b/tests/session/test_memory_dedup_actions.py index 6b94a674..8b6e886e 100644 --- a/tests/session/test_memory_dedup_actions.py +++ b/tests/session/test_memory_dedup_actions.py @@ -22,7 +22,6 @@ MemoryExtractor, MergedMemoryPayload, ) -from openviking.storage.vector_store.expr import And, Eq, Prefix from openviking_cli.session.user_id import UserIdentifier @@ -154,7 +153,7 @@ async def test_find_similar_memories_uses_path_must_filter_and__score(self): vikingdb = MagicMock() vikingdb.get_embedder.return_value = _DummyEmbedder() - vikingdb.search = AsyncMock( + vikingdb.search_similar_memories = AsyncMock( return_value=[ { "id": "uri_pref_hit", @@ -176,26 +175,19 @@ async def test_find_similar_memories_uses_path_must_filter_and__score(self): assert len(similar) == 1 assert similar[0].uri == existing.uri - call = vikingdb.search.await_args.kwargs - assert isinstance(call["filter"], And) - conds = call["filter"].conds - assert Eq("context_type", "memory") in conds - assert Eq("level", 2) in conds - assert Eq("account_id", "acc1") in conds - assert Eq("owner_space", _make_user().user_space_name()) in conds - assert ( - Prefix( - "uri", - f"viking://user/{_make_user().user_space_name()}/memories/preferences/", - ) - in conds + call = vikingdb.search_similar_memories.await_args.kwargs + assert call["account_id"] == "acc1" + assert call["owner_space"] == _make_user().user_space_name() + assert call["category_uri_prefix"] == ( + f"viking://user/{_make_user().user_space_name()}/memories/preferences/" ) + assert call["limit"] == 5 @pytest.mark.asyncio async def test_find_similar_memories_accepts_low_score_when_threshold_is_zero(self): vikingdb = MagicMock() vikingdb.get_embedder.return_value = _DummyEmbedder() - vikingdb.search = AsyncMock( + vikingdb.search_similar_memories = AsyncMock( return_value=[ { "id": "uri_low", @@ -380,7 +372,7 @@ async def test_create_with_empty_list_only_creates_new_memory(self): vikingdb = MagicMock() vikingdb.get_embedder.return_value = None - vikingdb.batch_delete = AsyncMock(return_value=1) + vikingdb.delete_uris = AsyncMock(return_value=None) vikingdb.enqueue_embedding_msg = AsyncMock() compressor = SessionCompressor(vikingdb=vikingdb) @@ -418,7 +410,7 @@ async def test_create_with_merge_is_executed_as_none(self): vikingdb = MagicMock() vikingdb.get_embedder.return_value = None - vikingdb.batch_delete = AsyncMock(return_value=1) + vikingdb.delete_uris = AsyncMock(return_value=None) vikingdb.enqueue_embedding_msg = AsyncMock() compressor = SessionCompressor(vikingdb=vikingdb) @@ -473,7 +465,7 @@ async def test_merge_bundle_failure_is_skipped_without_fallback(self): vikingdb = MagicMock() vikingdb.get_embedder.return_value = None - vikingdb.batch_delete = AsyncMock(return_value=1) + vikingdb.delete_uris = AsyncMock(return_value=None) vikingdb.enqueue_embedding_msg = AsyncMock() compressor = SessionCompressor(vikingdb=vikingdb) @@ -519,7 +511,7 @@ async def test_create_with_delete_runs_delete_before_create(self): vikingdb = MagicMock() vikingdb.get_embedder.return_value = None - vikingdb.batch_delete = AsyncMock(return_value=1) + vikingdb.delete_uris = AsyncMock(return_value=None) vikingdb.enqueue_embedding_msg = AsyncMock() compressor = SessionCompressor(vikingdb=vikingdb) @@ -563,4 +555,4 @@ async def _rm(*_args, **_kwargs): assert [m.uri for m in memories] == [new_memory.uri] assert call_order == ["delete", "create"] - vikingdb.batch_delete.assert_awaited_once() + vikingdb.delete_uris.assert_awaited_once_with(_make_ctx(), [target.uri]) From 9ff1cce680d6e2d79e524543e34bb221220b863b Mon Sep 17 00:00:00 2001 From: "zhoujiahui.01" Date: Fri, 27 Feb 2026 00:03:39 +0800 Subject: [PATCH 5/7] chore: align naming with vikingdb and rename session test --- .../vectordb-gateway-collection-refactor.md | 328 ------------------ openviking/core/directories.py | 3 +- openviking/session/session.py | 5 +- ...py => test_session_compressor_vikingdb.py} | 6 +- 4 files changed, 6 insertions(+), 336 deletions(-) delete mode 100644 docs/design/storage/vectordb-gateway-collection-refactor.md rename tests/session/{test_session_compressor_semantic_gateway.py => test_session_compressor_vikingdb.py} (81%) diff --git a/docs/design/storage/vectordb-gateway-collection-refactor.md b/docs/design/storage/vectordb-gateway-collection-refactor.md deleted file mode 100644 index f99398f5..00000000 --- a/docs/design/storage/vectordb-gateway-collection-refactor.md +++ /dev/null @@ -1,328 +0,0 @@ -# OpenViking 向量存储分层重构设计(Gateway / Collection / Filter) - -> 日期:2026-02-26 -> 状态:Draft(可直接进入实施) -> 范围:`openviking/storage`、`openviking/retrieve`、`openviking/eval` 的向量存储接入层 - ---- - -## 1. 背景与目标 - -当前向量存储路径中,`VikingVectorIndexBackend`、`vector_store.driver`、`collection_adapter`、`vectordb.collection` 在职责上有重叠: - -- 过滤表达式编译(AST -> DSL)分散在多层。 -- 后端差异(URL 适配、索引参数、读写字段规范化)未完全下沉。 -- `collection` 语义与“单 collection 绑定”的运行模型不完全一致。 - -本次重构目标: - -1. **通用业务逻辑上收**到 gateway/backend 层。 -2. **后端差异下沉**到 collection 层。 -3. `compile_filter` **可扩展但简单**(默认实现 + 子类 override)。 -4. backend/store **单 collection 绑定**(除 `create_collection(name, ...)` 外,不再传 collection)。 -5. 过滤语义统一:`uri`/`parent_uri` 等 path 字段统一使用 `must` 语义。 -6. 移除 `Prefix` / `Regex` expr AST 能力。 - ---- - -## 2. 已确认决策(来自前序讨论) - -### 2.1 分层决策 - -- `CollectionAdapter` 与旧 `driver` 职责高度重叠,最终应移除。 -- `compile_filter` 不放在独立 driver 层,放到 `ICollection`,由 `Collection` 在调用查询接口前触发。 -- 新增向量库时,只需在对应 collection 实现重写 `compile_filter`(默认可不重写)。 - -### 2.2 单 collection 约束 - -- `VikingVectorIndexBackend`(及其上层管理者)内部只持有一个当前 collection。 -- `create_collection(name, ...)` 是唯一需要显式 `name` 的入口。 -- 后续 CRUD / search / filter 等都操作绑定 collection,不再重复传 collection 参数。 - -### 2.3 filter 语义决策 - -- 输入兼容:`FilterExpr | dict | None`。 -- path 字段过滤不使用 prefix op,统一映射到 `must`。 -- 不引入 `_path_must` 之类额外包装层。 - -### 2.4 命名与代码风格决策 - -- 业务代码中原有 `self.vikingdb` 命名保持,不做无收益替换为 `self.vector_store`。 -- gateway 命名应更语义化(见第 8 节迁移建议)。 - -### 2.5 表达式能力决策 - -- `Prefix` / `Regex` 两个 expr AST 能力移除。 -- 若需要后端特有复杂语法,使用 `RawDSL` 或 backend-specific collection override。 - ---- - -## 3. 当前代码现状(实施前基线) - -### 3.1 主要组件 - -- `openviking/storage/viking_vector_index_backend.py` - - 当前承担大量业务逻辑、collection 管理、filter 编译调用。 -- `openviking/storage/vector_store/driver.py` + `drivers/*` - - 当前仍存在后端差异封装与 `compile_expr`。 -- `openviking/storage/collection_adapter.py` - - 与 driver 层重复(工厂 + backend 分发 + filter 编译 + normalize)。 -- `openviking/storage/vectordb/collection/collection.py` - - `ICollection` / `Collection` 封装,尚未成为 filter 编译单一入口。 - -### 3.2 现状问题清单 - -1. **重复抽象**:driver 与 collection_adapter 并存,维护成本高。 -2. **职责漂移**:backend 层仍携带后端差异处理逻辑。 -3. **扩展成本高**:新增后端需改多层(factory/driver/backend)。 -4. **接口噪音**:单 collection 场景下仍出现 collection 参数概念。 - ---- - -## 4. 目标架构 - -```text -[Business Callers] - | - v -[Semantic Gateway / VikingDBManager] - - 租户作用域/业务语义 - - 单 collection 生命周期 - - 通用查询编排 - | - v -[Collection (wrapper)] - - 所有含 filters 的调用前统一 compile_filter - - 统一结果包装 - | - v -[ICollection implementations] - - LocalCollection - - HttpCollection - - VolcengineCollection - - VikingDBCollection - - backend-specific compile_filter override(可选) -``` - -核心原则: - -- **只保留一条“filter 编译路径”**:`ICollection.compile_filter`。 -- **后端差异只出现在具体 collection 子类**。 -- **gateway 不持有 backend 语法细节**。 - ---- - -## 5. `compile_filter` 设计规范 - -### 5.1 接口定义 - -在 `ICollection` 增加默认实现: - -```python -def compile_filter(self, filter_expr: FilterExpr | dict | None) -> dict: - ... -``` - -### 5.2 默认行为 - -- `None` -> `{}` -- `dict` -> 原样透传 -- `RawDSL` -> 透传 payload -- `Eq/In` -> `{"op": "must", "field": ..., "conds": [...]}` -- `And/Or/Range/Contains/TimeRange` -> 按统一 DSL 映射 - -### 5.3 可扩展机制 - -- 新后端如语法不同:只在该后端 collection 中重写 `compile_filter`。 -- 未重写时自动使用默认实现,保证接入门槛低。 - -### 5.4 复杂度控制 - -- 不引入额外 compiler 注册中心。 -- 不新增 driver 级 compiler 层。 -- 优先“默认实现 + 最小 override”。 - ---- - -## 6. FilterExpr 能力边界(重构后) - -### 6.1 保留能力 - -- `And` -- `Or` -- `Eq` -- `In` -- `Range` -- `Contains` -- `TimeRange` -- `RawDSL` - -### 6.2 移除能力 - -- `Prefix` -- `Regex` - -### 6.3 path 字段规则 - -对 `uri` / `parent_uri` / 其他路径字段,统一使用: - -```json -{"op":"must","field":"uri","conds":["..."]} -``` - -不允许在 AST 层保留 prefix 语义入口。 - ---- - -## 7. 单 Collection 运行模型 - -### 7.1 约束 - -- backend 实例内部仅绑定一个 active collection。 -- 除 `create_collection(name, schema)` 外,其余操作基于绑定对象。 - -### 7.2 操作模型 - -- 创建流程: - - `create_collection(name, schema)` - - 建立 `self._collection` 绑定 - - 更新 meta cache -- 数据流程: - - `insert/update/upsert/delete/get/search/filter/count/...` 均操作 `self._collection` - -### 7.3 错误模型 - -- 未绑定 collection 时抛 `CollectionNotFoundError` 或统一运行时错误。 -- 不再依赖每次调用传 collection_name 做防御。 - ---- - -## 8. 命名与接口整理建议 - -### 8.1 gateway 命名 - -建议把“语义检索网关”命名统一为更语义化名称(例如 `SemanticGateway` / `SemanticContextGateway`)。 - -> 兼容策略:保留原类名 alias 一段时间,避免一次性大面积改动。 - -### 8.2 变量命名一致性 - -- 业务模块中已存在 `self.vikingdb` 的位置保持不变。 -- 不做“仅换名不换义”的全局改名(避免噪音 diff)。 - -### 8.3 说明 - -本设计文档不处理 CRUD private 收口;该议题后续单独评估。 - ---- - -## 9. 分阶段迁移计划(实施顺序) - -### Phase A:FilterExpr 与语义基线 - -1. 移除 `Prefix` / `Regex` AST 定义与导出。 -2. 清理编译分支中对应逻辑。 -3. 修复测试中对 `Prefix`/`Regex` 的 AST 断言。 - -**验收**:无 `Prefix`/`Regex` 类型引用;编译与静态检查通过。 - -### Phase B:compile_filter 下沉到 Collection - -1. `ICollection` 增加默认 `compile_filter`。 -2. `Collection` wrapper 在查询前统一调用 `compile_filter`。 -3. backend 中重复的 filter 编译逻辑删除。 - -**验收**:调用方仍可传 AST/dict,行为一致。 - -### Phase C:去除 driver / adapter 重复层 - -1. backend 不再依赖 `create_driver` / `VectorStoreDriver`。 -2. 删除(或停用)`collection_adapter.py` 与 `vector_store/driver*` 路径。 -3. 后端差异迁移到具体 `ICollection` 子类。 - -**验收**:新增 backend 仅需实现 collection + optional compile_filter override。 - -### Phase D:收口与稳定 - -1. 清理遗留 import/export。 -2. 更新设计文档与开发文档。 -3. 完成回归测试矩阵。 - -**验收**:无 dead adapter/driver 引用;主链路稳定。 - ---- - -## 10. 影响面与兼容性 - -### 10.1 影响模块 - -- `openviking/storage/*` -- `openviking/retrieve/*` -- `openviking/session/*` -- `openviking/eval/recorder/*` -- `openviking/eval/ragas/*` - -### 10.2 兼容策略 - -- dict filter 调用保持兼容。 -- AST 精简(去 Prefix/Regex)属于显式破坏性变更,依赖方需改为 `In` 或 `RawDSL`。 -- eval/recorder 现状可继续使用通用 CRUD,不在本文档范围做 private 收口。 - ---- - -## 11. 测试与验收标准 - -### 11.1 静态与构建 - -- `ruff check` 通过 -- `python -m compileall openviking` 通过 - -### 11.2 功能测试矩阵 - -1. **filter 编译**:AST / dict / RawDSL / None -2. **路径过滤**:`uri`/`parent_uri` 使用 `must` 语义 -3. **单 collection**:create 后无需传 collection 参数 -4. **后端差异**:至少验证一个后端 override `compile_filter` 生效 -5. **回归链路**:检索、去重、目录初始化、URI 更新映射 - -### 11.3 代码检索验收 - -- 无 `Prefix` / `Regex` expr 定义与引用。 -- 无 `collection_adapter` / `VectorStoreDriver` 生产路径引用(完成 Phase C 后)。 - ---- - -## 12. 风险与缓解 - -### 风险 1:迁移期双实现并存导致行为不一致 - -- **缓解**:以 `ICollection.compile_filter` 为唯一真源,旧分支尽快删除。 - -### 风险 2:测试桩接口与真实接口漂移 - -- **缓解**:统一测试 stub 最小接口契约,优先修复 `collection_exists_bound` 等缺失。 - -### 风险 3:后端特化语法回归 - -- **缓解**:在对应 collection override 中增加最小单测覆盖。 - ---- - -## 13. Out of Scope(本轮明确不做) - -1. CRUD public/private 收口策略。 -2. eval/recorder 能力边界重定义。 -3. 非向量存储模块(FS/Parser/Client)的结构性重构。 - ---- - -## 14. 实施完成定义(DoD) - -满足以下条件可认为本重构完成: - -1. filter 编译链路单一(Collection 入口)。 -2. backend 单 collection 绑定模式稳定。 -3. driver/adapter 重复层移除。 -4. `Prefix`/`Regex` expr 能力移除且无残留调用。 -5. 主流程回归通过并补齐设计文档。 - diff --git a/openviking/core/directories.py b/openviking/core/directories.py index 4c8c598f..080cf7ca 100644 --- a/openviking/core/directories.py +++ b/openviking/core/directories.py @@ -145,7 +145,6 @@ def __init__( vikingdb: "VikingDBManager", ): self.vikingdb = vikingdb - self.semantic_gateway = vikingdb async def initialize_account_directories(self, ctx: RequestContext) -> int: """Initialize account-shared scope roots.""" @@ -229,7 +228,7 @@ async def _ensure_directory( logger.debug(f"[VikingFS] Directory {uri} already exists") # 2. Ensure record exists in vector storage - existing = await self.semantic_gateway.get_context_by_uri( + existing = await self.vikingdb.get_context_by_uri( account_id=ctx.account_id, uri=uri, limit=1, diff --git a/openviking/session/session.py b/openviking/session/session.py index 07cb5fde..136823e4 100644 --- a/openviking/session/session.py +++ b/openviking/session/session.py @@ -77,7 +77,6 @@ def __init__( ): self._viking_fs = viking_fs self._vikingdb_manager = vikingdb_manager - self._semantic_gateway = vikingdb_manager self._session_compressor = session_compressor self.user = user or UserIdentifier.the_default_user() self.ctx = ctx or RequestContext(user=self.user, role=Role.ROOT) @@ -297,12 +296,12 @@ def commit(self) -> Dict[str, Any]: def _update_active_counts(self) -> int: """Update active_count for used contexts/skills.""" - if not self._semantic_gateway: + if not self._vikingdb_manager: return 0 uris = [usage.uri for usage in self._usage_records if usage.uri] try: - updated = run_async(self._semantic_gateway.increment_active_count(self.ctx, uris)) + updated = run_async(self._vikingdb_manager.increment_active_count(self.ctx, uris)) except Exception as e: logger.debug(f"Could not update active_count for usage URIs: {e}") updated = 0 diff --git a/tests/session/test_session_compressor_semantic_gateway.py b/tests/session/test_session_compressor_vikingdb.py similarity index 81% rename from tests/session/test_session_compressor_semantic_gateway.py rename to tests/session/test_session_compressor_vikingdb.py index d2a1c7d7..71e22533 100644 --- a/tests/session/test_session_compressor_semantic_gateway.py +++ b/tests/session/test_session_compressor_vikingdb.py @@ -12,9 +12,9 @@ @pytest.mark.asyncio -async def test_delete_existing_memory_uses_semantic_gateway(): +async def test_delete_existing_memory_uses_vikingdb_manager(): compressor = SessionCompressor.__new__(SessionCompressor) - compressor.semantic_gateway = AsyncMock() + compressor.vikingdb = AsyncMock() viking_fs = AsyncMock() memory = SimpleNamespace(uri="viking://user/user1/memories/events/e1") ctx = RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) @@ -23,4 +23,4 @@ async def test_delete_existing_memory_uses_semantic_gateway(): assert ok is True viking_fs.rm.assert_awaited_once_with(memory.uri, recursive=False, ctx=ctx) - compressor.semantic_gateway.delete_uris.assert_awaited_once_with(ctx, [memory.uri]) + compressor.vikingdb.delete_uris.assert_awaited_once_with(ctx, [memory.uri]) From 7729e89a8d96fe21d3fb8e53f1e2d0364213f91c Mon Sep 17 00:00:00 2001 From: "zhoujiahui.01" Date: Fri, 27 Feb 2026 11:59:32 +0800 Subject: [PATCH 6/7] fix --- openviking/storage/collection_adapter.py | 729 ------------------ openviking/storage/{vector_store => }/expr.py | 0 openviking/storage/vector_store/__init__.py | 27 - .../collection/vikingdb_collection.py | 10 +- .../collection/volcengine_collection.py | 19 +- .../storage/vectordb_adapters/__init__.py | 19 + openviking/storage/vectordb_adapters/base.py | 425 ++++++++++ .../storage/vectordb_adapters/factory.py | 29 + .../storage/vectordb_adapters/http_adapter.py | 75 ++ .../vectordb_adapters/local_adapter.py | 53 ++ .../vikingdb_private_adapter.py | 121 +++ .../vectordb_adapters/volcengine_adapter.py | 132 ++++ .../storage/viking_vector_index_backend.py | 4 +- .../utils/config/vectordb_config.py | 17 +- tests/storage/test_context_vector_gateway.py | 50 -- 15 files changed, 886 insertions(+), 824 deletions(-) delete mode 100644 openviking/storage/collection_adapter.py rename openviking/storage/{vector_store => }/expr.py (100%) delete mode 100644 openviking/storage/vector_store/__init__.py create mode 100644 openviking/storage/vectordb_adapters/__init__.py create mode 100644 openviking/storage/vectordb_adapters/base.py create mode 100644 openviking/storage/vectordb_adapters/factory.py create mode 100644 openviking/storage/vectordb_adapters/http_adapter.py create mode 100644 openviking/storage/vectordb_adapters/local_adapter.py create mode 100644 openviking/storage/vectordb_adapters/vikingdb_private_adapter.py create mode 100644 openviking/storage/vectordb_adapters/volcengine_adapter.py delete mode 100644 tests/storage/test_context_vector_gateway.py diff --git a/openviking/storage/collection_adapter.py b/openviking/storage/collection_adapter.py deleted file mode 100644 index 5e621ba0..00000000 --- a/openviking/storage/collection_adapter.py +++ /dev/null @@ -1,729 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Collection adapter layer for backend-specific storage integration.""" - -from __future__ import annotations - -import os -import uuid -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any, Dict, Iterable, Optional -from urllib.parse import urlparse - -from openviking.storage.errors import CollectionNotFoundError -from openviking.storage.vector_store.expr import ( - And, - Contains, - Eq, - FilterExpr, - In, - Or, - Range, - RawDSL, - TimeRange, -) -from openviking.storage.vectordb.collection.collection import Collection -from openviking.storage.vectordb.collection.http_collection import ( - HttpCollection, - get_or_create_http_collection, - list_vikingdb_collections, -) -from openviking.storage.vectordb.collection.local_collection import get_or_create_local_collection -from openviking.storage.vectordb.collection.result import FetchDataInCollectionResult -from openviking.storage.vectordb.collection.vikingdb_clients import VIKINGDB_APIS, VikingDBClient -from openviking.storage.vectordb.collection.vikingdb_collection import VikingDBCollection -from openviking.storage.vectordb.collection.volcengine_collection import ( - VolcengineCollection, - get_or_create_volcengine_collection, -) -from openviking_cli.utils import get_logger - -logger = get_logger(__name__) - - -def _parse_url(url: str) -> tuple[str, int]: - normalized = url - if not normalized.startswith(("http://", "https://")): - normalized = f"http://{normalized}" - parsed = urlparse(normalized) - host = parsed.hostname or "127.0.0.1" - port = parsed.port or 5000 - return host, port - - -def _normalize_collection_names(raw_collections: Iterable[Any]) -> list[str]: - names: list[str] = [] - for item in raw_collections: - if isinstance(item, str): - names.append(item) - elif isinstance(item, dict): - name = item.get("CollectionName") or item.get("collection_name") or item.get("name") - if isinstance(name, str): - names.append(name) - return names - - -class CollectionAdapter(ABC): - """Backend-specific adapter for single-collection operations.""" - - mode: str - - def __init__(self, collection_name: str): - self._collection_name = collection_name - self._collection: Optional[Collection] = None - - @property - def collection_name(self) -> str: - return self._collection_name - - @classmethod - @abstractmethod - def from_config(cls, config: Any) -> "CollectionAdapter": - """Create an adapter instance from VectorDB backend config.""" - - @abstractmethod - def _load_existing_collection_if_needed(self) -> None: - """Load existing bound collection handle when possible.""" - - @abstractmethod - def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: - """Create backend collection handle for bound collection.""" - - def collection_exists(self) -> bool: - self._load_existing_collection_if_needed() - return self._collection is not None - - def get_collection(self) -> Collection: - self._load_existing_collection_if_needed() - if self._collection is None: - raise CollectionNotFoundError(f"Collection {self._collection_name} does not exist") - return self._collection - - def create_collection( - self, - name: str, - schema: Dict[str, Any], - *, - distance: str, - sparse_weight: float, - index_name: str, - ) -> bool: - if self.collection_exists(): - return False - - self._collection_name = name - collection_meta = dict(schema) - scalar_index_fields = collection_meta.pop("ScalarIndex", []) - if "CollectionName" not in collection_meta: - collection_meta["CollectionName"] = name - - self._collection = self._create_backend_collection(collection_meta) - - scalar_index_fields = self.sanitize_scalar_index_fields( - scalar_index_fields=scalar_index_fields, - fields_meta=collection_meta.get("Fields", []), - ) - index_meta = self.build_default_index_meta( - index_name=index_name, - distance=distance, - use_sparse=sparse_weight > 0.0, - sparse_weight=sparse_weight, - scalar_index_fields=scalar_index_fields, - ) - self._collection.create_index(index_name, index_meta) - return True - - def drop_collection(self) -> bool: - if not self.collection_exists(): - return False - - coll = self.get_collection() - - # Drop indexes first so index lifecycle remains internal to adapter. - try: - for index_name in coll.list_indexes() or []: - try: - coll.drop_index(index_name) - except Exception as e: - logger.warning("Failed to drop index %s: %s", index_name, e) - except Exception as e: - logger.warning("Failed to list indexes before dropping collection: %s", e) - - try: - coll.drop() - except NotImplementedError: - logger.warning("Collection drop is not supported by backend mode=%s", self.mode) - return False - finally: - self._collection = None - - return True - - def close(self) -> None: - if self._collection is not None: - self._collection.close() - self._collection = None - - def get_collection_info(self) -> Optional[Dict[str, Any]]: - if not self.collection_exists(): - return None - return self.get_collection().get_meta_data() - - def sanitize_scalar_index_fields( - self, - scalar_index_fields: list[str], - fields_meta: list[dict[str, Any]], - ) -> list[str]: - return scalar_index_fields - - def build_default_index_meta( - self, - *, - index_name: str, - distance: str, - use_sparse: bool, - sparse_weight: float, - scalar_index_fields: list[str], - ) -> Dict[str, Any]: - index_type = "flat_hybrid" if use_sparse else "flat" - index_meta: Dict[str, Any] = { - "IndexName": index_name, - "VectorIndex": { - "IndexType": index_type, - "Distance": distance, - "Quant": "int8", - }, - "ScalarIndex": scalar_index_fields, - } - if use_sparse: - index_meta["VectorIndex"]["EnableSparse"] = True - index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight - return index_meta - - def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: - return record - - def compile_filter(self, expr: FilterExpr | Dict[str, Any] | None) -> Dict[str, Any]: - if expr is None: - return {} - if isinstance(expr, dict): - return expr - if isinstance(expr, RawDSL): - return expr.payload - if isinstance(expr, And): - conds = [self.compile_filter(c) for c in expr.conds if c is not None] - conds = [c for c in conds if c] - if not conds: - return {} - if len(conds) == 1: - return conds[0] - return {"op": "and", "conds": conds} - if isinstance(expr, Or): - conds = [self.compile_filter(c) for c in expr.conds if c is not None] - conds = [c for c in conds if c] - if not conds: - return {} - if len(conds) == 1: - return conds[0] - return {"op": "or", "conds": conds} - if isinstance(expr, Eq): - return {"op": "must", "field": expr.field, "conds": [expr.value]} - if isinstance(expr, In): - return {"op": "must", "field": expr.field, "conds": list(expr.values)} - if isinstance(expr, Range): - payload: Dict[str, Any] = {"op": "range", "field": expr.field} - if expr.gte is not None: - payload["gte"] = expr.gte - if expr.gt is not None: - payload["gt"] = expr.gt - if expr.lte is not None: - payload["lte"] = expr.lte - if expr.lt is not None: - payload["lt"] = expr.lt - return payload - if isinstance(expr, Contains): - return { - "op": "contains", - "field": expr.field, - "substring": expr.substring, - } - if isinstance(expr, TimeRange): - payload: Dict[str, Any] = {"op": "range", "field": expr.field} - if expr.start is not None: - payload["gte"] = expr.start - if expr.end is not None: - payload["lt"] = expr.end - return payload - raise TypeError(f"Unsupported filter expr type: {type(expr)!r}") - - def upsert(self, data: Dict[str, Any] | list[Dict[str, Any]]) -> list[str]: - coll = self.get_collection() - records = [data] if isinstance(data, dict) else data - normalized: list[Dict[str, Any]] = [] - ids: list[str] = [] - for item in records: - record = dict(item) - record_id = record.get("id") or str(uuid.uuid4()) - record["id"] = record_id - ids.append(record_id) - normalized.append(record) - coll.upsert_data(normalized) - return ids - - def get(self, ids: list[str]) -> list[Dict[str, Any]]: - coll = self.get_collection() - result = coll.fetch_data(ids) - - records: list[Dict[str, Any]] = [] - if isinstance(result, FetchDataInCollectionResult): - for item in result.items: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - records.append(self.normalize_record_for_read(record)) - return records - - if isinstance(result, dict) and "fetch" in result: - for item in result.get("fetch", []): - record = dict(item.get("fields", {})) if item.get("fields") else {} - record_id = item.get("id") - if record_id: - record["id"] = record_id - records.append(self.normalize_record_for_read(record)) - return records - - def query( - self, - *, - query_vector: Optional[list[float]] = None, - sparse_query_vector: Optional[Dict[str, float]] = None, - filter: Optional[Dict[str, Any] | FilterExpr] = None, - limit: int = 10, - offset: int = 0, - output_fields: Optional[list[str]] = None, - with_vector: bool = False, - order_by: Optional[str] = None, - order_desc: bool = False, - ) -> list[Dict[str, Any]]: - coll = self.get_collection() - vectordb_filter = self.compile_filter(filter) - - if query_vector or sparse_query_vector: - result = coll.search_by_vector( - index_name="default", - dense_vector=query_vector, - sparse_vector=sparse_query_vector, - limit=limit, - offset=offset, - filters=vectordb_filter, - output_fields=output_fields, - ) - elif order_by: - result = coll.search_by_scalar( - index_name="default", - field=order_by, - order="desc" if order_desc else "asc", - limit=limit, - offset=offset, - filters=vectordb_filter, - output_fields=output_fields, - ) - else: - result = coll.search_by_random( - index_name="default", - limit=limit, - offset=offset, - filters=vectordb_filter, - output_fields=output_fields, - ) - - records: list[Dict[str, Any]] = [] - for item in result.data: - record = dict(item.fields) if item.fields else {} - record["id"] = item.id - record["_score"] = item.score if item.score is not None else 0.0 - record = self.normalize_record_for_read(record) - if not with_vector: - record.pop("vector", None) - record.pop("sparse_vector", None) - records.append(record) - return records - - def delete( - self, - *, - ids: Optional[list[str]] = None, - filter: Optional[Dict[str, Any] | FilterExpr] = None, - limit: int = 100000, - ) -> int: - coll = self.get_collection() - delete_ids = list(ids or []) - if not delete_ids and filter is not None: - matched = self.query(filter=filter, limit=limit, with_vector=True) - delete_ids = [record["id"] for record in matched if record.get("id")] - - if not delete_ids: - return 0 - - coll.delete_data(delete_ids) - return len(delete_ids) - - def count(self, filter: Optional[Dict[str, Any] | FilterExpr] = None) -> int: - coll = self.get_collection() - result = coll.aggregate_data( - index_name="default", - op="count", - filters=self.compile_filter(filter), - ) - return result.agg.get("_total", 0) - - def clear(self) -> bool: - self.get_collection().delete_all_data() - return True - - -class LocalCollectionAdapter(CollectionAdapter): - """Adapter for local embedded vectordb backend.""" - - DEFAULT_LOCAL_PROJECT_NAME = "vectordb" - - def __init__(self, collection_name: str, project_path: str): - super().__init__(collection_name=collection_name) - self.mode = "local" - self._project_path = project_path - - @classmethod - def from_config(cls, config): - project_path = ( - str(Path(config.path) / cls.DEFAULT_LOCAL_PROJECT_NAME) if config.path else "" - ) - return cls(collection_name=config.name or "context", project_path=project_path) - - def _collection_path(self) -> str: - if not self._project_path: - return "" - return str(Path(self._project_path) / self._collection_name) - - def _load_existing_collection_if_needed(self) -> None: - if self._collection is not None: - return - collection_path = self._collection_path() - if not collection_path: - return - meta_path = os.path.join(collection_path, "collection_meta.json") - if os.path.exists(meta_path): - self._collection = get_or_create_local_collection(path=collection_path) - - def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: - collection_path = self._collection_path() - if collection_path: - os.makedirs(collection_path, exist_ok=True) - return get_or_create_local_collection(meta_data=meta, path=collection_path) - - -class HttpCollectionAdapter(CollectionAdapter): - """Adapter for remote HTTP vectordb project.""" - - def __init__(self, host: str, port: int, project_name: str, collection_name: str): - super().__init__(collection_name=collection_name) - self.mode = "http" - self._host = host - self._port = port - self._project_name = project_name - - @classmethod - def from_config(cls, config): - if not config.url: - raise ValueError("HTTP backend requires a valid URL") - host, port = _parse_url(config.url) - return cls( - host=host, - port=port, - project_name=config.project_name or "default", - collection_name=config.name or "context", - ) - - def _meta(self) -> Dict[str, Any]: - return { - "ProjectName": self._project_name, - "CollectionName": self._collection_name, - } - - def _remote_has_collection(self) -> bool: - raw = list_vikingdb_collections( - host=self._host, - port=self._port, - project_name=self._project_name, - ) - return self._collection_name in _normalize_collection_names(raw) - - def _load_existing_collection_if_needed(self) -> None: - if self._collection is not None: - return - if not self._remote_has_collection(): - return - self._collection = Collection( - HttpCollection( - ip=self._host, - port=self._port, - meta_data=self._meta(), - ) - ) - - def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: - payload = dict(meta) - payload.update(self._meta()) - return get_or_create_http_collection( - host=self._host, - port=self._port, - meta_data=payload, - ) - - -class VolcengineCollectionAdapter(CollectionAdapter): - """Adapter for Volcengine-hosted VikingDB.""" - - def __init__( - self, - *, - ak: str, - sk: str, - region: str, - host: str, - project_name: str, - collection_name: str, - ): - super().__init__(collection_name=collection_name) - self.mode = "volcengine" - self._ak = ak - self._sk = sk - self._region = region - self._host = host - self._project_name = project_name - - @classmethod - def from_config(cls, config): - if not ( - config.volcengine - and config.volcengine.ak - and config.volcengine.sk - and config.volcengine.region - ): - raise ValueError("Volcengine backend requires AK, SK, and Region configuration") - return cls( - ak=config.volcengine.ak, - sk=config.volcengine.sk, - region=config.volcengine.region, - host=config.volcengine.host or "", - project_name=config.project_name or "default", - collection_name=config.name or "context", - ) - - def _meta(self) -> Dict[str, Any]: - return { - "ProjectName": self._project_name, - "CollectionName": self._collection_name, - } - - def _config(self) -> Dict[str, Any]: - return { - "AK": self._ak, - "SK": self._sk, - "Region": self._region, - "Host": self._host, - } - - def _new_collection_handle(self) -> VolcengineCollection: - return VolcengineCollection( - ak=self._ak, - sk=self._sk, - region=self._region, - host=self._host, - meta_data=self._meta(), - ) - - def _load_existing_collection_if_needed(self) -> None: - if self._collection is not None: - return - candidate = self._new_collection_handle() - meta = candidate.get_meta_data() or {} - if meta and meta.get("CollectionName"): - self._collection = candidate - - def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: - payload = dict(meta) - payload.update(self._meta()) - return get_or_create_volcengine_collection( - config=self._config(), - meta_data=payload, - ) - - def sanitize_scalar_index_fields( - self, - scalar_index_fields: list[str], - fields_meta: list[dict[str, Any]], - ) -> list[str]: - date_time_fields = { - field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" - } - return [field for field in scalar_index_fields if field not in date_time_fields] - - def build_default_index_meta( - self, - *, - index_name: str, - distance: str, - use_sparse: bool, - sparse_weight: float, - scalar_index_fields: list[str], - ) -> Dict[str, Any]: - index_type = "hnsw_hybrid" if use_sparse else "hnsw" - index_meta: Dict[str, Any] = { - "IndexName": index_name, - "VectorIndex": { - "IndexType": index_type, - "Distance": distance, - "Quant": "int8", - }, - "ScalarIndex": scalar_index_fields, - } - if use_sparse: - index_meta["VectorIndex"]["EnableSparse"] = True - index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight - return index_meta - - def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: - for key in ("uri", "parent_uri"): - value = record.get(key) - if isinstance(value, str) and not value.startswith("viking://"): - stripped = value.strip("/") - if stripped: - record[key] = f"viking://{stripped}" - return record - - -class VikingDBPrivateCollectionAdapter(CollectionAdapter): - """Adapter for private VikingDB deployment.""" - - def __init__( - self, - *, - host: str, - headers: Optional[dict[str, str]], - project_name: str, - collection_name: str, - ): - super().__init__(collection_name=collection_name) - self.mode = "vikingdb" - self._host = host - self._headers = headers - self._project_name = project_name - - @classmethod - def from_config(cls, config): - if not config.vikingdb or not config.vikingdb.host: - raise ValueError("VikingDB backend requires a valid host") - return cls( - host=config.vikingdb.host, - headers=config.vikingdb.headers, - project_name=config.project_name or "default", - collection_name=config.name or "context", - ) - - def _client(self) -> VikingDBClient: - return VikingDBClient(self._host, self._headers) - - def _fetch_collection_meta(self) -> Optional[Dict[str, Any]]: - path, method = VIKINGDB_APIS["GetVikingdbCollection"] - req = { - "ProjectName": self._project_name, - "CollectionName": self._collection_name, - } - response = self._client().do_req(method, path=path, req_body=req) - if response.status_code != 200: - return None - result = response.json() - meta = result.get("Result", {}) - return meta or None - - def _load_existing_collection_if_needed(self) -> None: - if self._collection is not None: - return - meta = self._fetch_collection_meta() - if meta is None: - return - self._collection = Collection( - VikingDBCollection( - host=self._host, - headers=self._headers, - meta_data=meta, - ) - ) - - def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: - self._load_existing_collection_if_needed() - if self._collection is None: - raise NotImplementedError("private vikingdb collection should be pre-created") - return self._collection - - def sanitize_scalar_index_fields( - self, - scalar_index_fields: list[str], - fields_meta: list[dict[str, Any]], - ) -> list[str]: - date_time_fields = { - field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" - } - return [field for field in scalar_index_fields if field not in date_time_fields] - - def build_default_index_meta( - self, - *, - index_name: str, - distance: str, - use_sparse: bool, - sparse_weight: float, - scalar_index_fields: list[str], - ) -> Dict[str, Any]: - index_type = "hnsw_hybrid" if use_sparse else "hnsw" - index_meta: Dict[str, Any] = { - "IndexName": index_name, - "VectorIndex": { - "IndexType": index_type, - "Distance": distance, - "Quant": "int8", - }, - "ScalarIndex": scalar_index_fields, - } - if use_sparse: - index_meta["VectorIndex"]["EnableSparse"] = True - index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight - return index_meta - - def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: - for key in ("uri", "parent_uri"): - value = record.get(key) - if isinstance(value, str) and not value.startswith("viking://"): - stripped = value.strip("/") - if stripped: - record[key] = f"viking://{stripped}" - return record - - -_ADAPTER_REGISTRY: dict[str, type[CollectionAdapter]] = { - "local": LocalCollectionAdapter, - "http": HttpCollectionAdapter, - "volcengine": VolcengineCollectionAdapter, - "vikingdb": VikingDBPrivateCollectionAdapter, -} - - -def create_collection_adapter(config) -> CollectionAdapter: - """Unified factory entrypoint for backend-specific collection adapters.""" - adapter_cls = _ADAPTER_REGISTRY.get(config.backend) - if adapter_cls is None: - raise ValueError( - f"Vector backend {config.backend} is not supported. " - f"Available backends: {sorted(_ADAPTER_REGISTRY)}" - ) - return adapter_cls.from_config(config) diff --git a/openviking/storage/vector_store/expr.py b/openviking/storage/expr.py similarity index 100% rename from openviking/storage/vector_store/expr.py rename to openviking/storage/expr.py diff --git a/openviking/storage/vector_store/__init__.py b/openviking/storage/vector_store/__init__.py deleted file mode 100644 index 05ed11f9..00000000 --- a/openviking/storage/vector_store/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 -"""Vector store filter expression types.""" - -from openviking.storage.vector_store.expr import ( - And, - Contains, - Eq, - FilterExpr, - In, - Or, - Range, - RawDSL, - TimeRange, -) - -__all__ = [ - "FilterExpr", - "And", - "Or", - "Eq", - "In", - "Range", - "Contains", - "TimeRange", - "RawDSL", -] diff --git a/openviking/storage/vectordb/collection/vikingdb_collection.py b/openviking/storage/vectordb/collection/vikingdb_collection.py index 9bfa74cd..b9995a1f 100644 --- a/openviking/storage/vectordb/collection/vikingdb_collection.py +++ b/openviking/storage/vectordb/collection/vikingdb_collection.py @@ -371,17 +371,17 @@ def aggregate_data( filters: Optional[Dict[str, Any]] = None, cond: Optional[Dict[str, Any]] = None, ) -> AggregateResult: - path = "/api/vikingdb/data/aggregate" + path = "/api/vikingdb/data/agg" data = { "project": self.project_name, "collection_name": self.collection_name, "index_name": index_name, - "agg": { - "op": op, - "field": field, - }, + "op": op, + "field": field, "filter": filters, } + if cond is not None: + data["cond"] = cond resp_data = self._data_post(path, data) return self._parse_aggregate_result(resp_data, op, field) diff --git a/openviking/storage/vectordb/collection/volcengine_collection.py b/openviking/storage/vectordb/collection/volcengine_collection.py index 8564c2c0..4acb9445 100644 --- a/openviking/storage/vectordb/collection/volcengine_collection.py +++ b/openviking/storage/vectordb/collection/volcengine_collection.py @@ -35,14 +35,13 @@ def get_or_create_volcengine_collection(config: Dict[str, Any], meta_data: Dict[ ak = config.get("AK") sk = config.get("SK") region = config.get("Region") - host = config.get("Host", "") collection_name = meta_data.get("CollectionName") if not collection_name: raise ValueError("CollectionName is required in config") # Initialize Console client for creating Collection - client = ClientForConsoleApi(ak, sk, region, host) + client = ClientForConsoleApi(ak, sk, region) # Try to create Collection try: @@ -64,10 +63,10 @@ def get_or_create_volcengine_collection(config: Dict[str, Any], meta_data: Dict[ raise e logger.info(f"Collection {collection_name} created successfully") - return VolcengineCollection(ak, sk, region, host, meta_data) + return VolcengineCollection(ak, sk, region, meta_data=meta_data) # Return VolcengineCollection instance - return VolcengineCollection(ak=ak, sk=sk, region=region, host=host, meta_data=meta_data) + return VolcengineCollection(ak=ak, sk=sk, region=region, meta_data=meta_data) class VolcengineCollection(ICollection): @@ -76,7 +75,7 @@ def __init__( ak: str, sk: str, region: str, - host: str, + host: Optional[str] = None, meta_data: Optional[Dict[str, Any]] = None, ): self.console_client = ClientForConsoleApi(ak, sk, region, host) @@ -580,16 +579,16 @@ def aggregate_data( filters: Optional[Dict[str, Any]] = None, cond: Optional[Dict[str, Any]] = None, ) -> AggregateResult: - path = "/api/vikingdb/data/aggregate" + path = "/api/vikingdb/data/agg" data = { "project": self.project_name, "collection_name": self.collection_name, "index_name": index_name, - "agg": { - "op": op, - "field": field, - }, + "op": op, + "field": field, "filter": filters, } + if cond is not None: + data["cond"] = cond resp_data = self._data_post(path, data) return self._parse_aggregate_result(resp_data, op, field) diff --git a/openviking/storage/vectordb_adapters/__init__.py b/openviking/storage/vectordb_adapters/__init__.py new file mode 100644 index 00000000..8446b64a --- /dev/null +++ b/openviking/storage/vectordb_adapters/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""VectorDB backend collection adapter package.""" + +from .base import CollectionAdapter +from .factory import create_collection_adapter +from .http_adapter import HttpCollectionAdapter +from .local_adapter import LocalCollectionAdapter +from .vikingdb_private_adapter import VikingDBPrivateCollectionAdapter +from .volcengine_adapter import VolcengineCollectionAdapter + +__all__ = [ + "CollectionAdapter", + "LocalCollectionAdapter", + "HttpCollectionAdapter", + "VolcengineCollectionAdapter", + "VikingDBPrivateCollectionAdapter", + "create_collection_adapter", +] diff --git a/openviking/storage/vectordb_adapters/base.py b/openviking/storage/vectordb_adapters/base.py new file mode 100644 index 00000000..fc74bb79 --- /dev/null +++ b/openviking/storage/vectordb_adapters/base.py @@ -0,0 +1,425 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Base adapter primitives for backend-specific vector collection operations.""" + +from __future__ import annotations + +import uuid +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterable, Optional +from urllib.parse import urlparse + +from openviking.storage.errors import CollectionNotFoundError +from openviking.storage.expr import ( + And, + Contains, + Eq, + FilterExpr, + In, + Or, + Range, + RawDSL, + TimeRange, +) +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.result import FetchDataInCollectionResult +from openviking_cli.utils import get_logger + +logger = get_logger(__name__) + + +def _parse_url(url: str) -> tuple[str, int]: + normalized = url + if not normalized.startswith(("http://", "https://")): + normalized = f"http://{normalized}" + parsed = urlparse(normalized) + host = parsed.hostname or "127.0.0.1" + port = parsed.port or 5000 + return host, port + + +def _normalize_collection_names(raw_collections: Iterable[Any]) -> list[str]: + names: list[str] = [] + for item in raw_collections: + if isinstance(item, str): + names.append(item) + elif isinstance(item, dict): + name = item.get("CollectionName") or item.get("collection_name") or item.get("name") + if isinstance(name, str): + names.append(name) + return names + + +class CollectionAdapter(ABC): + """Backend-specific adapter for single-collection operations. + + Public API methods are kept without prefix (create/query/upsert/delete/count...). + Internal extension hooks for subclasses use leading underscore. + """ + + mode: str + + def __init__(self, collection_name: str): + self._collection_name = collection_name + self._collection: Optional[Collection] = None + + @property + def collection_name(self) -> str: + return self._collection_name + + @classmethod + @abstractmethod + def from_config(cls, config: Any) -> "CollectionAdapter": + """Create an adapter instance from VectorDB backend config.""" + + @abstractmethod + def _load_existing_collection_if_needed(self) -> None: + """Load existing bound collection handle when possible.""" + + @abstractmethod + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + """Create backend collection handle for bound collection.""" + + def collection_exists(self) -> bool: + self._load_existing_collection_if_needed() + return self._collection is not None + + def get_collection(self) -> Collection: + self._load_existing_collection_if_needed() + if self._collection is None: + raise CollectionNotFoundError(f"Collection {self._collection_name} does not exist") + return self._collection + + def create_collection( + self, + name: str, + schema: Dict[str, Any], + *, + distance: str, + sparse_weight: float, + index_name: str, + ) -> bool: + if self.collection_exists(): + return False + + self._collection_name = name + collection_meta = dict(schema) + scalar_index_fields = collection_meta.pop("ScalarIndex", []) + if "CollectionName" not in collection_meta: + collection_meta["CollectionName"] = name + + self._collection = self._create_backend_collection(collection_meta) + + scalar_index_fields = self._sanitize_scalar_index_fields( + scalar_index_fields=scalar_index_fields, + fields_meta=collection_meta.get("Fields", []), + ) + index_meta = self._build_default_index_meta( + index_name=index_name, + distance=distance, + use_sparse=sparse_weight > 0.0, + sparse_weight=sparse_weight, + scalar_index_fields=scalar_index_fields, + ) + self._collection.create_index(index_name, index_meta) + return True + + def drop_collection(self) -> bool: + if not self.collection_exists(): + return False + + coll = self.get_collection() + + # Drop indexes first so index lifecycle remains internal to adapter. + try: + for index_name in coll.list_indexes() or []: + try: + coll.drop_index(index_name) + except Exception as e: + logger.warning("Failed to drop index %s: %s", index_name, e) + except Exception as e: + logger.warning("Failed to list indexes before dropping collection: %s", e) + + try: + coll.drop() + except NotImplementedError: + logger.warning("Collection drop is not supported by backend mode=%s", self.mode) + return False + finally: + self._collection = None + + return True + + def close(self) -> None: + if self._collection is not None: + self._collection.close() + self._collection = None + + def get_collection_info(self) -> Optional[Dict[str, Any]]: + if not self.collection_exists(): + return None + return self.get_collection().get_meta_data() + + def _sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict[str, Any]], + ) -> list[str]: + return scalar_index_fields + + def _build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> Dict[str, Any]: + index_type = "flat_hybrid" if use_sparse else "flat" + index_meta: Dict[str, Any] = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def _normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: + return record + + def _compile_filter(self, expr: FilterExpr | Dict[str, Any] | None) -> Dict[str, Any]: + if expr is None: + return {} + if isinstance(expr, dict): + return expr + if isinstance(expr, RawDSL): + return expr.payload + if isinstance(expr, And): + conds = [self._compile_filter(c) for c in expr.conds if c is not None] + conds = [c for c in conds if c] + if not conds: + return {} + if len(conds) == 1: + return conds[0] + return {"op": "and", "conds": conds} + if isinstance(expr, Or): + conds = [self._compile_filter(c) for c in expr.conds if c is not None] + conds = [c for c in conds if c] + if not conds: + return {} + if len(conds) == 1: + return conds[0] + return {"op": "or", "conds": conds} + if isinstance(expr, Eq): + return {"op": "must", "field": expr.field, "conds": [expr.value]} + if isinstance(expr, In): + return {"op": "must", "field": expr.field, "conds": list(expr.values)} + if isinstance(expr, Range): + payload: Dict[str, Any] = {"op": "range", "field": expr.field} + if expr.gte is not None: + payload["gte"] = expr.gte + if expr.gt is not None: + payload["gt"] = expr.gt + if expr.lte is not None: + payload["lte"] = expr.lte + if expr.lt is not None: + payload["lt"] = expr.lt + return payload + if isinstance(expr, Contains): + return { + "op": "contains", + "field": expr.field, + "substring": expr.substring, + } + if isinstance(expr, TimeRange): + payload: Dict[str, Any] = {"op": "range", "field": expr.field} + if expr.start is not None: + payload["gte"] = expr.start + if expr.end is not None: + payload["lt"] = expr.end + return payload + raise TypeError(f"Unsupported filter expr type: {type(expr)!r}") + + # Backward-compatible aliases: keep old non-underscore names callable. + def sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict[str, Any]], + ) -> list[str]: + return self._sanitize_scalar_index_fields( + scalar_index_fields=scalar_index_fields, + fields_meta=fields_meta, + ) + + def build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> Dict[str, Any]: + return self._build_default_index_meta( + index_name=index_name, + distance=distance, + use_sparse=use_sparse, + sparse_weight=sparse_weight, + scalar_index_fields=scalar_index_fields, + ) + + def normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: + return self._normalize_record_for_read(record) + + def compile_filter(self, expr: FilterExpr | Dict[str, Any] | None) -> Dict[str, Any]: + return self._compile_filter(expr) + + def upsert(self, data: Dict[str, Any] | list[Dict[str, Any]]) -> list[str]: + coll = self.get_collection() + records = [data] if isinstance(data, dict) else data + normalized: list[Dict[str, Any]] = [] + ids: list[str] = [] + for item in records: + record = dict(item) + record_id = record.get("id") or str(uuid.uuid4()) + record["id"] = record_id + ids.append(record_id) + normalized.append(record) + coll.upsert_data(normalized) + return ids + + def get(self, ids: list[str]) -> list[Dict[str, Any]]: + coll = self.get_collection() + result = coll.fetch_data(ids) + + records: list[Dict[str, Any]] = [] + if isinstance(result, FetchDataInCollectionResult): + for item in result.items: + record = dict(item.fields) if item.fields else {} + record["id"] = item.id + records.append(self._normalize_record_for_read(record)) + return records + + if isinstance(result, dict) and "fetch" in result: + for item in result.get("fetch", []): + record = dict(item.get("fields", {})) if item.get("fields") else {} + record_id = item.get("id") + if record_id: + record["id"] = record_id + records.append(self._normalize_record_for_read(record)) + return records + + def query( + self, + *, + query_vector: Optional[list[float]] = None, + sparse_query_vector: Optional[Dict[str, float]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, + limit: int = 10, + offset: int = 0, + output_fields: Optional[list[str]] = None, + with_vector: bool = False, + order_by: Optional[str] = None, + order_desc: bool = False, + ) -> list[Dict[str, Any]]: + coll = self.get_collection() + vectordb_filter = self._compile_filter(filter) + + if query_vector or sparse_query_vector: + result = coll.search_by_vector( + index_name="default", + dense_vector=query_vector, + sparse_vector=sparse_query_vector, + limit=limit, + offset=offset, + filters=vectordb_filter, + output_fields=output_fields, + ) + elif order_by: + result = coll.search_by_scalar( + index_name="default", + field=order_by, + order="desc" if order_desc else "asc", + limit=limit, + offset=offset, + filters=vectordb_filter, + output_fields=output_fields, + ) + else: + result = coll.search_by_random( + index_name="default", + limit=limit, + offset=offset, + filters=vectordb_filter, + output_fields=output_fields, + ) + + records: list[Dict[str, Any]] = [] + for item in result.data: + record = dict(item.fields) if item.fields else {} + record["id"] = item.id + record["_score"] = item.score if item.score is not None else 0.0 + record = self._normalize_record_for_read(record) + if not with_vector: + record.pop("vector", None) + record.pop("sparse_vector", None) + records.append(record) + return records + + def delete( + self, + *, + ids: Optional[list[str]] = None, + filter: Optional[Dict[str, Any] | FilterExpr] = None, + limit: int = 100000, + ) -> int: + coll = self.get_collection() + delete_ids = list(ids or []) + if not delete_ids and filter is not None: + matched = self.query(filter=filter, limit=limit, with_vector=True) + delete_ids = [record["id"] for record in matched if record.get("id")] + + if not delete_ids: + return 0 + + coll.delete_data(delete_ids) + return len(delete_ids) + + @staticmethod + def _coerce_int(value: Any) -> Optional[int]: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float) and value.is_integer(): + return int(value) + if isinstance(value, str): + stripped = value.strip() + if stripped.isdigit(): + return int(stripped) + return None + + def count(self, filter: Optional[Dict[str, Any] | FilterExpr] = None) -> int: + coll = self.get_collection() + result = coll.aggregate_data( + index_name="default", + op="count", + filters=self._compile_filter(filter), + ) + if "_total" in result.agg: + parsed_total = self._coerce_int(result.agg.get("_total")) + if parsed_total is not None: + return parsed_total + + return 0 + + def clear(self) -> bool: + self.get_collection().delete_all_data() + return True diff --git a/openviking/storage/vectordb_adapters/factory.py b/openviking/storage/vectordb_adapters/factory.py new file mode 100644 index 00000000..21f15279 --- /dev/null +++ b/openviking/storage/vectordb_adapters/factory.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Adapter registry and factory entrypoints.""" + +from __future__ import annotations + +from .base import CollectionAdapter +from .http_adapter import HttpCollectionAdapter +from .local_adapter import LocalCollectionAdapter +from .vikingdb_private_adapter import VikingDBPrivateCollectionAdapter +from .volcengine_adapter import VolcengineCollectionAdapter + +_ADAPTER_REGISTRY: dict[str, type[CollectionAdapter]] = { + "local": LocalCollectionAdapter, + "http": HttpCollectionAdapter, + "volcengine": VolcengineCollectionAdapter, + "vikingdb": VikingDBPrivateCollectionAdapter, +} + + +def create_collection_adapter(config) -> CollectionAdapter: + """Unified factory entrypoint for backend-specific collection adapters.""" + adapter_cls = _ADAPTER_REGISTRY.get(config.backend) + if adapter_cls is None: + raise ValueError( + f"Vector backend {config.backend} is not supported. " + f"Available backends: {sorted(_ADAPTER_REGISTRY)}" + ) + return adapter_cls.from_config(config) diff --git a/openviking/storage/vectordb_adapters/http_adapter.py b/openviking/storage/vectordb_adapters/http_adapter.py new file mode 100644 index 00000000..d33b7243 --- /dev/null +++ b/openviking/storage/vectordb_adapters/http_adapter.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""HTTP backend collection adapter.""" + +from __future__ import annotations + +from typing import Any, Dict + +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.http_collection import ( + HttpCollection, + get_or_create_http_collection, + list_vikingdb_collections, +) + +from .base import CollectionAdapter, _normalize_collection_names, _parse_url + + +class HttpCollectionAdapter(CollectionAdapter): + """Adapter for remote HTTP vectordb project.""" + + def __init__(self, host: str, port: int, project_name: str, collection_name: str): + super().__init__(collection_name=collection_name) + self.mode = "http" + self._host = host + self._port = port + self._project_name = project_name + + @classmethod + def from_config(cls, config: Any): + if not config.url: + raise ValueError("HTTP backend requires a valid URL") + host, port = _parse_url(config.url) + return cls( + host=host, + port=port, + project_name=config.project_name or "default", + collection_name=config.name or "context", + ) + + def _meta(self) -> Dict[str, Any]: + return { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + + def _remote_has_collection(self) -> bool: + raw = list_vikingdb_collections( + host=self._host, + port=self._port, + project_name=self._project_name, + ) + return self._collection_name in _normalize_collection_names(raw) + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + if not self._remote_has_collection(): + return + self._collection = Collection( + HttpCollection( + ip=self._host, + port=self._port, + meta_data=self._meta(), + ) + ) + + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + payload = dict(meta) + payload.update(self._meta()) + return get_or_create_http_collection( + host=self._host, + port=self._port, + meta_data=payload, + ) diff --git a/openviking/storage/vectordb_adapters/local_adapter.py b/openviking/storage/vectordb_adapters/local_adapter.py new file mode 100644 index 00000000..29489e44 --- /dev/null +++ b/openviking/storage/vectordb_adapters/local_adapter.py @@ -0,0 +1,53 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Local backend collection adapter.""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Dict + +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.local_collection import get_or_create_local_collection + +from .base import CollectionAdapter + + +class LocalCollectionAdapter(CollectionAdapter): + """Adapter for local embedded vectordb backend.""" + + DEFAULT_LOCAL_PROJECT_NAME = "vectordb" + + def __init__(self, collection_name: str, project_path: str): + super().__init__(collection_name=collection_name) + self.mode = "local" + self._project_path = project_path + + @classmethod + def from_config(cls, config: Any): + project_path = ( + str(Path(config.path) / cls.DEFAULT_LOCAL_PROJECT_NAME) if config.path else "" + ) + return cls(collection_name=config.name or "context", project_path=project_path) + + def _collection_path(self) -> str: + if not self._project_path: + return "" + return str(Path(self._project_path) / self._collection_name) + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + collection_path = self._collection_path() + if not collection_path: + return + meta_path = os.path.join(collection_path, "collection_meta.json") + if os.path.exists(meta_path): + self._collection = get_or_create_local_collection(path=collection_path) + + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + collection_path = self._collection_path() + if collection_path: + os.makedirs(collection_path, exist_ok=True) + return get_or_create_local_collection(meta_data=meta, path=collection_path) diff --git a/openviking/storage/vectordb_adapters/vikingdb_private_adapter.py b/openviking/storage/vectordb_adapters/vikingdb_private_adapter.py new file mode 100644 index 00000000..fe15ad74 --- /dev/null +++ b/openviking/storage/vectordb_adapters/vikingdb_private_adapter.py @@ -0,0 +1,121 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Private VikingDB backend collection adapter.""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.vikingdb_clients import VIKINGDB_APIS, VikingDBClient +from openviking.storage.vectordb.collection.vikingdb_collection import VikingDBCollection + +from .base import CollectionAdapter + + +class VikingDBPrivateCollectionAdapter(CollectionAdapter): + """Adapter for private VikingDB deployment.""" + + def __init__( + self, + *, + host: str, + headers: Optional[dict[str, str]], + project_name: str, + collection_name: str, + ): + super().__init__(collection_name=collection_name) + self.mode = "vikingdb" + self._host = host + self._headers = headers + self._project_name = project_name + + @classmethod + def from_config(cls, config: Any): + if not config.vikingdb or not config.vikingdb.host: + raise ValueError("VikingDB backend requires a valid host") + return cls( + host=config.vikingdb.host, + headers=config.vikingdb.headers, + project_name=config.project_name or "default", + collection_name=config.name or "context", + ) + + def _client(self) -> VikingDBClient: + return VikingDBClient(self._host, self._headers) + + def _fetch_collection_meta(self) -> Optional[Dict[str, Any]]: + path, method = VIKINGDB_APIS["GetVikingdbCollection"] + req = { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + response = self._client().do_req(method, path=path, req_body=req) + if response.status_code != 200: + return None + result = response.json() + meta = result.get("Result", {}) + return meta or None + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + meta = self._fetch_collection_meta() + if meta is None: + return + self._collection = Collection( + VikingDBCollection( + host=self._host, + headers=self._headers, + meta_data=meta, + ) + ) + + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + self._load_existing_collection_if_needed() + if self._collection is None: + raise NotImplementedError("private vikingdb collection should be pre-created") + return self._collection + + def _sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict[str, Any]], + ) -> list[str]: + date_time_fields = { + field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" + } + return [field for field in scalar_index_fields if field not in date_time_fields] + + def _build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> Dict[str, Any]: + index_type = "hnsw_hybrid" if use_sparse else "hnsw" + index_meta: Dict[str, Any] = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def _normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: + for key in ("uri", "parent_uri"): + value = record.get(key) + if isinstance(value, str) and not value.startswith("viking://"): + stripped = value.strip("/") + if stripped: + record[key] = f"viking://{stripped}" + return record diff --git a/openviking/storage/vectordb_adapters/volcengine_adapter.py b/openviking/storage/vectordb_adapters/volcengine_adapter.py new file mode 100644 index 00000000..16d5ab08 --- /dev/null +++ b/openviking/storage/vectordb_adapters/volcengine_adapter.py @@ -0,0 +1,132 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Volcengine backend collection adapter.""" + +from __future__ import annotations + +from typing import Any, Dict + +from openviking.storage.vectordb.collection.collection import Collection +from openviking.storage.vectordb.collection.volcengine_collection import ( + VolcengineCollection, + get_or_create_volcengine_collection, +) + +from .base import CollectionAdapter + + +class VolcengineCollectionAdapter(CollectionAdapter): + """Adapter for Volcengine-hosted VikingDB.""" + + def __init__( + self, + *, + ak: str, + sk: str, + region: str, + project_name: str, + collection_name: str, + ): + super().__init__(collection_name=collection_name) + self.mode = "volcengine" + self._ak = ak + self._sk = sk + self._region = region + self._project_name = project_name + + @classmethod + def from_config(cls, config: Any): + if not ( + config.volcengine + and config.volcengine.ak + and config.volcengine.sk + and config.volcengine.region + ): + raise ValueError("Volcengine backend requires AK, SK, and Region configuration") + return cls( + ak=config.volcengine.ak, + sk=config.volcengine.sk, + region=config.volcengine.region, + project_name=config.project_name or "default", + collection_name=config.name or "context", + ) + + def _meta(self) -> Dict[str, Any]: + return { + "ProjectName": self._project_name, + "CollectionName": self._collection_name, + } + + def _config(self) -> Dict[str, Any]: + return { + "AK": self._ak, + "SK": self._sk, + "Region": self._region, + } + + def _new_collection_handle(self) -> VolcengineCollection: + return VolcengineCollection( + ak=self._ak, + sk=self._sk, + region=self._region, + meta_data=self._meta(), + ) + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + candidate = self._new_collection_handle() + meta = candidate.get_meta_data() or {} + if meta and meta.get("CollectionName"): + self._collection = candidate + + def _create_backend_collection(self, meta: Dict[str, Any]) -> Collection: + payload = dict(meta) + payload.update(self._meta()) + return get_or_create_volcengine_collection( + config=self._config(), + meta_data=payload, + ) + + def _sanitize_scalar_index_fields( + self, + scalar_index_fields: list[str], + fields_meta: list[dict[str, Any]], + ) -> list[str]: + date_time_fields = { + field.get("FieldName") for field in fields_meta if field.get("FieldType") == "date_time" + } + return [field for field in scalar_index_fields if field not in date_time_fields] + + def _build_default_index_meta( + self, + *, + index_name: str, + distance: str, + use_sparse: bool, + sparse_weight: float, + scalar_index_fields: list[str], + ) -> Dict[str, Any]: + index_type = "hnsw_hybrid" if use_sparse else "hnsw" + index_meta: Dict[str, Any] = { + "IndexName": index_name, + "VectorIndex": { + "IndexType": index_type, + "Distance": distance, + "Quant": "int8", + }, + "ScalarIndex": scalar_index_fields, + } + if use_sparse: + index_meta["VectorIndex"]["EnableSparse"] = True + index_meta["VectorIndex"]["SearchWithSparseLogitAlpha"] = sparse_weight + return index_meta + + def _normalize_record_for_read(self, record: Dict[str, Any]) -> Dict[str, Any]: + for key in ("uri", "parent_uri"): + value = record.get(key) + if isinstance(value, str) and not value.startswith("viking://"): + stripped = value.strip("/") + if stripped: + record[key] = f"viking://{stripped}" + return record diff --git a/openviking/storage/viking_vector_index_backend.py b/openviking/storage/viking_vector_index_backend.py index 18044cec..7307960a 100644 --- a/openviking/storage/viking_vector_index_backend.py +++ b/openviking/storage/viking_vector_index_backend.py @@ -8,10 +8,10 @@ from typing import Any, Dict, List, Optional from openviking.server.identity import RequestContext, Role -from openviking.storage.collection_adapter import CollectionAdapter, create_collection_adapter -from openviking.storage.vector_store.expr import And, Eq, FilterExpr, In, Or, RawDSL +from openviking.storage.expr import And, Eq, FilterExpr, In, Or, RawDSL from openviking.storage.vectordb.collection.collection import Collection from openviking.storage.vectordb.utils.logging_init import init_cpp_logging +from openviking.storage.vectordb_adapters import CollectionAdapter, create_collection_adapter from openviking_cli.utils import get_logger from openviking_cli.utils.config.vectordb_config import VectorDBBackendConfig diff --git a/openviking_cli/utils/config/vectordb_config.py b/openviking_cli/utils/config/vectordb_config.py index 4052e0c4..0a78877b 100644 --- a/openviking_cli/utils/config/vectordb_config.py +++ b/openviking_cli/utils/config/vectordb_config.py @@ -4,8 +4,11 @@ from pydantic import BaseModel, Field, model_validator +from openviking_cli.utils.logger import get_logger + COLLECTION_NAME = "context" DEFAULT_PROJECT_NAME = "default" +logger = get_logger(__name__) class VolcengineConfig(BaseModel): @@ -16,7 +19,13 @@ class VolcengineConfig(BaseModel): region: Optional[str] = Field( default=None, description="Volcengine region (e.g., 'cn-beijing')" ) - host: Optional[str] = Field(default=None, description="Volcengine VikingDB host (optional)") + host: Optional[str] = Field( + default=None, + description=( + "[Deprecated] Ignored in volcengine mode. " + "Hosts are derived from `region` to route console/data APIs correctly." + ), + ) model_config = {"extra": "forbid"} @@ -112,6 +121,12 @@ def validate_config(self): raise ValueError("VectorDB volcengine backend requires 'ak' and 'sk' to be set") if not self.volcengine.region: raise ValueError("VectorDB volcengine backend requires 'region' to be set") + if self.volcengine.host: + logger.warning( + "VectorDB volcengine backend: 'volcengine.host' is deprecated and ignored. " + "Using region-based console/data hosts for region='%s'.", + self.volcengine.region, + ) elif self.backend == "vikingdb": if not self.vikingdb or not self.vikingdb.host: diff --git a/tests/storage/test_context_vector_gateway.py b/tests/storage/test_context_vector_gateway.py deleted file mode 100644 index 30b054f9..00000000 --- a/tests/storage/test_context_vector_gateway.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 - -from unittest.mock import AsyncMock - -import pytest - -from openviking.server.identity import RequestContext, Role -from openviking.storage.context_vector_gateway import ContextVectorGateway -from openviking.storage.vector_store.expr import And -from openviking_cli.session.user_id import UserIdentifier - - -def _make_ctx() -> RequestContext: - return RequestContext(user=UserIdentifier("acc1", "user1", "agent1"), role=Role.USER) - - -@pytest.mark.asyncio -async def test_search_in_tenant_uses_bound_collection_and_tenant_scope(): - storage = AsyncMock() - storage.search.return_value = [] - gateway = ContextVectorGateway.from_storage(storage, collection_name="ctx_custom") - - await gateway.search_in_tenant( - ctx=_make_ctx(), - query_vector=[0.1], - context_type="resource", - target_directories=["viking://resources/foo"], - limit=2, - ) - - call = storage.search.await_args.kwargs - assert call["collection"] == "ctx_custom" - assert isinstance(call["filter"], And) - - -@pytest.mark.asyncio -async def test_increment_active_count_updates_by_uri(): - storage = AsyncMock() - storage.filter.return_value = [{"id": "r1", "active_count": 3}] - storage.update.return_value = True - gateway = ContextVectorGateway.from_storage(storage, collection_name="ctx_custom") - - updated = await gateway.increment_active_count(_make_ctx(), ["viking://resources/foo"]) - - assert updated == 1 - update_call = storage.update.await_args - assert update_call.args[0] == "ctx_custom" - assert update_call.args[1] == "r1" - assert update_call.args[2]["active_count"] == 4 From 696a5e00cc80ae3cc32e7a07deab039ce6111094 Mon Sep 17 00:00:00 2001 From: "zhoujiahui.01" Date: Fri, 27 Feb 2026 15:29:38 +0800 Subject: [PATCH 7/7] docs: add guide for integrating third-party vectordb adapters --- .../storage/vectordb_adapters/README.md | 211 ++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 openviking/storage/vectordb_adapters/README.md diff --git a/openviking/storage/vectordb_adapters/README.md b/openviking/storage/vectordb_adapters/README.md new file mode 100644 index 00000000..3cb8e065 --- /dev/null +++ b/openviking/storage/vectordb_adapters/README.md @@ -0,0 +1,211 @@ +# VectorDB Adapter 接入指南(新增第三方后端) + +本指南说明如何在 `openviking/storage/vectordb_adapters` 下新增一个第三方向量库后端,并接入 OpenViking 现有检索链路。 + +--- + +## 1. 目标与范围 + +### 目标 +- 以最小改动新增一个向量库后端。 +- 保持上层业务接口不变(`find/search` 等无需改调用方式)。 +- 将后端差异封装在 Adapter 层,不泄漏到业务层。 + +### 非目标 +- 不改上层语义检索策略(租户、目录层级、召回策略)。 +- 不增加新的对外 API 协议。 + +--- + +## 2. 架构位置与职责 + +当前分层职责如下: + +1. **上层语义层(OpenViking 业务)** + 面向语义接口,不关心后端协议差异。 + +2. **通用向量存储层(Store/Backend)** + 提供统一查询、写入、删除、计数能力。 + +3. **Adapter 层(本目录)** + 负责把统一能力映射到具体后端实现(local/http/volcengine/vikingdb/thirdparty)。 + +新增后端时,主要只改第 3 层。 + +--- + +## 3. 接入前提 + +在开始前,请确认: + +- 你已拿到第三方后端的: + - 集合管理 API(查/建/删集合) + - 数据 API(upsert/get/delete/search/aggregate) +- 你已明确该后端的: + - 认证方式(AK/SK、token、header) + - 过滤语法能力(是否支持 must/range/and/or) + - 索引参数约束(dense/sparse、距离度量、索引类型) + +--- + +## 4. 接入步骤 + +## Step 1:新增 Adapter 文件 + +在目录下新增文件,例如: + +- `openviking/storage/vectordb_adapters/thirdparty_adapter.py` + +定义类: + +- `ThirdPartyCollectionAdapter(CollectionAdapter)` + +基类位于: + +- `openviking/storage/vectordb_adapters/base.py` + +--- + +## Step 2:实现最小必需方法 + +你需要实现以下方法: + +1. `from_config(cls, config)` + - 从 `VectorDBBackendConfig` 读取后端配置并构造 adapter。 + - collection 名建议使用 `config.name or "context"`。 + +2. `_load_existing_collection_if_needed(self)` + - 懒加载已存在 collection handle。 + - 若不存在,保持 `_collection is None`。 + +3. `_create_backend_collection(self, meta)` + - 按传入 schema 创建 collection 并返回 handle。 + +--- + +## Step 3:按后端能力补充可选 Hook + +如后端有差异,可重写: + +- `_sanitize_scalar_index_fields(...)` +- `_build_default_index_meta(...)` + +目的:把后端特性差异收敛在 adapter 内。 + +--- + +## Step 4:注册到 Factory + +编辑: + +- `openviking/storage/vectordb_adapters/factory.py` + +在 `_ADAPTER_REGISTRY` 增加映射,例如: + +```python +"thirdparty": ThirdPartyCollectionAdapter +``` + +这样 `create_collection_adapter(config)` 会自动路由到你的实现。 + +--- + +## Step 5:补充配置模型 + +确保配置中可声明新 backend(如 `backend: thirdparty`)及其专属字段(endpoint/auth/region 等)。 + +原则: +- `create_collection` 时使用配置中的 name 绑定 collection。 +- 后续操作默认绑定,不需要每次传 collection_name。 + +--- + +## 5. Filter 与查询兼容规则 + +- Adapter 需要兼容统一过滤表达。 +- 上层传入的过滤表达会经由统一编译流程进入后端查询。 +- 若第三方语法不同,请在 adapter 内做映射,不改上层调用协议。 + +关键原则: +- **后端 DSL 不上浮到业务层**。 +- **业务层不依赖第三方私有查询语法**。 + +--- + +## 6. 最小代码骨架(示例) + +```python +from __future__ import annotations +from typing import Any, Dict + +from openviking.storage.vectordb_adapters.base import CollectionAdapter + +class ThirdPartyCollectionAdapter(CollectionAdapter): + def __init__(self, *, endpoint: str, token: str, collection_name: str): + super().__init__(collection_name=collection_name) + self.mode = "thirdparty" + self._endpoint = endpoint + self._token = token + + @classmethod + def from_config(cls, config: Any): + if not config.thirdparty or not config.thirdparty.endpoint: + raise ValueError("ThirdParty backend requires endpoint") + return cls( + endpoint=config.thirdparty.endpoint, + token=config.thirdparty.token, + collection_name=config.name or "context", + ) + + def _load_existing_collection_if_needed(self) -> None: + if self._collection is not None: + return + # TODO: 查询远端 collection 是否存在,存在则初始化 handle + # self._collection = ... + pass + + def _create_backend_collection(self, meta: Dict[str, Any]): + # TODO: 调后端 create collection,并返回 collection handle + # return ... + raise NotImplementedError +``` + +--- + +## 7. 测试要求(必须) + +至少覆盖以下场景: + +1. backend 工厂路由正确(能创建到新 adapter)。 +2. collection 生命周期可用(exists/create/drop)。 +3. 基础数据链路可用(upsert/get/delete/query)。 +4. count/aggregate 行为正确。 +5. filter 条件可正确生效(含组合条件)。 + +--- + +## 8. 常见问题与排查 + +### Q1:启动时报 backend 不支持 +- 检查 factory 是否注册。 +- 检查配置里的 backend 字符串是否与 registry key 一致。 + +### Q2:集合创建成功但查询为空 +- 检查 collection 绑定名是否一致。 +- 检查索引是否创建成功。 +- 检查 filter 映射是否把条件误转成空条件。 + +### Q3:count 与 query 条数不一致 +- 检查 aggregate API 的字段命名与返回结构解析。 +- 检查 count 使用的 filter 与 query 使用的 filter 是否一致。 + +--- + +## 9. 验收标准 + +当满足以下条件,即可视为接入完成: + +- `backend=thirdparty` 可正常初始化。 +- create 后可完成 upsert/get/query/delete/count 全流程。 +- 不改上层业务调用方式即可参与 `find/search` 检索链路。 +- 后端差异全部封装在 adapter 层。