From 5f06929cf8806144cd6ef08493118b505a73ef2c Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 20 May 2026 22:46:29 +0000 Subject: [PATCH 01/13] feat(mcp): add list and get tools for action log and tasks Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/action_log/__init__.py | 16 ++ superset/mcp_service/action_log/schemas.py | 205 +++++++++++++++ .../mcp_service/action_log/tool/__init__.py | 24 ++ .../action_log/tool/get_action_log_info.py | 97 +++++++ .../action_log/tool/list_action_logs.py | 145 +++++++++++ superset/mcp_service/app.py | 8 + superset/mcp_service/task/__init__.py | 16 ++ superset/mcp_service/task/schemas.py | 195 ++++++++++++++ superset/mcp_service/task/tool/__init__.py | 24 ++ .../mcp_service/task/tool/get_task_info.py | 108 ++++++++ superset/mcp_service/task/tool/list_tasks.py | 129 +++++++++ .../mcp_service/action_log/__init__.py | 16 ++ .../mcp_service/action_log/tool/__init__.py | 16 ++ .../action_log/tool/test_action_log_tools.py | 213 +++++++++++++++ tests/unit_tests/mcp_service/task/__init__.py | 16 ++ .../mcp_service/task/tool/__init__.py | 16 ++ .../mcp_service/task/tool/test_task_tools.py | 245 ++++++++++++++++++ 17 files changed, 1489 insertions(+) create mode 100644 superset/mcp_service/action_log/__init__.py create mode 100644 superset/mcp_service/action_log/schemas.py create mode 100644 superset/mcp_service/action_log/tool/__init__.py create mode 100644 superset/mcp_service/action_log/tool/get_action_log_info.py create mode 100644 superset/mcp_service/action_log/tool/list_action_logs.py create mode 100644 superset/mcp_service/task/__init__.py create mode 100644 superset/mcp_service/task/schemas.py create mode 100644 superset/mcp_service/task/tool/__init__.py create mode 100644 superset/mcp_service/task/tool/get_task_info.py create mode 100644 superset/mcp_service/task/tool/list_tasks.py create mode 100644 tests/unit_tests/mcp_service/action_log/__init__.py create mode 100644 
tests/unit_tests/mcp_service/action_log/tool/__init__.py create mode 100644 tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py create mode 100644 tests/unit_tests/mcp_service/task/__init__.py create mode 100644 tests/unit_tests/mcp_service/task/tool/__init__.py create mode 100644 tests/unit_tests/mcp_service/task/tool/test_task_tools.py diff --git a/superset/mcp_service/action_log/__init__.py b/superset/mcp_service/action_log/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/mcp_service/action_log/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/action_log/schemas.py b/superset/mcp_service/action_log/schemas.py new file mode 100644 index 000000000000..c6697cc70cbf --- /dev/null +++ b/superset/mcp_service/action_log/schemas.py @@ -0,0 +1,205 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pydantic schemas for action-log MCP tools.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, field_validator, PositiveInt + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + +DEFAULT_LOG_COLUMNS: list[str] = ["id", "action", "user_id", "dttm"] +ALL_LOG_COLUMNS: list[str] = [ + "id", + "action", + "user_id", + "dttm", + "dashboard_id", + "slice_id", + "json", +] +LOG_SORTABLE_COLUMNS: list[str] = ["id", "dttm"] + + +class ActionLogFilter(ColumnOperator): + """Filter object for action-log listing. + + col: Column to filter on. + opr: Operator to use. + value: Value to filter by. 
+ """ + + col: Literal["action", "user_id", "dashboard_id", "slice_id", "dttm"] = Field( + ..., + description="Column to filter on.", + ) + opr: ColumnOperatorEnum = Field(..., description="Operator to use.") + value: str | int | float | bool | list[str | int | float | bool] = Field( + ..., description="Value to filter by" + ) + + +class ActionLogInfo(BaseModel): + id: int | None = Field(None, description="Log entry ID") + action: str | None = Field(None, description="Action name") + user_id: int | None = Field( + None, description="ID of the user who performed the action" + ) + dttm: str | datetime | None = Field(None, description="Timestamp of the action") + dashboard_id: int | None = Field(None, description="Associated dashboard ID") + slice_id: int | None = Field(None, description="Associated chart/slice ID") + json: str | None = Field(None, description="JSON payload of the action") + + model_config = ConfigDict( + from_attributes=True, + ser_json_timedelta="iso8601", + populate_by_name=True, + ) + + def model_post_init(self, __context: Any) -> None: + if isinstance(self.dttm, datetime) and self.dttm.tzinfo is None: + from datetime import timezone + + object.__setattr__(self, "dttm", self.dttm.replace(tzinfo=timezone.utc)) + + +class ActionLogList(BaseModel): + action_logs: list[ActionLogInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: list[str] = Field(default_factory=list) + columns_loaded: list[str] = Field(default_factory=list) + columns_available: list[str] = Field(default_factory=list) + sortable_columns: list[str] = Field(default_factory=list) + filters_applied: list[ActionLogFilter] = Field(default_factory=list) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListActionLogsRequest(BaseModel): + """Request schema for list_action_logs.""" + + filters: Annotated[ 
+ list[ActionLogFilter], + Field( + default_factory=list, + description=( + "List of filter objects (col, opr, value). " + "Filter columns: action, user_id, dashboard_id, slice_id, dttm. " + "Cannot be used with 'search'." + ), + ), + ] + select_columns: Annotated[ + list[str], + Field( + default_factory=list, + description="Columns to return. Defaults to common columns.", + ), + ] + order_column: Annotated[ + str | None, + Field(default=None, description="Column to sort by (default: dttm)"), + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field(default="desc", description="Sort direction ('asc' or 'desc')"), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number (1-based)"), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Items per page (max {MAX_PAGE_SIZE})", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> list[ActionLogFilter]: + return parse_json_or_model_list(v, ActionLogFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> list[str]: + return parse_json_or_list(v, "select_columns") + + +class ActionLogError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Error type") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "ActionLogError": + from datetime import timezone + + return cls( + error=error, + error_type=error_type, + timestamp=datetime.now(timezone.utc), + ) + + +class GetActionLogInfoRequest(BaseModel): + """Request schema for get_action_log_info (ID-only lookup).""" + + identifier: Annotated[ + int, + Field(description="Log entry ID (integer)"), + ] + + +def serialize_action_log_object(log: Any) -> 
ActionLogInfo | None: + if not log: + return None + dttm = getattr(log, "dttm", None) + if isinstance(dttm, datetime) and dttm.tzinfo is None: + from datetime import timezone + + dttm = dttm.replace(tzinfo=timezone.utc) + return ActionLogInfo( + id=getattr(log, "id", None), + action=getattr(log, "action", None), + user_id=getattr(log, "user_id", None), + dttm=dttm, + dashboard_id=getattr(log, "dashboard_id", None), + slice_id=getattr(log, "slice_id", None), + json=getattr(log, "json", None), + ) diff --git a/superset/mcp_service/action_log/tool/__init__.py b/superset/mcp_service/action_log/tool/__init__.py new file mode 100644 index 000000000000..086da42c8658 --- /dev/null +++ b/superset/mcp_service/action_log/tool/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from .get_action_log_info import get_action_log_info +from .list_action_logs import list_action_logs + +__all__ = [ + "list_action_logs", + "get_action_log_info", +] diff --git a/superset/mcp_service/action_log/tool/get_action_log_info.py b/superset/mcp_service/action_log/tool/get_action_log_info.py new file mode 100644 index 000000000000..46ac6b8fde9c --- /dev/null +++ b/superset/mcp_service/action_log/tool/get_action_log_info.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Get action log info MCP tool.""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.action_log.schemas import ( + ActionLogError, + ActionLogInfo, + GetActionLogInfoRequest, + serialize_action_log_object, +) +from superset.mcp_service.mcp_core import ModelGetInfoCore + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="Log", + annotations=ToolAnnotations( + title="Get action log info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_action_log_info( + request: GetActionLogInfoRequest, + ctx: Context, +) -> ActionLogInfo | ActionLogError: + """Get a single action log entry by its integer ID. + + Returns the action, user_id, timestamp (dttm), dashboard_id, slice_id, + and JSON payload for the specified log record. + + ADMIN-ONLY: This tool requires admin privileges. + + Use list_action_logs to discover log IDs. 
+ """ + await ctx.info("Retrieving action log: identifier=%s" % (request.identifier,)) + + try: + from superset.daos.log import LogDAO + + with event_logger.log_context(action="mcp.get_action_log_info.lookup"): + get_tool = ModelGetInfoCore( + dao_class=LogDAO, + output_schema=ActionLogInfo, + error_schema=ActionLogError, + serializer=serialize_action_log_object, + supports_slug=False, + logger=logger, + ) + result = get_tool.run_tool(request.identifier) + + if isinstance(result, ActionLogInfo): + await ctx.info( + "Action log retrieved: id=%s, action=%s" % (result.id, result.action) + ) + else: + await ctx.warning( + "Action log retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: + await ctx.error( + "Action log retrieval failed: identifier=%s, error=%s, error_type=%s" + % (request.identifier, str(e), type(e).__name__) + ) + return ActionLogError( + error=f"Failed to get action log info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/action_log/tool/list_action_logs.py b/superset/mcp_service/action_log/tool/list_action_logs.py new file mode 100644 index 000000000000..a1ec8b07c8fc --- /dev/null +++ b/superset/mcp_service/action_log/tool/list_action_logs.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""List action logs MCP tool.""" + +import logging +from datetime import datetime, timedelta, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.extensions import event_logger +from superset.mcp_service.action_log.schemas import ( + ActionLogError, + ActionLogFilter, + ActionLogInfo, + ActionLogList, + ALL_LOG_COLUMNS, + DEFAULT_LOG_COLUMNS, + ListActionLogsRequest, + LOG_SORTABLE_COLUMNS, + serialize_action_log_object, +) +from superset.mcp_service.mcp_core import ModelListCore + +logger = logging.getLogger(__name__) + +_DEFAULT_LIST_ACTION_LOGS_REQUEST = ListActionLogsRequest() + + +@tool( + tags=["core"], + class_permission_name="Log", + annotations=ToolAnnotations( + title="List action logs", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_action_logs( + request: ListActionLogsRequest | None = None, + ctx: Context | None = None, +) -> ActionLogList | ActionLogError: + """List Superset action logs with filtering and pagination. + + Returns audit log entries recording user interactions with dashboards and + charts. Defaults to the last 7 days to avoid pulling large result sets. + + ADMIN-ONLY: This tool requires admin privileges. Non-admin users will + receive a permission error. + + Sortable columns for order_column: id, dttm + Filter columns: action, user_id, dashboard_id, slice_id, dttm + + When no dttm filter is provided the tool automatically applies + dttm >= (now - 7 days). 
Add an explicit dttm filter to override. + """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_action_logs") + + request = request or _DEFAULT_LIST_ACTION_LOGS_REQUEST.model_copy(deep=True) + + await ctx.info( + "Listing action logs: page=%s, page_size=%s" % (request.page, request.page_size) + ) + await ctx.debug( + "Action log parameters: filters=%s, order_column=%s, order_direction=%s" + % (request.filters, request.order_column, request.order_direction) + ) + + try: + from superset.daos.log import LogDAO + + # Inject default 7-day dttm filter unless caller already provides one + filters: list[ColumnOperator] = list(request.filters) + has_dttm_filter = any(getattr(f, "col", None) == "dttm" for f in filters) + if not has_dttm_filter: + cutoff = datetime.now(timezone.utc) - timedelta(days=7) + default_filter = ColumnOperator( + col="dttm", + opr=ColumnOperatorEnum.gte, + value=cutoff, + ) + filters = [default_filter] + filters + await ctx.debug( + "Applied default 7-day dttm filter: cutoff=%s" % (cutoff.isoformat(),) + ) + + def _serialize(obj: object, cols: list[str] | None) -> ActionLogInfo | None: + return serialize_action_log_object(obj) + + list_tool = ModelListCore( + dao_class=LogDAO, + output_schema=ActionLogInfo, + item_serializer=_serialize, + filter_type=ActionLogFilter, + default_columns=DEFAULT_LOG_COLUMNS, + search_columns=["action"], + list_field_name="action_logs", + output_list_schema=ActionLogList, + all_columns=ALL_LOG_COLUMNS, + sortable_columns=LOG_SORTABLE_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_action_logs.query"): + result = list_tool.run_tool( + filters=filters, + select_columns=request.select_columns, + order_column=request.order_column or "dttm", + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "Action logs listed: count=%s, total_count=%s" + % ( + len(result.action_logs) if 
hasattr(result, "action_logs") else 0, + getattr(result, "total_count", None), + ) + ) + return result + + except Exception as e: + await ctx.error( + "Action log listing failed: page=%s, error=%s, error_type=%s" + % (request.page, str(e), type(e).__name__) + ) + raise diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 01566b364569..851ed0d76e82 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -620,6 +620,10 @@ def create_mcp_app( # NOTE: Always add new prompt/resource imports here when creating new prompts/resources. # Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators. # They register automatically on import, similar to tools. +from superset.mcp_service.action_log.tool import ( # noqa: F401, E402 + get_action_log_info, + list_action_logs, +) from superset.mcp_service.chart import ( # noqa: F401, E402 prompts as chart_prompts, resources as chart_resources, @@ -670,6 +674,10 @@ def create_mcp_app( get_schema, health_check, ) +from superset.mcp_service.task.tool import ( # noqa: F401, E402 + get_task_info, + list_tasks, +) def _remove_disabled_tools(disabled_tools: set[str]) -> None: diff --git a/superset/mcp_service/task/__init__.py b/superset/mcp_service/task/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/mcp_service/task/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/task/schemas.py b/superset/mcp_service/task/schemas.py new file mode 100644 index 000000000000..7f41bc99263d --- /dev/null +++ b/superset/mcp_service/task/schemas.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Pydantic schemas for task MCP tools.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, field_validator, PositiveInt + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + +DEFAULT_TASK_COLUMNS: list[str] = ["id", "uuid", "task_type", "status", "changed_on"] +ALL_TASK_COLUMNS: list[str] = [ + "id", + "uuid", + "task_type", + "status", + "scope", + "changed_on", + "created_on", +] +TASK_SORTABLE_COLUMNS: list[str] = ["id", "changed_on", "created_on", "status"] + + +class TaskColumnFilter(ColumnOperator): + """Filter object for task listing. + + col: Column to filter on. + opr: Operator to use. + value: Value to filter by. 
+ """ + + col: Literal["task_type", "status", "scope"] = Field( + ..., + description="Column to filter on.", + ) + opr: ColumnOperatorEnum = Field(..., description="Operator to use.") + value: str | int | float | bool | list[str | int | float | bool] = Field( + ..., description="Value to filter by" + ) + + +class TaskInfo(BaseModel): + id: int | None = Field(None, description="Task ID") + uuid: str | None = Field(None, description="Task UUID") + task_type: str | None = Field(None, description="Task type (e.g., sql_execution)") + status: str | None = Field(None, description="Task status") + scope: str | None = Field(None, description="Task scope (private/shared/system)") + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + created_on: str | datetime | None = Field(None, description="Creation timestamp") + + model_config = ConfigDict( + from_attributes=True, + ser_json_timedelta="iso8601", + populate_by_name=True, + ) + + +class TaskList(BaseModel): + tasks: list[TaskInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: list[str] = Field(default_factory=list) + columns_loaded: list[str] = Field(default_factory=list) + columns_available: list[str] = Field(default_factory=list) + sortable_columns: list[str] = Field(default_factory=list) + filters_applied: list[TaskColumnFilter] = Field(default_factory=list) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListTasksRequest(BaseModel): + """Request schema for list_tasks.""" + + filters: Annotated[ + list[TaskColumnFilter], + Field( + default_factory=list, + description=( + "List of filter objects (col, opr, value). " + "Filter columns: task_type, status, scope. " + "Cannot be used with 'search'." 
+ ), + ), + ] + select_columns: Annotated[ + list[str], + Field( + default_factory=list, + description="Columns to return. Defaults to common columns.", + ), + ] + order_column: Annotated[ + str | None, + Field(default=None, description="Column to sort by (default: changed_on)"), + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field(default="desc", description="Sort direction ('asc' or 'desc')"), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number (1-based)"), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Items per page (max {MAX_PAGE_SIZE})", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> list[TaskColumnFilter]: + return parse_json_or_model_list(v, TaskColumnFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> list[str]: + return parse_json_or_list(v, "select_columns") + + +class TaskError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Error type") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "TaskError": + from datetime import timezone + + return cls( + error=error, + error_type=error_type, + timestamp=datetime.now(timezone.utc), + ) + + +class GetTaskInfoRequest(BaseModel): + """Request schema for get_task_info (ID or UUID lookup).""" + + identifier: Annotated[ + int | str, + Field(description="Task identifier — numeric ID or UUID string"), + ] + + +def serialize_task_object(task: Any) -> TaskInfo | None: + if not task: + return None + uuid_val = getattr(task, "uuid", None) + return TaskInfo( + id=getattr(task, "id", None), + uuid=str(uuid_val) if uuid_val is not None else None, + 
task_type=getattr(task, "task_type", None), + status=getattr(task, "status", None), + scope=getattr(task, "scope", None), + changed_on=getattr(task, "changed_on", None), + created_on=getattr(task, "created_on", None), + ) diff --git a/superset/mcp_service/task/tool/__init__.py b/superset/mcp_service/task/tool/__init__.py new file mode 100644 index 000000000000..acf3da684292 --- /dev/null +++ b/superset/mcp_service/task/tool/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .get_task_info import get_task_info +from .list_tasks import list_tasks + +__all__ = [ + "list_tasks", + "get_task_info", +] diff --git a/superset/mcp_service/task/tool/get_task_info.py b/superset/mcp_service/task/tool/get_task_info.py new file mode 100644 index 000000000000..1cbc38d2e44d --- /dev/null +++ b/superset/mcp_service/task/tool/get_task_info.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Get task info MCP tool.""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.task.schemas import ( + GetTaskInfoRequest, + serialize_task_object, + TaskError, + TaskInfo, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="Task", + annotations=ToolAnnotations( + title="Get task info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_task_info( + request: GetTaskInfoRequest, + ctx: Context, +) -> TaskInfo | TaskError: + """Get details for a single async task by ID or UUID. + + Returns task_type, status, scope, and timestamps for the specified task. + Non-admin users can only retrieve tasks they are subscribed to. + + Use list_tasks to discover task IDs and UUIDs. + + Example usage: + ```json + {"identifier": 42} + ``` + + Or with UUID: + ```json + {"identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab"} + ``` + """ + await ctx.info("Retrieving task: identifier=%s" % (request.identifier,)) + + try: + from superset.daos.tasks import TaskDAO + + with event_logger.log_context(action="mcp.get_task_info.lookup"): + # ModelGetInfoCore handles int ID and UUID string automatically. 
+ # TaskDAO.base_filter (TaskFilter) enforces subscription-based access. + get_tool = ModelGetInfoCore( + dao_class=TaskDAO, + output_schema=TaskInfo, + error_schema=TaskError, + serializer=serialize_task_object, + supports_slug=False, + logger=logger, + ) + result = get_tool.run_tool(request.identifier) + + if isinstance(result, TaskInfo): + await ctx.info( + "Task retrieved: id=%s, task_type=%s, status=%s" + % (result.id, result.task_type, result.status) + ) + else: + await ctx.warning( + "Task retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: + await ctx.error( + "Task retrieval failed: identifier=%s, error=%s, error_type=%s" + % (request.identifier, str(e), type(e).__name__) + ) + return TaskError( + error=f"Failed to get task info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/task/tool/list_tasks.py b/superset/mcp_service/task/tool/list_tasks.py new file mode 100644 index 000000000000..21a0aea5e674 --- /dev/null +++ b/superset/mcp_service/task/tool/list_tasks.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""List tasks MCP tool.""" + +import logging + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.task.schemas import ( + ALL_TASK_COLUMNS, + DEFAULT_TASK_COLUMNS, + ListTasksRequest, + serialize_task_object, + TASK_SORTABLE_COLUMNS, + TaskColumnFilter, + TaskError, + TaskInfo, + TaskList, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_LIST_TASKS_REQUEST = ListTasksRequest() + + +@tool( + tags=["core"], + class_permission_name="Task", + annotations=ToolAnnotations( + title="List tasks", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_tasks( + request: ListTasksRequest | None = None, + ctx: Context | None = None, +) -> TaskList | TaskError: + """List async tasks with filtering and pagination. + + Returns tasks visible to the current user. Non-admin users only see tasks + they are subscribed to (task creators are auto-subscribed). Admins see all + tasks. 
+ + Sortable columns for order_column: id, changed_on, created_on, status + Filter columns: task_type, status, scope + + Common task_type values: sql_execution, thumbnail, report + Common status values: pending, in_progress, success, failure, aborted + Common scope values: private, shared, system + """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_tasks") + + request = request or _DEFAULT_LIST_TASKS_REQUEST.model_copy(deep=True) + + await ctx.info( + "Listing tasks: page=%s, page_size=%s" % (request.page, request.page_size) + ) + await ctx.debug( + "Task parameters: filters=%s, order_column=%s, order_direction=%s" + % (request.filters, request.order_column, request.order_direction) + ) + + try: + from superset.daos.tasks import TaskDAO + + def _serialize(obj: object, cols: list[str] | None) -> TaskInfo | None: + return serialize_task_object(obj) + + # TaskDAO.base_filter = TaskFilter automatically scopes results: + # non-admins only see their subscribed tasks; admins see all. 
+ list_tool = ModelListCore( + dao_class=TaskDAO, + output_schema=TaskInfo, + item_serializer=_serialize, + filter_type=TaskColumnFilter, + default_columns=DEFAULT_TASK_COLUMNS, + search_columns=["task_type", "status", "scope"], + list_field_name="tasks", + output_list_schema=TaskList, + all_columns=ALL_TASK_COLUMNS, + sortable_columns=TASK_SORTABLE_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_tasks.query"): + result = list_tool.run_tool( + filters=request.filters, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "Tasks listed: count=%s, total_count=%s" + % ( + len(result.tasks) if hasattr(result, "tasks") else 0, + getattr(result, "total_count", None), + ) + ) + return result + + except Exception as e: + await ctx.error( + "Task listing failed: page=%s, error=%s, error_type=%s" + % (request.page, str(e), type(e).__name__) + ) + raise diff --git a/tests/unit_tests/mcp_service/action_log/__init__.py b/tests/unit_tests/mcp_service/action_log/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/action_log/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/action_log/tool/__init__.py b/tests/unit_tests/mcp_service/action_log/tool/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/action_log/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py new file mode 100644 index 000000000000..0f57cdf56ffd --- /dev/null +++ b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Unit tests for list_action_logs and get_action_log_info MCP tools.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.action_log.schemas import ( + ActionLogFilter, + ListActionLogsRequest, +) +from superset.mcp_service.app import mcp +from superset.utils import json + + +def create_mock_log( + log_id: int = 1, + action: str = "log", + user_id: int = 1, + dashboard_id: int | None = None, + slice_id: int | None = None, + json_payload: str | None = None, + dttm: datetime | None = None, +) -> MagicMock: + log = MagicMock() + log.id = log_id + log.action = action + log.user_id = user_id + log.dashboard_id = dashboard_id + log.slice_id = slice_id + log.json = json_payload or '{"event_name": "explore_slice"}' + log.dttm = dttm or datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + log.duration_ms = None + log.referrer = None + return log + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + from unittest.mock import Mock + + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +class TestActionLogFilterSchema: + def test_valid_filter_columns_accepted(self): + for col in ("action", "user_id", "dashboard_id", "slice_id", "dttm"): + f = ActionLogFilter(col=col, opr="eq", value="test") + assert 
f.col == col + + def test_invalid_filter_column_rejected(self): + with pytest.raises(ValidationError): + ActionLogFilter(col="not_a_column", opr="eq", value="x") + + def test_created_by_fk_rejected(self): + with pytest.raises(ValidationError): + ActionLogFilter(col="created_by_fk", opr="eq", value=1) + + +@patch("superset.daos.log.LogDAO.list") +@pytest.mark.asyncio +async def test_list_action_logs_basic(mock_list, mcp_server): + """Basic listing returns structured response.""" + log = create_mock_log() + mock_list.return_value = ([log], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_action_logs", {}) + + data = json.loads(result.content[0].text) + assert data["action_logs"] is not None + assert len(data["action_logs"]) == 1 + assert data["action_logs"][0]["id"] == 1 + assert data["action_logs"][0]["action"] == "log" + assert data["action_logs"][0]["user_id"] == 1 + + +@patch("superset.daos.log.LogDAO.list") +@pytest.mark.asyncio +async def test_list_action_logs_default_7day_filter_applied(mock_list, mcp_server): + """When no dttm filter is provided, a 7-day filter is injected automatically.""" + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + await client.call_tool("list_action_logs", {}) + + # Verify list() was called with a dttm filter in column_operators + call_kwargs = mock_list.call_args.kwargs + col_operators = call_kwargs.get("column_operators", []) + dttm_filters = [f for f in col_operators if getattr(f, "col", None) == "dttm"] + assert len(dttm_filters) == 1 + assert dttm_filters[0].opr == "gte" + + +@patch("superset.daos.log.LogDAO.list") +@pytest.mark.asyncio +async def test_list_action_logs_explicit_dttm_filter_skips_default( + mock_list, mcp_server +): + """When a dttm filter is provided, the default 7-day filter is NOT injected.""" + mock_list.return_value = ([], 0) + + request = ListActionLogsRequest( + filters=[{"col": "dttm", "opr": "gte", "value": "2020-01-01T00:00:00"}] + ) 
+ + async with Client(mcp_server) as client: + await client.call_tool("list_action_logs", {"request": request.model_dump()}) + + call_kwargs = mock_list.call_args.kwargs + col_operators = call_kwargs.get("column_operators", []) + dttm_filters = [f for f in col_operators if getattr(f, "col", None) == "dttm"] + # Only the user-provided filter, not the injected default + assert len(dttm_filters) == 1 + assert dttm_filters[0].value == "2020-01-01T00:00:00" + + +@patch("superset.daos.log.LogDAO.list") +@pytest.mark.asyncio +async def test_list_action_logs_default_sort_is_dttm_desc(mock_list, mcp_server): + """Default sort is dttm descending (most recent first).""" + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + await client.call_tool("list_action_logs", {}) + + call_kwargs = mock_list.call_args.kwargs + assert call_kwargs.get("order_column") == "dttm" + assert call_kwargs.get("order_direction") == "desc" + + +@patch("superset.daos.log.LogDAO.list") +@pytest.mark.asyncio +async def test_list_action_logs_pagination(mock_list, mcp_server): + """Pagination metadata is correct.""" + logs = [create_mock_log(log_id=i) for i in range(1, 6)] + mock_list.return_value = (logs, 20) + + async with Client(mcp_server) as client: + request = ListActionLogsRequest(page=1, page_size=5) + result = await client.call_tool( + "list_action_logs", {"request": request.model_dump()} + ) + + data = json.loads(result.content[0].text) + assert data["count"] == 5 + assert data["total_count"] == 20 + assert data["page"] == 1 + assert data["page_size"] == 5 + assert data["has_next"] is True + assert data["has_previous"] is False + + +@patch("superset.daos.log.LogDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_action_log_info_basic(mock_find, mcp_server): + """get_action_log_info returns log details by integer ID.""" + log = create_mock_log(log_id=42, action="explore_chart", user_id=7) + mock_find.return_value = log + + async with Client(mcp_server) as 
client: + result = await client.call_tool( + "get_action_log_info", {"request": {"identifier": 42}} + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 42 + assert data["action"] == "explore_chart" + assert data["user_id"] == 7 + + +@patch("superset.daos.log.LogDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_action_log_info_not_found(mock_find, mcp_server): + """get_action_log_info returns error when log does not exist.""" + mock_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_action_log_info", {"request": {"identifier": 9999}} + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" diff --git a/tests/unit_tests/mcp_service/task/__init__.py b/tests/unit_tests/mcp_service/task/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/task/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/unit_tests/mcp_service/task/tool/__init__.py b/tests/unit_tests/mcp_service/task/tool/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/task/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/task/tool/test_task_tools.py b/tests/unit_tests/mcp_service/task/tool/test_task_tools.py new file mode 100644 index 000000000000..8cb616ff435b --- /dev/null +++ b/tests/unit_tests/mcp_service/task/tool/test_task_tools.py @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Unit tests for list_tasks and get_task_info MCP tools.""" + +import uuid +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.app import mcp +from superset.mcp_service.task.schemas import ListTasksRequest, TaskColumnFilter +from superset.utils import json + +SAMPLE_UUID = str(uuid.uuid4()) + + +def create_mock_task( + task_id: int = 1, + task_uuid: str | None = None, + task_type: str = "sql_execution", + status: str = "success", + scope: str = "private", + changed_on: datetime | None = None, + created_on: datetime | None = None, +) -> MagicMock: + task = MagicMock() + task.id = task_id + task.uuid = task_uuid or SAMPLE_UUID + task.task_type = task_type + task.status = status + task.scope = scope + task.changed_on = changed_on or datetime(2024, 1, 2, 10, 0, 0, tzinfo=timezone.utc) + task.created_on = created_on or datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc) + return task + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + from unittest.mock import Mock + + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "testuser" + mock_get_user.return_value = mock_user + yield mock_get_user + + +class TestTaskColumnFilterSchema: + def test_valid_filter_columns_accepted(self): + for col in ("task_type", "status", "scope"): + f = TaskColumnFilter(col=col, opr="eq", 
value="test") + assert f.col == col + + def test_invalid_filter_column_rejected(self): + with pytest.raises(ValidationError): + TaskColumnFilter(col="user_id", opr="eq", value=1) + + def test_uuid_column_rejected(self): + with pytest.raises(ValidationError): + TaskColumnFilter(col="uuid", opr="eq", value="some-uuid") + + +@patch("superset.daos.tasks.TaskDAO.list") +@pytest.mark.asyncio +async def test_list_tasks_basic(mock_list, mcp_server): + """Basic task listing returns structured response.""" + task = create_mock_task() + mock_list.return_value = ([task], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_tasks", {}) + + data = json.loads(result.content[0].text) + assert data["tasks"] is not None + assert len(data["tasks"]) == 1 + assert data["tasks"][0]["id"] == 1 + assert data["tasks"][0]["task_type"] == "sql_execution" + assert data["tasks"][0]["status"] == "success" + + +@patch("superset.daos.tasks.TaskDAO.list") +@pytest.mark.asyncio +async def test_list_tasks_with_status_filter(mock_list, mcp_server): + """Status filter is passed through to the DAO correctly.""" + task = create_mock_task(status="pending") + mock_list.return_value = ([task], 1) + + async with Client(mcp_server) as client: + request = ListTasksRequest( + filters=[{"col": "status", "opr": "eq", "value": "pending"}] + ) + result = await client.call_tool("list_tasks", {"request": request.model_dump()}) + + data = json.loads(result.content[0].text) + assert len(data["tasks"]) == 1 + assert data["tasks"][0]["status"] == "pending" + + +@patch("superset.daos.tasks.TaskDAO.list") +@pytest.mark.asyncio +async def test_list_tasks_taskfilter_scoping_is_applied(mock_list, mcp_server): + """TaskDAO.list is called with base_filter (TaskFilter) applied via DAO class.""" + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + await client.call_tool("list_tasks", {}) + + # Verify the DAO's list() is called — the TaskFilter scoping is enforced + 
# by TaskDAO.base_filter = TaskFilter which BaseDAO applies automatically. + assert mock_list.called + + +@patch("superset.daos.tasks.TaskDAO.list") +@pytest.mark.asyncio +async def test_list_tasks_pagination(mock_list, mcp_server): + """Pagination metadata is correct.""" + tasks = [create_mock_task(task_id=i) for i in range(1, 4)] + mock_list.return_value = (tasks, 30) + + async with Client(mcp_server) as client: + request = ListTasksRequest(page=2, page_size=3) + result = await client.call_tool("list_tasks", {"request": request.model_dump()}) + + data = json.loads(result.content[0].text) + assert data["count"] == 3 + assert data["total_count"] == 30 + assert data["page"] == 2 + assert data["page_size"] == 3 + assert data["has_previous"] is True + assert data["has_next"] is True + + +@patch("superset.daos.tasks.TaskDAO.list") +@pytest.mark.asyncio +async def test_list_tasks_uuid_in_response(mock_list, mcp_server): + """Task UUID is returned as a string in the response.""" + task_uuid = str(uuid.uuid4()) + task = create_mock_task(task_uuid=task_uuid) + mock_list.return_value = ([task], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_tasks", {}) + + data = json.loads(result.content[0].text) + assert data["tasks"][0]["uuid"] == task_uuid + + +@patch("superset.daos.tasks.TaskDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_task_info_by_integer_id(mock_find, mcp_server): + """get_task_info resolves a task by integer ID.""" + task = create_mock_task(task_id=5, task_type="thumbnail", status="in_progress") + mock_find.return_value = task + + async with Client(mcp_server) as client: + result = await client.call_tool("get_task_info", {"request": {"identifier": 5}}) + + data = json.loads(result.content[0].text) + assert data["id"] == 5 + assert data["task_type"] == "thumbnail" + assert data["status"] == "in_progress" + + +@patch("superset.daos.tasks.TaskDAO.find_by_id") +@pytest.mark.asyncio +async def 
test_get_task_info_by_uuid(mock_find, mcp_server): + """get_task_info resolves a task by UUID string.""" + task_uuid = str(uuid.uuid4()) + task = create_mock_task(task_id=10, task_uuid=task_uuid, status="success") + mock_find.return_value = task + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_task_info", {"request": {"identifier": task_uuid}} + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 10 + assert data["status"] == "success" + + +@patch("superset.daos.tasks.TaskDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_task_info_not_found(mock_find, mcp_server): + """get_task_info returns error when task does not exist.""" + mock_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_task_info", {"request": {"identifier": 9999}} + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + + +@patch("superset.daos.tasks.TaskDAO.list") +@pytest.mark.asyncio +async def test_list_tasks_non_admin_sees_only_subscribed(mock_list, mcp_server): + """Non-admin users only receive tasks their subscriptions allow. + + The subscription scoping is enforced by TaskDAO.base_filter = TaskFilter, + which BaseDAO._apply_base_filter injects before each query. The MCP tool + itself adds no extra filtering — it just delegates to TaskDAO.list(), which + carries the filter class. This test confirms that: + + 1. list_tasks calls TaskDAO.list() (so the base_filter runs) + 2. 
Only items returned by that call appear in the response + """ + # Simulate DAO returning only the subscribed task + subscribed_task = create_mock_task(task_id=42, status="pending") + mock_list.return_value = ([subscribed_task], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_tasks", {}) + + data = json.loads(result.content[0].text) + assert len(data["tasks"]) == 1 + assert data["tasks"][0]["id"] == 42 + # TaskDAO.list was called exactly once — base_filter is applied inside + assert mock_list.call_count == 1 From 58bbbd784967e0d16016eaa5bcf8356faba9f782 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 20 May 2026 23:14:47 +0000 Subject: [PATCH 02/13] fix(mcp): convert dttm cutoff to ISO string so filters_applied validates The injected 7-day default filter used a datetime object as the value, but ActionLogFilter.value only allows str|int|float|bool|list. Pydantic rejects the datetime when building the filters_applied list in ActionLogList, causing a ValidationError on every call that triggered the default filter. 
Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/action_log/tool/list_action_logs.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/superset/mcp_service/action_log/tool/list_action_logs.py b/superset/mcp_service/action_log/tool/list_action_logs.py index a1ec8b07c8fc..eee80e20886a 100644 --- a/superset/mcp_service/action_log/tool/list_action_logs.py +++ b/superset/mcp_service/action_log/tool/list_action_logs.py @@ -90,16 +90,14 @@ async def list_action_logs( filters: list[ColumnOperator] = list(request.filters) has_dttm_filter = any(getattr(f, "col", None) == "dttm" for f in filters) if not has_dttm_filter: - cutoff = datetime.now(timezone.utc) - timedelta(days=7) + cutoff = (datetime.now(timezone.utc) - timedelta(days=7)).isoformat() default_filter = ColumnOperator( col="dttm", opr=ColumnOperatorEnum.gte, value=cutoff, ) filters = [default_filter] + filters - await ctx.debug( - "Applied default 7-day dttm filter: cutoff=%s" % (cutoff.isoformat(),) - ) + await ctx.debug("Applied default 7-day dttm filter: cutoff=%s" % (cutoff,)) def _serialize(obj: object, cols: list[str] | None) -> ActionLogInfo | None: return serialize_action_log_object(obj) From f75e448ffe262cac0cd4bd909d9f4e4aee9bb5eb Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 20 May 2026 23:26:38 +0000 Subject: [PATCH 03/13] fix(mcp): field filtering and search for action-log and task list tools - Add model_serializer to ActionLogInfo and TaskInfo that drops non-requested fields from output when select_columns context is set, matching the DatabaseInfo pattern - Switch list_action_logs and list_tasks to return model_dump with serialization context so only requested columns appear in responses - Add search field + search-XOR-filters validator to ListActionLogsRequest and ListTasksRequest - Pass search=request.search through to ModelListCore.run_tool() Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/action_log/schemas.py | 40 ++++++++++++++++++- 
.../action_log/tool/list_action_logs.py | 12 +++++- superset/mcp_service/task/schemas.py | 40 ++++++++++++++++++- superset/mcp_service/task/tool/list_tasks.py | 12 +++++- 4 files changed, 100 insertions(+), 4 deletions(-) diff --git a/superset/mcp_service/action_log/schemas.py b/superset/mcp_service/action_log/schemas.py index c6697cc70cbf..9eb54b8fc620 100644 --- a/superset/mcp_service/action_log/schemas.py +++ b/superset/mcp_service/action_log/schemas.py @@ -22,7 +22,15 @@ from datetime import datetime from typing import Annotated, Any, Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator, PositiveInt +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, + PositiveInt, +) from superset.daos.base import ColumnOperator, ColumnOperatorEnum from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE @@ -86,6 +94,16 @@ def model_post_init(self, __context: Any) -> None: object.__setattr__(self, "dttm", self.dttm.replace(tzinfo=timezone.utc)) + @model_serializer(mode="wrap") + def _filter_fields_by_context(self, serializer: Any, info: Any) -> dict[str, Any]: + data = serializer(self) + if info.context and isinstance(info.context, dict): + select_columns = info.context.get("select_columns") + if select_columns: + requested_fields = set(select_columns) + return {k: v for k, v in data.items() if k in requested_fields} + return data + class ActionLogList(BaseModel): action_logs: list[ActionLogInfo] @@ -127,6 +145,16 @@ class ListActionLogsRequest(BaseModel): description="Columns to return. Defaults to common columns.", ), ] + search: Annotated[ + str | None, + Field( + default=None, + description=( + "Text search string matched against action. " + "Cannot be used together with 'filters'." 
+ ), + ), + ] order_column: Annotated[ str | None, Field(default=None, description="Column to sort by (default: dttm)"), @@ -159,6 +187,16 @@ def parse_filters(cls, v: Any) -> list[ActionLogFilter]: def parse_columns(cls, v: Any) -> list[str]: return parse_json_or_list(v, "select_columns") + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListActionLogsRequest": + if self.search and self.filters: + raise ValueError( + "Cannot use both 'search' and 'filters' simultaneously. " + "Use 'search' for text matching on action, or 'filters' for " + "column-based filtering, but not both." + ) + return self + class ActionLogError(BaseModel): error: str = Field(..., description="Error message") diff --git a/superset/mcp_service/action_log/tool/list_action_logs.py b/superset/mcp_service/action_log/tool/list_action_logs.py index eee80e20886a..199f0190e628 100644 --- a/superset/mcp_service/action_log/tool/list_action_logs.py +++ b/superset/mcp_service/action_log/tool/list_action_logs.py @@ -119,6 +119,7 @@ def _serialize(obj: object, cols: list[str] | None) -> ActionLogInfo | None: with event_logger.log_context(action="mcp.list_action_logs.query"): result = list_tool.run_tool( filters=filters, + search=request.search, select_columns=request.select_columns, order_column=request.order_column or "dttm", order_direction=request.order_direction, @@ -133,7 +134,16 @@ def _serialize(obj: object, cols: list[str] | None) -> ActionLogInfo | None: getattr(result, "total_count", None), ) ) - return result + columns_to_filter = result.columns_requested + await ctx.debug( + "Applying field filtering via serialization context: columns=%s" + % (columns_to_filter,) + ) + with event_logger.log_context(action="mcp.list_action_logs.serialization"): + return result.model_dump( + mode="json", + context={"select_columns": columns_to_filter}, + ) except Exception as e: await ctx.error( diff --git a/superset/mcp_service/task/schemas.py b/superset/mcp_service/task/schemas.py 
index 7f41bc99263d..af5e7f6662d3 100644 --- a/superset/mcp_service/task/schemas.py +++ b/superset/mcp_service/task/schemas.py @@ -22,7 +22,15 @@ from datetime import datetime from typing import Annotated, Any, Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator, PositiveInt +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, + PositiveInt, +) from superset.daos.base import ColumnOperator, ColumnOperatorEnum from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE @@ -80,6 +88,16 @@ class TaskInfo(BaseModel): populate_by_name=True, ) + @model_serializer(mode="wrap") + def _filter_fields_by_context(self, serializer: Any, info: Any) -> dict[str, Any]: + data = serializer(self) + if info.context and isinstance(info.context, dict): + select_columns = info.context.get("select_columns") + if select_columns: + requested_fields = set(select_columns) + return {k: v for k, v in data.items() if k in requested_fields} + return data + class TaskList(BaseModel): tasks: list[TaskInfo] @@ -121,6 +139,16 @@ class ListTasksRequest(BaseModel): description="Columns to return. Defaults to common columns.", ), ] + search: Annotated[ + str | None, + Field( + default=None, + description=( + "Text search string matched against task_type, status, and scope. " + "Cannot be used together with 'filters'." + ), + ), + ] order_column: Annotated[ str | None, Field(default=None, description="Column to sort by (default: changed_on)"), @@ -153,6 +181,16 @@ def parse_filters(cls, v: Any) -> list[TaskColumnFilter]: def parse_columns(cls, v: Any) -> list[str]: return parse_json_or_list(v, "select_columns") + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListTasksRequest": + if self.search and self.filters: + raise ValueError( + "Cannot use both 'search' and 'filters' simultaneously. 
" + "Use 'search' for text matching on task_type/status/scope, or " + "'filters' for column-based filtering, but not both." + ) + return self + class TaskError(BaseModel): error: str = Field(..., description="Error message") diff --git a/superset/mcp_service/task/tool/list_tasks.py b/superset/mcp_service/task/tool/list_tasks.py index 21a0aea5e674..a9ca6329145e 100644 --- a/superset/mcp_service/task/tool/list_tasks.py +++ b/superset/mcp_service/task/tool/list_tasks.py @@ -105,6 +105,7 @@ def _serialize(obj: object, cols: list[str] | None) -> TaskInfo | None: with event_logger.log_context(action="mcp.list_tasks.query"): result = list_tool.run_tool( filters=request.filters, + search=request.search, select_columns=request.select_columns, order_column=request.order_column, order_direction=request.order_direction, @@ -119,7 +120,16 @@ def _serialize(obj: object, cols: list[str] | None) -> TaskInfo | None: getattr(result, "total_count", None), ) ) - return result + columns_to_filter = result.columns_requested + await ctx.debug( + "Applying field filtering via serialization context: columns=%s" + % (columns_to_filter,) + ) + with event_logger.log_context(action="mcp.list_tasks.serialization"): + return result.model_dump( + mode="json", + context={"select_columns": columns_to_filter}, + ) except Exception as e: await ctx.error( From 293707123c7ff090ae0e4ad764585cf9757b8273 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 20 May 2026 23:30:59 +0000 Subject: [PATCH 04/13] fix(mcp): add task_key/task_name to TaskInfo and strengthen test coverage - Add task_key and task_name fields to TaskInfo schema and ALL_TASK_COLUMNS; these are real Task model columns present in the REST API search_columns - Expand search_columns in list_tasks to include task_key and task_name - Strengthen test_list_action_logs_default_7day_filter_applied to also assert the injected filter appears in filters_applied with an ISO string value Co-Authored-By: Claude Sonnet 4.6 --- 
superset/mcp_service/task/schemas.py | 9 ++++++++- superset/mcp_service/task/tool/list_tasks.py | 3 ++- .../action_log/tool/test_action_log_tools.py | 10 +++++++++- .../mcp_service/task/tool/test_task_tools.py | 4 ++++ 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/superset/mcp_service/task/schemas.py b/superset/mcp_service/task/schemas.py index af5e7f6662d3..c908abf19e52 100644 --- a/superset/mcp_service/task/schemas.py +++ b/superset/mcp_service/task/schemas.py @@ -45,6 +45,8 @@ "id", "uuid", "task_type", + "task_key", + "task_name", "status", "scope", "changed_on", @@ -75,6 +77,8 @@ class TaskInfo(BaseModel): id: int | None = Field(None, description="Task ID") uuid: str | None = Field(None, description="Task UUID") task_type: str | None = Field(None, description="Task type (e.g., sql_execution)") + task_key: str | None = Field(None, description="Task deduplication key") + task_name: str | None = Field(None, description="Human-readable task name") status: str | None = Field(None, description="Task status") scope: str | None = Field(None, description="Task scope (private/shared/system)") changed_on: str | datetime | None = Field( @@ -144,7 +148,8 @@ class ListTasksRequest(BaseModel): Field( default=None, description=( - "Text search string matched against task_type, status, and scope. " + "Text search string matched against task_type, task_key, " + "task_name, status, and scope. " "Cannot be used together with 'filters'." 
), ), @@ -226,6 +231,8 @@ def serialize_task_object(task: Any) -> TaskInfo | None: id=getattr(task, "id", None), uuid=str(uuid_val) if uuid_val is not None else None, task_type=getattr(task, "task_type", None), + task_key=getattr(task, "task_key", None), + task_name=getattr(task, "task_name", None), status=getattr(task, "status", None), scope=getattr(task, "scope", None), changed_on=getattr(task, "changed_on", None), diff --git a/superset/mcp_service/task/tool/list_tasks.py b/superset/mcp_service/task/tool/list_tasks.py index a9ca6329145e..1aa701dd4aac 100644 --- a/superset/mcp_service/task/tool/list_tasks.py +++ b/superset/mcp_service/task/tool/list_tasks.py @@ -62,6 +62,7 @@ async def list_tasks( Sortable columns for order_column: id, changed_on, created_on, status Filter columns: task_type, status, scope + Search columns (via search=): task_type, task_key, task_name, status, scope Common task_type values: sql_execution, thumbnail, report Common status values: pending, in_progress, success, failure, aborted @@ -94,7 +95,7 @@ def _serialize(obj: object, cols: list[str] | None) -> TaskInfo | None: item_serializer=_serialize, filter_type=TaskColumnFilter, default_columns=DEFAULT_TASK_COLUMNS, - search_columns=["task_type", "status", "scope"], + search_columns=["task_type", "task_key", "task_name", "status", "scope"], list_field_name="tasks", output_list_schema=TaskList, all_columns=ALL_TASK_COLUMNS, diff --git a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py index 0f57cdf56ffd..89496ace78c8 100644 --- a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py +++ b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py @@ -111,7 +111,7 @@ async def test_list_action_logs_default_7day_filter_applied(mock_list, mcp_serve mock_list.return_value = ([], 0) async with Client(mcp_server) as client: - await client.call_tool("list_action_logs", {}) + 
result = await client.call_tool("list_action_logs", {}) # Verify list() was called with a dttm filter in column_operators call_kwargs = mock_list.call_args.kwargs @@ -120,6 +120,14 @@ async def test_list_action_logs_default_7day_filter_applied(mock_list, mcp_serve assert len(dttm_filters) == 1 assert dttm_filters[0].opr == "gte" + # Verify the injected filter appears in the serialized filters_applied + data = json.loads(result.content[0].text) + filters_applied = data.get("filters_applied", []) + dttm_applied = [f for f in filters_applied if f.get("col") == "dttm"] + assert len(dttm_applied) == 1 + assert dttm_applied[0]["opr"] == "gte" + assert isinstance(dttm_applied[0]["value"], str) # ISO string, not datetime + @patch("superset.daos.log.LogDAO.list") @pytest.mark.asyncio diff --git a/tests/unit_tests/mcp_service/task/tool/test_task_tools.py b/tests/unit_tests/mcp_service/task/tool/test_task_tools.py index 8cb616ff435b..9257b2c002dd 100644 --- a/tests/unit_tests/mcp_service/task/tool/test_task_tools.py +++ b/tests/unit_tests/mcp_service/task/tool/test_task_tools.py @@ -36,6 +36,8 @@ def create_mock_task( task_id: int = 1, task_uuid: str | None = None, task_type: str = "sql_execution", + task_key: str = "default-key", + task_name: str | None = None, status: str = "success", scope: str = "private", changed_on: datetime | None = None, @@ -45,6 +47,8 @@ def create_mock_task( task.id = task_id task.uuid = task_uuid or SAMPLE_UUID task.task_type = task_type + task.task_key = task_key + task.task_name = task_name task.status = status task.scope = scope task.changed_on = changed_on or datetime(2024, 1, 2, 10, 0, 0, tzinfo=timezone.utc) From b6566312a7569550b30ebd8c51532ed99ae9afe8 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 21 May 2026 01:37:26 +0000 Subject: [PATCH 05/13] fix(mcp): use ActionLogFilter for injected default dttm filter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pydantic v2 rejects a 
ColumnOperator instance when validating list[ActionLogFilter] — it requires an exact instance or dict, not a parent-class instance. The injected 7-day default dttm filter was created as a plain ColumnOperator, causing every test_list_action_logs_* call to fail with '1 validation error for ActionLogList'. Fix: construct the default filter as ActionLogFilter (which is a subclass of ColumnOperator), so it passes pydantic validation for ActionLogList.filters_applied: list[ActionLogFilter] and is still accepted everywhere ColumnOperator is expected. --- superset/mcp_service/action_log/tool/list_action_logs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/mcp_service/action_log/tool/list_action_logs.py b/superset/mcp_service/action_log/tool/list_action_logs.py index 199f0190e628..d0a8e97580cd 100644 --- a/superset/mcp_service/action_log/tool/list_action_logs.py +++ b/superset/mcp_service/action_log/tool/list_action_logs.py @@ -91,7 +91,7 @@ async def list_action_logs( has_dttm_filter = any(getattr(f, "col", None) == "dttm" for f in filters) if not has_dttm_filter: cutoff = (datetime.now(timezone.utc) - timedelta(days=7)).isoformat() - default_filter = ColumnOperator( + default_filter = ActionLogFilter( col="dttm", opr=ColumnOperatorEnum.gte, value=cutoff, From d5b95daa6b203b5b56a52db89e3baccfa194b893 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 21 May 2026 03:12:13 +0000 Subject: [PATCH 06/13] ci: trigger CI for fix From 28461c7591d3dad9f479d8ed16a1ed6cbb4976ff Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 27 May 2026 15:28:15 +0000 Subject: [PATCH 07/13] fix(mcp): normalize naive datetimes to UTC and strengthen task tool tests - Normalize changed_on/created_on naive datetimes in serialize_task_object (mirrors serialize_action_log_object pattern for dttm) - Add filter-forwarding assertion to test_list_tasks_with_status_filter - Add id_column="uuid" assertion to test_get_task_info_by_uuid Co-Authored-By: Claude Sonnet 
4.6 --- superset/mcp_service/task/schemas.py | 12 ++++++++++-- .../mcp_service/task/tool/test_task_tools.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/superset/mcp_service/task/schemas.py b/superset/mcp_service/task/schemas.py index c908abf19e52..3836ca068192 100644 --- a/superset/mcp_service/task/schemas.py +++ b/superset/mcp_service/task/schemas.py @@ -226,7 +226,15 @@ class GetTaskInfoRequest(BaseModel): def serialize_task_object(task: Any) -> TaskInfo | None: if not task: return None + from datetime import timezone + uuid_val = getattr(task, "uuid", None) + changed_on = getattr(task, "changed_on", None) + if isinstance(changed_on, datetime) and changed_on.tzinfo is None: + changed_on = changed_on.replace(tzinfo=timezone.utc) + created_on = getattr(task, "created_on", None) + if isinstance(created_on, datetime) and created_on.tzinfo is None: + created_on = created_on.replace(tzinfo=timezone.utc) return TaskInfo( id=getattr(task, "id", None), uuid=str(uuid_val) if uuid_val is not None else None, @@ -235,6 +243,6 @@ def serialize_task_object(task: Any) -> TaskInfo | None: task_name=getattr(task, "task_name", None), status=getattr(task, "status", None), scope=getattr(task, "scope", None), - changed_on=getattr(task, "changed_on", None), - created_on=getattr(task, "created_on", None), + changed_on=changed_on, + created_on=created_on, ) diff --git a/tests/unit_tests/mcp_service/task/tool/test_task_tools.py b/tests/unit_tests/mcp_service/task/tool/test_task_tools.py index 9257b2c002dd..a35576b34cb2 100644 --- a/tests/unit_tests/mcp_service/task/tool/test_task_tools.py +++ b/tests/unit_tests/mcp_service/task/tool/test_task_tools.py @@ -123,6 +123,14 @@ async def test_list_tasks_with_status_filter(mock_list, mcp_server): assert len(data["tasks"]) == 1 assert data["tasks"][0]["status"] == "pending" + # Verify the filter was forwarded to the DAO + call_kwargs = mock_list.call_args.kwargs + col_operators = 
call_kwargs.get("column_operators", []) + status_filters = [f for f in col_operators if getattr(f, "col", None) == "status"] + assert len(status_filters) == 1 + assert status_filters[0].opr.value == "eq" + assert status_filters[0].value == "pending" + @patch("superset.daos.tasks.TaskDAO.list") @pytest.mark.asyncio @@ -206,6 +214,9 @@ async def test_get_task_info_by_uuid(mock_find, mcp_server): assert data["id"] == 10 assert data["status"] == "success" + # Verify the DAO was called with id_column="uuid" for UUID-style identifiers + mock_find.assert_called_once_with(task_uuid, id_column="uuid", query_options=None) + @patch("superset.daos.tasks.TaskDAO.find_by_id") @pytest.mark.asyncio From 946bb37adb159eafe86ea2138e93bd3bc1ad395d Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 27 May 2026 18:23:43 +0000 Subject: [PATCH 08/13] fix(mcp): add config guards and fix dttm filter type for action/task tools --- superset/mcp_service/action_log/schemas.py | 2 +- .../action_log/tool/list_action_logs.py | 2 +- superset/mcp_service/app.py | 31 +++++++ .../action_log/tool/test_action_log_tools.py | 4 +- .../mcp_service/test_mcp_tool_registration.py | 86 ++++++++++++++++++- 5 files changed, 118 insertions(+), 7 deletions(-) diff --git a/superset/mcp_service/action_log/schemas.py b/superset/mcp_service/action_log/schemas.py index 9eb54b8fc620..74fb083bcd97 100644 --- a/superset/mcp_service/action_log/schemas.py +++ b/superset/mcp_service/action_log/schemas.py @@ -66,7 +66,7 @@ class ActionLogFilter(ColumnOperator): description="Column to filter on.", ) opr: ColumnOperatorEnum = Field(..., description="Operator to use.") - value: str | int | float | bool | list[str | int | float | bool] = Field( + value: str | int | float | bool | datetime | list[str | int | float | bool] = Field( ..., description="Value to filter by" ) diff --git a/superset/mcp_service/action_log/tool/list_action_logs.py b/superset/mcp_service/action_log/tool/list_action_logs.py index 
d0a8e97580cd..a7df565989fe 100644 --- a/superset/mcp_service/action_log/tool/list_action_logs.py +++ b/superset/mcp_service/action_log/tool/list_action_logs.py @@ -90,7 +90,7 @@ async def list_action_logs( filters: list[ColumnOperator] = list(request.filters) has_dttm_filter = any(getattr(f, "col", None) == "dttm" for f in filters) if not has_dttm_filter: - cutoff = (datetime.now(timezone.utc) - timedelta(days=7)).isoformat() + cutoff = datetime.now(timezone.utc) - timedelta(days=7) default_filter = ActionLogFilter( col="dttm", opr=ColumnOperatorEnum.gte, diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 851ed0d76e82..eeefcb8d69e9 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -699,6 +699,36 @@ def _remove_disabled_tools(disabled_tools: set[str]) -> None: ) +def _remove_tool_quietly(tool_name: str, reason: str) -> None: + """Remove a single tool from the global MCP instance, ignoring missing-tool errors.""" + try: + mcp.local_provider.remove_tool(tool_name) + logger.info("Disabled MCP tool: %s (%s)", tool_name, reason) + except KeyError: + pass + + +def _apply_config_guards(flask_app: Any) -> None: + """Remove tools whose backing features are administratively disabled. + + - Action-log tools: mirrors LogRestApi.is_enabled() which checks + FAB_ADD_SECURITY_VIEWS and SUPERSET_LOG_VIEW. + - Task tools: mirrors TaskRestApi conditional registration which checks + the GLOBAL_TASK_FRAMEWORK feature flag. 
+ """ + if not ( + flask_app.config.get("FAB_ADD_SECURITY_VIEWS", True) + and flask_app.config.get("SUPERSET_LOG_VIEW", True) + ): + for tool_name in ("list_action_logs", "get_action_log_info"): + _remove_tool_quietly(tool_name, "logging disabled by config flags") + + feature_flags: dict[str, Any] = flask_app.config.get("FEATURE_FLAGS", {}) + if not feature_flags.get("GLOBAL_TASK_FRAMEWORK", False): + for tool_name in ("list_tasks", "get_task_info"): + _remove_tool_quietly(tool_name, "GLOBAL_TASK_FRAMEWORK not enabled") + + def init_fastmcp_server( name: str | None = None, instructions: str | None = None, @@ -743,6 +773,7 @@ def init_fastmcp_server( # instructions never advertise tools that clients cannot actually call. disabled_tools: set[str] = flask_app.config.get("MCP_DISABLED_TOOLS", set()) _remove_disabled_tools(disabled_tools) + _apply_config_guards(flask_app) if instructions is None: instructions = get_default_instructions(branding, disabled_tools) diff --git a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py index 89496ace78c8..3ad4dd234fa1 100644 --- a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py +++ b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py @@ -126,7 +126,9 @@ async def test_list_action_logs_default_7day_filter_applied(mock_list, mcp_serve dttm_applied = [f for f in filters_applied if f.get("col") == "dttm"] assert len(dttm_applied) == 1 assert dttm_applied[0]["opr"] == "gte" - assert isinstance(dttm_applied[0]["value"], str) # ISO string, not datetime + assert isinstance( + dttm_applied[0]["value"], str + ) # serialized to ISO string in JSON output @patch("superset.daos.log.LogDAO.list") diff --git a/tests/unit_tests/mcp_service/test_mcp_tool_registration.py b/tests/unit_tests/mcp_service/test_mcp_tool_registration.py index 00a94fa78e95..c20fbb4d721b 100644 --- 
a/tests/unit_tests/mcp_service/test_mcp_tool_registration.py +++ b/tests/unit_tests/mcp_service/test_mcp_tool_registration.py @@ -106,11 +106,30 @@ def test_mcp_packages_discoverable_by_setuptools(): # --------------------------------------------------------------------------- -def _make_flask_app_mock(disabled_tools: set[str]) -> MagicMock: - """Return a minimal Flask app mock with MCP_DISABLED_TOOLS configured.""" +def _make_flask_app_mock( + disabled_tools: set[str], + feature_flags: dict[str, object] | None = None, + fab_security_views: bool = True, + log_view: bool = True, +) -> MagicMock: + """Return a minimal Flask app mock with MCP config set to safe defaults. + + Defaults enable all feature flags and logging so that tests focused on + MCP_DISABLED_TOOLS are not affected by the config guards added for action-log + and task tools. + """ + _feature_flags: dict[str, object] = ( + {"GLOBAL_TASK_FRAMEWORK": True} if feature_flags is None else feature_flags + ) + _config: dict[str, object] = { + "MCP_DISABLED_TOOLS": disabled_tools, + "FAB_ADD_SECURITY_VIEWS": fab_security_views, + "SUPERSET_LOG_VIEW": log_view, + "FEATURE_FLAGS": _feature_flags, + } flask_app = MagicMock() - flask_app.config.get.side_effect = lambda key, default=None: ( - disabled_tools if key == "MCP_DISABLED_TOOLS" else default + flask_app.config.get.side_effect = lambda key, default=None: _config.get( + key, default ) return flask_app @@ -257,6 +276,65 @@ def test_no_disabled_tools_returns_full_instructions() -> None: assert full == also_full +# --------------------------------------------------------------------------- +# Config-guard tests: action-log tools and task tools +# --------------------------------------------------------------------------- + + +def test_action_log_tools_removed_when_superset_log_view_disabled() -> None: + """Action-log tools removed when SUPERSET_LOG_VIEW=False. + + Mirrors LogRestApi.is_enabled() which checks FAB_ADD_SECURITY_VIEWS and + SUPERSET_LOG_VIEW. 
+ """ + flask_app = _make_flask_app_mock(set(), log_view=False) + + with ( + patch("superset.mcp_service.flask_singleton.app", flask_app), + patch.object(mcp.local_provider, "remove_tool") as mock_remove, + ): + init_fastmcp_server() + + removed = {call.args[0] for call in mock_remove.call_args_list} + assert "list_action_logs" in removed + assert "get_action_log_info" in removed + + +def test_action_log_tools_removed_when_fab_security_views_disabled() -> None: + """Action-log tools removed when FAB_ADD_SECURITY_VIEWS=False.""" + flask_app = _make_flask_app_mock(set(), fab_security_views=False) + + with ( + patch("superset.mcp_service.flask_singleton.app", flask_app), + patch.object(mcp.local_provider, "remove_tool") as mock_remove, + ): + init_fastmcp_server() + + removed = {call.args[0] for call in mock_remove.call_args_list} + assert "list_action_logs" in removed + assert "get_action_log_info" in removed + + +def test_task_tools_removed_when_global_task_framework_disabled() -> None: + """Task tools removed when GLOBAL_TASK_FRAMEWORK=False. + + Mirrors TaskRestApi conditional registration in initialization/__init__.py. 
+ """ + flask_app = _make_flask_app_mock( + set(), feature_flags={"GLOBAL_TASK_FRAMEWORK": False} + ) + + with ( + patch("superset.mcp_service.flask_singleton.app", flask_app), + patch.object(mcp.local_provider, "remove_tool") as mock_remove, + ): + init_fastmcp_server() + + removed = {call.args[0] for call in mock_remove.call_args_list} + assert "list_tasks" in removed + assert "get_task_info" in removed + + def test_instructions_generated_after_disabled_tools_removed() -> None: """init_fastmcp_server generates instructions AFTER removing disabled tools, so the instructions never advertise tools that clients cannot call.""" From e3ba2d170e9c068d76b9e27a37e5a65a5127a28d Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 27 May 2026 18:27:35 +0000 Subject: [PATCH 09/13] docs(mcp): correct action log tool docstrings to reflect actual RBAC permissions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The ADMIN-ONLY label was misleading — access is gated by the Log permission in Superset's RBAC, not a hard admin check. Updated both docstrings to describe the actual permission model. --- superset/mcp_service/action_log/tool/get_action_log_info.py | 3 ++- superset/mcp_service/action_log/tool/list_action_logs.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/superset/mcp_service/action_log/tool/get_action_log_info.py b/superset/mcp_service/action_log/tool/get_action_log_info.py index 46ac6b8fde9c..db5c9329fc94 100644 --- a/superset/mcp_service/action_log/tool/get_action_log_info.py +++ b/superset/mcp_service/action_log/tool/get_action_log_info.py @@ -53,7 +53,8 @@ async def get_action_log_info( Returns the action, user_id, timestamp (dttm), dashboard_id, slice_id, and JSON payload for the specified log record. - ADMIN-ONLY: This tool requires admin privileges. + Requires the Log permission (controlled by Superset's RBAC). Users without + that permission will receive a permission error. 
Use list_action_logs to discover log IDs. """ diff --git a/superset/mcp_service/action_log/tool/list_action_logs.py b/superset/mcp_service/action_log/tool/list_action_logs.py index a7df565989fe..f4d85fdc8859 100644 --- a/superset/mcp_service/action_log/tool/list_action_logs.py +++ b/superset/mcp_service/action_log/tool/list_action_logs.py @@ -61,8 +61,8 @@ async def list_action_logs( Returns audit log entries recording user interactions with dashboards and charts. Defaults to the last 7 days to avoid pulling large result sets. - ADMIN-ONLY: This tool requires admin privileges. Non-admin users will - receive a permission error. + Requires the Log permission (controlled by Superset's RBAC). Users without + that permission will receive a permission error. Sortable columns for order_column: id, dttm Filter columns: action, user_id, dashboard_id, slice_id, dttm From 8a84b91f89e021d8d0248542c9075eb152f25a25 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 00:28:49 +0000 Subject: [PATCH 10/13] fix(mcp): use feature_flag_manager for GLOBAL_TASK_FRAMEWORK guard in app.py --- superset/mcp_service/app.py | 9 ++-- .../mcp_service/test_mcp_tool_registration.py | 42 ++++++++++++------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index eeefcb8d69e9..e29fa320e6db 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -714,7 +714,9 @@ def _apply_config_guards(flask_app: Any) -> None: - Action-log tools: mirrors LogRestApi.is_enabled() which checks FAB_ADD_SECURITY_VIEWS and SUPERSET_LOG_VIEW. - Task tools: mirrors TaskRestApi conditional registration which checks - the GLOBAL_TASK_FRAMEWORK feature flag. + the GLOBAL_TASK_FRAMEWORK feature flag via feature_flag_manager so that + all Superset enablement paths (DEFAULT_FEATURE_FLAGS, GET_FEATURE_FLAGS_FUNC, + IS_FEATURE_ENABLED_FUNC, etc.) are respected. 
""" if not ( flask_app.config.get("FAB_ADD_SECURITY_VIEWS", True) @@ -723,8 +725,9 @@ def _apply_config_guards(flask_app: Any) -> None: for tool_name in ("list_action_logs", "get_action_log_info"): _remove_tool_quietly(tool_name, "logging disabled by config flags") - feature_flags: dict[str, Any] = flask_app.config.get("FEATURE_FLAGS", {}) - if not feature_flags.get("GLOBAL_TASK_FRAMEWORK", False): + from superset.extensions import feature_flag_manager # noqa: PLC0415 + + if not feature_flag_manager.is_feature_enabled("GLOBAL_TASK_FRAMEWORK"): for tool_name in ("list_tasks", "get_task_info"): _remove_tool_quietly(tool_name, "GLOBAL_TASK_FRAMEWORK not enabled") diff --git a/tests/unit_tests/mcp_service/test_mcp_tool_registration.py b/tests/unit_tests/mcp_service/test_mcp_tool_registration.py index c20fbb4d721b..268e8aa8f1b7 100644 --- a/tests/unit_tests/mcp_service/test_mcp_tool_registration.py +++ b/tests/unit_tests/mcp_service/test_mcp_tool_registration.py @@ -21,8 +21,25 @@ import logging from unittest.mock import MagicMock, patch +import pytest + from superset.mcp_service.app import get_default_instructions, init_fastmcp_server, mcp +# Patch target for the feature_flag_manager imported inside _apply_config_guards +_FFM_PATH = "superset.extensions.feature_flag_manager" + + +@pytest.fixture(autouse=True) +def gtf_ffm(): + """Default for this module: GLOBAL_TASK_FRAMEWORK is enabled. + + Tests that need to verify the disabled path override is_feature_enabled + after requesting this fixture by name. 
+ """ + with patch(_FFM_PATH) as mock_ffm: + mock_ffm.is_feature_enabled.return_value = True + yield mock_ffm + def _run(coro): """Run an async coroutine synchronously.""" @@ -108,24 +125,14 @@ def test_mcp_packages_discoverable_by_setuptools(): def _make_flask_app_mock( disabled_tools: set[str], - feature_flags: dict[str, object] | None = None, fab_security_views: bool = True, log_view: bool = True, ) -> MagicMock: - """Return a minimal Flask app mock with MCP config set to safe defaults. - - Defaults enable all feature flags and logging so that tests focused on - MCP_DISABLED_TOOLS are not affected by the config guards added for action-log - and task tools. - """ - _feature_flags: dict[str, object] = ( - {"GLOBAL_TASK_FRAMEWORK": True} if feature_flags is None else feature_flags - ) + """Return a minimal Flask app mock with MCP config set to safe defaults.""" _config: dict[str, object] = { "MCP_DISABLED_TOOLS": disabled_tools, "FAB_ADD_SECURITY_VIEWS": fab_security_views, "SUPERSET_LOG_VIEW": log_view, - "FEATURE_FLAGS": _feature_flags, } flask_app = MagicMock() flask_app.config.get.side_effect = lambda key, default=None: _config.get( @@ -315,14 +322,17 @@ def test_action_log_tools_removed_when_fab_security_views_disabled() -> None: assert "get_action_log_info" in removed -def test_task_tools_removed_when_global_task_framework_disabled() -> None: +def test_task_tools_removed_when_global_task_framework_disabled( + gtf_ffm: MagicMock, +) -> None: """Task tools removed when GLOBAL_TASK_FRAMEWORK=False. - Mirrors TaskRestApi conditional registration in initialization/__init__.py. + Uses feature_flag_manager.is_feature_enabled(), mirroring TaskRestApi + conditional registration in initialization/__init__.py. 
""" - flask_app = _make_flask_app_mock( - set(), feature_flags={"GLOBAL_TASK_FRAMEWORK": False} - ) + gtf_ffm.is_feature_enabled.return_value = False + + flask_app = _make_flask_app_mock(set()) with ( patch("superset.mcp_service.flask_singleton.app", flask_app), From eeb4f0fe6032e7a9edeb57d428f7f93edfdab8e7 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 15:00:42 +0000 Subject: [PATCH 11/13] fix(mcp): normalize dttm filter strings and add action-log/task tool docs - Add model_validator to ActionLogFilter to parse ISO string dttm values to timezone-aware datetime objects, preventing VARCHAR bind mismatch on Postgres TIMESTAMP columns (Pydantic's left-to-right union keeps strings as str when str precedes datetime in the union) - Feed config-guard removed tools (action-log, task) into disabled_tools before calling get_default_instructions so removed tools are never advertised in LLM instructions - Add Action Logs and Task Management sections to get_default_instructions output; existing per-line filtering strips them when tools are disabled - Update test assertions and add new coverage for both behaviors --- superset/mcp_service/action_log/schemas.py | 20 ++++++++++- superset/mcp_service/app.py | 26 +++++++++++++-- .../action_log/tool/test_action_log_tools.py | 3 +- .../mcp_service/test_mcp_tool_registration.py | 33 +++++++++++++++++++ 4 files changed, 77 insertions(+), 5 deletions(-) diff --git a/superset/mcp_service/action_log/schemas.py b/superset/mcp_service/action_log/schemas.py index 74fb083bcd97..8f8ccff9c093 100644 --- a/superset/mcp_service/action_log/schemas.py +++ b/superset/mcp_service/action_log/schemas.py @@ -19,7 +19,7 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timezone from typing import Annotated, Any, Literal from pydantic import ( @@ -70,6 +70,24 @@ class ActionLogFilter(ColumnOperator): ..., description="Value to filter by" ) + @model_validator(mode="after") + def 
normalize_dttm_value(self) -> "ActionLogFilter": + """Normalize string dttm values to datetime to avoid VARCHAR bind mismatch. + + Pydantic's left-to-right union matching keeps ISO strings as str when + str appears before datetime in the union. This validator parses them so + the DAO always receives a typed datetime for TIMESTAMP column comparisons. + """ + if self.col == "dttm" and isinstance(self.value, str): + try: + parsed = datetime.fromisoformat(self.value) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + self.value = parsed + except ValueError: + pass + return self + class ActionLogInfo(BaseModel): id: int | None = Field(None, description="Log entry ID") diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index e29fa320e6db..64f816e969d6 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -159,6 +159,14 @@ def get_default_instructions( Schema Discovery: - get_schema: Get schema metadata for chart/dataset/dashboard (columns, filters) +Action Logs (requires SUPERSET_LOG_VIEW and FAB_ADD_SECURITY_VIEWS): +- list_action_logs: List user action logs with filtering and pagination (defaults to last 7 days) +- get_action_log_info: Get a single action log entry by integer ID + +Task Management (requires GLOBAL_TASK_FRAMEWORK feature flag): +- list_tasks: List background tasks with status filtering and pagination +- get_task_info: Get task details by integer ID or UUID + System Information: - get_instance_info: Get instance-wide statistics, metadata, and current user identity - find_users: Resolve a person's name to user IDs for use as a filter value @@ -708,9 +716,12 @@ def _remove_tool_quietly(tool_name: str, reason: str) -> None: pass -def _apply_config_guards(flask_app: Any) -> None: +def _apply_config_guards(flask_app: Any) -> set[str]: """Remove tools whose backing features are administratively disabled. 
+ Returns the set of tool names that were removed so that callers can exclude + them from generated instructions. + - Action-log tools: mirrors LogRestApi.is_enabled() which checks FAB_ADD_SECURITY_VIEWS and SUPERSET_LOG_VIEW. - Task tools: mirrors TaskRestApi conditional registration which checks @@ -718,18 +729,24 @@ def _apply_config_guards(flask_app: Any) -> None: all Superset enablement paths (DEFAULT_FEATURE_FLAGS, GET_FEATURE_FLAGS_FUNC, IS_FEATURE_ENABLED_FUNC, etc.) are respected. """ + removed: set[str] = set() + if not ( flask_app.config.get("FAB_ADD_SECURITY_VIEWS", True) and flask_app.config.get("SUPERSET_LOG_VIEW", True) ): for tool_name in ("list_action_logs", "get_action_log_info"): _remove_tool_quietly(tool_name, "logging disabled by config flags") + removed.add(tool_name) from superset.extensions import feature_flag_manager # noqa: PLC0415 if not feature_flag_manager.is_feature_enabled("GLOBAL_TASK_FRAMEWORK"): for tool_name in ("list_tasks", "get_task_info"): _remove_tool_quietly(tool_name, "GLOBAL_TASK_FRAMEWORK not enabled") + removed.add(tool_name) + + return removed def init_fastmcp_server( @@ -776,10 +793,13 @@ def init_fastmcp_server( # instructions never advertise tools that clients cannot actually call. disabled_tools: set[str] = flask_app.config.get("MCP_DISABLED_TOOLS", set()) _remove_disabled_tools(disabled_tools) - _apply_config_guards(flask_app) + config_guard_removed = _apply_config_guards(flask_app) if instructions is None: - instructions = get_default_instructions(branding, disabled_tools) + # Merge MCP_DISABLED_TOOLS with config-guard removals so the instructions + # never advertise tools that have been suppressed by either mechanism. + all_disabled = disabled_tools | config_guard_removed + instructions = get_default_instructions(branding, all_disabled) # Configure the global mcp instance with provided settings. # Tools are already registered on this instance via @tool decorator imports above. 
diff --git a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py index 3ad4dd234fa1..06d576ec5e2a 100644 --- a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py +++ b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py @@ -151,7 +151,8 @@ async def test_list_action_logs_explicit_dttm_filter_skips_default( dttm_filters = [f for f in col_operators if getattr(f, "col", None) == "dttm"] # Only the user-provided filter, not the injected default assert len(dttm_filters) == 1 - assert dttm_filters[0].value == "2020-01-01T00:00:00" + # model_validator normalizes ISO strings to timezone-aware datetime objects + assert dttm_filters[0].value == datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @patch("superset.daos.log.LogDAO.list") diff --git a/tests/unit_tests/mcp_service/test_mcp_tool_registration.py b/tests/unit_tests/mcp_service/test_mcp_tool_registration.py index 268e8aa8f1b7..3ce4f64547a7 100644 --- a/tests/unit_tests/mcp_service/test_mcp_tool_registration.py +++ b/tests/unit_tests/mcp_service/test_mcp_tool_registration.py @@ -280,6 +280,10 @@ def test_no_disabled_tools_returns_full_instructions() -> None: assert "- execute_sql:" in full assert "- health_check:" in full + assert "- list_action_logs:" in full + assert "- get_action_log_info:" in full + assert "- list_tasks:" in full + assert "- get_task_info:" in full assert full == also_full @@ -345,6 +349,35 @@ def test_task_tools_removed_when_global_task_framework_disabled( assert "get_task_info" in removed +def test_config_guard_tools_excluded_from_instructions() -> None: + """Config-guard removed tools must be passed to get_default_instructions so + the instructions never advertise tools that are disabled by config flags.""" + flask_app = _make_flask_app_mock(set(), log_view=False) + + captured: list[str] = [] + + def fake_get_instructions( + branding: str = "Apache Superset", + 
disabled_tools: set[str] | None = None, + ) -> str: + captured.append(str(disabled_tools)) + return f"instructions for {branding}" + + with ( + patch("superset.mcp_service.flask_singleton.app", flask_app), + patch.object(mcp.local_provider, "remove_tool"), + patch( + "superset.mcp_service.app.get_default_instructions", + fake_get_instructions, + ), + ): + init_fastmcp_server() + + assert len(captured) == 1 + assert "list_action_logs" in captured[0] + assert "get_action_log_info" in captured[0] + + def test_instructions_generated_after_disabled_tools_removed() -> None: """init_fastmcp_server generates instructions AFTER removing disabled tools, so the instructions never advertise tools that clients cannot call.""" From 1fd7ecf4a67b725e9f3ca06cc7cc7e9a25dae8aa Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 15:44:35 +0000 Subject: [PATCH 12/13] fix(mcp): handle Z suffix in dttm filter parser for Python 3.10 compat --- superset/mcp_service/action_log/schemas.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/superset/mcp_service/action_log/schemas.py b/superset/mcp_service/action_log/schemas.py index 8f8ccff9c093..1263709fa8be 100644 --- a/superset/mcp_service/action_log/schemas.py +++ b/superset/mcp_service/action_log/schemas.py @@ -77,10 +77,16 @@ def normalize_dttm_value(self) -> "ActionLogFilter": Pydantic's left-to-right union matching keeps ISO strings as str when str appears before datetime in the union. This validator parses them so the DAO always receives a typed datetime for TIMESTAMP column comparisons. + + Replaces a trailing 'Z' with '+00:00' before parsing because + datetime.fromisoformat does not accept the 'Z' suffix on Python < 3.11. 
""" if self.col == "dttm" and isinstance(self.value, str): try: - parsed = datetime.fromisoformat(self.value) + val = self.value + if val.endswith("Z"): + val = val[:-1] + "+00:00" + parsed = datetime.fromisoformat(val) if parsed.tzinfo is None: parsed = parsed.replace(tzinfo=timezone.utc) self.value = parsed From d90456758091948c0b5227faca7d89f800921fdb Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 16:07:05 +0000 Subject: [PATCH 13/13] fix(mcp): sanitize action-log json payload before placing in LLM context The stored log `json` field is user-controlled data. Parse it and run each string leaf through `sanitize_for_llm_context` so the payload cannot masquerade as instructions when placed in an LLM context. Preserves the JSON shape (dict/list structure) so callers can still inspect individual fields; only string leaves are wrapped in UNTRUSTED-CONTENT delimiters. Falls back to sanitizing the raw string when the payload is not valid JSON. Addresses review feedback from richardfogaca. 
--- superset/mcp_service/action_log/schemas.py | 29 +++++++++- .../action_log/tool/test_action_log_tools.py | 53 +++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/superset/mcp_service/action_log/schemas.py b/superset/mcp_service/action_log/schemas.py index 1263709fa8be..fc82489a5c45 100644 --- a/superset/mcp_service/action_log/schemas.py +++ b/superset/mcp_service/action_log/schemas.py @@ -35,6 +35,7 @@ from superset.daos.base import ColumnOperator, ColumnOperatorEnum from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils import sanitize_for_llm_context from superset.mcp_service.utils.schema_utils import ( parse_json_or_list, parse_json_or_model_list, @@ -104,7 +105,9 @@ class ActionLogInfo(BaseModel): dttm: str | datetime | None = Field(None, description="Timestamp of the action") dashboard_id: int | None = Field(None, description="Associated dashboard ID") slice_id: int | None = Field(None, description="Associated chart/slice ID") - json: str | None = Field(None, description="JSON payload of the action") + json: Any = Field( + None, description="JSON payload of the action (user-controlled, sanitized)" + ) model_config = ConfigDict( from_attributes=True, @@ -248,6 +251,28 @@ class GetActionLogInfoRequest(BaseModel): ] +def _sanitize_log_json(raw: Any) -> Any: + """Parse the stored log JSON string and sanitize string leaves. + + Preserves the JSON shape so callers can inspect individual fields; wraps + every string leaf in UNTRUSTED-CONTENT delimiters so the payload cannot + inject instructions into the LLM context. Falls back to sanitizing the + raw string when it is not valid JSON. 
+ """ + if raw is None: + return None + if isinstance(raw, str): + try: + from superset.utils import json as json_utils # noqa: PLC0415 + + parsed = json_utils.loads(raw) + except (ValueError, TypeError): + parsed = raw + else: + parsed = raw + return sanitize_for_llm_context(parsed, field_path=("json",)) + + def serialize_action_log_object(log: Any) -> ActionLogInfo | None: if not log: return None @@ -263,5 +288,5 @@ def serialize_action_log_object(log: Any) -> ActionLogInfo | None: dttm=dttm, dashboard_id=getattr(log, "dashboard_id", None), slice_id=getattr(log, "slice_id", None), - json=getattr(log, "json", None), + json=_sanitize_log_json(getattr(log, "json", None)), ) diff --git a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py index 06d576ec5e2a..6e7d2bf5c18f 100644 --- a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py +++ b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py @@ -222,3 +222,56 @@ async def test_get_action_log_info_not_found(mock_find, mcp_server): data = json.loads(result.content[0].text) assert data["error_type"] == "not_found" + + +@patch("superset.daos.log.LogDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_action_log_info_json_payload_sanitized(mock_find, mcp_server): + """The json field is sanitized: string leaves are wrapped in UNTRUSTED-CONTENT.""" + log = create_mock_log( + log_id=1, + json_payload=( + '{"event_name": "explore_slice",' + ' "filters": [{"col": "name", "val": "inject me"}]}' + ), + ) + mock_find.return_value = log + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_action_log_info", {"request": {"identifier": 1}} + ) + + data = json.loads(result.content[0].text) + payload = data.get("json") + # Parsed JSON shape is preserved (dict, not raw string) + assert isinstance(payload, dict) + # String leaves are wrapped in UNTRUSTED-CONTENT delimiters + 
+ event_name = payload.get("event_name", "") + assert "<UNTRUSTED-CONTENT>" in event_name + assert "explore_slice" in event_name + assert "</UNTRUSTED-CONTENT>" in event_name + + +@patch("superset.daos.log.LogDAO.list") +@pytest.mark.asyncio +async def test_list_action_logs_json_payload_sanitized(mock_list, mcp_server): + """list_action_logs also sanitizes the json field in each entry.""" + log = create_mock_log( + log_id=5, + json_payload='{"event_name": "dashboard_load", "user_input": "inject me"}', + ) + mock_list.return_value = ([log], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_action_logs", + {"request": {"select_columns": ["id", "action", "json"]}}, + ) + + data = json.loads(result.content[0].text) + payload = data["action_logs"][0].get("json") + assert isinstance(payload, dict) + event_name_value = payload.get("event_name", "") + assert "<UNTRUSTED-CONTENT>" in event_name_value + assert "dashboard_load" in event_name_value