From 10ea6da4213ab83ed500c69af06af53d49493eda Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 20 May 2026 22:48:20 +0000 Subject: [PATCH 1/9] feat(mcp): add list and get tools for row level security and plugins Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/app.py | 16 ++ superset/mcp_service/plugin/__init__.py | 16 ++ superset/mcp_service/plugin/dao.py | 23 ++ superset/mcp_service/plugin/schemas.py | 213 +++++++++++++++ superset/mcp_service/plugin/tool/__init__.py | 24 ++ .../plugin/tool/get_plugin_info.py | 101 +++++++ .../mcp_service/plugin/tool/list_plugins.py | 123 +++++++++ superset/mcp_service/rls/__init__.py | 16 ++ superset/mcp_service/rls/schemas.py | 255 ++++++++++++++++++ superset/mcp_service/rls/tool/__init__.py | 24 ++ .../rls/tool/get_rls_filter_info.py | 101 +++++++ .../mcp_service/rls/tool/list_rls_filters.py | 123 +++++++++ .../unit_tests/mcp_service/plugin/__init__.py | 16 ++ .../mcp_service/plugin/tool/__init__.py | 16 ++ .../plugin/tool/test_plugin_tools.py | 172 ++++++++++++ tests/unit_tests/mcp_service/rls/__init__.py | 16 ++ .../mcp_service/rls/tool/__init__.py | 16 ++ .../mcp_service/rls/tool/test_rls_tools.py | 222 +++++++++++++++ 18 files changed, 1493 insertions(+) create mode 100644 superset/mcp_service/plugin/__init__.py create mode 100644 superset/mcp_service/plugin/dao.py create mode 100644 superset/mcp_service/plugin/schemas.py create mode 100644 superset/mcp_service/plugin/tool/__init__.py create mode 100644 superset/mcp_service/plugin/tool/get_plugin_info.py create mode 100644 superset/mcp_service/plugin/tool/list_plugins.py create mode 100644 superset/mcp_service/rls/__init__.py create mode 100644 superset/mcp_service/rls/schemas.py create mode 100644 superset/mcp_service/rls/tool/__init__.py create mode 100644 superset/mcp_service/rls/tool/get_rls_filter_info.py create mode 100644 superset/mcp_service/rls/tool/list_rls_filters.py create mode 100644 tests/unit_tests/mcp_service/plugin/__init__.py create mode 100644 tests/unit_tests/mcp_service/plugin/tool/__init__.py create mode 100644 tests/unit_tests/mcp_service/plugin/tool/test_plugin_tools.py create mode 100644 tests/unit_tests/mcp_service/rls/__init__.py create mode 100644 tests/unit_tests/mcp_service/rls/tool/__init__.py create mode 100644 tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 81c6bd1f0886..a444834194a2 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -123,6 +123,14 @@ def get_default_instructions( - list_databases: List database connections with advanced filters (1-based pagination) - get_database_info: Get detailed database connection info by ID (backend, capabilities) +Row Level Security (Admin only): +- list_rls_filters: List RLS filters with filtering and search (1-based pagination) +- get_rls_filter_info: Get detailed RLS filter info by ID (tables, roles, clause) + +Plugins (Admin only): +- list_plugins: List dynamic plugins with filtering and search (1-based pagination) +- get_plugin_info: Get detailed plugin info by ID (name, key, bundle URL) + Dataset Management: - list_datasets: List datasets with advanced filters (1-based pagination) - get_dataset_info: Get detailed dataset information by ID (includes columns/metrics) @@ -636,6 +644,14 @@ def create_mcp_app( from superset.mcp_service.explore.tool import ( # noqa: F401, E402 generate_explore_link, ) +from superset.mcp_service.plugin.tool import ( # noqa: F401, E402 + get_plugin_info, + list_plugins, +) +from superset.mcp_service.rls.tool import ( # noqa: F401, E402 + get_rls_filter_info, + list_rls_filters, +) from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402 execute_sql, open_sql_lab_with_context, diff --git a/superset/mcp_service/plugin/__init__.py b/superset/mcp_service/plugin/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/mcp_service/plugin/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/plugin/dao.py b/superset/mcp_service/plugin/dao.py new file mode 100644 index 000000000000..c5eb3e7f597c --- /dev/null +++ b/superset/mcp_service/plugin/dao.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from superset.daos.base import BaseDAO +from superset.models.dynamic_plugins import DynamicPlugin + + +class DynamicPluginDAO(BaseDAO[DynamicPlugin]): + pass diff --git a/superset/mcp_service/plugin/schemas.py b/superset/mcp_service/plugin/schemas.py new file mode 100644 index 000000000000..6283eff42da8 --- /dev/null +++ b/superset/mcp_service/plugin/schemas.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pydantic schemas for dynamic plugin responses. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Dict, List, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + +DEFAULT_PLUGIN_COLUMNS = ["id", "name", "key", "bundle_url"] + +ALL_PLUGIN_COLUMNS = [ + "id", + "name", + "key", + "bundle_url", + "changed_on", + "created_on", +] + +SORTABLE_PLUGIN_COLUMNS = ["id", "name", "key", "changed_on", "created_on"] + + +class PluginColumnFilter(ColumnOperator): + """Filter object for plugin listing.""" + + col: Literal["name", "key"] = Field(..., description="Column to filter on.") + opr: ColumnOperatorEnum = Field(..., description="Operator to use.") + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by" + ) + + +class PluginInfo(BaseModel): + id: int | None = Field(None, description="Plugin ID") + name: str | None = Field(None, description="Plugin display name") + key: str | None = Field(None, description="Plugin key (corresponds to viz_type)") + bundle_url: str | None = Field(None, description="URL to the plugin bundle") + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + created_on: str | datetime | None = Field(None, description="Creation timestamp") + model_config = ConfigDict( + from_attributes=True, + ser_json_timedelta="iso8601", + populate_by_name=True, + ) + + @model_serializer(mode="wrap") + def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]: + data = serializer(self) + if info.context and isinstance(info.context, dict): + select_columns = info.context.get("select_columns") + if select_columns: + requested_fields = set(select_columns) + return {k: v for k, v in data.items() if k in requested_fields} + return data + + +class PluginList(BaseModel): + plugins: List[PluginInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: List[str] = Field(default_factory=list) + columns_loaded: List[str] = Field(default_factory=list) + columns_available: List[str] = Field(default_factory=list) + sortable_columns: List[str] = Field(default_factory=list) + filters_applied: List[PluginColumnFilter] = Field(default_factory=list) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListPluginsRequest(BaseModel): + """Request schema for list_plugins.""" + + filters: Annotated[ + List[PluginColumnFilter], + Field( + default_factory=list, + description="List of filter objects (col, opr, value). " + "Cannot be used with search.", + ), + ] + select_columns: Annotated[ + List[str], + Field( + default_factory=list, + description="Columns to include in response. Defaults to common columns.", + ), + ] + search: Annotated[ + str | None, + Field( + default=None, + description="Text search on plugin name or key. " + "Cannot be used with filters.", + ), + ] + order_column: Annotated[ + str | None, Field(default=None, description="Column to order results by") + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field(default="desc", description="Sort direction"), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number (1-based)"), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Items per page (max {MAX_PAGE_SIZE})", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[PluginColumnFilter]: + return parse_json_or_model_list(v, PluginColumnFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListPluginsRequest": + if self.search and self.filters: + raise ValueError("Cannot use both 'search' and 'filters' simultaneously.") + return self + + +class PluginError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "PluginError": + from datetime import timezone + + return cls( + error=error, error_type=error_type, timestamp=datetime.now(timezone.utc) + ) + + +class GetPluginInfoRequest(BaseModel): + """Request schema for get_plugin_info.""" + + identifier: Annotated[ + int, + Field(description="Plugin ID"), + ] + + +def serialize_plugin_object(plugin: Any) -> PluginInfo | None: + if not plugin: + return None + + return PluginInfo( + id=getattr(plugin, "id", None), + name=getattr(plugin, "name", None), + key=getattr(plugin, "key", None), + bundle_url=getattr(plugin, "bundle_url", None), + changed_on=getattr(plugin, "changed_on", None), + created_on=getattr(plugin, "created_on", None), + ) diff --git a/superset/mcp_service/plugin/tool/__init__.py b/superset/mcp_service/plugin/tool/__init__.py new file mode 100644 index 000000000000..4f2781fe9dfb --- /dev/null +++ b/superset/mcp_service/plugin/tool/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .get_plugin_info import get_plugin_info +from .list_plugins import list_plugins + +__all__ = [ + "list_plugins", + "get_plugin_info", +] diff --git a/superset/mcp_service/plugin/tool/get_plugin_info.py b/superset/mcp_service/plugin/tool/get_plugin_info.py new file mode 100644 index 000000000000..6c77a298b168 --- /dev/null +++ b/superset/mcp_service/plugin/tool/get_plugin_info.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Get plugin info FastMCP tool. +""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.plugin.schemas import ( + GetPluginInfoRequest, + PluginError, + PluginInfo, + serialize_plugin_object, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="DynamicPlugin", + annotations=ToolAnnotations( + title="Get plugin info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_plugin_info( + request: GetPluginInfoRequest, ctx: Context +) -> PluginInfo | PluginError: + """Get dynamic plugin details by ID. Requires admin access. + + Returns full plugin configuration including name, key, and bundle URL. + + Example usage: + ```json + {"identifier": 1} + ``` + """ + await ctx.info( + "Retrieving plugin information: identifier=%s" % (request.identifier,) + ) + + try: + from superset.mcp_service.plugin.dao import DynamicPluginDAO + + with event_logger.log_context(action="mcp.get_plugin_info.lookup"): + get_tool = ModelGetInfoCore( + dao_class=DynamicPluginDAO, + output_schema=PluginInfo, + error_schema=PluginError, + serializer=serialize_plugin_object, + supports_slug=False, + logger=logger, + ) + result = get_tool.run_tool(request.identifier) + + if isinstance(result, PluginInfo): + await ctx.info( + "Plugin retrieved: id=%s, name=%s, key=%s" + % (result.id, result.name, result.key) + ) + else: + await ctx.warning( + "Plugin retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: + await ctx.error( + "Plugin info retrieval failed: identifier=%s, error=%s" + % (request.identifier, str(e)) + ) + return PluginError( + error=f"Failed to get plugin info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/plugin/tool/list_plugins.py b/superset/mcp_service/plugin/tool/list_plugins.py new file mode 100644 index 000000000000..8c5de9ec22ff --- /dev/null +++ b/superset/mcp_service/plugin/tool/list_plugins.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +List plugins FastMCP tool. +""" + +import logging + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.plugin.schemas import ( + ALL_PLUGIN_COLUMNS, + DEFAULT_PLUGIN_COLUMNS, + ListPluginsRequest, + PluginColumnFilter, + PluginError, + PluginInfo, + PluginList, + serialize_plugin_object, + SORTABLE_PLUGIN_COLUMNS, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_LIST_PLUGINS_REQUEST = ListPluginsRequest() + + +@tool( + tags=["core"], + class_permission_name="DynamicPlugin", + annotations=ToolAnnotations( + title="List plugins", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_plugins( + request: ListPluginsRequest | None = None, + ctx: Context | None = None, +) -> PluginList | PluginError: + """List dynamic plugins registered in this Superset instance. Requires admin access. + + Returns plugin metadata including name, key, and bundle URL. + + Sortable columns for order_column: id, name, key, changed_on, created_on + """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_plugins") + + request = request or _DEFAULT_LIST_PLUGINS_REQUEST.model_copy(deep=True) + + await ctx.info( + "Listing plugins: page=%s, page_size=%s, search=%s" + % (request.page, request.page_size, request.search) + ) + + try: + from superset.mcp_service.plugin.dao import DynamicPluginDAO + + def _serialize(obj: object, cols: list[str] | None) -> PluginInfo | None: + return serialize_plugin_object(obj) + + list_tool = ModelListCore( + dao_class=DynamicPluginDAO, + output_schema=PluginInfo, + item_serializer=_serialize, + filter_type=PluginColumnFilter, + default_columns=DEFAULT_PLUGIN_COLUMNS, + search_columns=["name", "key"], + list_field_name="plugins", + output_list_schema=PluginList, + all_columns=ALL_PLUGIN_COLUMNS, + sortable_columns=SORTABLE_PLUGIN_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_plugins.query"): + result = list_tool.run_tool( + filters=request.filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "Plugins listed: count=%s, total_count=%s" + % (len(result.plugins), result.total_count) + ) + + columns_to_filter = result.columns_requested + with event_logger.log_context(action="mcp.list_plugins.serialization"): + return result.model_dump( + mode="json", + context={"select_columns": columns_to_filter}, + ) + + except Exception as e: + await ctx.error( + "Plugin listing failed: error=%s, error_type=%s" + % (str(e), type(e).__name__) + ) + raise diff --git a/superset/mcp_service/rls/__init__.py b/superset/mcp_service/rls/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/mcp_service/rls/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/rls/schemas.py b/superset/mcp_service/rls/schemas.py new file mode 100644 index 000000000000..37a7734f01de --- /dev/null +++ b/superset/mcp_service/rls/schemas.py @@ -0,0 +1,255 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pydantic schemas for row level security filter responses. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Dict, List, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + +DEFAULT_RLS_COLUMNS = ["id", "name", "filter_type", "clause"] + +ALL_RLS_COLUMNS = [ + "id", + "name", + "filter_type", + "tables", + "roles", + "clause", + "group_key", + "changed_on", +] + +SORTABLE_RLS_COLUMNS = ["id", "name", "filter_type", "changed_on"] + + +class RlsColumnFilter(ColumnOperator): + """Filter object for RLS filter listing.""" + + col: Literal["name", "filter_type"] = Field( + ..., + description="Column to filter on.", + ) + opr: ColumnOperatorEnum = Field(..., description="Operator to use.") + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by" + ) + + +class RlsTableRef(BaseModel): + id: int | None = Field(None, description="Table ID") + table_name: str | None = Field(None, description="Table name") + model_config = ConfigDict(from_attributes=True) + + +class RlsRoleRef(BaseModel): + id: int | None = Field(None, description="Role ID") + name: str | None = Field(None, description="Role name") + model_config = ConfigDict(from_attributes=True) + + +class RlsFilterInfo(BaseModel): + id: int | None = Field(None, description="RLS filter ID") + name: str | None = Field(None, description="RLS filter name") + filter_type: str | None = Field(None, description="Filter type: Regular or Base") + tables: List[RlsTableRef] | None = Field( + None, description="Tables this filter applies to" + ) + roles: List[RlsRoleRef] | None = Field( + None, description="Roles this filter applies to" + ) + clause: str | None = Field(None, description="SQL WHERE clause") + group_key: str | None = Field( + None, description="Group key for Base filter grouping" + ) + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + model_config = ConfigDict( + from_attributes=True, + ser_json_timedelta="iso8601", + populate_by_name=True, + ) + + @model_serializer(mode="wrap") + def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]: + data = serializer(self) + if info.context and isinstance(info.context, dict): + select_columns = info.context.get("select_columns") + if select_columns: + requested_fields = set(select_columns) + return {k: v for k, v in data.items() if k in requested_fields} + return data + + +class RlsFilterList(BaseModel): + rls_filters: List[RlsFilterInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: List[str] = Field(default_factory=list) + columns_loaded: List[str] = Field(default_factory=list) + columns_available: List[str] = Field(default_factory=list) + sortable_columns: List[str] = Field(default_factory=list) + filters_applied: List[RlsColumnFilter] = Field(default_factory=list) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListRlsFiltersRequest(BaseModel): + """Request schema for list_rls_filters.""" + + filters: Annotated[ + List[RlsColumnFilter], + Field( + default_factory=list, + description="List of filter objects (col, opr, value). " + "Cannot be used with search.", + ), + ] + select_columns: Annotated[ + List[str], + Field( + default_factory=list, + description="Columns to include in response. Defaults to common columns.", + ), + ] + search: Annotated[ + str | None, + Field( + default=None, + description="Text search on filter name. Cannot be used with filters.", + ), + ] + order_column: Annotated[ + str | None, Field(default=None, description="Column to order results by") + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field(default="desc", description="Sort direction"), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number (1-based)"), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Items per page (max {MAX_PAGE_SIZE})", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[RlsColumnFilter]: + return parse_json_or_model_list(v, RlsColumnFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListRlsFiltersRequest": + if self.search and self.filters: + raise ValueError("Cannot use both 'search' and 'filters' simultaneously.") + return self + + +class RlsFilterError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "RlsFilterError": + from datetime import timezone + + return cls( + error=error, error_type=error_type, timestamp=datetime.now(timezone.utc) + ) + + +class GetRlsFilterInfoRequest(BaseModel): + """Request schema for get_rls_filter_info.""" + + identifier: Annotated[ + int, + Field(description="RLS filter ID"), + ] + + +def serialize_rls_filter_object(rls_filter: Any) -> RlsFilterInfo | None: + if not rls_filter: + return None + + tables = [ + RlsTableRef( + id=getattr(t, "id", None), + table_name=getattr(t, "table_name", None), + ) + for t in (getattr(rls_filter, "tables", None) or []) + ] + + roles = [ + RlsRoleRef( + id=getattr(r, "id", None), + name=getattr(r, "name", None), + ) + for r in (getattr(rls_filter, "roles", None) or []) + ] + + return RlsFilterInfo( + id=getattr(rls_filter, "id", None), + name=getattr(rls_filter, "name", None), + filter_type=getattr(rls_filter, "filter_type", None), + tables=tables, + roles=roles, + clause=getattr(rls_filter, "clause", None), + group_key=getattr(rls_filter, "group_key", None), + changed_on=getattr(rls_filter, "changed_on", None), + ) diff --git a/superset/mcp_service/rls/tool/__init__.py b/superset/mcp_service/rls/tool/__init__.py new file mode 100644 index 000000000000..c05033569f12 --- /dev/null +++ b/superset/mcp_service/rls/tool/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .get_rls_filter_info import get_rls_filter_info +from .list_rls_filters import list_rls_filters + +__all__ = [ + "list_rls_filters", + "get_rls_filter_info", +] diff --git a/superset/mcp_service/rls/tool/get_rls_filter_info.py b/superset/mcp_service/rls/tool/get_rls_filter_info.py new file mode 100644 index 000000000000..31c828689ac0 --- /dev/null +++ b/superset/mcp_service/rls/tool/get_rls_filter_info.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Get RLS filter info FastMCP tool. +""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.rls.schemas import ( + GetRlsFilterInfoRequest, + RlsFilterError, + RlsFilterInfo, + serialize_rls_filter_object, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="Row Level Security", + annotations=ToolAnnotations( + title="Get RLS filter info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_rls_filter_info( + request: GetRlsFilterInfoRequest, ctx: Context +) -> RlsFilterInfo | RlsFilterError: + """Get row level security filter details by ID. Requires admin access. + + Returns full RLS filter configuration including name, type, tables, roles, + and clause. + + Example usage: + ```json + {"identifier": 1} + ``` + """ + await ctx.info( + "Retrieving RLS filter information: identifier=%s" % (request.identifier,) + ) + + try: + from superset.daos.security import RLSDAO + + with event_logger.log_context(action="mcp.get_rls_filter_info.lookup"): + get_tool = ModelGetInfoCore( + dao_class=RLSDAO, + output_schema=RlsFilterInfo, + error_schema=RlsFilterError, + serializer=serialize_rls_filter_object, + supports_slug=False, + logger=logger, + ) + result = get_tool.run_tool(request.identifier) + + if isinstance(result, RlsFilterInfo): + await ctx.info( + "RLS filter retrieved: id=%s, name=%s" % (result.id, result.name) + ) + else: + await ctx.warning( + "RLS filter retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: + await ctx.error( + "RLS filter info retrieval failed: identifier=%s, error=%s" + % (request.identifier, str(e)) + ) + return RlsFilterError( + error=f"Failed to get RLS filter info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/rls/tool/list_rls_filters.py b/superset/mcp_service/rls/tool/list_rls_filters.py new file mode 100644 index 000000000000..b08b9bc32d31 --- /dev/null +++ b/superset/mcp_service/rls/tool/list_rls_filters.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +List RLS filters FastMCP tool. +""" + +import logging + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.rls.schemas import ( + ALL_RLS_COLUMNS, + DEFAULT_RLS_COLUMNS, + ListRlsFiltersRequest, + RlsColumnFilter, + RlsFilterError, + RlsFilterInfo, + RlsFilterList, + serialize_rls_filter_object, + SORTABLE_RLS_COLUMNS, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_LIST_RLS_FILTERS_REQUEST = ListRlsFiltersRequest() + + +@tool( + tags=["core"], + class_permission_name="Row Level Security", + annotations=ToolAnnotations( + title="List RLS filters", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_rls_filters( + request: ListRlsFiltersRequest | None = None, + ctx: Context | None = None, +) -> RlsFilterList | RlsFilterError: + """List row level security filters. Requires admin access. + + Returns RLS filter metadata including name, filter type, tables, roles, and clause. + + Sortable columns for order_column: id, name, filter_type, changed_on + """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_rls_filters") + + request = request or _DEFAULT_LIST_RLS_FILTERS_REQUEST.model_copy(deep=True) + + await ctx.info( + "Listing RLS filters: page=%s, page_size=%s, search=%s" + % (request.page, request.page_size, request.search) + ) + + try: + from superset.daos.security import RLSDAO + + def _serialize(obj: object, cols: list[str] | None) -> RlsFilterInfo | None: + return serialize_rls_filter_object(obj) + + list_tool = ModelListCore( + dao_class=RLSDAO, + output_schema=RlsFilterInfo, + item_serializer=_serialize, + filter_type=RlsColumnFilter, + default_columns=DEFAULT_RLS_COLUMNS, + search_columns=["name"], + list_field_name="rls_filters", + output_list_schema=RlsFilterList, + all_columns=ALL_RLS_COLUMNS, + sortable_columns=SORTABLE_RLS_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_rls_filters.query"): + result = list_tool.run_tool( + filters=request.filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "RLS filters listed: count=%s, total_count=%s" + % (len(result.rls_filters), result.total_count) + ) + + columns_to_filter = result.columns_requested + with event_logger.log_context(action="mcp.list_rls_filters.serialization"): + return result.model_dump( + mode="json", + context={"select_columns": columns_to_filter}, + ) + + except Exception as e: + await ctx.error( + "RLS filter listing failed: error=%s, error_type=%s" + % (str(e), type(e).__name__) + ) + raise diff --git a/tests/unit_tests/mcp_service/plugin/__init__.py b/tests/unit_tests/mcp_service/plugin/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/plugin/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/plugin/tool/__init__.py b/tests/unit_tests/mcp_service/plugin/tool/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/plugin/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/plugin/tool/test_plugin_tools.py b/tests/unit_tests/mcp_service/plugin/tool/test_plugin_tools.py new file mode 100644 index 000000000000..3afbf3ffbdf7 --- /dev/null +++ b/tests/unit_tests/mcp_service/plugin/tool/test_plugin_tools.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.app import mcp +from superset.mcp_service.plugin.schemas import ListPluginsRequest, PluginColumnFilter +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def create_mock_plugin( + plugin_id: int = 1, + name: str = "My Plugin", + key: str = "my_plugin", + bundle_url: str = "https://example.com/plugin.js", +) -> MagicMock: + plugin = MagicMock() + plugin.id = plugin_id + plugin.name = name + plugin.key = key + plugin.bundle_url = bundle_url + plugin.changed_on = None + plugin.created_on = None + return plugin + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +class TestPluginColumnFilterSchema: + def test_invalid_filter_column_rejected(self): + with pytest.raises(ValidationError): + PluginColumnFilter(col="bundle_url", opr="eq", value="test") + + def test_valid_name_filter(self): + f = PluginColumnFilter(col="name", opr="eq", value="test") + assert f.col == "name" + + def test_valid_key_filter(self): + f = PluginColumnFilter(col="key", opr="eq", value="my_plugin") + assert f.col == "key" + + +@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list") +@pytest.mark.asyncio +async def test_list_plugins_basic(mock_list, mcp_server): + plugin = create_mock_plugin() + mock_list.return_value = ([plugin], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_plugins", {}) + data = json.loads(result.content[0].text) + assert "plugins" in data + assert len(data["plugins"]) == 1 + assert data["plugins"][0]["id"] == 1 + assert data["plugins"][0]["name"] == "My Plugin" + assert data["plugins"][0]["key"] == "my_plugin" + + +@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list") +@pytest.mark.asyncio +async def test_list_plugins_with_request(mock_list, mcp_server): + plugin = create_mock_plugin() + mock_list.return_value = ([plugin], 1) + + async with Client(mcp_server) as client: + request = ListPluginsRequest(page=1, page_size=10) + result = await client.call_tool( + "list_plugins", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["count"] == 1 + assert data["total_count"] == 1 + + +@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list") +@pytest.mark.asyncio +async def test_list_plugins_with_search(mock_list, mcp_server): + plugin = create_mock_plugin(name="Custom Chart") + mock_list.return_value = ([plugin], 1) + + async with Client(mcp_server) as client: + request = ListPluginsRequest(page=1, page_size=10, search="custom") + result = await client.call_tool( + "list_plugins", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["plugins"][0]["name"] == "Custom Chart" + + +@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list") +@pytest.mark.asyncio +async def test_list_plugins_empty(mock_list, mcp_server): + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_plugins", {}) + data = json.loads(result.content[0].text) + assert data["count"] == 0 + assert data["plugins"] == [] + + +@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_plugin_info_basic(mock_find, mcp_server): + plugin = create_mock_plugin() + mock_find.return_value = plugin + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_plugin_info", {"request": {"identifier": 1}} + ) + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["name"] == "My Plugin" + assert data["key"] == "my_plugin" + assert data["bundle_url"] == "https://example.com/plugin.js" + + +@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_plugin_info_not_found(mock_find, mcp_server): + mock_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_plugin_info", {"request": {"identifier": 999}} + ) + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + + +def test_list_plugins_request_rejects_search_and_filters(): + with pytest.raises(ValidationError): + ListPluginsRequest( + search="test", + filters=[{"col": "name", "opr": "eq", "value": "x"}], + ) diff --git a/tests/unit_tests/mcp_service/rls/__init__.py b/tests/unit_tests/mcp_service/rls/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/rls/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/rls/tool/__init__.py b/tests/unit_tests/mcp_service/rls/tool/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/rls/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py b/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py new file mode 100644 index 000000000000..ab45317fab45 --- /dev/null +++ b/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py @@ -0,0 +1,222 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.app import mcp +from superset.mcp_service.rls.schemas import ListRlsFiltersRequest, RlsColumnFilter +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def create_mock_rls_filter( + filter_id: int = 1, + name: str = "test_filter", + filter_type: str = "Regular", + clause: str = "user_id = {{current_user_id()}}", + group_key: str | None = None, +) -> MagicMock: + rls_filter = MagicMock() + rls_filter.id = filter_id + rls_filter.name = name + rls_filter.filter_type = filter_type + rls_filter.clause = clause + rls_filter.group_key = group_key + rls_filter.changed_on = None + + table = MagicMock() + table.id = 1 + table.table_name = "sales" + rls_filter.tables = [table] + + role = MagicMock() + role.id = 1 + role.name = "Alpha" + rls_filter.roles = [role] + + return rls_filter + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +class TestRlsColumnFilterSchema: + def test_invalid_filter_column_rejected(self): + with pytest.raises(ValidationError): + RlsColumnFilter(col="clause", opr="eq", value="test") + + def test_valid_name_filter(self): + f = RlsColumnFilter(col="name", opr="eq", value="test") + assert f.col == "name" + + def test_valid_filter_type_filter(self): + f = RlsColumnFilter(col="filter_type", opr="eq", value="Regular") + assert f.col == "filter_type" + + +@patch("superset.daos.security.RLSDAO.list") +@pytest.mark.asyncio +async def test_list_rls_filters_basic(mock_list, mcp_server): + rls_filter = create_mock_rls_filter() + mock_list.return_value = ([rls_filter], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_rls_filters", {}) + assert result.content is not None + data = json.loads(result.content[0].text) + assert "rls_filters" in data + assert len(data["rls_filters"]) == 1 + assert data["rls_filters"][0]["id"] == 1 + assert data["rls_filters"][0]["name"] == "test_filter" + + +@patch("superset.daos.security.RLSDAO.list") +@pytest.mark.asyncio +async def test_list_rls_filters_with_request(mock_list, mcp_server): + rls_filter = create_mock_rls_filter() + mock_list.return_value = ([rls_filter], 1) + + async with Client(mcp_server) as client: + request = ListRlsFiltersRequest(page=1, page_size=10) + result = await client.call_tool( + "list_rls_filters", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["count"] == 1 + assert data["total_count"] == 1 + + +@patch("superset.daos.security.RLSDAO.list") +@pytest.mark.asyncio +async def test_list_rls_filters_with_search(mock_list, mcp_server): + rls_filter = create_mock_rls_filter(name="user_filter") + mock_list.return_value = ([rls_filter], 1) + + async with Client(mcp_server) as client: + request = ListRlsFiltersRequest(page=1, page_size=10, search="user") + result = await client.call_tool( + "list_rls_filters", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["rls_filters"][0]["name"] == "user_filter" + + +@patch("superset.daos.security.RLSDAO.list") +@pytest.mark.asyncio +async def test_list_rls_filters_returns_tables_and_roles(mock_list, mcp_server): + rls_filter = create_mock_rls_filter() + mock_list.return_value = ([rls_filter], 1) + + async with Client(mcp_server) as client: + request = ListRlsFiltersRequest( + page=1, + page_size=10, + select_columns=["id", "name", "tables", "roles"], + ) + result = await client.call_tool( + "list_rls_filters", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + item = data["rls_filters"][0] + assert "tables" in item + assert item["tables"][0]["table_name"] == "sales" + assert "roles" in item + assert item["roles"][0]["name"] == "Alpha" + + +@patch("superset.daos.security.RLSDAO.list") +@pytest.mark.asyncio +async def test_list_rls_filters_empty(mock_list, mcp_server): + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_rls_filters", {}) + data = json.loads(result.content[0].text) + assert data["count"] == 0 + assert data["rls_filters"] == [] + + +@patch("superset.daos.security.RLSDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_rls_filter_info_basic(mock_find, mcp_server): + rls_filter = create_mock_rls_filter() + mock_find.return_value = rls_filter + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_rls_filter_info", {"request": {"identifier": 1}} + ) + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["name"] == "test_filter" + assert data["filter_type"] == "Regular" + assert data["clause"] == "user_id = {{current_user_id()}}" + + +@patch("superset.daos.security.RLSDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_rls_filter_info_not_found(mock_find, mcp_server): + mock_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_rls_filter_info", {"request": {"identifier": 999}} + ) + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + + +@patch("superset.daos.security.RLSDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_rls_filter_info_includes_tables_and_roles(mock_find, mcp_server): + rls_filter = create_mock_rls_filter() + mock_find.return_value = rls_filter + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_rls_filter_info", {"request": {"identifier": 1}} + ) + data = json.loads(result.content[0].text) + assert data["tables"][0]["table_name"] == "sales" + assert data["roles"][0]["name"] == "Alpha" + + +def test_list_rls_filters_request_rejects_search_and_filters(): + with pytest.raises(ValidationError): + ListRlsFiltersRequest( + search="test", + filters=[{"col": "name", "opr": "eq", "value": "x"}], + ) From 80faffb32a882a879b43edae4e09cff105025f94 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 21 May 2026 01:30:46 +0000 Subject: [PATCH 2/9] fix(mcp): remove 'roles' from USER_DIRECTORY_FIELDS to allow RLS filter roles to be returned RLS filter `roles` (which roles a filter applies to) are core RLS data, not user-directory metadata. Including 'roles' in USER_DIRECTORY_FIELDS caused filter_user_directory_columns() to strip it from any requested select_columns list, making it impossible to retrieve via list_rls_filters. No dashboard/chart/dataset schema defines a 'roles' field, so removing it from the block set has no privacy impact on other tools. Fixes test_list_rls_filters_returns_tables_and_roles. --- superset/mcp_service/privacy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/superset/mcp_service/privacy.py b/superset/mcp_service/privacy.py index 0b64ec14e0ac..02d4982d23f7 100644 --- a/superset/mcp_service/privacy.py +++ b/superset/mcp_service/privacy.py @@ -40,7 +40,6 @@ "last_saved_by_name", "owner", "owners", - "roles", } ) @@ -140,7 +139,7 @@ def user_can_view_data_model_metadata() -> bool: def filter_user_directory_fields(data: dict[str, Any]) -> dict[str, Any]: - """Remove fields that expose users, roles, owners, or access metadata.""" + """Remove fields that expose users, owners, or access metadata.""" return { key: value for key, value in data.items() if key not in USER_DIRECTORY_FIELDS } From df6b03026916f453641d48c3751bc69cea091066 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 21 May 2026 02:12:18 +0000 Subject: [PATCH 3/9] ci: trigger CI for fix From f08a5103c7b2b3a3ffcd2fc44292e9c7edd1a146 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 21 May 2026 04:25:19 +0000 Subject: [PATCH 4/9] fix(mcp): restore 'roles' to USER_DIRECTORY_FIELDS and bypass filter in RLS list tool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 'roles' on a dashboard/chart exposes who has access to the resource and should be stripped by the USER_DIRECTORY_FIELDS privacy filter. 'roles' in an RLS filter is which roles the filter applies to — it is core filter data, not user-directory metadata. The RLS list tool now derives its column selection directly from ALL_RLS_COLUMNS (bypassing ModelListCore's USER_DIRECTORY_FIELDS filtering) so that RLS roles are selectable while dashboard roles remain hidden. Fixes three failing unit tests: - test_list_dashboards_omits_requested_user_directory_fields - test_get_allowed_fields_always_denies_user_directory_fields - test_filter_sensitive_data_strips_user_directory_fields_even_if_allowed Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/privacy.py | 1 + superset/mcp_service/rls/tool/list_rls_filters.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/superset/mcp_service/privacy.py b/superset/mcp_service/privacy.py index 02d4982d23f7..86dc552e7890 100644 --- a/superset/mcp_service/privacy.py +++ b/superset/mcp_service/privacy.py @@ -40,6 +40,7 @@ "last_saved_by_name", "owner", "owners", + "roles", } ) diff --git a/superset/mcp_service/rls/tool/list_rls_filters.py b/superset/mcp_service/rls/tool/list_rls_filters.py index b08b9bc32d31..e72d032a96e6 100644 --- a/superset/mcp_service/rls/tool/list_rls_filters.py +++ b/superset/mcp_service/rls/tool/list_rls_filters.py @@ -108,7 +108,20 @@ def _serialize(obj: object, cols: list[str] | None) -> RlsFilterInfo | None: % (len(result.rls_filters), result.total_count) ) - columns_to_filter = result.columns_requested + # Build column selection using ALL_RLS_COLUMNS as the source of truth, + # bypassing the USER_DIRECTORY_FIELDS privacy filter applied by + # ModelListCore. 'roles' in an RLS filter is which roles the filter + # applies to — core filter data — not user-directory metadata (like + # dashboard.roles, which exposes who has access to the resource). + if request.select_columns: + columns_to_filter = [ + c for c in request.select_columns if c in ALL_RLS_COLUMNS + ] + if not columns_to_filter: + columns_to_filter = list(DEFAULT_RLS_COLUMNS) + else: + columns_to_filter = list(DEFAULT_RLS_COLUMNS) + with event_logger.log_context(action="mcp.list_rls_filters.serialization"): return result.model_dump( mode="json", From 6e3a32511dcc2bac3cc503e5cc76a3787a93119f Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 21 May 2026 08:13:49 +0000 Subject: [PATCH 5/9] docs(mcp): document that list_rls_filters and list_plugins have inline column docs --- superset/mcp_service/app.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index a444834194a2..1869c6eeea9e 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -361,7 +361,10 @@ def get_default_instructions( General usage tips: - All listing tools use 1-based pagination (first page is 1) -- Use get_schema to discover filterable columns, sortable columns, and default columns +- Use get_schema (chart/dataset/dashboard/database) to discover filterable columns, + sortable columns, and default columns for those resource types +- For list_rls_filters and list_plugins, filterable/sortable columns are listed + inline in each tool's docstring — get_schema does not cover these tools - Use 'filters' parameter for advanced queries with filter columns from get_schema - IDs can be integer or UUID format where supported - All tools return structured, Pydantic-typed responses From d5ac996d85af34721e407bd8080534f3e34782a8 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 27 May 2026 18:22:43 +0000 Subject: [PATCH 6/9] fix(mcp): fix serializer signature and update docstring per review - Fix _serialize cols parameter type from list[str] | None to list[str] in both list_plugins.py and list_rls_filters.py to match ModelListCore Callable[[T, List[str]], S | None] callback signature - Update filter_user_directory_fields docstring to mention roles Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/plugin/tool/list_plugins.py | 2 +- superset/mcp_service/privacy.py | 2 +- superset/mcp_service/rls/tool/list_rls_filters.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/superset/mcp_service/plugin/tool/list_plugins.py b/superset/mcp_service/plugin/tool/list_plugins.py index 8c5de9ec22ff..f93821e507e9 100644 --- a/superset/mcp_service/plugin/tool/list_plugins.py +++ b/superset/mcp_service/plugin/tool/list_plugins.py @@ -75,7 +75,7 @@ async def list_plugins( try: from superset.mcp_service.plugin.dao import DynamicPluginDAO - def _serialize(obj: object, cols: list[str] | None) -> PluginInfo | None: + def _serialize(obj: object, cols: list[str]) -> PluginInfo | None: return serialize_plugin_object(obj) list_tool = ModelListCore( diff --git a/superset/mcp_service/privacy.py b/superset/mcp_service/privacy.py index 86dc552e7890..0b64ec14e0ac 100644 --- a/superset/mcp_service/privacy.py +++ b/superset/mcp_service/privacy.py @@ -140,7 +140,7 @@ def user_can_view_data_model_metadata() -> bool: def filter_user_directory_fields(data: dict[str, Any]) -> dict[str, Any]: - """Remove fields that expose users, owners, or access metadata.""" + """Remove fields that expose users, roles, owners, or access metadata.""" return { key: value for key, value in data.items() if key not in USER_DIRECTORY_FIELDS } diff --git a/superset/mcp_service/rls/tool/list_rls_filters.py b/superset/mcp_service/rls/tool/list_rls_filters.py index e72d032a96e6..ce8dfc814389 100644 --- a/superset/mcp_service/rls/tool/list_rls_filters.py +++ b/superset/mcp_service/rls/tool/list_rls_filters.py @@ -75,7 +75,7 @@ async def list_rls_filters( try: from superset.daos.security import RLSDAO - def _serialize(obj: object, cols: list[str] | None) -> RlsFilterInfo | None: + def _serialize(obj: object, cols: list[str]) -> RlsFilterInfo | None: return serialize_rls_filter_object(obj) list_tool = ModelListCore( From ae0a6ec614d668c3a56c9a4b2beab7ec248d39b9 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 27 May 2026 18:35:17 +0000 Subject: [PATCH 7/9] fix(mcp): prevent ValueError when select_columns contains only USER_DIRECTORY_FIELDS in list_rls_filters When the caller passes select_columns that consists entirely of USER_DIRECTORY_FIELDS columns (e.g. ["roles"]), ModelListCore raises ValueError because its privacy filter strips all columns, leaving an empty list. Strip USER_DIRECTORY_FIELDS from select_columns before passing to run_tool (falling back to None/defaults when the filtered list is empty). The existing bypass mechanism already restores these fields in the final serialized output using ALL_RLS_COLUMNS. Adds a regression test for the ["roles"]-only select_columns edge case. --- .../mcp_service/rls/tool/list_rls_filters.py | 15 ++++++++++++- .../mcp_service/rls/tool/test_rls_tools.py | 22 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/superset/mcp_service/rls/tool/list_rls_filters.py b/superset/mcp_service/rls/tool/list_rls_filters.py index ce8dfc814389..ecb00b32aac0 100644 --- a/superset/mcp_service/rls/tool/list_rls_filters.py +++ b/superset/mcp_service/rls/tool/list_rls_filters.py @@ -26,6 +26,7 @@ from superset.extensions import event_logger from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.privacy import USER_DIRECTORY_FIELDS from superset.mcp_service.rls.schemas import ( ALL_RLS_COLUMNS, DEFAULT_RLS_COLUMNS, @@ -92,11 +93,23 @@ def _serialize(obj: object, cols: list[str]) -> RlsFilterInfo | None: logger=logger, ) + # RLS 'roles' is valid filter data but lives in USER_DIRECTORY_FIELDS, + # so ModelListCore would raise ValueError for a column list that reduces + # to empty after privacy filtering (e.g. select_columns=["roles"]). + # Strip directory-field columns here; the bypass below restores them in + # the final serialized output from ALL_RLS_COLUMNS. + run_tool_columns = None + if request.select_columns: + non_directory = [ + c for c in request.select_columns if c not in USER_DIRECTORY_FIELDS + ] + run_tool_columns = non_directory if non_directory else None + with event_logger.log_context(action="mcp.list_rls_filters.query"): result = list_tool.run_tool( filters=request.filters, search=request.search, - select_columns=request.select_columns, + select_columns=run_tool_columns, order_column=request.order_column, order_direction=request.order_direction, page=max(request.page - 1, 0), diff --git a/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py b/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py index ab45317fab45..1a05dddd37f8 100644 --- a/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py +++ b/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py @@ -157,6 +157,28 @@ async def test_list_rls_filters_returns_tables_and_roles(mock_list, mcp_server): assert item["roles"][0]["name"] == "Alpha" +@patch("superset.daos.security.RLSDAO.list") +@pytest.mark.asyncio +async def test_list_rls_filters_roles_only_select_columns(mock_list, mcp_server): + """Requesting only 'roles' must not raise ValueError from the privacy filter.""" + rls_filter = create_mock_rls_filter() + mock_list.return_value = ([rls_filter], 1) + + async with Client(mcp_server) as client: + request = ListRlsFiltersRequest( + page=1, + page_size=10, + select_columns=["roles"], + ) + result = await client.call_tool( + "list_rls_filters", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + item = data["rls_filters"][0] + assert "roles" in item + assert item["roles"][0]["name"] == "Alpha" + + @patch("superset.daos.security.RLSDAO.list") @pytest.mark.asyncio async def test_list_rls_filters_empty(mock_list, mcp_server): From cf832788b5f3f53c97749b6324e8b886380390a9 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 27 May 2026 18:37:26 +0000 Subject: [PATCH 8/9] refactor(mcp): rename _serialize to _serialize_rls_filter/_serialize_plugin for consistency Align with the naming convention used by all other list tools (list_charts, list_dashboards, list_databases, list_datasets), which use _serialize_ for the item serializer closure. Addresses bito additional suggestion: serializer naming inconsistency. --- superset/mcp_service/plugin/tool/list_plugins.py | 4 ++-- superset/mcp_service/rls/tool/list_rls_filters.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/superset/mcp_service/plugin/tool/list_plugins.py b/superset/mcp_service/plugin/tool/list_plugins.py index f93821e507e9..04373a98ed19 100644 --- a/superset/mcp_service/plugin/tool/list_plugins.py +++ b/superset/mcp_service/plugin/tool/list_plugins.py @@ -75,13 +75,13 @@ async def list_plugins( try: from superset.mcp_service.plugin.dao import DynamicPluginDAO - def _serialize(obj: object, cols: list[str]) -> PluginInfo | None: + def _serialize_plugin(obj: object, cols: list[str]) -> PluginInfo | None: return serialize_plugin_object(obj) list_tool = ModelListCore( dao_class=DynamicPluginDAO, output_schema=PluginInfo, - item_serializer=_serialize, + item_serializer=_serialize_plugin, filter_type=PluginColumnFilter, default_columns=DEFAULT_PLUGIN_COLUMNS, search_columns=["name", "key"], diff --git a/superset/mcp_service/rls/tool/list_rls_filters.py b/superset/mcp_service/rls/tool/list_rls_filters.py index ecb00b32aac0..ae92dcf56df3 100644 --- a/superset/mcp_service/rls/tool/list_rls_filters.py +++ b/superset/mcp_service/rls/tool/list_rls_filters.py @@ -76,13 +76,13 @@ async def list_rls_filters( try: from superset.daos.security import RLSDAO - def _serialize(obj: object, cols: list[str]) -> RlsFilterInfo | None: + def _serialize_rls_filter(obj: object, cols: list[str]) -> RlsFilterInfo | None: return serialize_rls_filter_object(obj) list_tool = ModelListCore( dao_class=RLSDAO, output_schema=RlsFilterInfo, - item_serializer=_serialize, + item_serializer=_serialize_rls_filter, filter_type=RlsColumnFilter, default_columns=DEFAULT_RLS_COLUMNS, search_columns=["name"], From 544df097ed4ff90bdfb0dd36f1cb6be6433ca165 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 16:06:21 +0000 Subject: [PATCH 9/9] feat(mcp): add description/created_on to RlsFilterInfo; rename RlsColumnFilter to RlsFilter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `description` field to `RlsFilterInfo`, `ALL_RLS_COLUMNS`, and `serialize_rls_filter_object` (maps to `RowLevelSecurityFilter.description`) - Add `created_on` field alongside existing `changed_on` for consistency with all sibling schemas in the MCP service suite - Rename `RlsColumnFilter` → `RlsFilter` to follow the `Filter` naming convention used by every other filter class in the suite - Update test mock and test class name accordingly --- superset/mcp_service/rls/schemas.py | 18 +++++++++++++----- .../rls/tool/get_rls_filter_info.py | 4 ++-- .../mcp_service/rls/tool/list_rls_filters.py | 6 +++--- .../mcp_service/rls/tool/test_rls_tools.py | 12 +++++++----- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/superset/mcp_service/rls/schemas.py b/superset/mcp_service/rls/schemas.py index 37a7734f01de..bfb715266b05 100644 --- a/superset/mcp_service/rls/schemas.py +++ b/superset/mcp_service/rls/schemas.py @@ -47,18 +47,20 @@ ALL_RLS_COLUMNS = [ "id", "name", + "description", "filter_type", "tables", "roles", "clause", "group_key", + "created_on", "changed_on", ] SORTABLE_RLS_COLUMNS = ["id", "name", "filter_type", "changed_on"] -class RlsColumnFilter(ColumnOperator): +class RlsFilter(ColumnOperator): """Filter object for RLS filter listing.""" col: Literal["name", "filter_type"] = Field( @@ -86,6 +88,9 @@ class RlsRoleRef(BaseModel): class RlsFilterInfo(BaseModel): id: int | None = Field(None, description="RLS filter ID") name: str | None = Field(None, description="RLS filter name") + description: str | None = Field( + None, description="Human-readable description of the filter's purpose" + ) filter_type: str | None = Field(None, description="Filter type: Regular or Base") tables: List[RlsTableRef] | None = Field( None, description="Tables this filter applies to" @@ -97,6 +102,7 @@ class RlsFilterInfo(BaseModel): group_key: str | None = Field( None, description="Group key for Base filter grouping" ) + created_on: str | datetime | None = Field(None, description="Creation timestamp") changed_on: str | datetime | None = Field( None, description="Last modification timestamp" ) @@ -130,7 +136,7 @@ class RlsFilterList(BaseModel): columns_loaded: List[str] = Field(default_factory=list) columns_available: List[str] = Field(default_factory=list) sortable_columns: List[str] = Field(default_factory=list) - filters_applied: List[RlsColumnFilter] = Field(default_factory=list) + filters_applied: List[RlsFilter] = Field(default_factory=list) pagination: PaginationInfo | None = None timestamp: datetime | None = None model_config = ConfigDict(ser_json_timedelta="iso8601") @@ -140,7 +146,7 @@ class ListRlsFiltersRequest(BaseModel): """Request schema for list_rls_filters.""" filters: Annotated[ - List[RlsColumnFilter], + List[RlsFilter], Field( default_factory=list, description="List of filter objects (col, opr, value). " @@ -184,8 +190,8 @@ class ListRlsFiltersRequest(BaseModel): @field_validator("filters", mode="before") @classmethod - def parse_filters(cls, v: Any) -> List[RlsColumnFilter]: - return parse_json_or_model_list(v, RlsColumnFilter, "filters") + def parse_filters(cls, v: Any) -> List[RlsFilter]: + return parse_json_or_model_list(v, RlsFilter, "filters") @field_validator("select_columns", mode="before") @classmethod @@ -246,10 +252,12 @@ def serialize_rls_filter_object(rls_filter: Any) -> RlsFilterInfo | None: return RlsFilterInfo( id=getattr(rls_filter, "id", None), name=getattr(rls_filter, "name", None), + description=getattr(rls_filter, "description", None), filter_type=getattr(rls_filter, "filter_type", None), tables=tables, roles=roles, clause=getattr(rls_filter, "clause", None), group_key=getattr(rls_filter, "group_key", None), + created_on=getattr(rls_filter, "created_on", None), changed_on=getattr(rls_filter, "changed_on", None), ) diff --git a/superset/mcp_service/rls/tool/get_rls_filter_info.py b/superset/mcp_service/rls/tool/get_rls_filter_info.py index 31c828689ac0..130c9303f3af 100644 --- a/superset/mcp_service/rls/tool/get_rls_filter_info.py +++ b/superset/mcp_service/rls/tool/get_rls_filter_info.py @@ -51,8 +51,8 @@ async def get_rls_filter_info( ) -> RlsFilterInfo | RlsFilterError: """Get row level security filter details by ID. Requires admin access. - Returns full RLS filter configuration including name, type, tables, roles, - and clause. + Returns full RLS filter configuration including name, description, type, + tables, roles, clause, created_on, and changed_on. Example usage: ```json diff --git a/superset/mcp_service/rls/tool/list_rls_filters.py b/superset/mcp_service/rls/tool/list_rls_filters.py index ae92dcf56df3..468c325a7585 100644 --- a/superset/mcp_service/rls/tool/list_rls_filters.py +++ b/superset/mcp_service/rls/tool/list_rls_filters.py @@ -31,7 +31,7 @@ ALL_RLS_COLUMNS, DEFAULT_RLS_COLUMNS, ListRlsFiltersRequest, - RlsColumnFilter, + RlsFilter, RlsFilterError, RlsFilterInfo, RlsFilterList, @@ -59,7 +59,7 @@ async def list_rls_filters( ) -> RlsFilterList | RlsFilterError: """List row level security filters. Requires admin access. - Returns RLS filter metadata including name, filter type, tables, roles, and clause. + Returns RLS filter metadata including name, description, filter type, tables, roles, clause, created_on, and changed_on. Sortable columns for order_column: id, name, filter_type, changed_on """ @@ -83,7 +83,7 @@ def _serialize_rls_filter(obj: object, cols: list[str]) -> RlsFilterInfo | None: dao_class=RLSDAO, output_schema=RlsFilterInfo, item_serializer=_serialize_rls_filter, - filter_type=RlsColumnFilter, + filter_type=RlsFilter, default_columns=DEFAULT_RLS_COLUMNS, search_columns=["name"], list_field_name="rls_filters", diff --git a/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py b/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py index 1a05dddd37f8..4da445ead160 100644 --- a/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py +++ b/tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py @@ -23,7 +23,7 @@ from pydantic import ValidationError from superset.mcp_service.app import mcp -from superset.mcp_service.rls.schemas import ListRlsFiltersRequest, RlsColumnFilter +from superset.mcp_service.rls.schemas import ListRlsFiltersRequest, RlsFilter from superset.utils import json logging.basicConfig(level=logging.DEBUG) @@ -40,9 +40,11 @@ def create_mock_rls_filter( rls_filter = MagicMock() rls_filter.id = filter_id rls_filter.name = name + rls_filter.description = None rls_filter.filter_type = filter_type rls_filter.clause = clause rls_filter.group_key = group_key + rls_filter.created_on = None rls_filter.changed_on = None table = MagicMock() @@ -73,17 +75,17 @@ def mock_auth(): yield mock_get_user -class TestRlsColumnFilterSchema: +class TestRlsFilterSchema: def test_invalid_filter_column_rejected(self): with pytest.raises(ValidationError): - RlsColumnFilter(col="clause", opr="eq", value="test") + RlsFilter(col="clause", opr="eq", value="test") def test_valid_name_filter(self): - f = RlsColumnFilter(col="name", opr="eq", value="test") + f = RlsFilter(col="name", opr="eq", value="test") assert f.col == "name" def test_valid_filter_type_filter(self): - f = RlsColumnFilter(col="filter_type", opr="eq", value="Regular") + f = RlsFilter(col="filter_type", opr="eq", value="Regular") assert f.col == "filter_type"