Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/databricks/labs/ucx/framework/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import subprocess
from collections.abc import Callable, Iterator

logger = logging.getLogger(__name__)

Expand All @@ -22,6 +23,59 @@ def escape_sql_identifier(path: str, *, maxsplit: int = 2) -> str:
return ".".join(escaped)


def paginated_fetch_offset(
fetch_page: Callable[[int, int], dict],
items_key: str,
page_size: int,
start_index: int = 1,
) -> Iterator[dict]:
"""Paginate a SCIM-style offset API (startIndex / count).

Args:
fetch_page: Callable that takes (start_index, count) and returns a response dict.
items_key: Key in the response containing the list of items (e.g. "Resources").
page_size: Number of items to request per page.
start_index: 1-based index to start from (SCIM default is 1).

Yields:
Individual raw item dicts from each page.
"""
while True:
response = fetch_page(start_index, page_size)
items = response.get(items_key, [])
if not items:
break
yield from items
if len(items) < page_size:
break
start_index += len(items)


def paginated_fetch_cursor(
fetch_page: Callable[[str | None], dict],
items_key: str,
next_token_key: str = "next_page_token",
) -> Iterator[dict]:
"""Paginate a cursor/token-based API.

Args:
fetch_page: Callable that takes an optional page token and returns a response dict.
items_key: Key in the response containing the list of items.
next_token_key: Key in the response containing the next page token.

Yields:
Individual raw item dicts from each page.
"""
token: str | None = None
while True:
response = fetch_page(token)
items = response.get(items_key, [])
yield from items
token = response.get(next_token_key)
if not token:
break


def run_command(command: str | list[str]) -> tuple[int, str, str]:
args = command.split() if isinstance(command, str) else command
logger.info(f"Invoking command: {args!r}")
Expand Down
16 changes: 7 additions & 9 deletions src/databricks/labs/ucx/workspace_access/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from databricks.sdk.service import iam, ml
from databricks.sdk.service.iam import PermissionLevel

from databricks.labs.ucx.framework.utils import paginated_fetch_cursor

from databricks.labs.ucx.framework.crawlers import CrawlerBase
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.workspace_access.base import AclSupport, Permissions, StaticListing
Expand Down Expand Up @@ -439,20 +441,16 @@ def inner() -> Iterator[ml.Experiment]:

def feature_store_listing(ws: WorkspaceClient):
    """Build a lazy listing of feature-store tables for permission crawling.

    Args:
        ws: Workspace client whose low-level ``api_client`` is used to call the
            feature-table search endpoint.

    Returns:
        A zero-argument callable that, when invoked, walks every page of the search
        endpoint and returns one ``GenericPermissionsInfo`` per feature table, with
        request type "feature-tables".
    """

    def inner() -> list[GenericPermissionsInfo]:
        def fetch_feature_tables(token: str | None) -> dict:
            # The first call is made with token=None; later calls pass the
            # next_page_token taken from the previous response.
            result = ws.api_client.do(
                "GET", "/api/2.0/feature-store/feature-tables/search", query={"page_token": token, "max_results": 200}
            )
            assert isinstance(result, dict)
            return result

        return [
            GenericPermissionsInfo(table["id"], "feature-tables")
            for table in paginated_fetch_cursor(fetch_feature_tables, items_key="feature_tables")
        ]

    return inner
Expand Down
20 changes: 16 additions & 4 deletions src/databricks/labs/ucx/workspace_access/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from databricks.sdk.service.iam import Group, User

from databricks.labs.ucx.framework.crawlers import CrawlerBase
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.framework.utils import escape_sql_identifier, paginated_fetch_offset

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -913,6 +913,8 @@ def _get_strategy(


class AccountGroupLookup:
PAGE_SIZE = 10000

def __init__(self, ws: WorkspaceClient):
    """Keep a reference to the workspace client used for the account group API calls."""
    self._ws = ws

Expand Down Expand Up @@ -964,13 +966,23 @@ def _list_account_groups(self, scim_attributes: str) -> list[iam.Group]:
# TODO: we should avoid using this method, as it's not documented
# get account-level groups even if they're not (yet) assigned to a workspace
logger.info(f"Listing account groups with {scim_attributes}...")
account_groups = []
raw = self._ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query={"attributes": scim_attributes})
for resource in raw.get("Resources", []): # type: ignore[union-attr]

def fetch_account_groups(start_index: int, count: int) -> dict:
query = {"startIndex": start_index, "count": count, "attributes": scim_attributes}
return self._ws.api_client.do("GET", "/api/2.0/account/scim/v2/Groups", query=query) # type: ignore[return-value]

account_groups: list[iam.Group] = []
seen: set[str] = set()
for resource in paginated_fetch_offset(fetch_account_groups, items_key="Resources", page_size=self.PAGE_SIZE):
group = iam.Group.from_dict(resource)
if group.id and group.id in seen:
continue
if group.id:
seen.add(group.id)
if group.display_name in SYSTEM_GROUPS:
continue
account_groups.append(group)

logger.info(f"Found {len(account_groups)} account groups")
sorted_groups: list[iam.Group] = sorted(
account_groups, key=lambda _: _.display_name if _.display_name else ""
Expand Down
142 changes: 141 additions & 1 deletion tests/unit/framework/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import pytest

from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.framework.utils import (
escape_sql_identifier,
paginated_fetch_offset,
paginated_fetch_cursor,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -34,3 +38,139 @@ def test_escaped_when_column_contains_period() -> None:
expected = "`column.with.periods`"
path = "column.with.periods"
assert escape_sql_identifier(path, maxsplit=0) == expected


# --- paginated_fetch_offset tests ---


def test_offset_pagination_empty_first_page() -> None:
    """An empty first page produces no items at all."""
    queued_responses: list[dict] = [{"Resources": []}]

    def fetch_page(_start_index: int, _count: int) -> dict:
        # Serve the canned responses in order; each one may be consumed only once.
        return queued_responses.pop(0)

    fetched = paginated_fetch_offset(fetch_page, items_key="Resources", page_size=10)
    assert not list(fetched)


def test_offset_pagination_single_page() -> None:
    """A single short page (fewer items than page_size) is returned as-is."""
    expected = [{"id": "1"}, {"id": "2"}]
    queued_responses = [{"Resources": expected}]

    def fetch_page(_start_index: int, _count: int) -> dict:
        # Only one response is queued, so a second request would fail loudly.
        return queued_responses.pop(0)

    fetched = paginated_fetch_offset(fetch_page, items_key="Resources", page_size=10)
    assert list(fetched) == expected


def test_offset_pagination_multiple_pages() -> None:
    """Items spread over two pages are all returned and startIndex advances by page length."""
    first_page = [{"id": "1"}, {"id": "2"}]
    second_page = [{"id": "3"}]
    queued_responses = [{"Resources": first_page}, {"Resources": second_page}]
    requests_seen: list[tuple[int, int]] = []

    def fetch_page(start_index: int, count: int) -> dict:
        # Record every (start_index, count) pair so the paging arithmetic can be checked.
        requests_seen.append((start_index, count))
        return queued_responses.pop(0)

    fetched = list(paginated_fetch_offset(fetch_page, items_key="Resources", page_size=2))

    assert fetched == [{"id": "1"}, {"id": "2"}, {"id": "3"}]
    assert requests_seen[:2] == [(1, 2), (3, 2)]


def test_offset_pagination_terminates_on_missing_key() -> None:
    """A response without the items key ends pagination with no results."""
    queued_responses = [{"other_key": "value"}]

    def fetch_page(_start_index: int, _count: int) -> dict:
        return queued_responses.pop(0)

    fetched = paginated_fetch_offset(fetch_page, items_key="Resources", page_size=10)
    assert not list(fetched)


def test_offset_pagination_respects_start_index() -> None:
    """The first request uses the caller-supplied start_index."""
    queued_responses = [{"Resources": [{"id": "5"}]}]
    requests_seen: list[tuple[int, int]] = []

    def fetch_page(start_index: int, count: int) -> dict:
        # Record the request so the initial offset can be verified.
        requests_seen.append((start_index, count))
        return queued_responses.pop(0)

    fetched = list(
        paginated_fetch_offset(fetch_page, items_key="Resources", page_size=10, start_index=5)
    )

    assert fetched == [{"id": "5"}]
    first_request = requests_seen[0]
    assert first_request == (5, 10)


# --- paginated_fetch_cursor tests ---


def test_cursor_pagination_empty_first_page() -> None:
    """An empty first page without a token produces no items."""
    queued_responses: list[dict] = [{"feature_tables": []}]

    def fetch_page(_token: str | None) -> dict:
        # Serve the canned responses in order; each one may be consumed only once.
        return queued_responses.pop(0)

    fetched = paginated_fetch_cursor(fetch_page, items_key="feature_tables")
    assert not list(fetched)


def test_cursor_pagination_single_page_no_token() -> None:
    """A single page with no next_page_token is returned in full."""
    expected = [{"id": "t1"}, {"id": "t2"}]
    queued_responses = [{"feature_tables": expected}]

    def fetch_page(_token: str | None) -> dict:
        # Only one response is queued, so a second request would fail loudly.
        return queued_responses.pop(0)

    fetched = paginated_fetch_cursor(fetch_page, items_key="feature_tables")
    assert list(fetched) == expected


def test_cursor_pagination_multiple_pages() -> None:
    """Items span multiple pages with cursor tokens."""
    pages: list[dict] = [
        {"feature_tables": [{"id": "t1"}], "next_page_token": "token_abc"},
        {"feature_tables": [{"id": "t2"}, {"id": "t3"}]},
    ]
    captured_tokens: list[str | None] = []

    def fetch_page(token: str | None) -> dict:
        # Record each token so the cursor hand-off between requests can be verified.
        captured_tokens.append(token)
        return pages.pop(0)

    results = list(paginated_fetch_cursor(fetch_page, items_key="feature_tables"))
    assert results == [{"id": "t1"}, {"id": "t2"}, {"id": "t3"}]
    # The first request has no token; the second carries the token from the first page.
    assert captured_tokens == [None, "token_abc"]


def test_cursor_pagination_custom_token_key() -> None:
    """Supports a custom key name for the next page token."""
    pages: list[dict] = [
        {"items": [{"id": "1"}], "continuation": "xyz"},
        {"items": [{"id": "2"}]},
    ]
    captured_tokens: list[str | None] = []

    def fetch_page(token: str | None) -> dict:
        # Record each token to prove the value read from the custom key is forwarded.
        captured_tokens.append(token)
        return pages.pop(0)

    results = list(paginated_fetch_cursor(fetch_page, items_key="items", next_token_key="continuation"))
    assert results == [{"id": "1"}, {"id": "2"}]
    assert captured_tokens == [None, "xyz"]


def test_cursor_pagination_terminates_on_missing_items_key() -> None:
    """A response lacking both the items key and a token yields nothing."""
    queued_responses = [{"other": "data"}]

    def fetch_page(_token: str | None) -> dict:
        return queued_responses.pop(0)

    fetched = paginated_fetch_cursor(fetch_page, items_key="feature_tables")
    assert not list(fetched)
20 changes: 20 additions & 0 deletions tests/unit/workspace_access/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from databricks.sdk.service.workspace import Language, ObjectInfo, ObjectType

from databricks.labs.ucx.workspace_access.generic import (
GenericPermissionsInfo,
GenericPermissionsSupport,
Listing,
Permissions,
Expand Down Expand Up @@ -872,6 +873,25 @@ def do_api_side_effect(*_, query):
assert result[0].request_type == "feature-tables"


def test_feature_store_listing_collects_all_pages():
    """The listing follows next_page_token and returns tables from every page."""
    first_page = [{"id": f"table{i}"} for i in range(200)]
    second_page = [{"id": f"table{i}"} for i in range(200, 250)]

    def do_api_side_effect(*_, query):
        # A request carrying a token is the follow-up; it gets the final, token-less page.
        if query["page_token"]:
            return {"feature_tables": second_page}
        return {"feature_tables": first_page, "next_page_token": "token2"}

    ws = create_autospec(WorkspaceClient)
    ws.api_client.do.side_effect = do_api_side_effect

    listed = feature_store_listing(ws)()

    assert len(listed) == 250
    for item in listed:
        assert isinstance(item, GenericPermissionsInfo)
    for item in listed:
        assert item.request_type == "feature-tables"


def test_root_page_listing():
ws = create_autospec(WorkspaceClient)

Expand Down
Loading
Loading