diff --git a/src/databricks/labs/mcp/servers/unity_catalog/cli.py b/src/databricks/labs/mcp/servers/unity_catalog/cli.py index 2d0b067..41897ad 100644 --- a/src/databricks/labs/mcp/servers/unity_catalog/cli.py +++ b/src/databricks/labs/mcp/servers/unity_catalog/cli.py @@ -1,11 +1,7 @@ -from typing import List, Optional -from pydantic import field_validator from functools import lru_cache -from pydantic import ( - Field, - AliasChoices, - model_validator, -) +from typing import List, Optional + +from pydantic import AliasChoices, Field, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -40,6 +36,12 @@ class CliSettings(BaseSettings): ), ) + enable_lakeview: bool = Field( + default=True, + description="Enable Lakeview dashboard management tools.", + validation_alias=AliasChoices("lv", "enable_lakeview"), + ) + def get_catalog_name(self): return self.schema_full_name.split(".")[0] if self.schema_full_name else None @@ -55,9 +57,13 @@ def split_genie_space_ids(cls, v): @model_validator(mode="after") def check_schema_name_or_genie_space_ids(self): - if not self.schema_full_name and not self.genie_space_ids: + if ( + not self.schema_full_name + and not self.genie_space_ids + and not self.enable_lakeview + ): raise ValueError( - "At least one of --schema (-s) or --genie-space-ids (-g) must be provided." + "At least one of --schema (-s), --genie-space-ids (-g), or --enable-lakeview must be provided." ) return self diff --git a/src/databricks/labs/mcp/servers/unity_catalog/tools/__init__.py b/src/databricks/labs/mcp/servers/unity_catalog/tools/__init__.py index 39b9af0..c8a8e56 100644 --- a/src/databricks/labs/mcp/servers/unity_catalog/tools/__init__.py +++ b/src/databricks/labs/mcp/servers/unity_catalog/tools/__init__.py @@ -1,21 +1,22 @@ import collections from typing import TypeAlias, Union + from mcp.server.fastmcp import FastMCP -from mcp.types import ( - TextContent, - ImageContent, - EmbeddedResource, -) +from mcp.types import EmbeddedResource, ImageContent, TextContent from databricks.labs.mcp._version import __version__ as VERSION from databricks.labs.mcp.servers.unity_catalog.cli import get_settings +from databricks.labs.mcp.servers.unity_catalog.tools.functions import ( + UCFunctionTool, + list_uc_function_tools, +) from databricks.labs.mcp.servers.unity_catalog.tools.genie import ( GenieTool, list_genie_tools, ) -from databricks.labs.mcp.servers.unity_catalog.tools.functions import ( - UCFunctionTool, - list_uc_function_tools, +from databricks.labs.mcp.servers.unity_catalog.tools.lakeview import ( + LakeviewTool, + list_lakeview_tools, ) from databricks.labs.mcp.servers.unity_catalog.tools.vector_search import ( VectorSearchTool, @@ -24,20 +25,21 @@ from databricks.labs.mcp.utils import logger Content: TypeAlias = Union[TextContent, ImageContent, EmbeddedResource] -AvailableTool = UCFunctionTool | VectorSearchTool | GenieTool +AvailableTool = UCFunctionTool | VectorSearchTool | GenieTool | LakeviewTool def list_all_tools(settings) -> list[AvailableTool]: """ - Returns a list of all available tools, including Genie tools, UC functions, and vector search tools. + Returns a list of all available tools, including Genie, Lakeview, UC functions, and vector search tools. This function aggregates tools from different sources and returns them in a single list. """ - - return ( - list_genie_tools(settings) - + list_vector_search_tools(settings) - + list_uc_function_tools(settings) - ) + tools = [] + tools += list_genie_tools(settings) + if getattr(settings, "enable_lakeview", True): + tools += list_lakeview_tools(settings) + tools += list_vector_search_tools(settings) + tools += list_uc_function_tools(settings) + return tools def _warn_if_duplicate_tool_names(tools: list[AvailableTool]): diff --git a/src/databricks/labs/mcp/servers/unity_catalog/tools/lakeview.py b/src/databricks/labs/mcp/servers/unity_catalog/tools/lakeview.py new file mode 100644 index 0000000..17ccde7 --- /dev/null +++ b/src/databricks/labs/mcp/servers/unity_catalog/tools/lakeview.py @@ -0,0 +1,346 @@ +import json +import logging + +from databricks.sdk import WorkspaceClient +from mcp.types import TextContent +from mcp.types import Tool as ToolSpec +from pydantic import BaseModel, Field + +from databricks.labs.mcp.servers.unity_catalog.tools.base_tool import BaseTool + +LOGGER = logging.getLogger(__name__) + + +# --- Input Schemas --- + + +class ListDashboardsInput(BaseModel): + page_size: int = Field( + default=25, description="Number of dashboards per page (1-100).", ge=1, le=100 + ) + page_token: str = Field( + default="", description="Pagination token from a previous response." + ) + + +class GetDashboardInput(BaseModel): + dashboard_id: str = Field(..., description="The ID of the dashboard to retrieve.") + + +class CreateDashboardInput(BaseModel): + display_name: str = Field(..., description="Display name of the new dashboard.") + warehouse_id: str = Field( + default="", + description="SQL warehouse ID for the dashboard. If empty, uses serverless compute.", + ) + parent_path: str = Field( + default="", + description="Workspace path where the dashboard will be created (e.g. /Users/user@example.com).", + ) + serialized_dashboard: str = Field( + default="", + description="JSON-encoded Lakeview dashboard definition. If empty, creates a blank dashboard.", + ) + + +class UpdateDashboardInput(BaseModel): + dashboard_id: str = Field(..., description="The ID of the dashboard to update.") + display_name: str = Field( + default="", description="New display name. Leave empty to keep current." + ) + warehouse_id: str = Field( + default="", + description="New SQL warehouse ID. Leave empty to keep current.", + ) + serialized_dashboard: str = Field( + default="", + description="Updated JSON-encoded Lakeview dashboard definition. Leave empty to keep current.", + ) + + +class DeleteDashboardInput(BaseModel): + dashboard_id: str = Field( + ..., description="The ID of the dashboard to move to trash." + ) + + +class PublishDashboardInput(BaseModel): + dashboard_id: str = Field(..., description="The ID of the dashboard to publish.") + warehouse_id: str = Field( + default="", + description="SQL warehouse ID for the published dashboard. If empty, uses serverless compute.", + ) + embed_credentials: bool = Field( + default=True, + description="Whether viewers can run queries using the publisher's credentials.", + ) + + +class UnpublishDashboardInput(BaseModel): + dashboard_id: str = Field(..., description="The ID of the dashboard to unpublish.") + + +class GetPublishedDashboardInput(BaseModel): + dashboard_id: str = Field( + ..., description="The ID of the dashboard whose published version to retrieve." + ) + + +# --- Tool Implementations --- + + +def _list_dashboards(client: WorkspaceClient, args) -> list[TextContent]: + model = ListDashboardsInput.model_validate(args) + response = client.lakeview.list( + page_size=model.page_size, + page_token=model.page_token or None, + ) + dashboards = [] + for d in response: + dashboards.append( + { + "dashboard_id": d.dashboard_id, + "display_name": d.display_name, + "parent_path": d.parent_path, + "lifecycle_state": ( + d.lifecycle_state.value if d.lifecycle_state else None + ), + "create_time": str(d.create_time) if d.create_time else None, + "update_time": str(d.update_time) if d.update_time else None, + } + ) + if len(dashboards) >= model.page_size: + break + return [ + TextContent( + type="text", text=json.dumps(dashboards, default=str, separators=(",", ":")) + ) + ] + + +def _get_dashboard(client: WorkspaceClient, args) -> list[TextContent]: + model = GetDashboardInput.model_validate(args) + dashboard = client.lakeview.get(model.dashboard_id) + result = { + "dashboard_id": dashboard.dashboard_id, + "display_name": dashboard.display_name, + "parent_path": dashboard.parent_path, + "warehouse_id": dashboard.warehouse_id, + "lifecycle_state": ( + dashboard.lifecycle_state.value if dashboard.lifecycle_state else None + ), + "serialized_dashboard": dashboard.serialized_dashboard, + "create_time": str(dashboard.create_time) if dashboard.create_time else None, + "update_time": str(dashboard.update_time) if dashboard.update_time else None, + } + return [ + TextContent( + type="text", text=json.dumps(result, default=str, separators=(",", ":")) + ) + ] + + +def _create_dashboard(client: WorkspaceClient, args) -> list[TextContent]: + model = CreateDashboardInput.model_validate(args) + kwargs = {"display_name": model.display_name} + if model.warehouse_id: + kwargs["warehouse_id"] = model.warehouse_id + if model.parent_path: + kwargs["parent_path"] = model.parent_path + if model.serialized_dashboard: + kwargs["serialized_dashboard"] = model.serialized_dashboard + + dashboard = client.lakeview.create(**kwargs) + result = { + "dashboard_id": dashboard.dashboard_id, + "display_name": dashboard.display_name, + "parent_path": dashboard.parent_path, + "lifecycle_state": ( + dashboard.lifecycle_state.value if dashboard.lifecycle_state else None + ), + } + return [ + TextContent( + type="text", text=json.dumps(result, default=str, separators=(",", ":")) + ) + ] + + +def _update_dashboard(client: WorkspaceClient, args) -> list[TextContent]: + model = UpdateDashboardInput.model_validate(args) + kwargs = {"dashboard_id": model.dashboard_id} + if model.display_name: + kwargs["display_name"] = model.display_name + if model.warehouse_id: + kwargs["warehouse_id"] = model.warehouse_id + if model.serialized_dashboard: + kwargs["serialized_dashboard"] = model.serialized_dashboard + + dashboard = client.lakeview.update(**kwargs) + result = { + "dashboard_id": dashboard.dashboard_id, + "display_name": dashboard.display_name, + "lifecycle_state": ( + dashboard.lifecycle_state.value if dashboard.lifecycle_state else None + ), + "update_time": str(dashboard.update_time) if dashboard.update_time else None, + } + return [ + TextContent( + type="text", text=json.dumps(result, default=str, separators=(",", ":")) + ) + ] + + +def _delete_dashboard(client: WorkspaceClient, args) -> list[TextContent]: + model = DeleteDashboardInput.model_validate(args) + client.lakeview.trash(model.dashboard_id) + return [ + TextContent( + type="text", + text=json.dumps( + {"dashboard_id": model.dashboard_id, "status": "trashed"}, + separators=(",", ":"), + ), + ) + ] + + +def _publish_dashboard(client: WorkspaceClient, args) -> list[TextContent]: + model = PublishDashboardInput.model_validate(args) + kwargs = { + "dashboard_id": model.dashboard_id, + "embed_credentials": model.embed_credentials, + } + if model.warehouse_id: + kwargs["warehouse_id"] = model.warehouse_id + + published = client.lakeview.publish(**kwargs) + result = { + "dashboard_id": model.dashboard_id, + "warehouse_id": published.warehouse_id if published.warehouse_id else None, + "embed_credentials": ( + published.embed_credentials + if hasattr(published, "embed_credentials") + else None + ), + "revision_create_time": ( + str(published.revision_create_time) + if published.revision_create_time + else None + ), + } + return [ + TextContent( + type="text", text=json.dumps(result, default=str, separators=(",", ":")) + ) + ] + + +def _unpublish_dashboard(client: WorkspaceClient, args) -> list[TextContent]: + model = UnpublishDashboardInput.model_validate(args) + client.lakeview.unpublish(model.dashboard_id) + return [ + TextContent( + type="text", + text=json.dumps( + {"dashboard_id": model.dashboard_id, "status": "unpublished"}, + separators=(",", ":"), + ), + ) + ] + + +def _get_published_dashboard(client: WorkspaceClient, args) -> list[TextContent]: + model = GetPublishedDashboardInput.model_validate(args) + published = client.lakeview.get_published(model.dashboard_id) + result = { + "dashboard_id": model.dashboard_id, + "warehouse_id": published.warehouse_id if published.warehouse_id else None, + "embed_credentials": ( + published.embed_credentials + if hasattr(published, "embed_credentials") + else None + ), + "revision_create_time": ( + str(published.revision_create_time) + if published.revision_create_time + else None + ), + } + return [ + TextContent( + type="text", text=json.dumps(result, default=str, separators=(",", ":")) + ) + ] + + +# --- Tool Registry --- + + +class LakeviewTool(BaseTool): + + def __init__(self, name, description, input_schema, func): + self.func = func + tool_spec = ToolSpec( + name=name, + description=description, + inputSchema=input_schema, + ) + super().__init__(tool_spec) + + def execute(self, **kwargs): + return self.func(client=WorkspaceClient(), args=kwargs) + + +def list_lakeview_tools(settings) -> list[LakeviewTool]: + return [ + LakeviewTool( + name="lakeview_list_dashboards", + description="List AI/BI (Lakeview) dashboards in the workspace with pagination support.", + input_schema=ListDashboardsInput.model_json_schema(), + func=_list_dashboards, + ), + LakeviewTool( + name="lakeview_get_dashboard", + description="Get full metadata and serialized definition of an AI/BI dashboard.", + input_schema=GetDashboardInput.model_json_schema(), + func=_get_dashboard, + ), + LakeviewTool( + name="lakeview_create_dashboard", + description="Create a new AI/BI (Lakeview) dashboard in the workspace.", + input_schema=CreateDashboardInput.model_json_schema(), + func=_create_dashboard, + ), + LakeviewTool( + name="lakeview_update_dashboard", + description="Update an existing AI/BI dashboard's display name, warehouse, or definition.", + input_schema=UpdateDashboardInput.model_json_schema(), + func=_update_dashboard, + ), + LakeviewTool( + name="lakeview_delete_dashboard", + description="Move an AI/BI dashboard to trash. This is recoverable.", + input_schema=DeleteDashboardInput.model_json_schema(), + func=_delete_dashboard, + ), + LakeviewTool( + name="lakeview_publish_dashboard", + description="Publish an AI/BI dashboard, making it accessible via its published URL.", + input_schema=PublishDashboardInput.model_json_schema(), + func=_publish_dashboard, + ), + LakeviewTool( + name="lakeview_unpublish_dashboard", + description="Remove the published version of an AI/BI dashboard.", + input_schema=UnpublishDashboardInput.model_json_schema(), + func=_unpublish_dashboard, + ), + LakeviewTool( + name="lakeview_get_published_dashboard", + description="Get details of the published version of an AI/BI dashboard.", + input_schema=GetPublishedDashboardInput.model_json_schema(), + func=_get_published_dashboard, + ), + ] diff --git a/tests/test_cli.py b/tests/test_cli.py index 2c00795..82fe7d7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,7 +1,9 @@ import sys +from unittest.mock import patch + import pytest from pydantic import ValidationError -from unittest.mock import patch + from databricks.labs.mcp.servers.unity_catalog.cli import get_settings @@ -37,14 +39,8 @@ def test_arguments(catalog: str, schema: str) -> None: assert settings.get_schema_name() == schema -@pytest.mark.parametrize( - "argv", - [ - ["unitycatalog-mcp"], # neither -s nor -g - ["unitycatalog-mcp", "-s", "schema_no_catalog"], - ], -) -def test_required_arguments(argv) -> None: +def test_required_arguments_invalid_schema() -> None: + argv = ["unitycatalog-mcp", "-s", "schema_no_catalog"] with patch.object(sys, "argv", argv): with pytest.raises(ValidationError) as exc_info: get_settings() @@ -55,6 +51,26 @@ def test_required_arguments(argv) -> None: ) +def test_required_arguments_all_disabled() -> None: + from databricks.labs.mcp.servers.unity_catalog.cli import CliSettings + + with pytest.raises(ValidationError): + CliSettings( + schema_full_name=None, + genie_space_ids=[], + enable_lakeview=False, + _env_file=None, + _cli_parse_args=False, + ) + + +def test_lakeview_enabled_by_default() -> None: + argv = ["unitycatalog-mcp"] + with patch.object(sys, "argv", argv): + settings = get_settings() + assert settings.enable_lakeview is True + + @pytest.mark.parametrize( "argv,expected_catalog,expected_schema,expected_genie_space_ids", [ diff --git a/tests/test_lakeview.py b/tests/test_lakeview.py new file mode 100644 index 0000000..d00a153 --- /dev/null +++ b/tests/test_lakeview.py @@ -0,0 +1,272 @@ +import json +from unittest import mock + +from databricks.labs.mcp.servers.unity_catalog.tools.lakeview import ( + LakeviewTool, + _create_dashboard, + _delete_dashboard, + _get_dashboard, + _get_published_dashboard, + _list_dashboards, + _publish_dashboard, + _unpublish_dashboard, + _update_dashboard, + list_lakeview_tools, +) + + +class DummySettings: + schema_full_name = "main.default" + genie_space_ids = [] + + +class DummyLifecycleState: + value = "ACTIVE" + + +class DummyDashboard: + def __init__(self, dashboard_id="d1", display_name="Test Dashboard"): + self.dashboard_id = dashboard_id + self.display_name = display_name + self.parent_path = "/Users/test@example.com" + self.warehouse_id = "wh-123" + self.lifecycle_state = DummyLifecycleState() + self.serialized_dashboard = '{"pages":[]}' + self.create_time = "2024-01-01T00:00:00Z" + self.update_time = "2024-01-02T00:00:00Z" + + +class DummyPublishedDashboard: + def __init__(self): + self.warehouse_id = "wh-123" + self.embed_credentials = True + self.revision_create_time = "2024-01-03T00:00:00Z" + + +class DummyLakeviewAPI: + def list(self, page_size=None, page_token=None): + return [ + DummyDashboard("d1", "Dashboard 1"), + DummyDashboard("d2", "Dashboard 2"), + ] + + def get(self, dashboard_id): + return DummyDashboard(dashboard_id, "Retrieved Dashboard") + + def create(self, **kwargs): + d = DummyDashboard("new-id", kwargs.get("display_name", "New")) + return d + + def update(self, **kwargs): + d = DummyDashboard( + kwargs["dashboard_id"], kwargs.get("display_name", "Updated") + ) + return d + + def trash(self, dashboard_id): + return None + + def publish(self, **kwargs): + return DummyPublishedDashboard() + + def unpublish(self, dashboard_id): + return None + + def get_published(self, dashboard_id): + return DummyPublishedDashboard() + + +class DummyWorkspaceClient: + def __init__(self): + self.lakeview = DummyLakeviewAPI() + + +# --- Tool Listing Tests --- + + +def test_list_lakeview_tools_count_and_types(): + tools = list_lakeview_tools(DummySettings()) + assert len(tools) == 8 + assert all(isinstance(t, LakeviewTool) for t in tools) + + +def test_list_lakeview_tools_names(): + tools = list_lakeview_tools(DummySettings()) + names = {t.tool_spec.name for t in tools} + expected = { + "lakeview_list_dashboards", + "lakeview_get_dashboard", + "lakeview_create_dashboard", + "lakeview_update_dashboard", + "lakeview_delete_dashboard", + "lakeview_publish_dashboard", + "lakeview_unpublish_dashboard", + "lakeview_get_published_dashboard", + } + assert names == expected + + +def test_list_lakeview_tools_have_descriptions(): + tools = list_lakeview_tools(DummySettings()) + for tool in tools: + assert tool.tool_spec.description + assert len(tool.tool_spec.description) > 10 + + +def test_list_lakeview_tools_have_input_schemas(): + tools = list_lakeview_tools(DummySettings()) + for tool in tools: + assert tool.tool_spec.inputSchema is not None + assert "properties" in tool.tool_spec.inputSchema + + +# --- Execution Tests --- + + +def test_list_dashboards(): + client = DummyWorkspaceClient() + result = _list_dashboards(client, {}) + assert isinstance(result, list) + data = json.loads(result[0].text) + assert len(data) == 2 + assert data[0]["dashboard_id"] == "d1" + assert data[1]["display_name"] == "Dashboard 2" + + +def test_list_dashboards_with_page_size(): + client = DummyWorkspaceClient() + result = _list_dashboards(client, {"page_size": 1}) + data = json.loads(result[0].text) + assert len(data) == 1 + + +def test_get_dashboard(): + client = DummyWorkspaceClient() + result = _get_dashboard(client, {"dashboard_id": "d1"}) + data = json.loads(result[0].text) + assert data["dashboard_id"] == "d1" + assert data["display_name"] == "Retrieved Dashboard" + assert data["serialized_dashboard"] == '{"pages":[]}' + assert data["warehouse_id"] == "wh-123" + + +def test_create_dashboard(): + client = DummyWorkspaceClient() + result = _create_dashboard(client, {"display_name": "My New Dashboard"}) + data = json.loads(result[0].text) + assert data["dashboard_id"] == "new-id" + assert data["display_name"] == "My New Dashboard" + + +def test_create_dashboard_with_all_fields(): + client = DummyWorkspaceClient() + result = _create_dashboard( + client, + { + "display_name": "Full Dashboard", + "warehouse_id": "wh-456", + "parent_path": "/Users/test@example.com", + "serialized_dashboard": '{"pages":[{"name":"p1"}]}', + }, + ) + data = json.loads(result[0].text) + assert data["dashboard_id"] == "new-id" + + +def test_update_dashboard(): + client = DummyWorkspaceClient() + result = _update_dashboard( + client, {"dashboard_id": "d1", "display_name": "Renamed"} + ) + data = json.loads(result[0].text) + assert data["dashboard_id"] == "d1" + assert data["display_name"] == "Renamed" + + +def test_delete_dashboard(): + client = DummyWorkspaceClient() + result = _delete_dashboard(client, {"dashboard_id": "d1"}) + data = json.loads(result[0].text) + assert data["dashboard_id"] == "d1" + assert data["status"] == "trashed" + + +def test_publish_dashboard(): + client = DummyWorkspaceClient() + result = _publish_dashboard( + client, {"dashboard_id": "d1", "embed_credentials": True} + ) + data = json.loads(result[0].text) + assert data["dashboard_id"] == "d1" + assert data["warehouse_id"] == "wh-123" + + +def test_publish_dashboard_with_warehouse(): + client = DummyWorkspaceClient() + result = _publish_dashboard( + client, + { + "dashboard_id": "d1", + "warehouse_id": "wh-override", + "embed_credentials": False, + }, + ) + data = json.loads(result[0].text) + assert data["dashboard_id"] == "d1" + + +def test_unpublish_dashboard(): + client = DummyWorkspaceClient() + result = _unpublish_dashboard(client, {"dashboard_id": "d1"}) + data = json.loads(result[0].text) + assert data["dashboard_id"] == "d1" + assert data["status"] == "unpublished" + + +def test_get_published_dashboard(): + client = DummyWorkspaceClient() + result = _get_published_dashboard(client, {"dashboard_id": "d1"}) + data = json.loads(result[0].text) + assert data["dashboard_id"] == "d1" + assert data["warehouse_id"] == "wh-123" + assert data["embed_credentials"] is True + + +# --- LakeviewTool.execute() integration --- + + +@mock.patch( + "databricks.labs.mcp.servers.unity_catalog.tools.lakeview.WorkspaceClient", + new=DummyWorkspaceClient, +) +def test_lakeview_tool_execute_list(): + tools = list_lakeview_tools(DummySettings()) + list_tool = next(t for t in tools if t.tool_spec.name == "lakeview_list_dashboards") + result = list_tool.execute() + assert isinstance(result, list) + data = json.loads(result[0].text) + assert len(data) == 2 + + +@mock.patch( + "databricks.labs.mcp.servers.unity_catalog.tools.lakeview.WorkspaceClient", + new=DummyWorkspaceClient, +) +def test_lakeview_tool_execute_get(): + tools = list_lakeview_tools(DummySettings()) + get_tool = next(t for t in tools if t.tool_spec.name == "lakeview_get_dashboard") + result = get_tool.execute(dashboard_id="abc-123") + data = json.loads(result[0].text) + assert data["dashboard_id"] == "abc-123" + + +@mock.patch( + "databricks.labs.mcp.servers.unity_catalog.tools.lakeview.WorkspaceClient", + new=DummyWorkspaceClient, +) +def test_lakeview_tool_execute_delete(): + tools = list_lakeview_tools(DummySettings()) + del_tool = next(t for t in tools if t.tool_spec.name == "lakeview_delete_dashboard") + result = del_tool.execute(dashboard_id="abc-123") + data = json.loads(result[0].text) + assert data["status"] == "trashed"