diff --git a/src/databricks/labs/ucx/framework/utils.py b/src/databricks/labs/ucx/framework/utils.py index d428447911..9bec95bd80 100644 --- a/src/databricks/labs/ucx/framework/utils.py +++ b/src/databricks/labs/ucx/framework/utils.py @@ -1,5 +1,6 @@ import logging import subprocess +from collections.abc import Callable, Iterator logger = logging.getLogger(__name__) @@ -22,6 +23,59 @@ def escape_sql_identifier(path: str, *, maxsplit: int = 2) -> str: return ".".join(escaped) +def paginated_fetch_offset( + fetch_page: Callable[[int, int], dict], + items_key: str, + page_size: int, + start_index: int = 1, +) -> Iterator[dict]: + """Paginate a SCIM-style offset API (startIndex / count). + + Args: + fetch_page: Callable that takes (start_index, count) and returns a response dict. + items_key: Key in the response containing the list of items (e.g. "Resources"). + page_size: Number of items to request per page. + start_index: 1-based index to start from (SCIM default is 1). + + Yields: + Individual raw item dicts from each page. + """ + while True: + response = fetch_page(start_index, page_size) + items = response.get(items_key, []) + if not items: + break + yield from items + if len(items) < page_size: + break + start_index += len(items) + + +def paginated_fetch_cursor( + fetch_page: Callable[[str | None], dict], + items_key: str, + next_token_key: str = "next_page_token", +) -> Iterator[dict]: + """Paginate a cursor/token-based API. + + Args: + fetch_page: Callable that takes an optional page token and returns a response dict. + items_key: Key in the response containing the list of items. + next_token_key: Key in the response containing the next page token. + + Yields: + Individual raw item dicts from each page. + """ + token: str | None = None + while True: + response = fetch_page(token) + items = response.get(items_key, []) + yield from items + token = response.get(next_token_key) + if not token: + break + + def run_command(command: str | list[str]) -> tuple[int, str, str]: args = command.split() if isinstance(command, str) else command logger.info(f"Invoking command: {args!r}") diff --git a/src/databricks/labs/ucx/workspace_access/generic.py b/src/databricks/labs/ucx/workspace_access/generic.py index a646317ece..a9568ef92f 100644 --- a/src/databricks/labs/ucx/workspace_access/generic.py +++ b/src/databricks/labs/ucx/workspace_access/generic.py @@ -23,6 +23,8 @@ from databricks.sdk.service import iam, ml from databricks.sdk.service.iam import PermissionLevel +from databricks.labs.ucx.framework.utils import paginated_fetch_cursor + from databricks.labs.ucx.framework.crawlers import CrawlerBase from databricks.labs.ucx.framework.utils import escape_sql_identifier from databricks.labs.ucx.workspace_access.base import AclSupport, Permissions, StaticListing @@ -439,20 +441,16 @@ def inner() -> Iterator[ml.Experiment]: def feature_store_listing(ws: WorkspaceClient): def inner() -> list[GenericPermissionsInfo]: - feature_tables = [] - token = None - while True: + def fetch_feature_tables(token: str | None) -> dict: result = ws.api_client.do( "GET", "/api/2.0/feature-store/feature-tables/search", query={"page_token": token, "max_results": 200} ) assert isinstance(result, dict) - for table in result.get("feature_tables", []): - feature_tables.append(GenericPermissionsInfo(table["id"], "feature-tables")) - - if "next_page_token" not in result: - break - token = result["next_page_token"] # type: ignore[index] + return result + feature_tables = [] + for table in paginated_fetch_cursor(fetch_feature_tables, items_key="feature_tables"): + feature_tables.append(GenericPermissionsInfo(table["id"], "feature-tables")) return feature_tables return inner diff --git a/src/databricks/labs/ucx/workspace_access/groups.py b/src/databricks/labs/ucx/workspace_access/groups.py index 814043e1c6..8b4ff00e0d 100644 --- a/src/databricks/labs/ucx/workspace_access/groups.py +++ b/src/databricks/labs/ucx/workspace_access/groups.py @@ -24,7 +24,7 @@ from databricks.sdk.service.iam import Group, User from databricks.labs.ucx.framework.crawlers import CrawlerBase -from databricks.labs.ucx.framework.utils import escape_sql_identifier +from databricks.labs.ucx.framework.utils import escape_sql_identifier, paginated_fetch_offset logger = logging.getLogger(__name__) @@ -913,6 +913,8 @@ def _get_strategy( class AccountGroupLookup: + PAGE_SIZE = 10000 + def __init__(self, ws: WorkspaceClient): self._ws = ws @@ -964,13 +966,23 @@ def _list_account_groups(self, scim_attributes: str) -> list[iam.Group]: # TODO: we should avoid using this method, as it's not documented # get account-level groups even if they're not (yet) assigned to a workspace logger.info(f"Listing account groups with {scim_attributes}...") - account_groups = [] - raw = self._ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query={"attributes": scim_attributes}) - for resource in raw.get("Resources", []): # type: ignore[union-attr] + + def fetch_account_groups(start_index: int, count: int) -> dict: + query = {"startIndex": start_index, "count": count, "attributes": scim_attributes} + return self._ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query=query) # type: ignore[return-value] + + account_groups: list[iam.Group] = [] + seen: set[str] = set() + for resource in paginated_fetch_offset(fetch_account_groups, items_key="Resources", page_size=self.PAGE_SIZE): group = iam.Group.from_dict(resource) + if group.id and group.id in seen: + continue + if group.id: + seen.add(group.id) if group.display_name in SYSTEM_GROUPS: continue account_groups.append(group) + logger.info(f"Found {len(account_groups)} account groups") sorted_groups: list[iam.Group] = sorted( account_groups, key=lambda _: _.display_name if _.display_name else "" diff --git a/tests/unit/framework/test_utils.py b/tests/unit/framework/test_utils.py index 966a8c9945..a48ff7418b 100644 --- a/tests/unit/framework/test_utils.py +++ b/tests/unit/framework/test_utils.py @@ -1,6 +1,10 @@ import pytest -from databricks.labs.ucx.framework.utils import escape_sql_identifier +from databricks.labs.ucx.framework.utils import ( + escape_sql_identifier, + paginated_fetch_offset, + paginated_fetch_cursor, +) @pytest.mark.parametrize( @@ -34,3 +38,139 @@ def test_escaped_when_column_contains_period() -> None: expected = "`column.with.periods`" path = "column.with.periods" assert escape_sql_identifier(path, maxsplit=0) == expected + + +# --- paginated_fetch_offset tests --- + + +def test_offset_pagination_empty_first_page() -> None: + """No items returned on the first request.""" + pages: list[dict] = [{"Resources": []}] + + def fetch_page(_start_index: int, _count: int) -> dict: + return pages.pop(0) + + results = list(paginated_fetch_offset(fetch_page, items_key="Resources", page_size=10)) + assert not results + + +def test_offset_pagination_single_page() -> None: + """All items fit in a single page (fewer items than page_size).""" + items = [{"id": "1"}, {"id": "2"}] + pages = [{"Resources": items}] + + def fetch_page(_start_index: int, _count: int) -> dict: + return pages.pop(0) + + results = list(paginated_fetch_offset(fetch_page, items_key="Resources", page_size=10)) + assert results == items + + +def test_offset_pagination_multiple_pages() -> None: + """Items span multiple pages; pagination advances startIndex correctly.""" + page1 = [{"id": "1"}, {"id": "2"}] + page2 = [{"id": "3"}] + pages = [{"Resources": page1}, {"Resources": page2}] + captured_calls: list[tuple[int, int]] = [] + + def fetch_page(start_index: int, count: int) -> dict: + captured_calls.append((start_index, count)) + return pages.pop(0) + + results = list(paginated_fetch_offset(fetch_page, items_key="Resources", page_size=2)) + assert results == [{"id": "1"}, {"id": "2"}, {"id": "3"}] + assert captured_calls[0] == (1, 2) + assert captured_calls[1] == (3, 2) + + +def test_offset_pagination_terminates_on_missing_key() -> None: + """Stops when the items key is missing from the response entirely.""" + pages = [{"other_key": "value"}] + + def fetch_page(_start_index: int, _count: int) -> dict: + return pages.pop(0) + + results = list(paginated_fetch_offset(fetch_page, items_key="Resources", page_size=10)) + assert not results + + +def test_offset_pagination_respects_start_index() -> None: + """Custom start_index is passed to the first request.""" + pages = [{"Resources": [{"id": "5"}]}] + captured_calls: list[tuple[int, int]] = [] + + def fetch_page(start_index: int, count: int) -> dict: + captured_calls.append((start_index, count)) + return pages.pop(0) + + results = list(paginated_fetch_offset(fetch_page, items_key="Resources", page_size=10, start_index=5)) + assert results == [{"id": "5"}] + assert captured_calls[0] == (5, 10) + + +# --- paginated_fetch_cursor tests --- + + +def test_cursor_pagination_empty_first_page() -> None: + """No items returned on the first request.""" + pages: list[dict] = [{"feature_tables": []}] + + def fetch_page(_token: str | None) -> dict: + return pages.pop(0) + + results = list(paginated_fetch_cursor(fetch_page, items_key="feature_tables")) + assert not results + + +def test_cursor_pagination_single_page_no_token() -> None: + """All items in one page, no next_page_token in response.""" + items = [{"id": "t1"}, {"id": "t2"}] + pages = [{"feature_tables": items}] + + def fetch_page(_token: str | None) -> dict: + return pages.pop(0) + + results = list(paginated_fetch_cursor(fetch_page, items_key="feature_tables")) + assert results == items + + +def test_cursor_pagination_multiple_pages() -> None: + """Items span multiple pages with cursor tokens.""" + pages: list[dict] = [ + {"feature_tables": [{"id": "t1"}], "next_page_token": "token_abc"}, + {"feature_tables": [{"id": "t2"}, {"id": "t3"}]}, + ] + captured_tokens: list[str | None] = [] + + def fetch_page(token: str | None) -> dict: + captured_tokens.append(token) + return pages.pop(0) + + results = list(paginated_fetch_cursor(fetch_page, items_key="feature_tables")) + assert results == [{"id": "t1"}, {"id": "t2"}, {"id": "t3"}] + assert captured_tokens == [None, "token_abc"] + + +def test_cursor_pagination_custom_token_key() -> None: + """Supports a custom key name for the next page token.""" + pages: list[dict] = [ + {"items": [{"id": "1"}], "continuation": "xyz"}, + {"items": [{"id": "2"}]}, + ] + + def fetch_page(_token: str | None) -> dict: + return pages.pop(0) + + results = list(paginated_fetch_cursor(fetch_page, items_key="items", next_token_key="continuation")) + assert results == [{"id": "1"}, {"id": "2"}] + + +def test_cursor_pagination_terminates_on_missing_items_key() -> None: + """Stops when the items key is missing from the response.""" + pages = [{"other": "data"}] + + def fetch_page(_token: str | None) -> dict: + return pages.pop(0) + + results = list(paginated_fetch_cursor(fetch_page, items_key="feature_tables")) + assert not results diff --git a/tests/unit/workspace_access/test_generic.py b/tests/unit/workspace_access/test_generic.py index 860b17a9c5..2a3572a7bf 100644 --- a/tests/unit/workspace_access/test_generic.py +++ b/tests/unit/workspace_access/test_generic.py @@ -21,6 +21,7 @@ from databricks.sdk.service.workspace import Language, ObjectInfo, ObjectType from databricks.labs.ucx.workspace_access.generic import ( + GenericPermissionsInfo, GenericPermissionsSupport, Listing, Permissions, @@ -872,6 +873,25 @@ def do_api_side_effect(*_, query): assert result[0].request_type == "feature-tables" +def test_feature_store_listing_collects_all_pages(): + ws = create_autospec(WorkspaceClient) + page1_tables = [{"id": f"table{i}"} for i in range(200)] + page2_tables = [{"id": f"table{i}"} for i in range(200, 250)] + + def do_api_side_effect(*_, query): + if not query["page_token"]: + return {"feature_tables": page1_tables, "next_page_token": "token2"} + return {"feature_tables": page2_tables} + + ws.api_client.do.side_effect = do_api_side_effect + + result = feature_store_listing(ws)() + + assert len(result) == 250 + assert all(isinstance(item, GenericPermissionsInfo) for item in result) + assert all(item.request_type == "feature-tables" for item in result) + + def test_root_page_listing(): ws = create_autospec(WorkspaceClient) diff --git a/tests/unit/workspace_access/test_groups.py b/tests/unit/workspace_access/test_groups.py index 1adc823fc2..d4f07c47bc 100644 --- a/tests/unit/workspace_access/test_groups.py +++ b/tests/unit/workspace_access/test_groups.py @@ -14,6 +14,7 @@ from databricks.sdk.service.iam import ComplexValue, Group, ResourceMeta from databricks.labs.ucx.workspace_access.groups import ( + AccountGroupLookup, ConfigureGroups, GroupManager, MigratedGroup, @@ -568,6 +569,83 @@ def reflect_account_side_effect(method, *_, **__): ) +def test_list_account_groups_paginates_through_multiple_pages(): + """Verify _list_account_groups uses paginated_fetch_offset to retrieve all pages.""" + wsclient = create_autospec(WorkspaceClient) + + page1 = [ + Group(id="1", display_name="alpha").as_dict(), + Group(id="2", display_name="beta").as_dict(), + ] + page2 = [ + Group(id="3", display_name="gamma").as_dict(), + ] + responses = iter([{"Resources": page1}, {"Resources": page2}]) + + def do_side_effect(_method, *_args, **_kwargs): + return next(responses) + + wsclient.api_client.do.side_effect = do_side_effect + + with patch.object(AccountGroupLookup, "PAGE_SIZE", 2): + lookup = AccountGroupLookup(wsclient) + groups = lookup.get_mapping() + + assert len(groups) == 3 + assert "alpha" in groups + assert "beta" in groups + assert "gamma" in groups + assert wsclient.api_client.do.call_count == 2 + + +def test_list_account_groups_filters_system_groups_across_pages(): + """System groups should be filtered even when they appear on later pages.""" + wsclient = create_autospec(WorkspaceClient) + + page1 = [ + Group(id="1", display_name="real_group").as_dict(), + Group(id="2", display_name="users").as_dict(), + ] + responses = iter([{"Resources": page1}]) + + def do_side_effect(_method, *_args, **_kwargs): + return next(responses) + + wsclient.api_client.do.side_effect = do_side_effect + + lookup = AccountGroupLookup(wsclient) + groups = lookup.get_mapping() + + assert len(groups) == 1 + assert "real_group" in groups + assert "users" not in groups + + +def test_list_account_groups_deduplicates_across_pages(): + """Duplicate group IDs across pages should be counted only once.""" + wsclient = create_autospec(WorkspaceClient) + + duplicate = Group(id="1", display_name="alpha").as_dict() + # page1 fills PAGE_SIZE (2), so pagination continues; page2 is shorter, so it stops. + # alpha appears on both pages and must only be counted once. + page1 = [duplicate, Group(id="2", display_name="beta").as_dict()] + page2 = [duplicate] + responses = iter([{"Resources": page1}, {"Resources": page2}]) + + def do_side_effect(_method, *_args, **_kwargs): + return next(responses) + + wsclient.api_client.do.side_effect = do_side_effect + + with patch.object(AccountGroupLookup, "PAGE_SIZE", 2): + lookup = AccountGroupLookup(wsclient) + groups = lookup.get_mapping() + + assert len(groups) == 2 + assert "alpha" in groups + assert "beta" in groups + + def test_delete_original_workspace_groups_should_delete_reflected_acc_groups_in_workspace(fake_sleep: Mock) -> None: account_id = "11" ws_id = "1"