diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f762fbee..a5a5da18 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: name: Style Guide Enforcement (flake8) args: - '--max-line-length=120' - - --ignore=D100,D203,D405,W503,E203,E501,F841,E126,E712,E123,E131,F821,E121,W605,E402 + - --ignore=D100,D203,D405,W503,E203,E501,F841,E126,E712,E123,E131,F821,E121,W605,E402,E704 - repo: 'https://github.com/asottile/pyupgrade' rev: v3.21.2 hooks: diff --git a/src/superannotate/__init__.py b/src/superannotate/__init__.py index c10018bb..ebed1f89 100644 --- a/src/superannotate/__init__.py +++ b/src/superannotate/__init__.py @@ -2,7 +2,7 @@ import os import sys -__version__ = "4.5.4dev1" +__version__ = "4.5.5dev2" os.environ.update({"sa_version": __version__}) diff --git a/src/superannotate/lib/app/interface/responses.py b/src/superannotate/lib/app/interface/responses.py new file mode 100644 index 00000000..a4bc1501 --- /dev/null +++ b/src/superannotate/lib/app/interface/responses.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from collections.abc import Callable +from collections.abc import Iterator +from typing import Generic +from typing import overload +from typing import TypeVar + +T = TypeVar("T") + + +class BaseResult(list, Generic[T]): + """A generic list-like wrapper for results with lazy loading support. + + Inherits from ``list`` for full backward compatibility with code that + expects a real list (``isinstance(x, list)``, JSON serializers, etc.). + Data is fetched lazily on first access. + """ + + def __init__(self, data_fetcher: Callable[[], list[T]]) -> None: + super().__init__() + self._data_fetcher = data_fetcher + self._loaded = False + + def _ensure_data(self) -> None: + """Lazily fetch data if not already loaded.""" + if not self._loaded: + list.extend(self, self._data_fetcher()) + self._loaded = True + + def data(self) -> list[T]: + self._ensure_data() + return list(self) + + def __iter__(self) -> Iterator[T]: + self._ensure_data() + return list.__iter__(self) + + def __len__(self) -> int: + self._ensure_data() + return list.__len__(self) + + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> list[T]: ... + + def __getitem__(self, index: int | slice) -> T | list[T]: + self._ensure_data() + return list.__getitem__(self, index) + + def __repr__(self) -> str: + self._ensure_data() + return list.__repr__(self) + + def __bool__(self) -> bool: + self._ensure_data() + return list.__len__(self) > 0 + + def __contains__(self, item: object) -> bool: + self._ensure_data() + return list.__contains__(self, item) + + def __eq__(self, other: object) -> bool: + self._ensure_data() + return list.__eq__(self, other) + + __hash__ = None # type: ignore[assignment] + + +class QueryResult(BaseResult[dict]): + """A list-like wrapper for query results that supports .count() method. + + This class wraps a list of query results while maintaining full backward + compatibility with list-like operations (iteration, indexing, len()). + Data is fetched lazily - only when accessed. Calling .count() does not + trigger data fetching. + """ + + def __init__( + self, + data_fetcher: Callable[[], list[dict]], + count_fetcher: Callable[[], int], + ) -> None: + super().__init__(data_fetcher) + self._count_fetcher = count_fetcher + + def count(self) -> int: + """Return the count of items matching the query from the server. + + This method does not trigger data fetching - it makes a separate + lightweight API call to get only the count. + """ + return self._count_fetcher() diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py index dae44d8a..2011f566 100644 --- a/src/superannotate/lib/app/interface/sdk_interface.py +++ b/src/superannotate/lib/app/interface/sdk_interface.py @@ -11,6 +11,7 @@ import warnings from collections.abc import Callable from collections.abc import Iterable +from functools import partial from pathlib import Path from typing import Annotated from typing import Any @@ -80,6 +81,8 @@ from lib.infrastructure.query_builder import QueryBuilderChain from lib.infrastructure.query_builder import FieldValidationHandler +from lib.app.interface.responses import QueryResult + logger = logging.getLogger("sa") NotEmptyStr = Annotated[str, StringConstraints(strict=True, min_length=1)] @@ -152,6 +155,7 @@ def __init__( self._annotation_adapter: BaseMultimodalAnnotationAdapter | None = None self._overwrite = overwrite self._annotation: dict | None = None + self._set_component_called = False def _set_small_annotation_adapter(self, annotation: dict | None = None): self._annotation_adapter = MultimodalSmallAnnotationAdapter( @@ -221,7 +225,9 @@ def save(self): self._set_large_annotation_adapter(self.annotation) else: self._set_small_annotation_adapter(self.annotation) - self._annotation_adapter.save() + if self._set_component_called: + self._annotation_adapter.save() + self._set_component_called = False def get_metadata(self): """ @@ -281,6 +287,7 @@ def set_component_value(self, component_id: str, value: Any): """ self.annotation_adapter.set_component_value(component_id, value) + self._set_component_called = True return self @@ -4267,10 +4274,12 @@ def query( project: NotEmptyStr | int | tuple[int, int] | tuple[str, str], query: NotEmptyStr | None = None, subset: NotEmptyStr | None = None, - ): + ) -> QueryResult: """Return items that satisfy the given query. Query syntax should be in SuperAnnotate query language(https://doc.superannotate.com/docs/explore-overview). + The returned QueryResult behaves like a list of dicts, and additionally exposes a .count() method. + :param project: Accepts a project as a string ("project" or "project/folder") or as a tuple (project_id, folder_id), where the folder is optional.” :type project: Union[str, int, Tuple[int, int], Tuple[str, str]] @@ -4282,14 +4291,54 @@ def query( :type subset: str :return: queried items' metadata list - :rtype: list of dicts + :rtype: QueryResult (list of dicts with .count() method) + + Request Example: + :: + + sa_client = SAClient() + + queried_items = sa_client.query( + project="Image Project", + query="metadata(lastAction.email = test@superannotate.com)" + ) + for item in queried_items: + print(item["name"]) + + .. py:method:: query.count() -> int + + Returns the total number of items matching the query. + + :return: total number of matching items + :rtype: int + + Request Example: + :: + + sa_client = SAClient() + + total = sa_client.query( + project="Image Project", + query="metadata(lastAction.email = test@superannotate.com)" + ).count() + print(f"Total matching items: {total}") """ project, folder = self.controller.get_project_folder(project) - items = self.controller.query_entities(project, folder, query, subset) - exclude = { - "meta", - } - return BaseSerializer.serialize_iterable(items, exclude=exclude) + fetch_entities = partial( + self.controller.query_entities, project, folder, query, subset + ) + return QueryResult( + data_fetcher=lambda: BaseSerializer.serialize_iterable( + fetch_entities(), exclude={"meta"} + ), + count_fetcher=partial( + self.controller.query_items_count, + project=project, + folder=folder, + query=query, + subset=subset, + ), + ) def get_item_metadata( self, diff --git a/src/superannotate/lib/core/serviceproviders.py b/src/superannotate/lib/core/serviceproviders.py index 6a28e04e..d92de4b9 100644 --- a/src/superannotate/lib/core/serviceproviders.py +++ b/src/superannotate/lib/core/serviceproviders.py @@ -750,7 +750,9 @@ def saqul_query( def query_item_count( self, project: entities.ProjectEntity, + folder: entities.FolderEntity = None, query: str = None, + subset_id: int = None, ) -> ServiceResponse: raise NotImplementedError diff --git a/src/superannotate/lib/core/usecases/items.py b/src/superannotate/lib/core/usecases/items.py index 5a1f2177..fbed9f53 100644 --- a/src/superannotate/lib/core/usecases/items.py +++ b/src/superannotate/lib/core/usecases/items.py @@ -170,13 +170,17 @@ def __init__( self, reporter: Reporter, project: ProjectEntity, + folder: FolderEntity, service_provider: BaseServiceProvider, query: str, + subset: str = None, ): super().__init__(reporter) self._project = project + self._folder = folder self._service_provider = service_provider self._query = query + self._subset = subset def validate_arguments(self): if self._query: @@ -197,9 +201,40 @@ def validate_arguments(self): if not response.ok: raise AppException(response.error) + if not any([self._query, self._subset]): + raise AppException( + "The query and subset params cannot have the value None at the same time." + ) + if self._subset and not self._folder.is_root: + raise AppException( + "The folder name should be specified in the query string." + ) + def execute(self) -> Response: if self.is_valid(): - query_kwargs = {"query": self._query} + query_kwargs = {} + if self._subset: + response = self._service_provider.explore.list_subsets(self._project) + if response.ok: + subset = next( + (_sub for _sub in response.data if _sub.name == self._subset), + None, + ) + else: + self._response.errors = response.error + return self._response + if not subset: + self._response.errors = AppException( + "Subset not found. Use the superannotate." + "get_subsets() function to get a list of the available subsets." + ) + return self._response + query_kwargs["subset_id"] = subset.id + if self._query: + query_kwargs["query"] = self._query + query_kwargs["folder"] = ( + None if self._folder.name == "root" else self._folder + ) service_response = self._service_provider.explore.query_item_count( self._project, **query_kwargs, @@ -862,7 +897,7 @@ def execute(self): item_names=self._item_names[i : i + self.CHUNK_SIZE], # noqa: E203, annotation_status=self._annotation_status_code, ) - if not status_changed: + if not status_changed.ok: self._response.errors = AppException(self.ERROR_MESSAGE) break return self._response diff --git a/src/superannotate/lib/infrastructure/controller.py b/src/superannotate/lib/infrastructure/controller.py index febb40bf..985444b5 100644 --- a/src/superannotate/lib/infrastructure/controller.py +++ b/src/superannotate/lib/infrastructure/controller.py @@ -2039,13 +2039,20 @@ def query_entities( self.service_provider, items, project, folder, map_fields=False ) - def query_items_count(self, project_name: str, query: str = None) -> int: - project = self.get_project(project_name) + def query_items_count( + self, + project: ProjectEntity, + folder: FolderEntity, + query: str = None, + subset: str = None, + ) -> int: use_case = usecases.QueryEntitiesCountUseCase( reporter=self.get_default_reporter(), project=project, + folder=folder, query=query, + subset=subset, service_provider=self.service_provider, ) response = use_case.execute() diff --git a/src/superannotate/lib/infrastructure/services/explore.py b/src/superannotate/lib/infrastructure/services/explore.py index 4afbd544..341fae49 100644 --- a/src/superannotate/lib/infrastructure/services/explore.py +++ b/src/superannotate/lib/infrastructure/services/explore.py @@ -196,13 +196,19 @@ def saqul_query( def query_item_count( self, project: entities.ProjectEntity, + folder: entities.FolderEntity = None, query: str = None, + subset_id: int = None, ) -> ServiceResponse: params = { "project_id": project.id, "includeFolderNames": True, } + if folder: + params["folder_id"] = folder.id + if subset_id: + params["subset_id"] = subset_id data = {"query": query} response = self.client.request( urljoin(self.explore_service_url, self.URL_QUERY_COUNT), diff --git a/tests/integration/items/test_item_context.py b/tests/integration/items/test_item_context.py index 454e4879..7bcd7ede 100644 --- a/tests/integration/items/test_item_context.py +++ b/tests/integration/items/test_item_context.py @@ -1,8 +1,12 @@ import json import os from pathlib import Path +from unittest import TestCase +from unittest.mock import MagicMock +from unittest.mock import patch from src.superannotate import FileChangedError +from src.superannotate import ItemContext from src.superannotate import SAClient from tests.integration.base import BaseTestCase @@ -135,3 +139,63 @@ def tearDown(self) -> None: sa.delete_project(self.PROJECT_NAME) except Exception: ... + + +class TestItemContextSetComponentCalledFlag(TestCase): + def _make_context(self): + ic = ItemContext( + controller=MagicMock(), + project=MagicMock(), + folder=MagicMock(), + item=MagicMock(), + overwrite=True, + ) + ic._annotation_adapter = MagicMock() + ic._annotation_adapter.annotation = {"metadata": {}, "data": {}} + return ic + + def test_dirty_flag_initial_state(self): + ic = self._make_context() + self.assertFalse(ic._set_component_called) + + def test_set_component_value_marks_dirty(self): + ic = self._make_context() + ic.set_component_value("component_id", "value") + self.assertTrue(ic._set_component_called) + + def test_save_called_on_exit_after_set_component_value(self): + ic = self._make_context() + with patch.object(ItemContext, "save", autospec=True) as save_mock: + with ic: + ic.set_component_value("component_id", "value") + save_mock.assert_called_once_with(ic) + + def test_dirty_flag_reset_after_save(self): + ic = self._make_context() + with patch.object(ic, "_set_small_annotation_adapter"), patch.object( + ic, "_set_large_annotation_adapter" + ): + ic.set_component_value("component_id", "value") + self.assertTrue(ic._set_component_called) + ic.save() + self.assertFalse(ic._set_component_called) + + def test_no_double_save_on_exit_after_manual_save(self): + ic = self._make_context() + with patch.object(ic, "_set_small_annotation_adapter"), patch.object( + ic, "_set_large_annotation_adapter" + ): + with ic: + ic.set_component_value("component_id", "value") + ic.save() + self.assertEqual(ic._annotation_adapter.save.call_count, 1) + self.assertEqual(ic._annotation_adapter.save.call_count, 1) + + def test_save_not_called_when_exception_raised(self): + ic = self._make_context() + with patch.object(ItemContext, "save", autospec=True) as save_mock: + with self.assertRaises(RuntimeError): + with ic: + ic.set_component_value("component_id", "value") + raise RuntimeError("boom") + save_mock.assert_not_called() diff --git a/tests/integration/items/test_saqul_query.py b/tests/integration/items/test_saqul_query.py index 1e964c3b..4a2a2412 100644 --- a/tests/integration/items/test_saqul_query.py +++ b/tests/integration/items/test_saqul_query.py @@ -59,13 +59,51 @@ def test_query_on_100(self): sa.attach_items(self.PROJECT_NAME, os.path.join(DATA_SET_PATH, "100_urls.csv")) entities = sa.query(self.PROJECT_NAME, "metadata(status = NotStarted)") assert len(entities) == 100 - assert ( - sa.controller.query_items_count( - self.PROJECT_NAME, "metadata(status = NotStarted)" - ) - == 100 + assert entities.count() == len(entities) + + def test_query_result_list_like_behavior(self): + sa.attach_items(self.PROJECT_NAME, os.path.join(DATA_SET_PATH, "100_urls.csv")) + result = sa.query(self.PROJECT_NAME, "metadata(status = NotStarted)") + + self.assertEqual(len(result), 100) + self.assertIsInstance(result[0], dict) + self.assertIn("name", result[0]) + self.assertIsInstance(result[-1], dict) + self.assertEqual(len(result[0:5]), 5) + + items = [item for item in result] + self.assertEqual(len(items), 100) + self.assertIsInstance(list(result), list) + + def test_query_result_lazy_count(self): + sa.attach_items(self.PROJECT_NAME, os.path.join(DATA_SET_PATH, "100_urls.csv")) + result = sa.query(self.PROJECT_NAME, "metadata(status = NotStarted)") + + self.assertFalse(result._loaded) + self.assertEqual(result.count(), 100) + self.assertFalse(result._loaded) + + _ = result[0] + self.assertTrue(result._loaded) + + def test_query_result_count_respects_subset(self): + subset_name = "subset_a" + sa.attach_items(self.PROJECT_NAME, os.path.join(DATA_SET_PATH, "100_urls.csv")) + all_items = sa.query(self.PROJECT_NAME, "metadata(status = NotStarted)") + subset_items = [ + {"name": item["name"], "path": self.PROJECT_NAME} for item in all_items[:30] + ] + sa.add_items_to_subset(self.PROJECT_NAME, subset_name, subset_items) + + result = sa.query( + self.PROJECT_NAME, + "metadata(status = NotStarted)", + subset=subset_name, ) + self.assertEqual(result.count(), len(subset_items)) + self.assertEqual(result.count(), len(list(result))) + def test_validate_saqul_query(self): try: self.assertRaises(