From b09fbe01ecc4b909cd0a70d80ffdc21d4bd8ef9d Mon Sep 17 00:00:00 2001 From: Abdul Basit Anees Date: Thu, 11 Sep 2025 21:45:49 +0000 Subject: [PATCH 1/2] Add score threshold for searching in index --- aixplain/modules/model/index_model.py | 20 +++++++------- tests/functional/model/run_model_test.py | 34 +++++++++++++++++++----- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index 61488c12..a367da55 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -14,6 +14,7 @@ DOCLING_MODEL_ID = "677bee6c6eb56331f9192a91" + class IndexFilterOperator(Enum): """Enumeration of operators available for filtering index records. @@ -30,6 +31,7 @@ class IndexFilterOperator(Enum): GREATER_THAN_OR_EQUALS (str): Greater than or equal to operator (">=") LESS_THAN_OR_EQUALS (str): Less than or equal to operator ("<=") """ + EQUALS = "==" NOT_EQUALS = "!=" CONTAINS = "in" @@ -197,7 +199,9 @@ def to_dict(self) -> Dict: data["collection_type"] = self.version.split("-", 1)[0] return data - def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -> ModelResponse: + def search( + self, query: str, top_k: int = 10, filters: List[IndexFilter] = [], score_threshold: float = 0.0 + ) -> ModelResponse: """Search for documents in the index Args: @@ -226,7 +230,7 @@ def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) - "data": query or uri, "dataType": value_type, "filters": [filter.to_dict() for filter in filters], - "payload": {"uri": uri, "value_type": value_type, "top_k": top_k}, + "payload": {"uri": uri, "value_type": value_type, "top_k": top_k, "score_threshold": score_threshold}, } return self.run(data=data) @@ -295,8 +299,7 @@ def count(self) -> int: raise Exception(f"Failed to count documents: {response.error_message}") def get_record(self, record_id: Text) -> ModelResponse: - """ - Get a document from the index. + """Get a document from the index. Args: record_id (Text): ID of the document to retrieve. @@ -317,8 +320,7 @@ def get_record(self, record_id: Text) -> ModelResponse: raise Exception(f"Failed to get record: {response.error_message}") def delete_record(self, record_id: Text) -> ModelResponse: - """ - Delete a document from the index. + """Delete a document from the index. Args: record_id (Text): ID of the document to delete. @@ -396,8 +398,7 @@ def parse_file(file_path: str) -> ModelResponse: raise Exception(f"Failed to parse file: {e}") def retrieve_records_with_filter(self, filter: IndexFilter) -> ModelResponse: - """ - Retrieve records from the index that match the given filter. + """Retrieve records from the index that match the given filter. Args: filter (IndexFilter): The filter criteria to apply when retrieving records. @@ -420,8 +421,7 @@ def retrieve_records_with_filter(self, filter: IndexFilter) -> ModelResponse: raise Exception(f"Failed to retrieve records with filter: {response.error_message}") def delete_records_by_date(self, date: float) -> ModelResponse: - """ - Delete records from the index that match the given date. + """Delete records from the index that match the given date. Args: date (float): The date (as a timestamp) to match records for deletion. diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 9af2ccc3..49b885a1 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -15,6 +15,7 @@ CACHE_FOLDER = ".cache" + def pytest_generate_tests(metafunc): if "llm_model" in metafunc.fixturenames: four_weeks_ago = datetime.now(timezone.utc) - timedelta(weeks=4) @@ -41,7 +42,6 @@ def pytest_generate_tests(metafunc): def test_llm_run(llm_model): """Testing LLMs with history context""" - assert isinstance(llm_model, LLM) response = llm_model.run( data="What is my name?", @@ -166,7 +166,11 @@ def test_index_model_with_filter(embedding_model, supplier_params): for _ in range(retries): try: index_model.upsert( - [Record(value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"})] + [ + Record( + value="The world is great", value_type="text", uri="", id="2", attributes={"category": "world"} + ) + ] ) break except Exception: @@ -185,7 +189,6 @@ def test_index_model_with_filter(embedding_model, supplier_params): def test_llm_run_with_file(): """Testing LLM with local file input containing emoji""" - # Create test file path test_file_path = Path(__file__).parent / "data" / "test_input.txt" @@ -204,7 +207,6 @@ def test_llm_run_with_file(): def test_aixplain_model_cache_creation(): """Ensure AssetCache is triggered and cache is created.""" - cache_file = os.path.join(CACHE_FOLDER, "models.json") # Clean up cache before the test @@ -232,7 +234,9 @@ def test_index_model_air_with_image(): from aixplain.factories.index_factory.utils import AirParams params = AirParams( - name=f"Image Index {uuid4()}", description="Index for images", embedding_model=EmbeddingModel.JINA_CLIP_V2_MULTIMODAL + name=f"Image Index {uuid4()}", + description="Index for images", + embedding_model=EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, ) index_model = IndexFactory.create(params=params) @@ -334,7 +338,9 @@ def test_index_model_with_txt_file(): # Create index with OpenAI Ada 002 for text processing params = AirParams( - name=f"File Index {uuid4()}", description="Index for file processing", embedding_model=EmbeddingModel.OPENAI_ADA002 + name=f"File Index {uuid4()}", + description="Index for file processing", + embedding_model=EmbeddingModel.OPENAI_ADA002, ) index_model = IndexFactory.create(params=params) @@ -368,7 +374,9 @@ def test_index_model_with_pdf_file(): # Create index with OpenAI Ada 002 for text processing params = AirParams( - name=f"PDF Index {uuid4()}", description="Index for PDF processing", embedding_model=EmbeddingModel.OPENAI_ADA002 + name=f"PDF Index {uuid4()}", + description="Index for PDF processing", + embedding_model=EmbeddingModel.OPENAI_ADA002, ) index_model = IndexFactory.create(params=params) @@ -506,3 +514,15 @@ def test_delete_records_by_date(setup_index_with_test_records): response = index_model.retrieve_records_with_filter(filter_all) assert response.status == "SUCCESS" assert len(response.details) == 2 + + +def test_index_model_with_score_threshold(setup_index_with_test_records): + """Testing Index Model with score threshold""" + index_model = setup_index_with_test_records + response = index_model.search("technology", score_threshold=0.0) + assert response.status == "SUCCESS" + assert len(response.details) == 5 + + response = index_model.search("technology", score_threshold=1.0) + assert response.status == "SUCCESS" + assert len(response.details) == 0 From 1ed5b3441bb747ac4aa3e9396b5c8863f2ffcb50 Mon Sep 17 00:00:00 2001 From: aix-ahmet Date: Mon, 23 Feb 2026 22:01:51 +0300 Subject: [PATCH 2/2] Add score_threshold filtering option for index search --- aixplain/modules/model/index_model.py | 11 +- .../functional/model/test_score_threshold.py | 125 ++++++++++++++++++ tests/unit/index_model_test.py | 19 ++- 3 files changed, 148 insertions(+), 7 deletions(-) create mode 100644 tests/functional/model/test_score_threshold.py diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index a367da55..aed7934e 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -1,3 +1,5 @@ +"""Index model module for document indexing and search operations.""" + import os import warnings from uuid import uuid4 @@ -122,6 +124,8 @@ def __init__( class IndexModel(Model): + """A model for indexing and searching documents using vector embeddings.""" + def __init__( self, id: Text, @@ -202,18 +206,21 @@ def to_dict(self) -> Dict: def search( self, query: str, top_k: int = 10, filters: List[IndexFilter] = [], score_threshold: float = 0.0 ) -> ModelResponse: - """Search for documents in the index + """Search for documents in the index. Args: query (str): Query to be searched top_k (int, optional): Number of results to be returned. Defaults to 10. filters (List[IndexFilter], optional): Filters to be applied. Defaults to []. + score_threshold (float, optional): Minimum score threshold for results. Results with + scores below this threshold will be filtered out. Defaults to 0.0. Returns: ModelResponse: Response from the indexing service Example: - index_model.search("Hello") + - index_model.search("Hello", score_threshold=0.5) - index_model.search("", filters=[IndexFilter(field="category", value="animate", operator=IndexFilterOperator.EQUALS)]) """ from aixplain.factories import FileFactory @@ -235,7 +242,7 @@ def search( return self.run(data=data) def upsert(self, documents: Union[List[Record], str], splitter: Optional[Splitter] = None) -> ModelResponse: - """Upsert documents into the index + """Upsert documents into the index. Args: documents (Union[List[Record], str]): List of documents to be upserted or a file path diff --git a/tests/functional/model/test_score_threshold.py b/tests/functional/model/test_score_threshold.py new file mode 100644 index 00000000..7f783e60 --- /dev/null +++ b/tests/functional/model/test_score_threshold.py @@ -0,0 +1,125 @@ +"""Test score threshold filtering for IndexModel.search()""" + +import pytest +import time +import uuid +from aixplain.factories import IndexFactory +from aixplain.modules.model.record import Record + + +@pytest.fixture(scope="module") +def setup_test_index(): + """Create an index with test records for score threshold testing.""" + unique_name = f"score_threshold_test_{uuid.uuid4().hex[:8]}" + index_model = IndexFactory.create(name=unique_name, description="Test index for score threshold") + + # Wait for index to be ready + time.sleep(5) + + # Add test records with varied content + test_records = [ + Record(id="1", value="Python programming language tutorial", value_type="text", attributes={"topic": "python"}), + Record( + id="2", value="JavaScript web development basics", value_type="text", attributes={"topic": "javascript"} + ), + Record( + id="3", value="Python data science and machine learning", value_type="text", attributes={"topic": "python"} + ), + Record(id="4", value="Cloud computing infrastructure", value_type="text", attributes={"topic": "cloud"}), + Record(id="5", value="Python Flask web framework", value_type="text", attributes={"topic": "python"}), + ] + + index_model.upsert(test_records) + time.sleep(10) # Wait for indexing + + yield index_model + + # Cleanup + index_model.delete() + + +def test_score_threshold_zero_returns_all(setup_test_index): + """With score_threshold=0.0, all matching results should be returned.""" + index_model = setup_test_index + response = index_model.search("Python", score_threshold=0.0) + + assert response.status == "SUCCESS" + print(f"\nResults with threshold 0.0: {len(response.details)} records") + for detail in response.details: + print(f" - Score: {detail.get('score', 'N/A')}") + + +def test_score_threshold_high_returns_none(setup_test_index): + """With score_threshold=1.0, no results should be returned.""" + index_model = setup_test_index + + response = index_model.search("Python", score_threshold=1.0) + assert response.status == "SUCCESS" + print(f"\nResults with threshold 1.0: {len(response.details)} records") + assert len(response.details) == 0, "No results should match with threshold=1.0" + + +def test_score_threshold_filters_correctly(setup_test_index): + """Verify that higher threshold returns fewer or equal results.""" + index_model = setup_test_index + + # Get all results first + response_all = index_model.search("Python", score_threshold=0.0) + all_count = len(response_all.details) + + # Get filtered results with medium threshold + response_filtered = index_model.search("Python", score_threshold=0.5) + filtered_count = len(response_filtered.details) + + print(f"\nAll results (threshold=0.0): {all_count}") + print(f"Filtered results (threshold=0.5): {filtered_count}") + + # Filtered count should be <= all count + assert filtered_count <= all_count + + +def test_score_threshold_with_filter(setup_test_index): + """Test score_threshold combined with filters.""" + from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + + index_model = setup_test_index + + # Search with filter only (no threshold) + filter_ = IndexFilter(field="topic", value="python", operator=IndexFilterOperator.EQUALS) + response = index_model.search("Python", filters=[filter_], score_threshold=0.0) + assert response.status == "SUCCESS" + filtered_count = len(response.details) + print(f"\nResults with filter only (threshold=0.0): {filtered_count}") + + # Search with filter AND high threshold - should return 0 + response = index_model.search("Python", filters=[filter_], score_threshold=1.0) + assert response.status == "SUCCESS" + print(f"Results with filter + threshold=1.0: {len(response.details)}") + assert len(response.details) == 0 + + +def test_score_threshold_with_filter_medium(setup_test_index): + """Test score_threshold with filter at medium threshold.""" + from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator + + index_model = setup_test_index + + filter_ = IndexFilter(field="topic", value="python", operator=IndexFilterOperator.EQUALS) + + # Get all filtered results + response_all = index_model.search("Python", filters=[filter_], score_threshold=0.0) + all_count = len(response_all.details) + + # Get filtered results with medium threshold + response_filtered = index_model.search("Python", filters=[filter_], score_threshold=0.8) + filtered_count = len(response_filtered.details) + + print(f"\nFiltered results (threshold=0.0): {all_count}") + print(f"Filtered results (threshold=0.8): {filtered_count}") + + # Higher threshold should return fewer or equal results + assert filtered_count <= all_count + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 5f5e13eb..faed73dd 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -183,10 +183,10 @@ def test_validate_record_failure_no_uri(mocker): def test_validate_record_failure_no_value(mocker): - record = Record(uri="test.jpg", value_type="text", id=0, attributes={}) + record = Record(value_type="text", id=0, attributes={}) with pytest.raises(Exception) as e: record.validate() - assert str(e.value) == "Index Upsert Error: Value is required for text records" + assert str(e.value) == "Index Upsert Error: Either value or uri is required for text records" def test_record_to_dict(): @@ -233,15 +233,24 @@ def test_index_factory_create_failure(): with pytest.raises(Exception) as e: IndexFactory.create(description="test") - assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + assert ( + str(e.value) + == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + ) with pytest.raises(Exception) as e: IndexFactory.create(name="test") - assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + assert ( + str(e.value) + == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + ) with pytest.raises(Exception) as e: IndexFactory.create(name="test", description="test", embedding_model=None) - assert str(e.value) == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + assert ( + str(e.value) + == "Index Factory Exception: name, description, and embedding_model must be provided when params is not" + ) def test_index_model_splitter():