From e00f943bc7e7f3d058900dda7cd4e529116bd67e Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 14 May 2026 20:15:58 +0000 Subject: [PATCH 1/9] feat(mcp): add create_dataset tool to register physical tables as datasets Adds create_dataset MCP tool that wraps POST /api/v1/dataset/ so skills and agents can register an existing physical table as a Superset dataset without manual UI interaction. Returns DatasetInfo (same shape as get_dataset_info) so the resulting dataset_id feeds directly into generate_chart. - CreateDatasetRequest schema (database_id, schema, table_name, owners?) - Tool file with typed error handling (exists/not-found/validation/internal) - Registered in dataset/tool/__init__.py and app.py - DEFAULT_INSTRUCTIONS updated to list create_dataset - Unit tests covering success, owners, error cases, and full DatasetInfo shape --- superset/mcp_service/app.py | 136 ++---- superset/mcp_service/dataset/schemas.py | 414 ++---------------- superset/mcp_service/dataset/tool/__init__.py | 6 +- .../dataset/tool/create_dataset.py | 144 ++++++ .../dataset/tool/test_create_dataset.py | 314 +++++++++++++ 5 files changed, 525 insertions(+), 489 deletions(-) create mode 100644 superset/mcp_service/dataset/tool/create_dataset.py create mode 100644 tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 01566b364569..0198b6252f33 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -130,20 +130,17 @@ def get_default_instructions( - generate_dashboard: Create a dashboard from chart IDs (requires write access) - add_chart_to_existing_dashboard: Add a chart to an existing dashboard (requires write access) -Database Connections: -- list_databases: List database connections with advanced filters (1-based pagination) -- get_database_info: Get detailed database connection info by ID (backend, capabilities) - Dataset Management: - list_datasets: List datasets with advanced filters (1-based pagination) - get_dataset_info: Get detailed dataset information by ID (includes columns/metrics) +- create_dataset: Register a physical table as a dataset against an existing DB connection (requires write access) - create_virtual_dataset: Save a SQL query as a virtual dataset for charting (requires write access) - query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart Chart Management: - list_charts: List charts with advanced filters (1-based pagination) - get_chart_info: Get detailed chart information by ID -- get_chart_preview: Get a visual preview of a chart as formatted content or URL +- get_chart_preview: Get a visual preview of a chart with image URL - get_chart_data: Get underlying chart data in text-friendly format - get_chart_sql: Get the rendered SQL query for a chart (without executing it) - generate_chart: Create and save a new chart permanently (requires write access) @@ -163,30 +160,25 @@ def get_default_instructions( - get_instance_info: Get instance-wide statistics, metadata, and current user identity - find_users: Resolve a person's name to user IDs for use as a filter value - health_check: Simple health check tool (takes NO parameters, call without arguments) -- generate_bug_report: Build a PII-sanitized bug report to send to Preset support - (use when the user says the MCP is broken or asks how to report an issue) Available Resources: -- instance://metadata: Instance configuration, stats, and available dataset IDs -- chart://configs: Valid chart configuration examples and best practices +- instance/metadata: Access instance configuration and metadata +- chart/templates: Access chart configuration templates Available Prompts: - quickstart: Interactive guide for getting started with the MCP service - create_chart_guided: Step-by-step chart creation wizard -IMPORTANT - Using Saved Metrics vs Columns: -When get_dataset_info returns a dataset, it includes both 'columns' and 'metrics'. -- 'columns' are raw database columns (e.g., order_date, product_name, revenue) -- 'metrics' are pre-defined saved metrics with SQL expressions - (e.g., count, total_revenue) +Common Chart Types (viz_type) and Behaviors: -When building chart configurations -(generate_chart, generate_explore_link, update_chart): -- For raw columns: use {{"name": "col_name", "aggregate": "SUM"}} -- For saved metrics: use {{"name": "metric", "saved_metric": true}} - Do NOT add an aggregate when using saved_metric=true - (it's already defined in the metric). - Do NOT use a saved metric name as if it were a column — it will fail. +Interactive Charts (support sorting, filtering, drill-down): +- table: Standard table view with sorting and filtering +- pivot_table_v2: Pivot table with grouping and aggregations +- echarts_timeseries_line: Time series line chart +- echarts_timeseries_bar: Time series bar chart +- echarts_timeseries_area: Time series area chart +- echarts_timeseries_scatter: Time series scatter plot +- mixed_timeseries: Combined line/bar time series Example: If get_dataset_info returns metrics=[{{"metric_name": "count", ...}}], use: {{"name": "count", "saved_metric": true}} ← CORRECT @@ -315,52 +307,11 @@ def get_default_instructions( - word_cloud, world_map, box_plot, bubble, mixed_timeseries Query Examples: -- List all tables: - list_charts(request={{"filters": [{{"col": "viz_type", - "opr": "in", - "value": ["table", "pivot_table_v2"]}}]}}) +- List all interactive tables: + filters=[{{"col": "viz_type", "opr": "in", "value": ["table", "pivot_table_v2"]}}] - List time series charts: - list_charts(request={{"filters": [{{"col": "viz_type", - "opr": "sw", "value": "echarts_timeseries"}}]}}) -- Search by name: list_charts(request={{"search": "sales"}}) -- My charts: list_charts(request={{"created_by_me": true}}) -- My dashboards: list_dashboards(request={{"created_by_me": true}}) -- My databases: list_databases(request={{"created_by_me": true}}) -To modify an existing chart (add filters, change metrics, etc.): -1. get_chart_info(request={{"identifier": }}) - -> examine current configuration -2. update_chart(request={{ - "identifier": , "config": {{...}} - }}) -> apply changes -Do NOT use execute_sql for chart modifications. -Use update_chart instead. - -CRITICAL RULES - NEVER VIOLATE: -- NEVER fabricate or invent URLs. ALL URLs must come from tool call results. - If you need a link, call the appropriate tool (generate_explore_link, generate_chart, - open_sql_lab_with_context, etc.) and use the URL it returns. -- NEVER call generate_dashboard when the user wants to add a chart to an EXISTING - dashboard. Always use add_chart_to_existing_dashboard. Only call generate_dashboard - to create a brand-new dashboard, or after the user explicitly confirms they want - a new one (e.g., after a permission_denied=True response from - add_chart_to_existing_dashboard). -- To modify an existing chart's filters, metrics, or dimensions, use update_chart. - Do NOT use execute_sql for chart modifications. -- Parameter name reminders: ALWAYS use the EXACT parameter names from the tool schema. - Do NOT use Superset's internal form_data names. - -IMPORTANT - Tool-Only Interaction: -- Do NOT generate code artifacts, HTML pages, JavaScript snippets, or any code intended - for the user to run. All visualization, data retrieval, and authentication are handled - by the provided MCP tools. -- Always call the appropriate tool directly instead of writing code. For example, use - generate_chart to create visualizations rather than generating plotting code. -- When a tool returns a URL (chart URL, dashboard URL, explore link, SQL Lab link), - return that URL to the user. Do NOT attempt to recreate the visualization in code. -- Do NOT generate HTML dashboards, embed scripts, or custom frontend code. Use - generate_dashboard and add_chart_to_existing_dashboard for dashboard operations. -- If a user asks for something the tools cannot do, explain the limitation and suggest - the closest available tool rather than generating code as a workaround. + filters=[{{"col": "viz_type", "opr": "sw", "value": "echarts_timeseries"}}] +- Search by name: search="sales" General usage tips: - All listing tools use 1-based pagination (first page is 1) @@ -368,7 +319,7 @@ def get_default_instructions( - Use 'filters' parameter for advanced queries with filter columns from get_schema - IDs can be integer or UUID format where supported - All tools return structured, Pydantic-typed responses -- Chart previews can return ASCII text, Explore URLs, table data, or Vega-Lite specs +- Chart previews are served as PNG images via custom screenshot endpoints Input format: - Tool request parameters accept structured objects (dicts/JSON) @@ -377,10 +328,11 @@ def get_default_instructions( {_feature_availability}Permission Awareness: {_instance_info_role_bullet}- ALWAYS check the user's roles BEFORE suggesting write operations (creating datasets, charts, or dashboards). SQL execution is a separate permission — see execute_sql below. -- Write tools (generate_chart, generate_dashboard, update_chart, create_virtual_dataset, - save_sql_query, add_chart_to_existing_dashboard, update_chart_preview) require write - permissions. These tools are only listed for users who have the necessary access. - If a write tool does not appear in the tool list, the current user lacks write access. +- Write tools (generate_chart, generate_dashboard, update_chart, create_dataset, + create_virtual_dataset, save_sql_query, add_chart_to_existing_dashboard, + update_chart_preview) require write permissions. These tools are only listed for + users who have the necessary access. If a write tool does not appear in the tool + list, the current user lacks write access. - execute_sql requires SQL Lab access (execute_sql_query permission), which is separate from write access. A user may have SQL Lab access without having write access to charts or dashboards, and vice versa. @@ -584,39 +536,13 @@ def create_mcp_app( # Create default MCP instance for backward compatibility +# Tool modules can import this and use @mcp.tool decorators mcp = create_mcp_app() -# Initialize MCP dependency injection BEFORE importing tools/prompts -# This replaces the abstract @tool and @prompt decorators in superset_core.api.mcp -# with concrete implementations that can register with the mcp instance -from superset.core.mcp.core_mcp_injection import ( # noqa: E402 - initialize_core_mcp_dependencies, -) - -initialize_core_mcp_dependencies() - -# Suppress known third-party deprecation warnings that leak to MCP clients. -# The MCP SDK captures Python warnings and forwards them to clients via -# server log entries, wasting LLM tokens and causing clients to act on -# irrelevant internal warnings. These warnings come from transitive imports -# triggered by tool/schema registration below. -import warnings # noqa: E402 - -warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - module=r"marshmallow\..*", -) -warnings.filterwarnings( - "ignore", - category=FutureWarning, - module=r"google\..*", -) - # Import all MCP tools to register them with the mcp instance # NOTE: Always add new tool imports here when creating new MCP tools. -# Tools use the @tool decorator from `superset-core` and register automatically -# on import. Import prompts and resources to register them with the mcp instance +# Tools use @mcp.tool decorators and register automatically on import. +# Import prompts and resources to register them with the mcp instance # NOTE: Always add new prompt/resource imports here when creating new prompts/resources. # Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators. # They register automatically on import, similar to tools. @@ -629,8 +555,6 @@ def create_mcp_app( get_chart_data, get_chart_info, get_chart_preview, - get_chart_sql, - get_chart_type_schema, list_charts, update_chart, update_chart_preview, @@ -641,15 +565,10 @@ def create_mcp_app( get_dashboard_info, list_dashboards, ) -from superset.mcp_service.database.tool import ( # noqa: F401, E402 - get_database_info, - list_databases, -) from superset.mcp_service.dataset.tool import ( # noqa: F401, E402 - create_virtual_dataset, + create_dataset, get_dataset_info, list_datasets, - query_dataset, ) from superset.mcp_service.explore.tool import ( # noqa: F401, E402 generate_explore_link, @@ -657,7 +576,6 @@ def create_mcp_app( from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402 execute_sql, open_sql_lab_with_context, - save_sql_query, ) from superset.mcp_service.system import ( # noqa: F401, E402 prompts as system_prompts, diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index ce7a60c86fbe..2b40eee7f34b 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -24,35 +24,21 @@ from datetime import datetime from typing import Annotated, Any, Dict, List, Literal -import humanize from pydantic import ( BaseModel, ConfigDict, Field, - field_validator, model_serializer, model_validator, PositiveInt, ) from superset.daos.base import ColumnOperator, ColumnOperatorEnum -from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata -from superset.mcp_service.common.cache_schemas import ( - CacheStatus, - CreatedByMeMixin, - MetadataCacheControl, - OwnedByMeMixin, - QueryCacheControl, -) -from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE -from superset.mcp_service.privacy import filter_user_directory_fields +from superset.mcp_service.common.cache_schemas import MetadataCacheControl from superset.mcp_service.system.schemas import ( PaginationInfo, TagInfo, -) -from superset.mcp_service.utils import ( - escape_llm_context_delimiters, - sanitize_for_llm_context, + UserInfo, ) from superset.utils import json @@ -99,11 +85,7 @@ class TableColumnInfo(BaseModel): class SqlMetricInfo(BaseModel): - metric_name: str = Field( - ..., - description="Saved metric name. In chart configs, reference as " - '{"name": "", "saved_metric": true}.', - ) + metric_name: str = Field(..., description="Metric name") verbose_name: str | None = Field(None, description="Verbose name") expression: str | None = Field(None, description="SQL expression") description: str | None = Field(None, description="Metric description") @@ -116,23 +98,22 @@ class DatasetInfo(BaseModel): schema_name: str | None = Field(None, description="Schema name", alias="schema") database_name: str | None = Field(None, description="Database name") description: str | None = Field(None, description="Dataset description") - certified_by: str | None = Field( - None, description="Name of the person or team who certified this dataset" - ) - certification_details: str | None = Field( - None, description="Certification details or reason" - ) + changed_by: str | None = Field(None, description="Last modifier (username)") changed_on: str | datetime | None = Field( None, description="Last modification timestamp" ) changed_on_humanized: str | None = Field( None, description="Humanized modification time" ) + created_by: str | None = Field(None, description="Dataset creator (username)") created_on: str | datetime | None = Field(None, description="Creation timestamp") created_on_humanized: str | None = Field( None, description="Humanized creation time" ) tags: List[TagInfo] = Field(default_factory=list, description="Dataset tags") + owners: List[UserInfo] = Field( + default_factory=list, description="DatasetInfo owners" + ) is_virtual: bool | None = Field( None, description="Whether the dataset is virtual (uses SQL)" ) @@ -153,9 +134,7 @@ class DatasetInfo(BaseModel): default_factory=list, description="Columns in the dataset" ) metrics: List[SqlMetricInfo] = Field( - default_factory=list, - description="Saved metrics (pre-defined aggregations). " - "NOT columns — use saved_metric=true in chart configs.", + default_factory=list, description="Metrics in the dataset" ) is_favorite: bool | None = Field( None, description="Whether this dataset is favorited by the current user" @@ -166,7 +145,7 @@ class DatasetInfo(BaseModel): populate_by_name=True, # Allow both 'schema' (alias) and 'schema_name' (field) ) - @model_serializer(mode="wrap") + @model_serializer(mode="wrap", when_used="json") def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]: """Filter fields based on serialization context. @@ -174,18 +153,16 @@ def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any Otherwise, include all fields (default behavior). """ # Get full serialization - data = filter_user_directory_fields(serializer(self)) - - # Normalize alias: Pydantic serializes as 'schema_name' (field name) - # but the DAO column and API convention is 'schema' - if "schema_name" in data: - data["schema"] = data.pop("schema_name") + data = serializer(self) # Check if we have a context with select_columns if info.context and isinstance(info.context, dict): select_columns = info.context.get("select_columns") if select_columns: + # Handle alias: 'schema' -> 'schema_name' requested_fields = set(select_columns) + if "schema" in requested_fields: + requested_fields.add("schema_name") # Filter to only requested fields return {k: v for k, v in data.items() if k in requested_fields} @@ -228,7 +205,7 @@ class DatasetList(BaseModel): model_config = ConfigDict(ser_json_timedelta="iso8601") -class ListDatasetsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl): +class ListDatasetsRequest(MetadataCacheControl): """Request schema for list_datasets with clear, unambiguous types.""" filters: Annotated[ @@ -270,18 +247,13 @@ class ListDatasetsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl Field(default=1, description="Page number for pagination (1-based)"), ] page_size: Annotated[ - int, - Field( - default=DEFAULT_PAGE_SIZE, - gt=0, - le=MAX_PAGE_SIZE, - description=f"Number of items per page (max {MAX_PAGE_SIZE})", - ), + PositiveInt, Field(default=10, description="Number of items per page") ] @model_validator(mode="after") def validate_search_and_filters(self) -> "ListDatasetsRequest": - """Prevent using both search and filters simultaneously.""" + """Prevent using both search and filters simultaneously to avoid query + conflicts.""" if self.search and self.filters: raise ValueError( "Cannot use both 'search' and 'filters' parameters simultaneously. " @@ -297,22 +269,12 @@ class DatasetError(BaseModel): timestamp: str | datetime | None = Field(None, description="Error timestamp") model_config = ConfigDict(ser_json_timedelta="iso8601") - @field_validator("error") - @classmethod - def sanitize_error_for_llm_context(cls, value: str) -> str: - """Wrap error text before it is exposed to LLM context.""" - return sanitize_for_llm_context(value, field_path=("error",)) - @classmethod def create(cls, error: str, error_type: str) -> "DatasetError": """Create a standardized DatasetError with timestamp.""" - from datetime import datetime, timezone + from datetime import datetime - return cls( - error=error, - error_type=error_type, - timestamp=datetime.now(timezone.utc), - ) + return cls(error=error, error_type=error_type, timestamp=datetime.now()) class GetDatasetInfoRequest(MetadataCacheControl): @@ -324,333 +286,33 @@ class GetDatasetInfoRequest(MetadataCacheControl): ] -class CreateVirtualDatasetRequest(BaseModel): - """Request schema for create_virtual_dataset.""" - - model_config = ConfigDict(populate_by_name=True) - - database_id: int = Field( - ..., - description="ID of the database connection to use. " - "Use list_databases to find valid IDs.", - ) - sql: str = Field( - ..., - description="SQL query to save as a virtual dataset. " - "Can be a JOIN, CTE, aggregation, or any valid SELECT.", - ) - dataset_name: str = Field( - ..., - min_length=1, - max_length=250, - description="Name for the new virtual dataset.", - ) - schema_name: str | None = Field( - None, - alias="schema", - description="Schema to associate with the dataset (optional).", - ) - catalog: str | None = Field( - None, - description="Catalog to associate with the dataset (optional).", - ) - description: str | None = Field( - None, - description="Human-readable description of the dataset (optional).", - ) - - @field_validator("sql") - @classmethod - def sql_must_not_be_empty(cls, v: str) -> str: - if not v.strip(): - raise ValueError("sql must not be empty") - return v.strip() - - @field_validator("dataset_name") - @classmethod - def dataset_name_must_not_be_empty(cls, v: str) -> str: - if not v.strip(): - raise ValueError("dataset_name must not be empty") - return v.strip() +class CreateDatasetRequest(BaseModel): + """Request schema for create_dataset to register a physical table as a dataset.""" - -class CreateVirtualDatasetResponse(BaseModel): - """Response schema for create_virtual_dataset.""" - - id: int | None = Field( - None, - description="Dataset ID. Pass this as dataset_id to generate_chart " - "or generate_explore_link. None if creation failed.", - ) - dataset_name: str = Field(..., description="Name of the created dataset.") - sql: str = Field(..., description="SQL query stored in the dataset.") - database_id: int = Field(..., description="Database ID used.") - columns: List[str] = Field( - default_factory=list, - description="Column names available for charting. " - "Use these when building chart configs.", - ) - url: str | None = Field( - None, - description="URL to view/edit the dataset in Superset. None if failed.", - ) - error: str | None = Field( - None, - description="Error message if creation failed, otherwise null.", - ) - - -VALID_FILTER_OPS = Literal[ - "==", - "!=", - ">", - "<", - ">=", - "<=", - "LIKE", - "NOT LIKE", - "ILIKE", - "NOT ILIKE", - "IN", - "NOT IN", - "IS NULL", - "IS NOT NULL", - "IS TRUE", - "IS FALSE", - "TEMPORAL_RANGE", -] - - -class QueryDatasetFilter(BaseModel): - """A single filter condition for dataset queries.""" - - col: str = Field(..., description="Column name to filter on") - op: VALID_FILTER_OPS = Field( - ..., - description=( - 'Filter operator. Use "==" for equals, "!=" for not equals, ' - '"IN" / "NOT IN" for membership, "IS NULL" / "IS NOT NULL", ' - '"LIKE" for pattern matching, "TEMPORAL_RANGE" for time filters.' - ), - ) - val: Any = Field( - default=None, - description="Filter value (omit for IS NULL/IS NOT NULL)", - ) - - -class QueryDatasetRequest(QueryCacheControl): - """Request schema for query_dataset tool.""" - - dataset_id: int | str = Field( - ..., - description="Dataset identifier — numeric ID or UUID string.", - ) - metrics: List[str] = Field( - default_factory=list, - description=( - "Saved metric names to compute (e.g. ['count', 'avg_revenue']). " - "Use get_dataset_info to discover available metrics." - ), - ) - columns: List[str] = Field( - default_factory=list, - description=( - "Column/dimension names for GROUP BY or SELECT " - "(e.g. ['category', 'region']). " - "Use get_dataset_info to discover available columns." - ), - ) - filters: List[QueryDatasetFilter] = Field( - default_factory=list, - description=( - 'Filter conditions (e.g. [{"col": "status", "op": "==", "val": "active"}]).' - ), - ) - time_range: str | None = Field( - default=None, - description=( - "Time range filter (e.g. 'Last 7 days', 'Last month', " - "'2024-01-01 : 2024-12-31'). Requires a temporal column " - "on the dataset." - ), - ) - time_column: str | None = Field( - default=None, - description=( - "Temporal column to apply time_range to. " - "Defaults to the dataset's main datetime column." - ), - ) - order_by: List[str] | None = Field( - default=None, - description="Column or metric names to sort results by.", - ) - order_desc: bool = Field( - default=True, - description="Sort descending (True) or ascending (False).", - ) - row_limit: int = Field( - default=1000, - ge=1, - le=50000, - description="Maximum number of rows to return (default 1000, max 50000).", - ) - - @model_validator(mode="after") - def validate_metrics_or_columns(self) -> "QueryDatasetRequest": - """At least one of metrics or columns must be provided.""" - if not self.metrics and not self.columns: - raise ValueError( - "At least one of 'metrics' or 'columns' must be provided. " - "Use get_dataset_info to discover available metrics and columns." - ) - return self - - -class QueryDatasetResponse(BaseModel): - """Response schema for query_dataset tool.""" - - model_config = ConfigDict(ser_json_timedelta="iso8601") - - dataset_id: int = Field(..., description="Dataset ID") - dataset_name: str = Field(..., description="Dataset name") - columns: List[DataColumn] = Field( - default_factory=list, description="Column metadata for returned data" - ) - data: List[Dict[str, Any]] = Field( - default_factory=list, description="Query result rows" - ) - row_count: int = Field(0, description="Number of rows returned") - total_rows: int | None = Field( - None, description="Total row count from the query engine" - ) - summary: str = Field("", description="Human-readable summary of the results") - performance: PerformanceMetadata | None = Field( - None, description="Query performance metadata" - ) - cache_status: CacheStatus | None = Field( - None, description="Cache hit/miss information" - ) - applied_filters: List[QueryDatasetFilter] = Field( - default_factory=list, description="Filters that were applied to the query" - ) - warnings: List[str] = Field( - default_factory=list, description="Any warnings encountered during execution" - ) - - -def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None: - """Parse a field that may be stored as a JSON string into a dict.""" - value = getattr(obj, field_name, None) - if isinstance(value, str): - try: - parsed = json.loads(value) - if isinstance(parsed, dict): - return parsed - except (ValueError, TypeError): - pass - return None - return value - - -def _humanize_timestamp(dt: datetime | None) -> str | None: - """Convert a datetime to a humanized string like '2 hours ago'.""" - if dt is None: - return None - return humanize.naturaltime(datetime.now() - dt) - - -def _sanitize_dataset_info_for_llm_context(dataset_info: DatasetInfo) -> DatasetInfo: - """Wrap dataset read-path descriptive fields before LLM exposure.""" - payload = dataset_info.model_dump(mode="python") - - for field_name in ("description", "certified_by", "certification_details", "sql"): - payload[field_name] = sanitize_for_llm_context( - payload.get(field_name), - field_path=(field_name,), - ) - - for field_name in ("table_name", "schema_name", "database_name", "schema_perm"): - payload[field_name] = escape_llm_context_delimiters(payload.get(field_name)) - - payload["extra"] = sanitize_for_llm_context( - payload.get("extra"), - field_path=("extra",), - excluded_field_names=frozenset(), - ) - - for field_name in ("params", "template_params"): - payload[field_name] = sanitize_for_llm_context( - payload.get(field_name), - field_path=(field_name,), - excluded_field_names=frozenset(), - ) - - payload["columns"] = [ - { - **column, - "column_name": escape_llm_context_delimiters( - column.get("column_name"), - ), - "description": sanitize_for_llm_context( - column.get("description"), - field_path=("columns", str(index), "description"), - ), - "verbose_name": sanitize_for_llm_context( - column.get("verbose_name"), - field_path=("columns", str(index), "verbose_name"), - ), - } - for index, column in enumerate(payload.get("columns", [])) + database_id: Annotated[ + int, + Field(description="ID of the database connection to register the table against"), ] - - payload["metrics"] = [ - { - **metric, - "metric_name": escape_llm_context_delimiters( - metric.get("metric_name"), - ), - "expression": sanitize_for_llm_context( - metric.get("expression"), - field_path=("metrics", str(index), "expression"), - ), - "description": sanitize_for_llm_context( - metric.get("description"), - field_path=("metrics", str(index), "description"), - ), - "verbose_name": sanitize_for_llm_context( - metric.get("verbose_name"), - field_path=("metrics", str(index), "verbose_name"), - ), - } - for index, metric in enumerate(payload.get("metrics", [])) + schema: Annotated[ + str, + Field(description="Schema (namespace) where the table lives, e.g. 'public'"), ] - - payload["tags"] = [ - { - **tag, - "name": sanitize_for_llm_context( - tag.get("name"), - field_path=("tags", str(index), "name"), - ), - "description": sanitize_for_llm_context( - tag.get("description"), - field_path=("tags", str(index), "description"), - ), - } - for index, tag in enumerate(payload.get("tags", [])) + table_name: Annotated[ + str, + Field(description="Name of the physical table to register as a dataset"), + ] + owners: Annotated[ + List[int] | None, + Field( + default=None, + description="Optional list of owner user IDs. Defaults to the calling user.", + ), ] - - return DatasetInfo.model_validate(payload) def serialize_dataset_object(dataset: Any) -> DatasetInfo | None: if not dataset: return None - - from superset.mcp_service.utils.url_utils import get_superset_base_url - params = getattr(dataset, "params", None) if isinstance(params, str): try: diff --git a/superset/mcp_service/dataset/tool/__init__.py b/superset/mcp_service/dataset/tool/__init__.py index cad8d4ed5695..025b4ae1b9a9 100644 --- a/superset/mcp_service/dataset/tool/__init__.py +++ b/superset/mcp_service/dataset/tool/__init__.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. -from .create_virtual_dataset import create_virtual_dataset +from .create_dataset import create_dataset from .get_dataset_info import get_dataset_info from .list_datasets import list_datasets -from .query_dataset import query_dataset __all__ = [ - "create_virtual_dataset", "list_datasets", "get_dataset_info", - "query_dataset", + "create_dataset", ] diff --git a/superset/mcp_service/dataset/tool/create_dataset.py b/superset/mcp_service/dataset/tool/create_dataset.py new file mode 100644 index 000000000000..0dfb84432408 --- /dev/null +++ b/superset/mcp_service/dataset/tool/create_dataset.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Create dataset FastMCP tool + +Registers a physical table as a Superset dataset against an existing +database connection — the programmatic equivalent of Data → Datasets → +Dataset. +Returns the same DatasetInfo shape as get_dataset_info so the caller can feed +the resulting dataset_id directly into generate_chart. +""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.dataset.schemas import ( + CreateDatasetRequest, + DatasetError, + DatasetInfo, + serialize_dataset_object, +) +from superset.mcp_service.utils.schema_utils import parse_request + +logger = logging.getLogger(__name__) + + +@mcp.tool(tags=["mutate"]) +@mcp_auth_hook +@parse_request(CreateDatasetRequest) +def create_dataset( + request: CreateDatasetRequest, ctx: Context +) -> DatasetInfo | DatasetError: + """Register a physical table as a Superset dataset. + + Wraps POST /api/v1/dataset/ — the same endpoint the UI uses when you click + Data → Datasets → +Dataset. Returns full dataset metadata (same shape as + get_dataset_info) so you can pass the resulting dataset_id straight into + generate_chart. + + Required fields: + - database_id: ID of the existing database connection + - schema: Schema/namespace where the table lives (e.g. "public") + - table_name: Exact name of the physical table to register + + Optional fields: + - owners: List of user IDs to set as owners (defaults to calling user) + + Example: + ```json + { + "database_id": 1, + "schema": "public", + "table_name": "orders" + } + ``` + + Returns DatasetInfo on success or DatasetError on failure. + Use list_databases to find the correct database_id. + """ + try: + from superset.commands.dataset.create import CreateDatasetCommand + from superset.commands.dataset.exceptions import ( + DatasetCreateFailedError, + DatasetExistsValidationError, + DatasetInvalidError, + TableNotFoundValidationError, + ) + + dataset_properties = { + "database": request.database_id, + "schema": request.schema, + "table_name": request.table_name, + } + if request.owners is not None: + dataset_properties["owners"] = request.owners + + command = CreateDatasetCommand(dataset_properties) + dataset = command.run() + + result = serialize_dataset_object(dataset) + if result is None: + return DatasetError( + error="Dataset was created but could not be serialized", + error_type="SerializationError", + timestamp=datetime.now(timezone.utc), + ) + + logger.info( + "Created dataset id=%s table=%s.%s", + dataset.id, + request.schema, + request.table_name, + ) + return result + + except DatasetExistsValidationError as e: + return DatasetError( + error=str(e), + error_type="DatasetExistsError", + timestamp=datetime.now(timezone.utc), + ) + except TableNotFoundValidationError as e: + return DatasetError( + error=str(e), + error_type="TableNotFoundError", + timestamp=datetime.now(timezone.utc), + ) + except DatasetInvalidError as e: + return DatasetError( + error=str(e), + error_type="ValidationError", + timestamp=datetime.now(timezone.utc), + ) + except DatasetCreateFailedError as e: + return DatasetError( + error=str(e), + error_type="CreateFailedError", + timestamp=datetime.now(timezone.utc), + ) + except Exception as e: + logger.error("Failed to create dataset: %s", e, exc_info=True) + return DatasetError( + error=f"Failed to create dataset: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py new file mode 100644 index 000000000000..d129519f3f77 --- /dev/null +++ b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py @@ -0,0 +1,314 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Unit tests for create_dataset MCP tool.""" + +import logging +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastmcp import Client +from fastmcp.exceptions import ToolError + +from superset.mcp_service.app import mcp +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def _make_mock_dataset( + dataset_id: int = 42, + table_name: str = "orders", + schema: str = "public", + database_name: str = "main_db", +) -> MagicMock: + dataset = MagicMock() + dataset.id = dataset_id + dataset.table_name = table_name + dataset.schema = schema + dataset.description = None + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = f"[{database_name}].[{schema}]" + dataset.url = f"/tablemodelview/edit/{dataset_id}" + dataset.database = MagicMock() + dataset.database.database_name = database_name + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + dataset.uuid = f"dataset-uuid-{dataset_id}" + dataset.columns = [] + dataset.metrics = [] + return dataset + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +class TestCreateDataset: + """Tests for the create_dataset MCP tool.""" + + @patch( + "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" + ) + @pytest.mark.asyncio + async def test_create_dataset_success(self, mock_command_class, mcp_server): + """Happy path: tool creates dataset and returns DatasetInfo.""" + mock_dataset = _make_mock_dataset() + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 42 + assert data["table_name"] == "orders" + assert data["schema_name"] == "public" + + # Verify the command was called with the right properties + call_kwargs = mock_command_class.call_args[0][0] + assert call_kwargs["database"] == 1 + assert call_kwargs["schema"] == "public" + assert call_kwargs["table_name"] == "orders" + assert "owners" not in call_kwargs + + @patch( + "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" + ) + @pytest.mark.asyncio + async def test_create_dataset_with_owners(self, mock_command_class, mcp_server): + """Owners list is forwarded to the command when supplied.""" + mock_dataset = _make_mock_dataset() + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 2, + "schema": "sales", + "table_name": "transactions", + "owners": [5, 10], + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 42 + + call_kwargs = mock_command_class.call_args[0][0] + assert call_kwargs["owners"] == [5, 10] + + @patch( + "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" + ) + @pytest.mark.asyncio + async def test_create_dataset_already_exists(self, mock_command_class, mcp_server): + """Returns DatasetError when a dataset for the table already exists.""" + from superset.commands.dataset.exceptions import DatasetExistsValidationError + from superset.sql.parse import Table + + mock_command = MagicMock() + mock_command.run.side_effect = DatasetExistsValidationError( + Table("orders", "public", None) + ) + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "DatasetExistsError" + assert "error" in data + + @patch( + "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" + ) + @pytest.mark.asyncio + async def test_create_dataset_table_not_found( + self, mock_command_class, mcp_server + ): + """Returns DatasetError when the physical table does not exist in the DB.""" + from superset.commands.dataset.exceptions import TableNotFoundValidationError + from superset.sql.parse import Table + + mock_command = MagicMock() + mock_command.run.side_effect = TableNotFoundValidationError( + Table("missing_table", "public", None) + ) + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "missing_table", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "TableNotFoundError" + + @patch( + "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" + ) + @pytest.mark.asyncio + async def test_create_dataset_unexpected_error( + self, mock_command_class, mcp_server + ): + """Unexpected exceptions are caught and returned as InternalError.""" + mock_command = MagicMock() + mock_command.run.side_effect = RuntimeError("DB connection lost") + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "InternalError" + assert "DB connection lost" in data["error"] + + @pytest.mark.asyncio + async def test_create_dataset_missing_required_fields(self, mcp_server): + """Missing required fields raise a validation error before the tool runs.""" + async with Client(mcp_server) as client: + with pytest.raises(ToolError): + await client.call_tool( + "create_dataset", + { + "request": { + # database_id and table_name are omitted intentionally + "schema": "public", + } + }, + ) + + @patch( + "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" + ) + @pytest.mark.asyncio + async def test_create_dataset_returns_full_dataset_info( + self, mock_command_class, mcp_server + ): + """The returned DatasetInfo includes columns, metrics, and all core fields.""" + mock_dataset = _make_mock_dataset(dataset_id=99, table_name="sales", schema="dw") + + col = MagicMock() + col.column_name = "amount" + col.verbose_name = "Amount" + col.type = "NUMERIC" + col.is_dttm = False + col.groupby = True + col.filterable = True + col.description = "Sale amount" + mock_dataset.columns = [col] + + metric = MagicMock() + metric.metric_name = "total_sales" + metric.verbose_name = "Total Sales" + metric.expression = "SUM(amount)" + metric.description = "Sum of amounts" + metric.d3format = None + mock_dataset.metrics = [metric] + + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "dw", + "table_name": "sales", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 99 + assert data["table_name"] == "sales" + assert data["schema_name"] == "dw" + assert data["is_virtual"] is False + assert len(data["columns"]) == 1 + assert data["columns"][0]["column_name"] == "amount" + assert len(data["metrics"]) == 1 + assert data["metrics"][0]["metric_name"] == "total_sales" From 26784693715e8e039ee4b6ec279b0f12521e5ea3 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 14 May 2026 20:19:17 +0000 Subject: [PATCH 2/9] style: ruff format create_dataset tool files --- superset/mcp_service/dataset/schemas.py | 4 ++- .../dataset/tool/test_create_dataset.py | 32 ++++++------------- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index 2b40eee7f34b..d14b2fb4a112 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -291,7 +291,9 @@ class CreateDatasetRequest(BaseModel): database_id: Annotated[ int, - Field(description="ID of the database connection to register the table against"), + Field( + description="ID of the database connection to register the table against" + ), ] schema: Annotated[ str, diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py index d129519f3f77..ae062f923460 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py @@ -87,9 +87,7 @@ def mock_auth(): class TestCreateDataset: """Tests for the create_dataset MCP tool.""" - @patch( - "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" - ) + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_success(self, mock_command_class, mcp_server): """Happy path: tool creates dataset and returns DatasetInfo.""" @@ -123,9 +121,7 @@ async def test_create_dataset_success(self, mock_command_class, mcp_server): assert call_kwargs["table_name"] == "orders" assert "owners" not in call_kwargs - @patch( - "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" - ) + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_with_owners(self, mock_command_class, mcp_server): """Owners list is forwarded to the command when supplied.""" @@ -153,9 +149,7 @@ async def test_create_dataset_with_owners(self, mock_command_class, mcp_server): call_kwargs = mock_command_class.call_args[0][0] assert call_kwargs["owners"] == [5, 10] - @patch( - "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" - ) + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_already_exists(self, mock_command_class, mcp_server): """Returns DatasetError when a dataset for the table already exists.""" @@ -184,13 +178,9 @@ async def test_create_dataset_already_exists(self, mock_command_class, mcp_serve assert data["error_type"] == "DatasetExistsError" assert "error" in data - @patch( - "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" - ) + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") @pytest.mark.asyncio - async def test_create_dataset_table_not_found( - self, mock_command_class, mcp_server - ): + async def test_create_dataset_table_not_found(self, mock_command_class, mcp_server): """Returns DatasetError when the physical table does not exist in the DB.""" from superset.commands.dataset.exceptions import TableNotFoundValidationError from superset.sql.parse import Table @@ -216,9 +206,7 @@ async def test_create_dataset_table_not_found( data = json.loads(result.content[0].text) assert data["error_type"] == "TableNotFoundError" - @patch( - "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" - ) + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_unexpected_error( self, mock_command_class, mcp_server @@ -259,15 +247,15 @@ async def test_create_dataset_missing_required_fields(self, mcp_server): }, ) - @patch( - "superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand" - ) + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_returns_full_dataset_info( self, mock_command_class, mcp_server ): """The returned DatasetInfo includes columns, metrics, and all core fields.""" - mock_dataset = _make_mock_dataset(dataset_id=99, table_name="sales", schema="dw") + mock_dataset = _make_mock_dataset( + dataset_id=99, table_name="sales", schema="dw" + ) col = MagicMock() col.column_name = "amount" From ef3c2035458a4a2426eebd317f87f1ed5c2a1dae Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Fri, 22 May 2026 01:44:41 +0000 Subject: [PATCH 3/9] fix(mcp): fix create_dataset CI failures - schemas.py: restore full apache/master version and add CreateDatasetRequest (previous cherry-pick used an older shorter version missing helper functions _sanitize_dataset_info_for_llm_context, _humanize_timestamp, etc.) - create_dataset.py: remove parse_request decorator (not in apache/master yet) --- superset/mcp_service/dataset/schemas.py | 405 +++++++++++++++++- .../dataset/tool/create_dataset.py | 2 - 2 files changed, 384 insertions(+), 23 deletions(-) diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index d14b2fb4a112..0bbc4061f8c1 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -24,21 +24,35 @@ from datetime import datetime from typing import Annotated, Any, Dict, List, Literal +import humanize from pydantic import ( BaseModel, ConfigDict, Field, + field_validator, model_serializer, model_validator, PositiveInt, ) from superset.daos.base import ColumnOperator, ColumnOperatorEnum -from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata +from superset.mcp_service.common.cache_schemas import ( + CacheStatus, + CreatedByMeMixin, + MetadataCacheControl, + OwnedByMeMixin, + QueryCacheControl, +) +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.privacy import filter_user_directory_fields from superset.mcp_service.system.schemas import ( PaginationInfo, TagInfo, - UserInfo, +) +from superset.mcp_service.utils import ( + escape_llm_context_delimiters, + sanitize_for_llm_context, ) from superset.utils import json @@ -85,7 +99,11 @@ class TableColumnInfo(BaseModel): class SqlMetricInfo(BaseModel): - metric_name: str = Field(..., description="Metric name") + metric_name: str = Field( + ..., + description="Saved metric name. In chart configs, reference as " + '{"name": "", "saved_metric": true}.', + ) verbose_name: str | None = Field(None, description="Verbose name") expression: str | None = Field(None, description="SQL expression") description: str | None = Field(None, description="Metric description") @@ -98,22 +116,23 @@ class DatasetInfo(BaseModel): schema_name: str | None = Field(None, description="Schema name", alias="schema") database_name: str | None = Field(None, description="Database name") description: str | None = Field(None, description="Dataset description") - changed_by: str | None = Field(None, description="Last modifier (username)") + certified_by: str | None = Field( + None, description="Name of the person or team who certified this dataset" + ) + certification_details: str | None = Field( + None, description="Certification details or reason" + ) changed_on: str | datetime | None = Field( None, description="Last modification timestamp" ) changed_on_humanized: str | None = Field( None, description="Humanized modification time" ) - created_by: str | None = Field(None, description="Dataset creator (username)") created_on: str | datetime | None = Field(None, description="Creation timestamp") created_on_humanized: str | None = Field( None, description="Humanized creation time" ) tags: List[TagInfo] = Field(default_factory=list, description="Dataset tags") - owners: List[UserInfo] = Field( - default_factory=list, description="DatasetInfo owners" - ) is_virtual: bool | None = Field( None, description="Whether the dataset is virtual (uses SQL)" ) @@ -134,7 +153,9 @@ class DatasetInfo(BaseModel): default_factory=list, description="Columns in the dataset" ) metrics: List[SqlMetricInfo] = Field( - default_factory=list, description="Metrics in the dataset" + default_factory=list, + description="Saved metrics (pre-defined aggregations). " + "NOT columns — use saved_metric=true in chart configs.", ) is_favorite: bool | None = Field( None, description="Whether this dataset is favorited by the current user" @@ -145,7 +166,7 @@ class DatasetInfo(BaseModel): populate_by_name=True, # Allow both 'schema' (alias) and 'schema_name' (field) ) - @model_serializer(mode="wrap", when_used="json") + @model_serializer(mode="wrap") def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]: """Filter fields based on serialization context. @@ -153,16 +174,18 @@ def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any Otherwise, include all fields (default behavior). """ # Get full serialization - data = serializer(self) + data = filter_user_directory_fields(serializer(self)) + + # Normalize alias: Pydantic serializes as 'schema_name' (field name) + # but the DAO column and API convention is 'schema' + if "schema_name" in data: + data["schema"] = data.pop("schema_name") # Check if we have a context with select_columns if info.context and isinstance(info.context, dict): select_columns = info.context.get("select_columns") if select_columns: - # Handle alias: 'schema' -> 'schema_name' requested_fields = set(select_columns) - if "schema" in requested_fields: - requested_fields.add("schema_name") # Filter to only requested fields return {k: v for k, v in data.items() if k in requested_fields} @@ -205,7 +228,7 @@ class DatasetList(BaseModel): model_config = ConfigDict(ser_json_timedelta="iso8601") -class ListDatasetsRequest(MetadataCacheControl): +class ListDatasetsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl): """Request schema for list_datasets with clear, unambiguous types.""" filters: Annotated[ @@ -247,13 +270,18 @@ class ListDatasetsRequest(MetadataCacheControl): Field(default=1, description="Page number for pagination (1-based)"), ] page_size: Annotated[ - PositiveInt, Field(default=10, description="Number of items per page") + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Number of items per page (max {MAX_PAGE_SIZE})", + ), ] @model_validator(mode="after") def validate_search_and_filters(self) -> "ListDatasetsRequest": - """Prevent using both search and filters simultaneously to avoid query - conflicts.""" + """Prevent using both search and filters simultaneously.""" if self.search and self.filters: raise ValueError( "Cannot use both 'search' and 'filters' parameters simultaneously. " @@ -269,12 +297,22 @@ class DatasetError(BaseModel): timestamp: str | datetime | None = Field(None, description="Error timestamp") model_config = ConfigDict(ser_json_timedelta="iso8601") + @field_validator("error") + @classmethod + def sanitize_error_for_llm_context(cls, value: str) -> str: + """Wrap error text before it is exposed to LLM context.""" + return sanitize_for_llm_context(value, field_path=("error",)) + @classmethod def create(cls, error: str, error_type: str) -> "DatasetError": """Create a standardized DatasetError with timestamp.""" - from datetime import datetime + from datetime import datetime, timezone - return cls(error=error, error_type=error_type, timestamp=datetime.now()) + return cls( + error=error, + error_type=error_type, + timestamp=datetime.now(timezone.utc), + ) class GetDatasetInfoRequest(MetadataCacheControl): @@ -307,14 +345,339 @@ class CreateDatasetRequest(BaseModel): List[int] | None, Field( default=None, - description="Optional list of owner user IDs. Defaults to the calling user.", + description="Optional list of owner user IDs. " + "Defaults to the calling user.", + ), + ] + + +class CreateVirtualDatasetRequest(BaseModel): + """Request schema for create_virtual_dataset.""" + + model_config = ConfigDict(populate_by_name=True) + + database_id: int = Field( + ..., + description="ID of the database connection to use. " + "Use list_databases to find valid IDs.", + ) + sql: str = Field( + ..., + description="SQL query to save as a virtual dataset. " + "Can be a JOIN, CTE, aggregation, or any valid SELECT.", + ) + dataset_name: str = Field( + ..., + min_length=1, + max_length=250, + description="Name for the new virtual dataset.", + ) + schema_name: str | None = Field( + None, + alias="schema", + description="Schema to associate with the dataset (optional).", + ) + catalog: str | None = Field( + None, + description="Catalog to associate with the dataset (optional).", + ) + description: str | None = Field( + None, + description="Human-readable description of the dataset (optional).", + ) + + @field_validator("sql") + @classmethod + def sql_must_not_be_empty(cls, v: str) -> str: + if not v.strip(): + raise ValueError("sql must not be empty") + return v.strip() + + @field_validator("dataset_name") + @classmethod + def dataset_name_must_not_be_empty(cls, v: str) -> str: + if not v.strip(): + raise ValueError("dataset_name must not be empty") + return v.strip() + + +class CreateVirtualDatasetResponse(BaseModel): + """Response schema for create_virtual_dataset.""" + + id: int | None = Field( + None, + description="Dataset ID. Pass this as dataset_id to generate_chart " + "or generate_explore_link. None if creation failed.", + ) + dataset_name: str = Field(..., description="Name of the created dataset.") + sql: str = Field(..., description="SQL query stored in the dataset.") + database_id: int = Field(..., description="Database ID used.") + columns: List[str] = Field( + default_factory=list, + description="Column names available for charting. " + "Use these when building chart configs.", + ) + url: str | None = Field( + None, + description="URL to view/edit the dataset in Superset. None if failed.", + ) + error: str | None = Field( + None, + description="Error message if creation failed, otherwise null.", + ) + + +VALID_FILTER_OPS = Literal[ + "==", + "!=", + ">", + "<", + ">=", + "<=", + "LIKE", + "NOT LIKE", + "ILIKE", + "NOT ILIKE", + "IN", + "NOT IN", + "IS NULL", + "IS NOT NULL", + "IS TRUE", + "IS FALSE", + "TEMPORAL_RANGE", +] + + +class QueryDatasetFilter(BaseModel): + """A single filter condition for dataset queries.""" + + col: str = Field(..., description="Column name to filter on") + op: VALID_FILTER_OPS = Field( + ..., + description=( + 'Filter operator. Use "==" for equals, "!=" for not equals, ' + '"IN" / "NOT IN" for membership, "IS NULL" / "IS NOT NULL", ' + '"LIKE" for pattern matching, "TEMPORAL_RANGE" for time filters.' + ), + ) + val: Any = Field( + default=None, + description="Filter value (omit for IS NULL/IS NOT NULL)", + ) + + +class QueryDatasetRequest(QueryCacheControl): + """Request schema for query_dataset tool.""" + + dataset_id: int | str = Field( + ..., + description="Dataset identifier — numeric ID or UUID string.", + ) + metrics: List[str] = Field( + default_factory=list, + description=( + "Saved metric names to compute (e.g. ['count', 'avg_revenue']). " + "Use get_dataset_info to discover available metrics." + ), + ) + columns: List[str] = Field( + default_factory=list, + description=( + "Column/dimension names for GROUP BY or SELECT " + "(e.g. ['category', 'region']). " + "Use get_dataset_info to discover available columns." + ), + ) + filters: List[QueryDatasetFilter] = Field( + default_factory=list, + description=( + 'Filter conditions (e.g. [{"col": "status", "op": "==", "val": "active"}]).' + ), + ) + time_range: str | None = Field( + default=None, + description=( + "Time range filter (e.g. 'Last 7 days', 'Last month', " + "'2024-01-01 : 2024-12-31'). Requires a temporal column " + "on the dataset." + ), + ) + time_column: str | None = Field( + default=None, + description=( + "Temporal column to apply time_range to. " + "Defaults to the dataset's main datetime column." ), + ) + order_by: List[str] | None = Field( + default=None, + description="Column or metric names to sort results by.", + ) + order_desc: bool = Field( + default=True, + description="Sort descending (True) or ascending (False).", + ) + row_limit: int = Field( + default=1000, + ge=1, + le=50000, + description="Maximum number of rows to return (default 1000, max 50000).", + ) + + @model_validator(mode="after") + def validate_metrics_or_columns(self) -> "QueryDatasetRequest": + """At least one of metrics or columns must be provided.""" + if not self.metrics and not self.columns: + raise ValueError( + "At least one of 'metrics' or 'columns' must be provided. " + "Use get_dataset_info to discover available metrics and columns." + ) + return self + + +class QueryDatasetResponse(BaseModel): + """Response schema for query_dataset tool.""" + + model_config = ConfigDict(ser_json_timedelta="iso8601") + + dataset_id: int = Field(..., description="Dataset ID") + dataset_name: str = Field(..., description="Dataset name") + columns: List[DataColumn] = Field( + default_factory=list, description="Column metadata for returned data" + ) + data: List[Dict[str, Any]] = Field( + default_factory=list, description="Query result rows" + ) + row_count: int = Field(0, description="Number of rows returned") + total_rows: int | None = Field( + None, description="Total row count from the query engine" + ) + summary: str = Field("", description="Human-readable summary of the results") + performance: PerformanceMetadata | None = Field( + None, description="Query performance metadata" + ) + cache_status: CacheStatus | None = Field( + None, description="Cache hit/miss information" + ) + applied_filters: List[QueryDatasetFilter] = Field( + default_factory=list, description="Filters that were applied to the query" + ) + warnings: List[str] = Field( + default_factory=list, description="Any warnings encountered during execution" + ) + + +def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None: + """Parse a field that may be stored as a JSON string into a dict.""" + value = getattr(obj, field_name, None) + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + return parsed + except (ValueError, TypeError): + pass + return None + return value + + +def _humanize_timestamp(dt: datetime | None) -> str | None: + """Convert a datetime to a humanized string like '2 hours ago'.""" + if dt is None: + return None + return humanize.naturaltime(datetime.now() - dt) + + +def _sanitize_dataset_info_for_llm_context(dataset_info: DatasetInfo) -> DatasetInfo: + """Wrap dataset read-path descriptive fields before LLM exposure.""" + payload = dataset_info.model_dump(mode="python") + + for field_name in ("description", "certified_by", "certification_details", "sql"): + payload[field_name] = sanitize_for_llm_context( + payload.get(field_name), + field_path=(field_name,), + ) + + for field_name in ("table_name", "schema_name", "database_name", "schema_perm"): + payload[field_name] = escape_llm_context_delimiters(payload.get(field_name)) + + payload["extra"] = sanitize_for_llm_context( + payload.get("extra"), + field_path=("extra",), + excluded_field_names=frozenset(), + ) + + for field_name in ("params", "template_params"): + payload[field_name] = sanitize_for_llm_context( + payload.get(field_name), + field_path=(field_name,), + excluded_field_names=frozenset(), + ) + + payload["columns"] = [ + { + **column, + "column_name": escape_llm_context_delimiters( + column.get("column_name"), + ), + "description": sanitize_for_llm_context( + column.get("description"), + field_path=("columns", str(index), "description"), + ), + "verbose_name": sanitize_for_llm_context( + column.get("verbose_name"), + field_path=("columns", str(index), "verbose_name"), + ), + } + for index, column in enumerate(payload.get("columns", [])) + ] + + payload["metrics"] = [ + { + **metric, + "metric_name": escape_llm_context_delimiters( + metric.get("metric_name"), + ), + "expression": sanitize_for_llm_context( + metric.get("expression"), + field_path=("metrics", str(index), "expression"), + ), + "description": sanitize_for_llm_context( + metric.get("description"), + field_path=("metrics", str(index), "description"), + ), + "verbose_name": sanitize_for_llm_context( + metric.get("verbose_name"), + field_path=("metrics", str(index), "verbose_name"), + ), + } + for index, metric in enumerate(payload.get("metrics", [])) ] + payload["tags"] = [ + { + **tag, + "name": sanitize_for_llm_context( + tag.get("name"), + field_path=("tags", str(index), "name"), + ), + "description": sanitize_for_llm_context( + tag.get("description"), + field_path=("tags", str(index), "description"), + ), + } + for index, tag in enumerate(payload.get("tags", [])) + ] + + return DatasetInfo.model_validate(payload) + def serialize_dataset_object(dataset: Any) -> DatasetInfo | None: if not dataset: return None + + from superset.mcp_service.utils.url_utils import get_superset_base_url + params = getattr(dataset, "params", None) if isinstance(params, str): try: diff --git a/superset/mcp_service/dataset/tool/create_dataset.py b/superset/mcp_service/dataset/tool/create_dataset.py index 0dfb84432408..6484f8ab36de 100644 --- a/superset/mcp_service/dataset/tool/create_dataset.py +++ b/superset/mcp_service/dataset/tool/create_dataset.py @@ -37,14 +37,12 @@ DatasetInfo, serialize_dataset_object, ) -from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @mcp.tool(tags=["mutate"]) @mcp_auth_hook -@parse_request(CreateDatasetRequest) def create_dataset( request: CreateDatasetRequest, ctx: Context ) -> DatasetInfo | DatasetError: From e2d60960a514a45d8a2fc6c864fdcbd20fe53aff Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Tue, 26 May 2026 19:17:05 +0000 Subject: [PATCH 4/9] =?UTF-8?q?test(mcp):=20fix=20patch=20paths=20in=20tes?= =?UTF-8?q?t=5Fcreate=5Fdataset=20=E2=80=94=20CreateDatasetCommand=20is=20?= =?UTF-8?q?a=20lazy=20import?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CreateDatasetCommand is imported inside the function body, so patching at superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand fails with AttributeError. Patch at the source module instead. Also fix data["schema_name"] assertions: DatasetInfo.model_serializer renames the field to "schema" in the serialized output. --- .../dataset/tool/test_create_dataset.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py index ae062f923460..ba2898eb6603 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py @@ -87,7 +87,7 @@ def mock_auth(): class TestCreateDataset: """Tests for the create_dataset MCP tool.""" - @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @patch("superset.commands.dataset.create.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_success(self, mock_command_class, mcp_server): """Happy path: tool creates dataset and returns DatasetInfo.""" @@ -112,7 +112,7 @@ async def test_create_dataset_success(self, mock_command_class, mcp_server): data = json.loads(result.content[0].text) assert data["id"] == 42 assert data["table_name"] == "orders" - assert data["schema_name"] == "public" + assert data["schema"] == "public" # Verify the command was called with the right properties call_kwargs = mock_command_class.call_args[0][0] @@ -121,7 +121,7 @@ async def test_create_dataset_success(self, mock_command_class, mcp_server): assert call_kwargs["table_name"] == "orders" assert "owners" not in call_kwargs - @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @patch("superset.commands.dataset.create.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_with_owners(self, mock_command_class, mcp_server): """Owners list is forwarded to the command when supplied.""" @@ -149,7 +149,7 @@ async def test_create_dataset_with_owners(self, mock_command_class, mcp_server): call_kwargs = mock_command_class.call_args[0][0] assert call_kwargs["owners"] == [5, 10] - @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @patch("superset.commands.dataset.create.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_already_exists(self, mock_command_class, mcp_server): """Returns DatasetError when a dataset for the table already exists.""" @@ -178,7 +178,7 @@ async def test_create_dataset_already_exists(self, mock_command_class, mcp_serve assert data["error_type"] == "DatasetExistsError" assert "error" in data - @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @patch("superset.commands.dataset.create.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_table_not_found(self, mock_command_class, mcp_server): """Returns DatasetError when the physical table does not exist in the DB.""" @@ -206,7 +206,7 @@ async def test_create_dataset_table_not_found(self, mock_command_class, mcp_serv data = json.loads(result.content[0].text) assert data["error_type"] == "TableNotFoundError" - @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @patch("superset.commands.dataset.create.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_unexpected_error( self, mock_command_class, mcp_server @@ -247,7 +247,7 @@ async def test_create_dataset_missing_required_fields(self, mcp_server): }, ) - @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @patch("superset.commands.dataset.create.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_returns_full_dataset_info( self, mock_command_class, mcp_server @@ -294,7 +294,7 @@ async def test_create_dataset_returns_full_dataset_info( data = json.loads(result.content[0].text) assert data["id"] == 99 assert data["table_name"] == "sales" - assert data["schema_name"] == "dw" + assert data["schema"] == "dw" assert data["is_virtual"] is False assert len(data["columns"]) == 1 assert data["columns"][0]["column_name"] == "amount" From f1245e34d78867339f0c65ca6dc88d79d1d724a8 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 04:47:09 +0000 Subject: [PATCH 5/9] fix(mcp): restore missing tool imports and fix test mock for create_dataset Restores tool imports that were accidentally dropped from app.py: create_virtual_dataset, query_dataset, get_chart_sql, get_chart_type_schema, get_database_info, list_databases, save_sql_query. Exports create_virtual_dataset and query_dataset from dataset/tool/__init__.py. Fixes KeyError in test_create_dataset by setting is_favorite=None on the mock dataset to avoid Pydantic bool|None validation errors from MagicMock auto-attributes. Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/app.py | 9 +++++++++ superset/mcp_service/dataset/tool/__init__.py | 8 ++++++-- .../mcp_service/dataset/tool/test_create_dataset.py | 3 +++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 0198b6252f33..71d831a2ef92 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -555,6 +555,8 @@ def create_mcp_app( get_chart_data, get_chart_info, get_chart_preview, + get_chart_sql, + get_chart_type_schema, list_charts, update_chart, update_chart_preview, @@ -565,10 +567,16 @@ def create_mcp_app( get_dashboard_info, list_dashboards, ) +from superset.mcp_service.database.tool import ( # noqa: F401, E402 + get_database_info, + list_databases, +) from superset.mcp_service.dataset.tool import ( # noqa: F401, E402 create_dataset, + create_virtual_dataset, get_dataset_info, list_datasets, + query_dataset, ) from superset.mcp_service.explore.tool import ( # noqa: F401, E402 generate_explore_link, @@ -576,6 +584,7 @@ def create_mcp_app( from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402 execute_sql, open_sql_lab_with_context, + save_sql_query, ) from superset.mcp_service.system import ( # noqa: F401, E402 prompts as system_prompts, diff --git a/superset/mcp_service/dataset/tool/__init__.py b/superset/mcp_service/dataset/tool/__init__.py index 025b4ae1b9a9..6fd3c12133c2 100644 --- a/superset/mcp_service/dataset/tool/__init__.py +++ b/superset/mcp_service/dataset/tool/__init__.py @@ -16,11 +16,15 @@ # under the License. from .create_dataset import create_dataset +from .create_virtual_dataset import create_virtual_dataset from .get_dataset_info import get_dataset_info from .list_datasets import list_datasets +from .query_dataset import query_dataset __all__ = [ - "list_datasets", - "get_dataset_info", "create_dataset", + "create_virtual_dataset", + "get_dataset_info", + "list_datasets", + "query_dataset", ] diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py index ba2898eb6603..b8438da26d5b 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py @@ -42,6 +42,8 @@ def _make_mock_dataset( dataset.table_name = table_name dataset.schema = schema dataset.description = None + dataset.certified_by = None + dataset.certification_details = None dataset.changed_by_name = "admin" dataset.changed_on = None dataset.changed_on_humanized = None @@ -51,6 +53,7 @@ def _make_mock_dataset( dataset.tags = [] dataset.owners = [] dataset.is_virtual = False + dataset.is_favorite = None dataset.database_id = 1 dataset.schema_perm = f"[{database_name}].[{schema}]" dataset.url = f"/tablemodelview/edit/{dataset_id}" From 39a77095e03ac95be4068b32a4b592c79f855faf Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 06:05:52 +0000 Subject: [PATCH 6/9] =?UTF-8?q?fix(mcp):=20address=20bot=20review=20?= =?UTF-8?q?=E2=80=94=20make=20schema=20optional,=20switch=20to=20@tool=20d?= =?UTF-8?q?ecorator,=20normalize=20whitespace?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - CreateDatasetRequest.schema is now str | None (default None) so databases without schema namespaces (e.g. SQLite) can register tables without error - create_dataset switches from @mcp.tool/@mcp_auth_hook to the standard @tool decorator from superset_core.mcp.decorators, adding Dataset write RBAC and ToolAnnotations consistent with create_virtual_dataset - Blank/whitespace-only schema values are normalized to None before forwarding to CreateDatasetCommand, avoiding spurious table-not-found failures - Unexpected exceptions now re-raise (middleware handles them) instead of being swallowed into an InternalError response; test updated accordingly - Uses DatasetError.create() factory and event_logger/ctx instrumentation Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/dataset/schemas.py | 8 +- .../dataset/tool/create_dataset.py | 97 ++++++++++--------- .../dataset/tool/test_create_dataset.py | 27 +++--- 3 files changed, 67 insertions(+), 65 deletions(-) diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index 0bbc4061f8c1..88eb9c3361b6 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -334,8 +334,12 @@ class CreateDatasetRequest(BaseModel): ), ] schema: Annotated[ - str, - Field(description="Schema (namespace) where the table lives, e.g. 'public'"), + str | None, + Field( + default=None, + description="Schema (namespace) where the table lives, e.g. 'public'. " + "Omit or pass None for databases without schema namespaces (e.g. SQLite).", + ), ] table_name: Annotated[ str, diff --git a/superset/mcp_service/dataset/tool/create_dataset.py b/superset/mcp_service/dataset/tool/create_dataset.py index 6484f8ab36de..975327e27658 100644 --- a/superset/mcp_service/dataset/tool/create_dataset.py +++ b/superset/mcp_service/dataset/tool/create_dataset.py @@ -25,12 +25,11 @@ """ import logging -from datetime import datetime, timezone from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations -from superset.mcp_service.app import mcp -from superset.mcp_service.auth import mcp_auth_hook +from superset.extensions import event_logger from superset.mcp_service.dataset.schemas import ( CreateDatasetRequest, DatasetError, @@ -41,9 +40,17 @@ logger = logging.getLogger(__name__) -@mcp.tool(tags=["mutate"]) -@mcp_auth_hook -def create_dataset( +@tool( + tags=["mutate"], + class_permission_name="Dataset", + method_permission_name="write", + annotations=ToolAnnotations( + title="Register physical table as dataset", + readOnlyHint=False, + destructiveHint=False, + ), +) +async def create_dataset( request: CreateDatasetRequest, ctx: Context ) -> DatasetInfo | DatasetError: """Register a physical table as a Superset dataset. @@ -55,10 +62,11 @@ def create_dataset( Required fields: - database_id: ID of the existing database connection - - schema: Schema/namespace where the table lives (e.g. "public") - table_name: Exact name of the physical table to register Optional fields: + - schema: Schema/namespace where the table lives (e.g. "public"). Omit for + databases without schema namespaces (e.g. SQLite). - owners: List of user IDs to set as owners (defaults to calling user) Example: @@ -73,6 +81,14 @@ def create_dataset( Returns DatasetInfo on success or DatasetError on failure. Use list_databases to find the correct database_id. """ + # Normalize schema: strip whitespace and treat blank strings as None + schema = request.schema.strip() if request.schema else None + + await ctx.info( + "Registering physical table as dataset: database_id=%s, table=%s.%s" + % (request.database_id, schema, request.table_name) + ) + try: from superset.commands.dataset.create import CreateDatasetCommand from superset.commands.dataset.exceptions import ( @@ -82,61 +98,46 @@ def create_dataset( TableNotFoundValidationError, ) - dataset_properties = { + dataset_properties: dict[str, object] = { "database": request.database_id, - "schema": request.schema, "table_name": request.table_name, } + if schema is not None: + dataset_properties["schema"] = schema if request.owners is not None: dataset_properties["owners"] = request.owners - command = CreateDatasetCommand(dataset_properties) - dataset = command.run() + with event_logger.log_context(action="mcp.create_dataset.create"): + dataset = CreateDatasetCommand(dataset_properties).run() result = serialize_dataset_object(dataset) if result is None: - return DatasetError( + return DatasetError.create( error="Dataset was created but could not be serialized", error_type="SerializationError", - timestamp=datetime.now(timezone.utc), ) - logger.info( - "Created dataset id=%s table=%s.%s", - dataset.id, - request.schema, - request.table_name, + await ctx.info( + "Dataset registered: id=%s, table=%s.%s" + % (dataset.id, schema, request.table_name) ) return result - except DatasetExistsValidationError as e: - return DatasetError( - error=str(e), - error_type="DatasetExistsError", - timestamp=datetime.now(timezone.utc), - ) - except TableNotFoundValidationError as e: - return DatasetError( - error=str(e), - error_type="TableNotFoundError", - timestamp=datetime.now(timezone.utc), - ) - except DatasetInvalidError as e: - return DatasetError( - error=str(e), - error_type="ValidationError", - timestamp=datetime.now(timezone.utc), - ) - except DatasetCreateFailedError as e: - return DatasetError( - error=str(e), - error_type="CreateFailedError", - timestamp=datetime.now(timezone.utc), - ) - except Exception as e: - logger.error("Failed to create dataset: %s", e, exc_info=True) - return DatasetError( - error=f"Failed to create dataset: {str(e)}", - error_type="InternalError", - timestamp=datetime.now(timezone.utc), + except DatasetExistsValidationError as exc: + await ctx.warning("Dataset already exists: %s" % (str(exc),)) + return DatasetError.create(error=str(exc), error_type="DatasetExistsError") + except TableNotFoundValidationError as exc: + await ctx.warning("Table not found: %s" % (str(exc),)) + return DatasetError.create(error=str(exc), error_type="TableNotFoundError") + except DatasetInvalidError as exc: + messages = exc.normalized_messages() + await ctx.warning("Dataset validation failed: %s" % (messages,)) + return DatasetError.create(error=str(messages), error_type="ValidationError") + except DatasetCreateFailedError as exc: + await ctx.error("Dataset creation failed: %s" % (str(exc),)) + return DatasetError.create(error=str(exc), error_type="CreateFailedError") + except Exception as exc: + await ctx.error( + "Unexpected error creating dataset: %s: %s" % (type(exc).__name__, str(exc)) ) + raise diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py index b8438da26d5b..9123abaf53d8 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py @@ -214,26 +214,23 @@ async def test_create_dataset_table_not_found(self, mock_command_class, mcp_serv async def test_create_dataset_unexpected_error( self, mock_command_class, mcp_server ): - """Unexpected exceptions are caught and returned as InternalError.""" + """Unexpected exceptions are re-raised as ToolError (handled by middleware).""" mock_command = MagicMock() mock_command.run.side_effect = RuntimeError("DB connection lost") mock_command_class.return_value = mock_command async with Client(mcp_server) as client: - result = await client.call_tool( - "create_dataset", - { - "request": { - "database_id": 1, - "schema": "public", - "table_name": "orders", - } - }, - ) - - data = json.loads(result.content[0].text) - assert data["error_type"] == "InternalError" - assert "DB connection lost" in data["error"] + with pytest.raises(ToolError): + await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) @pytest.mark.asyncio async def test_create_dataset_missing_required_fields(self, mcp_server): From 65476671e853466f85af7350a3aeb51c16453d7c Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 08:00:58 +0000 Subject: [PATCH 7/9] test(mcp): add DatasetInvalidError test; add min_length to table_name - Add test_create_dataset_invalid_error to cover the DatasetInvalidError handler in create_dataset (previously untested path) - Add min_length=1 to CreateDatasetRequest.table_name to reject empty strings at the schema layer, consistent with CreateVirtualDatasetRequest Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/dataset/schemas.py | 5 +++- .../dataset/tool/test_create_dataset.py | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index 88eb9c3361b6..fa39e4ca3647 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -343,7 +343,10 @@ class CreateDatasetRequest(BaseModel): ] table_name: Annotated[ str, - Field(description="Name of the physical table to register as a dataset"), + Field( + min_length=1, + description="Name of the physical table to register as a dataset", + ), ] owners: Annotated[ List[int] | None, diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py index 9123abaf53d8..c04aea881394 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py @@ -209,6 +209,32 @@ async def test_create_dataset_table_not_found(self, mock_command_class, mcp_serv data = json.loads(result.content[0].text) assert data["error_type"] == "TableNotFoundError" + @patch("superset.commands.dataset.create.CreateDatasetCommand") + @pytest.mark.asyncio + async def test_create_dataset_invalid_error(self, mock_command_class, mcp_server): + """DatasetInvalidError is returned as ValidationError type.""" + from superset.commands.dataset.exceptions import DatasetInvalidError + + mock_command = MagicMock() + mock_command.run.side_effect = DatasetInvalidError() + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "ValidationError" + assert "error" in data + @patch("superset.commands.dataset.create.CreateDatasetCommand") @pytest.mark.asyncio async def test_create_dataset_unexpected_error( From 35c8c7c74f74fc92f399065c9b1bf6b91206d7a5 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 11:12:59 +0000 Subject: [PATCH 8/9] fix(mcp): strip whitespace from table_name in create_dataset MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors schema normalization — table_name is now stripped before being forwarded to CreateDatasetCommand, preventing whitespace-only strings from reaching the database layer. Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/dataset/tool/create_dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/superset/mcp_service/dataset/tool/create_dataset.py b/superset/mcp_service/dataset/tool/create_dataset.py index 975327e27658..fffc769b31dd 100644 --- a/superset/mcp_service/dataset/tool/create_dataset.py +++ b/superset/mcp_service/dataset/tool/create_dataset.py @@ -81,12 +81,13 @@ async def create_dataset( Returns DatasetInfo on success or DatasetError on failure. Use list_databases to find the correct database_id. """ - # Normalize schema: strip whitespace and treat blank strings as None + # Normalize schema and table_name: strip whitespace, treat blank schema as None schema = request.schema.strip() if request.schema else None + table_name = request.table_name.strip() await ctx.info( "Registering physical table as dataset: database_id=%s, table=%s.%s" - % (request.database_id, schema, request.table_name) + % (request.database_id, schema, table_name) ) try: @@ -100,7 +101,7 @@ async def create_dataset( dataset_properties: dict[str, object] = { "database": request.database_id, - "table_name": request.table_name, + "table_name": table_name, } if schema is not None: dataset_properties["schema"] = schema @@ -118,8 +119,7 @@ async def create_dataset( ) await ctx.info( - "Dataset registered: id=%s, table=%s.%s" - % (dataset.id, schema, request.table_name) + "Dataset registered: id=%s, table=%s.%s" % (dataset.id, schema, table_name) ) return result From b2fb339517123e28887741593b415f0d635c899f Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 28 May 2026 14:59:14 +0000 Subject: [PATCH 9/9] ci: retrigger CI run