diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index 61488c12..aed7934e 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -1,3 +1,5 @@ +"""Index model module for document indexing and search operations.""" + import os import warnings from uuid import uuid4 @@ -14,6 +16,7 @@ DOCLING_MODEL_ID = "677bee6c6eb56331f9192a91" + class IndexFilterOperator(Enum): """Enumeration of operators available for filtering index records. @@ -30,6 +33,7 @@ class IndexFilterOperator(Enum): GREATER_THAN_OR_EQUALS (str): Greater than or equal to operator (">=") LESS_THAN_OR_EQUALS (str): Less than or equal to operator ("<=") """ + EQUALS = "==" NOT_EQUALS = "!=" CONTAINS = "in" @@ -120,6 +124,8 @@ def __init__( class IndexModel(Model): + """A model for indexing and searching documents using vector embeddings.""" + def __init__( self, id: Text, @@ -197,19 +203,24 @@ def to_dict(self) -> Dict: data["collection_type"] = self.version.split("-", 1)[0] return data - def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: - """Search for documents in the index + def search( + self, query: str, top_k: int = 10, filters: List[IndexFilter] = [], score_threshold: float = 0.0 + ) -> ModelResponse: + """Search for documents in the index. Args: query (str): Query to be searched top_k (int, optional): Number of results to be returned. Defaults to 10. filters (List[IndexFilter], optional): Filters to be applied. Defaults to []. + score_threshold (float, optional): Minimum score threshold for results. Results with + scores below this threshold will be filtered out. Defaults to 0.0. Returns: ModelResponse: Response from the indexing service Example: - index_model.search("Hello") + - index_model.search("Hello", score_threshold=0.5) - index_model.search("", filters=[IndexFilter(field="category", value="animate", operator=IndexFilterOperator.EQUALS)]) """ from aixplain.factories import FileFactory @@ -226,12 +237,12 @@ def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) - "data": query or uri, "dataType": value_type, "filters": [filter.to_dict() for filter in filters], - "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, + "payload": {"uri": uri, "value_type": value_type, "top_k": top_k, "score_threshold": score_threshold}, } return self.run(data=data) def upsert(self, documents: Union[List[Record], str], splitter: Optional[Splitter] = None) -> ModelResponse: - """Upsert documents into the index + """Upsert documents into the index. Args: documents (Union[List[Record], str]): List of documents to be upserted or a file path @@ -295,8 +306,7 @@ def count(self) -> int: raise Exception(f"Failed to count documents: {response.error_message}") def get_record(self, record_id: Text) -> ModelResponse: - """ - Get a document from the index. + """Get a document from the index. Args: record_id (Text): ID of the document to retrieve. @@ -317,8 +327,7 @@ def get_record(self, record_id: Text) -> ModelResponse: raise Exception(f"Failed to get record: {response.error_message}") def delete_record(self, record_id: Text) -> ModelResponse: - """ - Delete a document from the index. + """Delete a document from the index. Args: record_id (Text): ID of the document to delete. @@ -396,8 +405,7 @@ def parse_file(file_path: str) -> ModelResponse: raise Exception(f"Failed to parse file: {e}") def retrieve_records_with_filter(self, filter: IndexFilter) -> ModelResponse: - """ - Retrieve records from the index that match the given filter. + """Retrieve records from the index that match the given filter. Args: filter (IndexFilter): The filter criteria to apply when retrieving records. @@ -420,8 +428,7 @@ def retrieve_records_with_filter(self, filter: IndexFilter) -> ModelResponse: raise Exception(f"Failed to retrieve records with filter: {response.error_message}") def delete_records_by_date(self, date: float) -> ModelResponse: - """ - Delete records from the index that match the given date. + """Delete records from the index that match the given date. Args: date (float): The date (as a timestamp) to match records for deletion. diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index b36971b0..33764e3f 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -46,7 +46,6 @@ def pytest_generate_tests(metafunc): def test_llm_run(llm_model): """Testing LLMs with history context""" - assert isinstance(llm_model, LLM) response = llm_model.run( data="What is my name?", @@ -225,7 +224,6 @@ def test_index_model_with_filter(embedding_model, supplier_params): def test_llm_run_with_file(): """Testing LLM with local file input containing emoji""" - # Create test file path test_file_path = Path(__file__).parent / "data" / "test_input.txt" @@ -607,3 +605,15 @@ def test_delete_records_by_date(setup_index_with_test_records): response = index_model.retrieve_records_with_filter(filter_all) assert response.status == "SUCCESS" assert len(response.details) == 2 + + +def test_index_model_with_score_threshold(setup_index_with_test_records): + """Testing Index Model with score threshold""" + index_model = setup_index_with_test_records + response = index_model.search("technology", score_threshold=0.0) + assert response.status == "SUCCESS" + assert len(response.details) == 5 + + response = index_model.search("technology", score_threshold=1.0) + assert response.status == "SUCCESS" + assert len(response.details) == 0 diff --git a/tests/functional/model/test_score_threshold.py b/tests/functional/model/test_score_threshold.py new file mode 100644 index 00000000..7f783e60 --- /dev/null +++ b/tests/functional/model/test_score_threshold.py @@ -0,0 +1,125 @@ +"""Test score threshold filtering for IndexModel.search()""" + +import pytest +import time +import uuid +from aixplain.factories import IndexFactory +from aixplain.modules.model.record import Record + + +@pytest.fixture(scope="module") +def setup_test_index(): + """Create an index with test records for score threshold testing.""" + unique_name = f"score_threshold_test_{uuid.uuid4().hex[:8]}" + index_model = IndexFactory.create(name=unique_name, description="Test index for score threshold") + + # Wait for index to be ready + time.sleep(5) + + # Add test records with varied content + test_records = [ + Record(id="1", value="Python programming language tutorial", value_type="text", attributes={"topic": "python"}), + Record( + id="2", value="JavaScript web development basics", value_type="text", attributes={"topic": "javascript"} + ), + Record( + id="3", value="Python data science and machine learning", value_type="text", attributes={"topic": "python"} + ), + Record(id="4", value="Cloud computing infrastructure", value_type="text", attributes={"topic": "cloud"}), + Record(id="5", value="Python Flask web framework", value_type="text", attributes={"topic": "python"}), + ] + + index_model.upsert(test_records) + time.sleep(10) # Wait for indexing + + yield index_model + + # Cleanup + index_model.delete() + + +def test_score_threshold_zero_returns_all(setup_test_index): + """With score_threshold=0.0, all matching results should be returned.""" + index_model = setup_test_index + response = index_model.search("Python", score_threshold=0.0) + + assert response.status == "SUCCESS" + print(f"\nResults with threshold 0.0: {len(response.details)} records") + for detail in response.details: + print(f" - Score: {detail.get('score', 'N/A')}") + + +def test_score_threshold_high_returns_none(setup_test_index): + """With score_threshold=1.0, no results should be returned.""" + index_model = setup_test_index + + response = index_model.search("Python", score_threshold=1.0) + assert response.status == "SUCCESS" + print(f"\nResults with threshold 1.0: {len(response.details)} records") + assert len(response.details) == 0, "No results should match with threshold=1.0" + + +def test_score_threshold_filters_correctly(setup_test_index): + """Verify that higher threshold returns fewer or equal results.""" + index_model = setup_test_index + + # Get all results first + response_all = index_model.search("Python", score_threshold=0.0) + all_count = len(response_all.details) + + # Get filtered results with medium threshold + response_filtered = index_model.search("Python", score_threshold=0.5) + filtered_count = len(response_filtered.details) + + print(f"\nAll results (threshold=0.0): {all_count}") + print(f"Filtered results (threshold=0.5): {filtered_count}") + + # Filtered count should be <= all count + assert filtered_count <= all_count + + +def test_score_threshold_with_filter(setup_test_index): + """Test score_threshold combined with filters.""" + from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + + index_model = setup_test_index + + # Search with filter only (no threshold) + filter_ = IndexFilter(field="topic", value="python", operator=IndexFilterOperator.EQUALS) + response = index_model.search("Python", filters=[filter_], score_threshold=0.0) + assert response.status == "SUCCESS" + filtered_count = len(response.details) + print(f"\nResults with filter only (threshold=0.0): {filtered_count}") + + # Search with filter AND high threshold - should return 0 + response = index_model.search("Python", filters=[filter_], score_threshold=1.0) + assert response.status == "SUCCESS" + print(f"Results with filter + threshold=1.0: {len(response.details)}") + assert len(response.details) == 0 + + +def test_score_threshold_with_filter_medium(setup_test_index): + """Test score_threshold with filter at medium threshold.""" + from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + + index_model = setup_test_index + + filter_ = IndexFilter(field="topic", value="python", operator=IndexFilterOperator.EQUALS) + + # Get all filtered results + response_all = index_model.search("Python", filters=[filter_], score_threshold=0.0) + all_count = len(response_all.details) + + # Get filtered results with medium threshold + response_filtered = index_model.search("Python", filters=[filter_], score_threshold=0.8) + filtered_count = len(response_filtered.details) + + print(f"\nFiltered results (threshold=0.0): {all_count}") + print(f"Filtered results (threshold=0.8): {filtered_count}") + + # Higher threshold should return fewer or equal results + assert filtered_count <= all_count + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 5f5e13eb..faed73dd 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -183,10 +183,10 @@ def test_validate_record_failure_no_uri(mocker): def test_validate_record_failure_no_value(mocker): - record = Record(uri="test.jpg", value_type="text", id=0, attributes={}) + record = Record(value_type="text", id=0, attributes={}) with pytest.raises(Exception) as e: record.validate() - assert str(e.value) == "Index Upsert Error: Value is required for text records" + assert str(e.value) == "Index Upsert Error: Either value or uri is required for text records" def test_record_to_dict(): @@ -233,15 +233,24 @@ def test_index_factory_create_failure(): with pytest.raises(Exception) as e: IndexFactory.create(description="test") - assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + assert ( + str(e.value) + == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + ) with pytest.raises(Exception) as e: IndexFactory.create(name="test") - assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + assert ( + str(e.value) + == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + ) with pytest.raises(Exception) as e: IndexFactory.create(name="test", description="test", embedding_model=None) - assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + assert ( + str(e.value) + == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + ) def test_index_model_splitter():