Skip to content

Commit d28b85c

Browse files
committed
Add ability to query.count
1 parent 8d5a9ea commit d28b85c

4 files changed

Lines changed: 193 additions & 8 deletions

File tree

src/superannotate/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import sys
44

5-
__version__ = "4.5.4dev1"
5+
__version__ = "4.5.5dev2"
66

77

88
os.environ.update({"sa_version": __version__})
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from __future__ import annotations
2+
3+
from typing import Callable
4+
from typing import Generic
5+
from typing import Iterator
6+
from typing import TypeVar
7+
from typing import overload
8+
9+
T = TypeVar("T")


class BaseResult(Generic[T]):
    """A generic, lazily loaded, list-like wrapper around a result list.

    The wrapped list is produced by ``data_fetcher`` the first time any
    list-like operation (iteration, indexing, ``len()``, ``in``, ``bool()``,
    ``repr()`` or ``==``) is performed, and is cached afterwards so the
    fetcher is not invoked again once it has returned a list.

    The wrapper is meant to be a drop-in replacement for a plain ``list``
    return value, so it also compares equal to a ``list`` (or to another
    ``BaseResult``) holding equal items.
    """

    # Defining ``__eq__`` below makes instances unhashable; this is spelled
    # out explicitly because it mirrors the plain lists the wrapper replaces.
    __hash__ = None  # type: ignore[assignment]

    def __init__(self, data_fetcher: Callable[[], list[T]]) -> None:
        """Store the fetcher; no data is loaded at construction time.

        :param data_fetcher: zero-argument callable returning the result list.
        """
        self._data: list[T] | None = None
        self._data_fetcher = data_fetcher

    def _ensure_data(self) -> list[T]:
        """Return the cached list, invoking the fetcher on first use."""
        if self._data is None:
            self._data = self._data_fetcher()
        return self._data

    def __iter__(self) -> Iterator[T]:
        return iter(self._ensure_data())

    def __len__(self) -> int:
        return len(self._ensure_data())

    @overload
    def __getitem__(self, index: int) -> T: ...

    @overload
    def __getitem__(self, index: slice) -> list[T]: ...

    def __getitem__(self, index: int | slice) -> T | list[T]:
        # Slices yield a plain list, exactly as slicing the wrapped list does.
        return self._ensure_data()[index]

    def __repr__(self) -> str:
        # Rendered as the underlying list so printed output looks like a list.
        return repr(self._ensure_data())

    def __bool__(self) -> bool:
        return bool(self._ensure_data())

    def __contains__(self, item: T) -> bool:
        return item in self._ensure_data()

    def __eq__(self, other: object) -> bool:
        """Compare by content against a ``list`` or another ``BaseResult``.

        Any other operand type yields ``NotImplemented`` so Python falls back
        to its default handling (which matches ``list == tuple`` being False).
        """
        if isinstance(other, BaseResult):
            return self._ensure_data() == other._ensure_data()
        if isinstance(other, list):
            return self._ensure_data() == other
        return NotImplemented
53+
54+
55+
class QueryResult(BaseResult[dict]):
    """Lazily loaded, list-like query results with an extra ``count()``.

    Iteration, indexing and ``len()`` are inherited from ``BaseResult`` and
    load the items on first use. ``count()`` goes through a separate
    ``count_fetcher`` callable and never touches the item list.
    """

    def __init__(
        self,
        data_fetcher: Callable[[], list[dict]],
        count_fetcher: Callable[[], int],
    ) -> None:
        """Keep both fetchers; neither one is invoked here.

        :param data_fetcher: zero-argument callable returning the item dicts.
        :param count_fetcher: zero-argument callable returning the item count.
        """
        self._count_fetcher = count_fetcher
        super().__init__(data_fetcher=data_fetcher)

    def count(self) -> int:
        """Return whatever ``count_fetcher`` reports for the query.

        The fetcher is called on every invocation (the value is not cached)
        and the item list is not loaded as a side effect.
        """
        total = self._count_fetcher()
        return total

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import warnings
1212
from collections.abc import Callable
1313
from collections.abc import Iterable
14+
from functools import partial
1415
from pathlib import Path
1516
from typing import Annotated
1617
from typing import Any
@@ -80,6 +81,8 @@
8081
from lib.infrastructure.query_builder import QueryBuilderChain
8182
from lib.infrastructure.query_builder import FieldValidationHandler
8283

84+
from lib.app.interface.responses import QueryResult
85+
8386
logger = logging.getLogger("sa")
8487

8588
NotEmptyStr = Annotated[str, StringConstraints(strict=True, min_length=1)]
@@ -4267,10 +4270,14 @@ def query(
42674270
project: NotEmptyStr | int | tuple[int, int] | tuple[str, str],
42684271
query: NotEmptyStr | None = None,
42694272
subset: NotEmptyStr | None = None,
4270-
):
4273+
) -> QueryResult:
42714274
"""Return items that satisfy the given query.
42724275
Query syntax should be in SuperAnnotate query language(https://doc.superannotate.com/docs/explore-overview).
42734276
4277+
The returned object behaves like a list of dicts (supports iteration,
4278+
indexing, and ``len()``) and additionally exposes a ``.count()`` method
4279+
that returns the total number of matching items without fetching them.
4280+
42744281
:param project: Accepts a project as a string ("project" or "project/folder") or as a tuple (project_id, folder_id), where the folder is optional.
42754282
:type project: Union[str, int, Tuple[int, int], Tuple[str, str]]
42764283
@@ -4283,13 +4290,39 @@ def query(
42834290
42844291
:return: queried items' metadata list
42854292
:rtype: QueryResult (list-like collection of dicts)
4293+
4294+
Request Example:
4295+
::
4296+
4297+
client = SAClient()
4298+
4299+
# Iterate over queried items (fetches data)
4300+
queried_items = client.query(
4301+
project="Image Project",
4302+
query="instance(error = true)"
4303+
)
4304+
for item in queried_items:
4305+
print(item["name"])
4306+
4307+
# Get only the count without fetching all items
4308+
total = client.query(
4309+
project="Image Project",
4310+
query="instance(error = true)"
4311+
).count()
4312+
print(f"Total matching items: {total}")
42864313
"""
4287-
project, folder = self.controller.get_project_folder(project)
4288-
items = self.controller.query_entities(project, folder, query, subset)
4289-
exclude = {
4290-
"meta",
4291-
}
4292-
return BaseSerializer.serialize_iterable(items, exclude=exclude)
4314+
project_entity, folder = self.controller.get_project_folder(project)
4315+
fetch_entities = partial(
4316+
self.controller.query_entities, project_entity, folder, query, subset
4317+
)
4318+
return QueryResult(
4319+
data_fetcher=lambda: BaseSerializer.serialize_iterable(
4320+
fetch_entities(), exclude={"meta"}
4321+
),
4322+
count_fetcher=partial(
4323+
self.controller.query_items_count, project_entity.name, query
4324+
),
4325+
)
42934326

42944327
def get_item_metadata(
42954328
self,

tests/integration/items/test_saqul_query.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,80 @@ def test_query_on_100(self):
6666
== 100
6767
)
6868

69+
def test_query_result_list_like_behavior(self):
    """The object returned by ``sa.query`` must still act like a list."""
    csv_path = os.path.join(DATA_SET_PATH, "100_urls.csv")
    sa.attach_items(self.PROJECT_NAME, csv_path)
    queried = sa.query(self.PROJECT_NAME, "metadata(status = NotStarted)")

    # len()
    self.assertEqual(len(queried), 100)

    # positive index
    head = queried[0]
    self.assertIsInstance(head, dict)
    self.assertIn("name", head)

    # negative index
    tail = queried[-1]
    self.assertIsInstance(tail, dict)

    # slice
    first_five = queried[0:5]
    self.assertEqual(len(first_five), 5)

    # iteration: every element is a dict and all 100 are visited
    seen = 0
    for seen, entry in enumerate(queried, start=1):
        self.assertIsInstance(entry, dict)
    self.assertEqual(seen, 100)

    # conversion to a real list
    materialized = list(queried)
    self.assertEqual(len(materialized), 100)
    self.assertIsInstance(materialized, list)
101+
102+
def test_query_result_count_method(self):
    """``.count()`` on the query result reports the number of matches."""
    csv_path = os.path.join(DATA_SET_PATH, "100_urls.csv")
    sa.attach_items(self.PROJECT_NAME, csv_path)
    queried = sa.query(self.PROJECT_NAME, "metadata(status = NotStarted)")

    total = queried.count()
    self.assertEqual(total, 100)
    self.assertIsInstance(total, int)

    # The reported count has to agree with the number of fetched items.
    self.assertEqual(total, len(queried))
114+
115+
def test_query_result_lazy_loading(self):
    """``.count()`` must leave the items unloaded; indexing loads them."""
    csv_path = os.path.join(DATA_SET_PATH, "100_urls.csv")
    sa.attach_items(self.PROJECT_NAME, csv_path)
    queried = sa.query(self.PROJECT_NAME, "metadata(status = NotStarted)")

    # Nothing has been accessed yet, so the cache is still empty.
    self.assertIsNone(queried._data)

    # Counting goes through its own fetcher and keeps the cache empty.
    total = queried.count()
    self.assertEqual(total, 100)
    self.assertIsNone(queried._data)

    # The first element access populates the cache.
    head = queried[0]
    self.assertIsNotNone(queried._data)
    self.assertIsInstance(head, dict)
132+
133+
def test_query_result_repr(self):
    """``repr()`` of the query result looks like the repr of a list."""
    csv_path = os.path.join(DATA_SET_PATH, "100_urls.csv")
    sa.attach_items(self.PROJECT_NAME, csv_path)
    queried = sa.query(self.PROJECT_NAME, "metadata(status = NotStarted)")

    rendered = repr(queried)
    self.assertIsInstance(rendered, str)
    self.assertTrue(rendered.startswith("["))
142+
69143
def test_validate_saqul_query(self):
70144
try:
71145
self.assertRaises(

0 commit comments

Comments
 (0)