# Canonical metric names accepted by to_distance_strategy, keyed in upper case.
_DISTANCE_STRATEGY_MAP = {
    "COSINE": DistanceStrategy.COSINE,
    "EUCLIDEAN_DISTANCE": DistanceStrategy.EUCLIDEAN_DISTANCE,
    "DOT_PRODUCT": DistanceStrategy.DOT_PRODUCT,
}


def to_distance_strategy(metric: str) -> DistanceStrategy:
    """Convert a distance metric string to a DistanceStrategy enum.

    Matching is case-insensitive and ignores leading/trailing whitespace, so
    values such as "cosine" or " COSINE " resolve to the same member.

    Args:
        metric: Name of the distance metric. Must be one of the keys of
            ``_DISTANCE_STRATEGY_MAP`` (compared after stripping and
            upper-casing).

    Returns:
        The ``DistanceStrategy`` member registered for ``metric``.

    Raises:
        ValueError: If ``metric`` is None, empty or whitespace-only, is not a
            string, or does not name a supported metric.
    """
    if not metric:
        raise ValueError("Distance metric is required but was None/empty")
    if not isinstance(metric, str):
        # Raise ValueError rather than letting .upper() fail with AttributeError,
        # so callers only have one exception type to handle for bad metrics.
        raise ValueError(f"Distance metric must be a string, got {type(metric).__name__}: {metric!r}")

    normalized = metric.strip().upper()
    if not normalized:
        # Whitespace-only input carries no metric name; report it as missing.
        raise ValueError("Distance metric is required but was None/empty")

    strategy = _DISTANCE_STRATEGY_MAP.get(normalized)
    if strategy is None:
        raise ValueError(
            f"Unrecognized distance metric: '{metric}'. Expected one of: {list(_DISTANCE_STRATEGY_MAP)}"
        )
    return strategy
single vector table and return documents with metadata""" - LOGGER.info("Searching table: %s with distance metric: %s", table_name, table_distance_metric) + # Normalize distance metric for consistent use in both OracleVS construction and score conversion + distance_strategy = to_distance_strategy(table_distance_metric) + LOGGER.info( + "Searching table: %s with distance metric: %s (resolved: %s)", + table_name, + table_distance_metric, + distance_strategy.value, + ) # Initialize Vector Store for this table using its specific distance metric - vectorstores = OracleVS(db_conn, embed_client, table_name, table_distance_metric) + vectorstores = OracleVS(db_conn, embed_client, table_name, distance_strategy) # For Similarity searches, call with_score to preserve scores if vector_search.search_type == "Similarity": @@ -212,12 +221,12 @@ def _search_table(table_name, question, db_conn, embed_client, vector_search, ta for doc, score in docs_and_scores: # Convert distance to similarity score (for COSINE: similarity = 1 - distance) # For COSINE metric, distance is in [0, 2] range, similarity in [0, 1] - if table_distance_metric == "COSINE": + if distance_strategy.value == "COSINE": similarity = 1.0 - (score / 2.0) - elif table_distance_metric == "DOT": + elif distance_strategy.value == "DOT_PRODUCT": # For DOT product, higher is better (already a similarity) similarity = score - else: # EUCLIDEAN or EUCLIDEAN_SQUARED + else: # EUCLIDEAN_DISTANCE # For Euclidean, lower distance = higher similarity # Use inverse: similarity = 1 / (1 + distance) similarity = 1.0 / (1.0 + score) diff --git a/tests/unit/common/test_functions.py b/tests/unit/common/test_functions.py index e724ef48..f198e7f1 100644 --- a/tests/unit/common/test_functions.py +++ b/tests/unit/common/test_functions.py @@ -6,6 +6,7 @@ Tests utility functions for URL checking, vector store operations, and SQL operations. 
class TestToDistanceStrategy:
    """Unit tests for functions.to_distance_strategy."""

    def test_cosine_string(self):
        """The "COSINE" string resolves to DistanceStrategy.COSINE."""
        resolved = functions.to_distance_strategy("COSINE")
        assert resolved == DistanceStrategy.COSINE

    def test_euclidean_distance_string(self):
        """The "EUCLIDEAN_DISTANCE" string resolves to DistanceStrategy.EUCLIDEAN_DISTANCE."""
        resolved = functions.to_distance_strategy("EUCLIDEAN_DISTANCE")
        assert resolved == DistanceStrategy.EUCLIDEAN_DISTANCE

    def test_dot_product_string(self):
        """The "DOT_PRODUCT" string resolves to DistanceStrategy.DOT_PRODUCT."""
        resolved = functions.to_distance_strategy("DOT_PRODUCT")
        assert resolved == DistanceStrategy.DOT_PRODUCT

    def test_case_insensitive(self):
        """A lower-case metric name is matched the same as its upper-case form."""
        resolved = functions.to_distance_strategy("cosine")
        assert resolved == DistanceStrategy.COSINE

    def test_none_raises_error(self):
        """Passing None is rejected with a ValueError mentioning "required"."""
        from pytest import raises  # pylint: disable=import-outside-toplevel

        with raises(ValueError, match="required"):
            functions.to_distance_strategy(None)

    def test_unrecognized_raises_error(self):
        """An unknown metric name is rejected with a ValueError mentioning "Unrecognized"."""
        from pytest import raises  # pylint: disable=import-outside-toplevel

        with raises(ValueError, match="Unrecognized"):
            functions.to_distance_strategy("UNKNOWN")
class TestOracleVSDistanceStrategyType:
    """Tests that OracleVS receives a DistanceStrategy enum, not a raw string."""

    def test_oraclevs_receives_distance_strategy_enum(self):
        """OracleVS must be called with a DistanceStrategy enum, not a string.

        This test would have caught the langchain-core 1.2.22 breakage where
        _get_distance_function raises ValueError on non-enum values.
        """
        from server.mcp.tools.vs_retriever import _search_table  # pylint: disable=import-outside-toplevel

        mock_vectorstore = MagicMock()
        mock_vectorstore.similarity_search_with_score.return_value = [
            (Document(page_content="Test", metadata={}), 0.2),
        ]

        vector_search = VectorSearchSettings(
            search_type="Similarity",
            top_k=1,
            score_threshold=0.0,
        )

        with patch("server.mcp.tools.vs_retriever.OracleVS", return_value=mock_vectorstore) as mock_cls:
            _search_table(
                table_name="TEST_TABLE",
                question="test query",
                db_conn=MagicMock(),
                embed_client=MagicMock(),
                vector_search=vector_search,
                table_distance_metric="COSINE",
            )

        # The 4th positional arg to OracleVS() must be a DistanceStrategy enum
        args = mock_cls.call_args[0]
        assert len(args) >= 4, "OracleVS should receive at least 4 positional args"
        assert isinstance(args[3], DistanceStrategy), (
            f"Expected DistanceStrategy enum, got {type(args[3]).__name__}: {args[3]}"
        )
        assert args[3] == DistanceStrategy.COSINE

    def test_oraclevs_raises_for_none_metric(self):
        """When distance_metric is None, _search_table should raise ValueError.

        OracleVS is patched so the test never touches the real vector store
        class, and so it can verify that the metric is rejected before any
        vector store is constructed.
        """
        import pytest  # pylint: disable=import-outside-toplevel

        from server.mcp.tools.vs_retriever import _search_table  # pylint: disable=import-outside-toplevel

        vector_search = VectorSearchSettings(
            search_type="Similarity",
            top_k=1,
            score_threshold=0.0,
        )

        with patch("server.mcp.tools.vs_retriever.OracleVS") as mock_cls:
            with pytest.raises(ValueError, match="required"):
                _search_table(
                    table_name="TEST_TABLE",
                    question="test query",
                    db_conn=MagicMock(),
                    embed_client=MagicMock(),
                    vector_search=vector_search,
                    table_distance_metric=None,
                )

        # Validation must fail fast: no vector store may be built for a missing metric.
        mock_cls.assert_not_called()