Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/common/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import csv
import json

from langchain_community.vectorstores.utils import DistanceStrategy

import oracledb
import requests

Expand Down Expand Up @@ -190,7 +192,7 @@ def parse_vs_comment(comment: str) -> dict:
"model": parsed.get("model"),
"chunk_size": parsed.get("chunk_size"),
"chunk_overlap": parsed.get("chunk_overlap"),
"distance_metric": parsed.get("distance_metric"),
"distance_metric": parsed.get("distance_metric") or parsed.get("distance_strategy"),
"index_type": parsed.get("index_type"),
"parse_status": "success",
}
Expand All @@ -200,6 +202,25 @@ def parse_vs_comment(comment: str) -> dict:
return default_result


# Lookup table from the upper-case metric name to its DistanceStrategy member.
# Each key is also the member's name, so members are resolved by name lookup.
# Insertion order is preserved on purpose: the keys are listed verbatim in the
# error message raised for an unrecognized metric.
_DISTANCE_STRATEGY_MAP = {
    metric_name: DistanceStrategy[metric_name]
    for metric_name in ("COSINE", "EUCLIDEAN_DISTANCE", "DOT_PRODUCT")
}


def to_distance_strategy(metric: str) -> DistanceStrategy:
    """Convert a distance metric string to a DistanceStrategy enum.

    Matching is case-insensitive and ignores leading/trailing whitespace.

    Args:
        metric: Name of the distance metric; must be one of the keys of
            ``_DISTANCE_STRATEGY_MAP``.

    Returns:
        The DistanceStrategy member registered for ``metric``.

    Raises:
        ValueError: If ``metric`` is None/empty, is not a string, or does not
            name a supported metric.
    """
    if not metric:
        raise ValueError("Distance metric is required but was None/empty")

    # Non-string values have no .upper(); report them with the same ValueError
    # as any other unsupported metric instead of leaking an AttributeError.
    strategy = None
    if isinstance(metric, str):
        strategy = _DISTANCE_STRATEGY_MAP.get(metric.strip().upper())

    if strategy is None:
        raise ValueError(
            f"Unrecognized distance metric: '{metric}'. Expected one of: {list(_DISTANCE_STRATEGY_MAP.keys())}"
        )
    return strategy


def is_sql_accessible(db_conn: str, query: str) -> tuple[bool, str]:
"""Check if the DB connection and SQL is working one field."""

Expand Down
5 changes: 3 additions & 2 deletions src/server/api/utils/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import server.api.utils.oci as utils_oci

from common import schema, functions
from common.functions import to_distance_strategy


LOGGER = logging.getLogger("api.utils.embed")
Expand Down Expand Up @@ -368,7 +369,7 @@ def _create_temp_vector_store(
client=db_conn,
embedding_function=embed_client,
table_name=vector_store_tmp.vector_store,
distance_strategy=vector_store.distance_metric,
distance_strategy=to_distance_strategy(vector_store.distance_metric),
query="AI Optimizer for Apps - Powered by Oracle",
)
return vs_tmp, vector_store_tmp
Expand Down Expand Up @@ -401,7 +402,7 @@ def _merge_and_index_vector_store(
client=db_conn,
embedding_function=embed_client,
table_name=vector_store.vector_store,
distance_strategy=vector_store.distance_metric,
distance_strategy=to_distance_strategy(vector_store.distance_metric),
query="AI Optimizer for Apps - Powered by Oracle",
)

Expand Down
21 changes: 15 additions & 6 deletions src/server/mcp/tools/vs_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from pydantic import BaseModel

from langchain_community.vectorstores.oraclevs import OracleVS

from litellm import completion

from common.functions import to_distance_strategy


import server.api.utils.settings as utils_settings
import server.api.utils.databases as utils_databases
import server.api.utils.models as utils_models
Expand Down Expand Up @@ -196,10 +198,17 @@ def _deduplicate_documents(documents: List) -> List:

def _search_table(table_name, question, db_conn, embed_client, vector_search, table_distance_metric):
"""Search a single vector table and return documents with metadata"""
LOGGER.info("Searching table: %s with distance metric: %s", table_name, table_distance_metric)
# Normalize distance metric for consistent use in both OracleVS construction and score conversion
distance_strategy = to_distance_strategy(table_distance_metric)
LOGGER.info(
"Searching table: %s with distance metric: %s (resolved: %s)",
table_name,
table_distance_metric,
distance_strategy.value,
)

# Initialize Vector Store for this table using its specific distance metric
vectorstores = OracleVS(db_conn, embed_client, table_name, table_distance_metric)
vectorstores = OracleVS(db_conn, embed_client, table_name, distance_strategy)

# For Similarity searches, call with_score to preserve scores
if vector_search.search_type == "Similarity":
Expand All @@ -212,12 +221,12 @@ def _search_table(table_name, question, db_conn, embed_client, vector_search, ta
for doc, score in docs_and_scores:
# Convert distance to similarity score (for COSINE: similarity = 1 - distance / 2)
# For COSINE metric, distance is in [0, 2] range, similarity in [0, 1]
if table_distance_metric == "COSINE":
if distance_strategy.value == "COSINE":
similarity = 1.0 - (score / 2.0)
elif table_distance_metric == "DOT":
elif distance_strategy.value == "DOT_PRODUCT":
# For DOT product, higher is better (already a similarity)
similarity = score
else: # EUCLIDEAN or EUCLIDEAN_SQUARED
else: # EUCLIDEAN_DISTANCE
# For Euclidean, lower distance = higher similarity
# Use inverse: similarity = 1 / (1 + distance)
similarity = 1.0 / (1.0 + score)
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/common/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

Tests utility functions for URL checking, vector store operations, and SQL operations.
"""
# pylint: disable=import-outside-toplevel

import json
import os
import tempfile
from unittest.mock import patch, MagicMock
import requests
import oracledb
from langchain_community.vectorstores.utils import DistanceStrategy

from common import functions

Expand Down Expand Up @@ -465,3 +467,35 @@ def test_successful_query_creates_csv(self, mock_connect):
content = f.read()
assert "COL1,COL2" in content
assert "val1,val2" in content


class TestToDistanceStrategy:
    """Unit tests covering functions.to_distance_strategy."""

    def test_cosine_string(self):
        """The COSINE string resolves to DistanceStrategy.COSINE."""
        resolved = functions.to_distance_strategy("COSINE")
        assert resolved == DistanceStrategy.COSINE

    def test_euclidean_distance_string(self):
        """The EUCLIDEAN_DISTANCE string resolves to DistanceStrategy.EUCLIDEAN_DISTANCE."""
        resolved = functions.to_distance_strategy("EUCLIDEAN_DISTANCE")
        assert resolved == DistanceStrategy.EUCLIDEAN_DISTANCE

    def test_dot_product_string(self):
        """The DOT_PRODUCT string resolves to DistanceStrategy.DOT_PRODUCT."""
        resolved = functions.to_distance_strategy("DOT_PRODUCT")
        assert resolved == DistanceStrategy.DOT_PRODUCT

    def test_case_insensitive(self):
        """A lower-case metric name resolves to the same member as its upper-case form."""
        resolved = functions.to_distance_strategy("cosine")
        assert resolved == DistanceStrategy.COSINE

    def test_none_raises_error(self):
        """A None metric is rejected with a ValueError whose message mentions 'required'."""
        import pytest

        with pytest.raises(ValueError, match="required"):
            functions.to_distance_strategy(None)

    def test_unrecognized_raises_error(self):
        """An unknown metric name is rejected with a ValueError whose message mentions 'Unrecognized'."""
        import pytest

        with pytest.raises(ValueError, match="Unrecognized"):
            functions.to_distance_strategy("UNKNOWN")
2 changes: 2 additions & 0 deletions tests/unit/server/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@ def _make_vector_store(
model: str = "text-embedding-3-small",
chunk_size: int = 1000,
chunk_overlap: int = 200,
distance_metric: str = "COSINE",
**kwargs,
) -> DatabaseVectorStorage:
return DatabaseVectorStorage(
vector_store=vector_store,
model=model,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
distance_metric=distance_metric,
**kwargs,
)

Expand Down
68 changes: 66 additions & 2 deletions tests/unit/server/mcp/test_vs_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from unittest.mock import MagicMock, patch
from langchain_core.documents import Document
from langchain_community.vectorstores.utils import DistanceStrategy

from common.schema import VectorSearchSettings

Expand Down Expand Up @@ -86,7 +87,7 @@ def test_dot_product_similarity(self):
db_conn=MagicMock(),
embed_client=MagicMock(),
vector_search=vector_search,
table_distance_metric="DOT",
table_distance_metric="DOT_PRODUCT",
)

# DOT scores should pass through unchanged
Expand Down Expand Up @@ -125,7 +126,7 @@ def test_euclidean_distance_to_similarity(self):
db_conn=MagicMock(),
embed_client=MagicMock(),
vector_search=vector_search,
table_distance_metric="EUCLIDEAN",
table_distance_metric="EUCLIDEAN_DISTANCE",
)

# Verify similarity scores
Expand Down Expand Up @@ -421,3 +422,66 @@ def test_similarity_score_rounded_to_three_decimals(self):

# Should be rounded to 3 decimals
assert documents[0].metadata["similarity_score"] == 0.833


class TestOracleVSDistanceStrategyType:
    """Checks the type of the distance strategy that _search_table hands to OracleVS."""

    def test_oraclevs_receives_distance_strategy_enum(self):
        """OracleVS must be constructed with a DistanceStrategy member, never a raw string.

        Guards against the langchain-core 1.2.22 breakage in which
        _get_distance_function raises ValueError for non-enum values.
        """
        from server.mcp.tools.vs_retriever import _search_table

        search_settings = VectorSearchSettings(
            search_type="Similarity",
            top_k=1,
            score_threshold=0.0,
        )

        # Stand-in vector store returning a single scored document
        fake_store = MagicMock()
        fake_store.similarity_search_with_score.return_value = [
            (Document(page_content="Test", metadata={}), 0.2),
        ]

        with patch("server.mcp.tools.vs_retriever.OracleVS", return_value=fake_store) as oraclevs_cls:
            _search_table(
                table_name="TEST_TABLE",
                question="test query",
                db_conn=MagicMock(),
                embed_client=MagicMock(),
                vector_search=search_settings,
                table_distance_metric="COSINE",
            )

        # The distance strategy is the 4th positional argument of OracleVS()
        positional = oraclevs_cls.call_args.args
        assert len(positional) >= 4, "OracleVS should receive at least 4 positional args"
        strategy_arg = positional[3]
        assert isinstance(strategy_arg, DistanceStrategy), (
            f"Expected DistanceStrategy enum, got {type(strategy_arg).__name__}: {strategy_arg}"
        )
        assert strategy_arg == DistanceStrategy.COSINE

    def test_oraclevs_raises_for_none_metric(self):
        """_search_table raises ValueError when the table distance metric is None."""
        import pytest

        from server.mcp.tools.vs_retriever import _search_table

        search_settings = VectorSearchSettings(
            search_type="Similarity",
            top_k=1,
            score_threshold=0.0,
        )

        with pytest.raises(ValueError, match="required"):
            _search_table(
                table_name="TEST_TABLE",
                question="test query",
                db_conn=MagicMock(),
                embed_client=MagicMock(),
                vector_search=search_settings,
                table_distance_metric=None,
            )
Loading