Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions aixplain/modules/model/index_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Index model module for document indexing and search operations."""

import os
import warnings
from uuid import uuid4
Expand All @@ -14,6 +16,7 @@

DOCLING_MODEL_ID = "677bee6c6eb56331f9192a91"


class IndexFilterOperator(Enum):
"""Enumeration of operators available for filtering index records.

Expand All @@ -30,6 +33,7 @@ class IndexFilterOperator(Enum):
GREATER_THAN_OR_EQUALS (str): Greater than or equal to operator (">=")
LESS_THAN_OR_EQUALS (str): Less than or equal to operator ("<=")
"""

EQUALS = "=="
NOT_EQUALS = "!="
CONTAINS = "in"
Expand Down Expand Up @@ -120,6 +124,8 @@ def __init__(


class IndexModel(Model):
"""A model for indexing and searching documents using vector embeddings."""

def __init__(
self,
id: Text,
Expand Down Expand Up @@ -197,19 +203,24 @@ def to_dict(self) -> Dict:
data["collection_type"] = self.version.split("-", 1)[0]
return data

def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse:
"""Search for documents in the index
def search(
self, query: str, top_k: int = 10, filters: List[IndexFilter] = [], score_threshold: float = 0.0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add this as a filter object?
Operators such as larger-than, smaller-than, etc., and then a value. Add this as an IndexFilter.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ahmetgunduz Coming back to this. IndexFilter is generally meant for metadata filtering, so wouldn't it make more sense to keep score_threshold separate? We can give it a more meaningful name.

) -> ModelResponse:
"""Search for documents in the index.

Args:
query (str): Query to be searched
top_k (int, optional): Number of results to be returned. Defaults to 10.
filters (List[IndexFilter], optional): Filters to be applied. Defaults to [].
score_threshold (float, optional): Minimum score threshold for results. Results with
scores below this threshold will be filtered out. Defaults to 0.0.

Returns:
ModelResponse: Response from the indexing service

Example:
- index_model.search("Hello")
- index_model.search("Hello", score_threshold=0.5)
- index_model.search("", filters=[IndexFilter(field="category", value="animate", operator=IndexFilterOperator.EQUALS)])
"""
from aixplain.factories import FileFactory
Expand All @@ -226,12 +237,12 @@ def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -
"data": query or uri,
"dataType": value_type,
"filters": [filter.to_dict() for filter in filters],
"payload": {"uri": uri, "value_type": value_type, "top_k": top_k},
"payload": {"uri": uri, "value_type": value_type, "top_k": top_k, "score_threshold": score_threshold},
}
return self.run(data=data)

def upsert(self, documents: Union[List[Record], str], splitter: Optional[Splitter] = None) -> ModelResponse:
"""Upsert documents into the index
"""Upsert documents into the index.

Args:
documents (Union[List[Record], str]): List of documents to be upserted or a file path
Expand Down Expand Up @@ -295,8 +306,7 @@ def count(self) -> int:
raise Exception(f"Failed to count documents: {response.error_message}")

def get_record(self, record_id: Text) -> ModelResponse:
"""
Get a document from the index.
"""Get a document from the index.

Args:
record_id (Text): ID of the document to retrieve.
Expand All @@ -317,8 +327,7 @@ def get_record(self, record_id: Text) -> ModelResponse:
raise Exception(f"Failed to get record: {response.error_message}")

def delete_record(self, record_id: Text) -> ModelResponse:
"""
Delete a document from the index.
"""Delete a document from the index.

Args:
record_id (Text): ID of the document to delete.
Expand Down Expand Up @@ -396,8 +405,7 @@ def parse_file(file_path: str) -> ModelResponse:
raise Exception(f"Failed to parse file: {e}")

def retrieve_records_with_filter(self, filter: IndexFilter) -> ModelResponse:
"""
Retrieve records from the index that match the given filter.
"""Retrieve records from the index that match the given filter.

Args:
filter (IndexFilter): The filter criteria to apply when retrieving records.
Expand All @@ -420,8 +428,7 @@ def retrieve_records_with_filter(self, filter: IndexFilter) -> ModelResponse:
raise Exception(f"Failed to retrieve records with filter: {response.error_message}")

def delete_records_by_date(self, date: float) -> ModelResponse:
"""
Delete records from the index that match the given date.
"""Delete records from the index that match the given date.

Args:
date (float): The date (as a timestamp) to match records for deletion.
Expand Down
14 changes: 12 additions & 2 deletions tests/functional/model/run_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def pytest_generate_tests(metafunc):

def test_llm_run(llm_model):
"""Testing LLMs with history context"""

assert isinstance(llm_model, LLM)
response = llm_model.run(
data="What is my name?",
Expand Down Expand Up @@ -225,7 +224,6 @@ def test_index_model_with_filter(embedding_model, supplier_params):

def test_llm_run_with_file():
"""Testing LLM with local file input containing emoji"""

# Create test file path
test_file_path = Path(__file__).parent / "data" / "test_input.txt"

Expand Down Expand Up @@ -607,3 +605,15 @@ def test_delete_records_by_date(setup_index_with_test_records):
response = index_model.retrieve_records_with_filter(filter_all)
assert response.status == "SUCCESS"
assert len(response.details) == 2


def test_index_model_with_score_threshold(setup_index_with_test_records):
    """Check the hit count returned at the two extreme score thresholds.

    Searches the prepared index for "technology" with ``score_threshold`` set
    to 0.0 and then 1.0, expecting 5 and 0 records respectively.
    """
    index = setup_index_with_test_records
    # (score_threshold, expected number of returned records)
    expectations = ((0.0, 5), (1.0, 0))
    for threshold, expected_hits in expectations:
        result = index.search("technology", score_threshold=threshold)
        assert result.status == "SUCCESS"
        assert len(result.details) == expected_hits
125 changes: 125 additions & 0 deletions tests/functional/model/test_score_threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Test score threshold filtering for IndexModel.search()"""

import pytest
import time
import uuid
from aixplain.factories import IndexFactory
from aixplain.modules.model.record import Record


@pytest.fixture(scope="module")
def setup_test_index():
    """Create an index with test records for score threshold testing.

    The index gets a unique name, is populated with five text records that
    carry a ``topic`` attribute, and is deleted again once the module's tests
    have finished.

    Yields:
        The index model returned by ``IndexFactory.create`` after the test
        records have been upserted.
    """
    unique_name = f"score_threshold_test_{uuid.uuid4().hex[:8]}"
    index_model = IndexFactory.create(name=unique_name, description="Test index for score threshold")

    # Everything after creation runs under try/finally: code placed after a
    # fixture's ``yield`` is skipped when setup raises before yielding, which
    # would otherwise leave the freshly created index behind.
    try:
        # Wait for index to be ready
        time.sleep(5)

        # Add test records with varied content
        test_records = [
            Record(
                id="1", value="Python programming language tutorial", value_type="text", attributes={"topic": "python"}
            ),
            Record(
                id="2", value="JavaScript web development basics", value_type="text", attributes={"topic": "javascript"}
            ),
            Record(
                id="3",
                value="Python data science and machine learning",
                value_type="text",
                attributes={"topic": "python"},
            ),
            Record(id="4", value="Cloud computing infrastructure", value_type="text", attributes={"topic": "cloud"}),
            Record(id="5", value="Python Flask web framework", value_type="text", attributes={"topic": "python"}),
        ]

        index_model.upsert(test_records)
        time.sleep(10)  # Wait for indexing

        yield index_model
    finally:
        # Cleanup
        index_model.delete()


def test_score_threshold_zero_returns_all(setup_test_index):
    """With score_threshold=0.0, matching results should be returned.

    Searches for "Python" with the lowest threshold, prints the score of each
    returned record, and requires a successful, non-empty response.
    """
    index_model = setup_test_index
    response = index_model.search("Python", score_threshold=0.0)

    assert response.status == "SUCCESS"
    print(f"\nResults with threshold 0.0: {len(response.details)} records")
    for detail in response.details:
        print(f"  - Score: {detail.get('score', 'N/A')}")

    # Without this check the test would also pass on an empty result set,
    # e.g. when nothing was indexed or the threshold dropped every record.
    assert len(response.details) > 0, "Expected at least one result with threshold=0.0"


def test_score_threshold_high_returns_none(setup_test_index):
    """Searching with ``score_threshold=1.0`` must succeed and yield no records."""
    result = setup_test_index.search("Python", score_threshold=1.0)

    assert result.status == "SUCCESS"
    hits = result.details
    print(f"\nResults with threshold 1.0: {len(hits)} records")
    assert len(hits) == 0, "No results should match with threshold=1.0"


def test_score_threshold_filters_correctly(setup_test_index):
    """Verify that higher threshold returns fewer or equal results.

    Runs the same query with ``score_threshold`` 0.0 and 0.5, requires both
    searches to succeed, and compares the number of returned records.
    """
    index_model = setup_test_index

    # Get all results first
    response_all = index_model.search("Python", score_threshold=0.0)
    # A failed search must fail the test instead of feeding the comparison.
    assert response_all.status == "SUCCESS"
    all_count = len(response_all.details)

    # Get filtered results with medium threshold
    response_filtered = index_model.search("Python", score_threshold=0.5)
    assert response_filtered.status == "SUCCESS"
    filtered_count = len(response_filtered.details)

    print(f"\nAll results (threshold=0.0): {all_count}")
    print(f"Filtered results (threshold=0.5): {filtered_count}")

    # Filtered count should be <= all count
    assert filtered_count <= all_count


def test_score_threshold_with_filter(setup_test_index):
    """Combine a metadata filter with the lowest and the highest score threshold.

    Both searches restrict results to records whose ``topic`` equals "python".
    The search with threshold 1.0 is expected to return no records.
    """
    from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator

    index = setup_test_index
    topic_filter = IndexFilter(field="topic", value="python", operator=IndexFilterOperator.EQUALS)

    # Filter alone: the threshold of 0.0 is the default and excludes nothing extra.
    unthresholded = index.search("Python", filters=[topic_filter], score_threshold=0.0)
    assert unthresholded.status == "SUCCESS"
    print(f"\nResults with filter only (threshold=0.0): {len(unthresholded.details)}")

    # Filter plus the maximum threshold: nothing should be returned.
    strict = index.search("Python", filters=[topic_filter], score_threshold=1.0)
    assert strict.status == "SUCCESS"
    print(f"Results with filter + threshold=1.0: {len(strict.details)}")
    assert len(strict.details) == 0


def test_score_threshold_with_filter_medium(setup_test_index):
    """Test score_threshold with filter at medium threshold.

    Runs the same filtered query (``topic`` equals "python") with
    ``score_threshold`` 0.0 and 0.8, requires both searches to succeed, and
    compares the number of returned records.
    """
    from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator

    index_model = setup_test_index

    filter_ = IndexFilter(field="topic", value="python", operator=IndexFilterOperator.EQUALS)

    # Get all filtered results
    response_all = index_model.search("Python", filters=[filter_], score_threshold=0.0)
    # A failed search must fail the test instead of feeding the comparison.
    assert response_all.status == "SUCCESS"
    all_count = len(response_all.details)

    # Get filtered results with medium threshold
    response_filtered = index_model.search("Python", filters=[filter_], score_threshold=0.8)
    assert response_filtered.status == "SUCCESS"
    filtered_count = len(response_filtered.details)

    print(f"\nFiltered results (threshold=0.0): {all_count}")
    print(f"Filtered results (threshold=0.8): {filtered_count}")

    # Higher threshold should return fewer or equal results
    assert filtered_count <= all_count


if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly reports
    # failures to the calling shell instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
19 changes: 14 additions & 5 deletions tests/unit/index_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ def test_validate_record_failure_no_uri(mocker):


def test_validate_record_failure_no_value(mocker):
    """A text record with neither ``value`` nor ``uri`` must fail validation."""
    # Only the post-change lines are kept: the superseded assignment was
    # immediately overwritten, and its expected message contradicted this one.
    record = Record(value_type="text", id=0, attributes={})
    with pytest.raises(Exception) as e:
        record.validate()
    assert str(e.value) == "Index Upsert Error: Either value or uri is required for text records"


def test_record_to_dict():
Expand Down Expand Up @@ -233,15 +233,24 @@ def test_index_factory_create_failure():

with pytest.raises(Exception) as e:
IndexFactory.create(description="test")
assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not"
assert (
str(e.value)
== "Index Factory Exception: name, description, and embedding_model must be provided when params is not"
)

with pytest.raises(Exception) as e:
IndexFactory.create(name="test")
assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not"
assert (
str(e.value)
== "Index Factory Exception: name, description, and embedding_model must be provided when params is not"
)

with pytest.raises(Exception) as e:
IndexFactory.create(name="test", description="test", embedding_model=None)
assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not"
assert (
str(e.value)
== "Index Factory Exception: name, description, and embedding_model must be provided when params is not"
)


def test_index_model_splitter():
Expand Down